summaryrefslogtreecommitdiff
path: root/internal/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'internal/authentication')
-rw-r--r--internal/authentication/ldap_connection_factory.go78
-rw-r--r--internal/authentication/ldap_connection_factory_mock.go143
-rw-r--r--internal/authentication/ldap_user_provider.go65
-rw-r--r--internal/authentication/ldap_user_provider_test.go57
4 files changed, 328 insertions, 15 deletions
diff --git a/internal/authentication/ldap_connection_factory.go b/internal/authentication/ldap_connection_factory.go
new file mode 100644
index 000000000..b6c6b2f48
--- /dev/null
+++ b/internal/authentication/ldap_connection_factory.go
@@ -0,0 +1,78 @@
+package authentication
+
+import (
+ "crypto/tls"
+
+ "gopkg.in/ldap.v3"
+)
+
+// ********************* CONNECTION *********************
+
+// LDAPConnection interface representing a connection to the ldap.
+type LDAPConnection interface {
+ Bind(username, password string) error
+ Close()
+
+ Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
+ Modify(modifyRequest *ldap.ModifyRequest) error
+}
+
+// LDAPConnectionImpl the production implementation of an ldap connection
+type LDAPConnectionImpl struct {
+ conn *ldap.Conn
+}
+
+// NewLDAPConnectionImpl create a new ldap connection
+func NewLDAPConnectionImpl(conn *ldap.Conn) *LDAPConnectionImpl {
+ return &LDAPConnectionImpl{conn}
+}
+
+func (lc *LDAPConnectionImpl) Bind(username, password string) error {
+ return lc.conn.Bind(username, password)
+}
+
+func (lc *LDAPConnectionImpl) Close() {
+ lc.conn.Close()
+}
+
+func (lc *LDAPConnectionImpl) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
+ return lc.conn.Search(searchRequest)
+}
+
+func (lc *LDAPConnectionImpl) Modify(modifyRequest *ldap.ModifyRequest) error {
+ return lc.conn.Modify(modifyRequest)
+}
+
+// ********************* FACTORY ***********************
+
+// LDAPConnectionFactory an interface of factory of ldap connections
+type LDAPConnectionFactory interface {
+ DialTLS(network, addr string, config *tls.Config) (LDAPConnection, error)
+ Dial(network, addr string) (LDAPConnection, error)
+}
+
+// LDAPConnectionFactoryImpl the production implementation of an ldap connection factory.
+type LDAPConnectionFactoryImpl struct{}
+
+// NewLDAPConnectionFactoryImpl create a concrete ldap connection factory
+func NewLDAPConnectionFactoryImpl() *LDAPConnectionFactoryImpl {
+ return &LDAPConnectionFactoryImpl{}
+}
+
+// DialTLS contact ldap server over TLS.
+func (lcf *LDAPConnectionFactoryImpl) DialTLS(network, addr string, config *tls.Config) (LDAPConnection, error) {
+ conn, err := ldap.DialTLS(network, addr, config)
+ if err != nil {
+ return nil, err
+ }
+ return NewLDAPConnectionImpl(conn), nil
+}
+
+// Dial contact ldap server over raw tcp.
+func (lcf *LDAPConnectionFactoryImpl) Dial(network, addr string) (LDAPConnection, error) {
+ conn, err := ldap.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ return NewLDAPConnectionImpl(conn), nil
+}
diff --git a/internal/authentication/ldap_connection_factory_mock.go b/internal/authentication/ldap_connection_factory_mock.go
new file mode 100644
index 000000000..f40983fb6
--- /dev/null
+++ b/internal/authentication/ldap_connection_factory_mock.go
@@ -0,0 +1,143 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: internal/authentication/ldap_connection_factory.go
+
+// Package authentication is a generated GoMock package.
+package authentication
+
+import (
+ tls "crypto/tls"
+ gomock "github.com/golang/mock/gomock"
+ ldap_v3 "gopkg.in/ldap.v3"
+ reflect "reflect"
+)
+
+// MockLDAPConnection is a mock of LDAPConnection interface
+type MockLDAPConnection struct {
+ ctrl *gomock.Controller
+ recorder *MockLDAPConnectionMockRecorder
+}
+
+// MockLDAPConnectionMockRecorder is the mock recorder for MockLDAPConnection
+type MockLDAPConnectionMockRecorder struct {
+ mock *MockLDAPConnection
+}
+
+// NewMockLDAPConnection creates a new mock instance
+func NewMockLDAPConnection(ctrl *gomock.Controller) *MockLDAPConnection {
+ mock := &MockLDAPConnection{ctrl: ctrl}
+ mock.recorder = &MockLDAPConnectionMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockLDAPConnection) EXPECT() *MockLDAPConnectionMockRecorder {
+ return m.recorder
+}
+
+// Bind mocks base method
+func (m *MockLDAPConnection) Bind(username, password string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Bind", username, password)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Bind indicates an expected call of Bind
+func (mr *MockLDAPConnectionMockRecorder) Bind(username, password interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockLDAPConnection)(nil).Bind), username, password)
+}
+
+// Close mocks base method
+func (m *MockLDAPConnection) Close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "Close")
+}
+
+// Close indicates an expected call of Close
+func (mr *MockLDAPConnectionMockRecorder) Close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLDAPConnection)(nil).Close))
+}
+
+// Search mocks base method
+func (m *MockLDAPConnection) Search(searchRequest *ldap_v3.SearchRequest) (*ldap_v3.SearchResult, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Search", searchRequest)
+ ret0, _ := ret[0].(*ldap_v3.SearchResult)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Search indicates an expected call of Search
+func (mr *MockLDAPConnectionMockRecorder) Search(searchRequest interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockLDAPConnection)(nil).Search), searchRequest)
+}
+
+// Modify mocks base method
+func (m *MockLDAPConnection) Modify(modifyRequest *ldap_v3.ModifyRequest) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Modify", modifyRequest)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Modify indicates an expected call of Modify
+func (mr *MockLDAPConnectionMockRecorder) Modify(modifyRequest interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Modify", reflect.TypeOf((*MockLDAPConnection)(nil).Modify), modifyRequest)
+}
+
+// MockLDAPConnectionFactory is a mock of LDAPConnectionFactory interface
+type MockLDAPConnectionFactory struct {
+ ctrl *gomock.Controller
+ recorder *MockLDAPConnectionFactoryMockRecorder
+}
+
+// MockLDAPConnectionFactoryMockRecorder is the mock recorder for MockLDAPConnectionFactory
+type MockLDAPConnectionFactoryMockRecorder struct {
+ mock *MockLDAPConnectionFactory
+}
+
+// NewMockLDAPConnectionFactory creates a new mock instance
+func NewMockLDAPConnectionFactory(ctrl *gomock.Controller) *MockLDAPConnectionFactory {
+ mock := &MockLDAPConnectionFactory{ctrl: ctrl}
+ mock.recorder = &MockLDAPConnectionFactoryMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockLDAPConnectionFactory) EXPECT() *MockLDAPConnectionFactoryMockRecorder {
+ return m.recorder
+}
+
+// DialTLS mocks base method
+func (m *MockLDAPConnectionFactory) DialTLS(network, addr string, config *tls.Config) (LDAPConnection, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DialTLS", network, addr, config)
+ ret0, _ := ret[0].(LDAPConnection)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// DialTLS indicates an expected call of DialTLS
+func (mr *MockLDAPConnectionFactoryMockRecorder) DialTLS(network, addr, config interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLS", reflect.TypeOf((*MockLDAPConnectionFactory)(nil).DialTLS), network, addr, config)
+}
+
+// Dial mocks base method
+func (m *MockLDAPConnectionFactory) Dial(network, addr string) (LDAPConnection, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Dial", network, addr)
+ ret0, _ := ret[0].(LDAPConnection)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Dial indicates an expected call of Dial
+func (mr *MockLDAPConnectionFactoryMockRecorder) Dial(network, addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockLDAPConnectionFactory)(nil).Dial), network, addr)
+}
diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go
index d02d525c9..803ae0083 100644
--- a/internal/authentication/ldap_user_provider.go
+++ b/internal/authentication/ldap_user_provider.go
@@ -1,35 +1,70 @@
package authentication
import (
+ "crypto/tls"
"fmt"
+ "net/url"
"strings"
"github.com/clems4ever/authelia/internal/configuration/schema"
+ "github.com/clems4ever/authelia/internal/logging"
"gopkg.in/ldap.v3"
)
// LDAPUserProvider is a provider using a LDAP or AD as a user database.
type LDAPUserProvider struct {
configuration schema.LDAPAuthenticationBackendConfiguration
+
+ connectionFactory LDAPConnectionFactory
}
-func (p *LDAPUserProvider) connect(userDN string, password string) (*ldap.Conn, error) {
- conn, err := ldap.Dial("tcp", p.configuration.URL)
- if err != nil {
- return nil, err
+// NewLDAPUserProvider creates a new instance of LDAPUserProvider.
+func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider {
+ return &LDAPUserProvider{
+ configuration: configuration,
+ connectionFactory: NewLDAPConnectionFactoryImpl(),
}
+}
- err = conn.Bind(userDN, password)
+func NewLDAPUserProviderWithFactory(configuration schema.LDAPAuthenticationBackendConfiguration,
+ connectionFactory LDAPConnectionFactory) *LDAPUserProvider {
+ return &LDAPUserProvider{
+ configuration: configuration,
+ connectionFactory: connectionFactory,
+ }
+}
+
+func (p *LDAPUserProvider) connect(userDN string, password string) (LDAPConnection, error) {
+ var newConnection LDAPConnection
+
+ url, err := url.Parse(p.configuration.URL)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("Unable to parse URL to LDAP: %s", url)
}
- return conn, nil
-}
-// NewLDAPUserProvider creates a new instance of LDAPUserProvider.
-func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider {
- return &LDAPUserProvider{configuration}
+ if url.Scheme == "ldaps" {
+ logging.Logger().Debug("LDAP client starts a TLS session")
+ conn, err := p.connectionFactory.DialTLS("tcp", url.Host, &tls.Config{
+ InsecureSkipVerify: p.configuration.SkipVerify,
+ })
+ if err != nil {
+ return nil, err
+ }
+ newConnection = conn
+ } else {
+ logging.Logger().Debug("LDAP client starts a session over raw TCP")
+ conn, err := p.connectionFactory.Dial("tcp", url.Host)
+ if err != nil {
+ return nil, err
+ }
+ newConnection = conn
+ }
+
+ if err := newConnection.Bind(userDN, password); err != nil {
+ return nil, err
+ }
+ return newConnection, nil
}
// CheckUserPassword checks if provided password matches for the given user.
@@ -54,7 +89,7 @@ func (p *LDAPUserProvider) CheckUserPassword(username string, password string) (
return true, nil
}
-func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, attribute string) ([]string, error) {
+func (p *LDAPUserProvider) getUserAttribute(conn LDAPConnection, username string, attribute string) ([]string, error) {
client, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return nil, err
@@ -86,7 +121,7 @@ func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, at
return sr.Entries[0].Attributes[0].Values, nil
}
-func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string, error) {
+func (p *LDAPUserProvider) getUserDN(conn LDAPConnection, username string) (string, error) {
values, err := p.getUserAttribute(conn, username, "dn")
if err != nil {
@@ -100,7 +135,7 @@ func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string,
return values[0], nil
}
-func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string, error) {
+func (p *LDAPUserProvider) getUserUID(conn LDAPConnection, username string) (string, error) {
values, err := p.getUserAttribute(conn, username, "uid")
if err != nil {
@@ -114,7 +149,7 @@ func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string,
return values[0], nil
}
-func (p *LDAPUserProvider) createGroupsFilter(conn *ldap.Conn, username string) (string, error) {
+func (p *LDAPUserProvider) createGroupsFilter(conn LDAPConnection, username string) (string, error) {
if strings.Index(p.configuration.GroupsFilter, "{0}") >= 0 {
return strings.Replace(p.configuration.GroupsFilter, "{0}", username, -1), nil
} else if strings.Index(p.configuration.GroupsFilter, "{dn}") >= 0 {
diff --git a/internal/authentication/ldap_user_provider_test.go b/internal/authentication/ldap_user_provider_test.go
new file mode 100644
index 000000000..1e666170b
--- /dev/null
+++ b/internal/authentication/ldap_user_provider_test.go
@@ -0,0 +1,57 @@
+package authentication
+
+import (
+ "testing"
+
+ "github.com/clems4ever/authelia/internal/configuration/schema"
+ gomock "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/require"
+)
+
+func TestShouldCreateRawConnectionWhenSchemeIsLDAP(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockFactory := NewMockLDAPConnectionFactory(ctrl)
+ mockConn := NewMockLDAPConnection(ctrl)
+
+ ldap := NewLDAPUserProviderWithFactory(schema.LDAPAuthenticationBackendConfiguration{
+ URL: "ldap://127.0.0.1:389",
+ }, mockFactory)
+
+ mockFactory.EXPECT().
+ Dial(gomock.Eq("tcp"), gomock.Eq("127.0.0.1:389")).
+ Return(mockConn, nil)
+
+ mockConn.EXPECT().
+ Bind(gomock.Eq("cn=admin,dc=example,dc=com"), gomock.Eq("password")).
+ Return(nil)
+
+ _, err := ldap.connect("cn=admin,dc=example,dc=com", "password")
+
+ require.NoError(t, err)
+}
+
+func TestShouldCreateTLSConnectionWhenSchemeIsLDAPS(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockFactory := NewMockLDAPConnectionFactory(ctrl)
+ mockConn := NewMockLDAPConnection(ctrl)
+
+ ldap := NewLDAPUserProviderWithFactory(schema.LDAPAuthenticationBackendConfiguration{
+ URL: "ldaps://127.0.0.1:389",
+ }, mockFactory)
+
+ mockFactory.EXPECT().
+ DialTLS(gomock.Eq("tcp"), gomock.Eq("127.0.0.1:389"), gomock.Any()).
+ Return(mockConn, nil)
+
+ mockConn.EXPECT().
+ Bind(gomock.Eq("cn=admin,dc=example,dc=com"), gomock.Eq("password")).
+ Return(nil)
+
+ _, err := ldap.connect("cn=admin,dc=example,dc=com", "password")
+
+ require.NoError(t, err)
+}