diff options
Diffstat (limited to 'internal/authentication')
| -rw-r--r-- | internal/authentication/ldap_connection_factory.go | 78 | ||||
| -rw-r--r-- | internal/authentication/ldap_connection_factory_mock.go | 143 | ||||
| -rw-r--r-- | internal/authentication/ldap_user_provider.go | 65 | ||||
| -rw-r--r-- | internal/authentication/ldap_user_provider_test.go | 57 | 
4 files changed, 328 insertions, 15 deletions
diff --git a/internal/authentication/ldap_connection_factory.go b/internal/authentication/ldap_connection_factory.go new file mode 100644 index 000000000..b6c6b2f48 --- /dev/null +++ b/internal/authentication/ldap_connection_factory.go @@ -0,0 +1,78 @@ +package authentication + +import ( +	"crypto/tls" + +	"gopkg.in/ldap.v3" +) + +// ********************* CONNECTION ********************* + +// LDAPConnection interface representing a connection to the ldap. +type LDAPConnection interface { +	Bind(username, password string) error +	Close() + +	Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) +	Modify(modifyRequest *ldap.ModifyRequest) error +} + +// LDAPConnectionImpl the production implementation of an ldap connection +type LDAPConnectionImpl struct { +	conn *ldap.Conn +} + +// NewLDAPConnectionImpl create a new ldap connection +func NewLDAPConnectionImpl(conn *ldap.Conn) *LDAPConnectionImpl { +	return &LDAPConnectionImpl{conn} +} + +func (lc *LDAPConnectionImpl) Bind(username, password string) error { +	return lc.conn.Bind(username, password) +} + +func (lc *LDAPConnectionImpl) Close() { +	lc.conn.Close() +} + +func (lc *LDAPConnectionImpl) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { +	return lc.conn.Search(searchRequest) +} + +func (lc *LDAPConnectionImpl) Modify(modifyRequest *ldap.ModifyRequest) error { +	return lc.conn.Modify(modifyRequest) +} + +// ********************* FACTORY *********************** + +// LDAPConnectionFactory an interface of factory of ldap connections +type LDAPConnectionFactory interface { +	DialTLS(network, addr string, config *tls.Config) (LDAPConnection, error) +	Dial(network, addr string) (LDAPConnection, error) +} + +// LDAPConnectionFactoryImpl the production implementation of an ldap connection factory. +type LDAPConnectionFactoryImpl struct{} + +// NewLDAPConnectionFactoryImpl create a concrete ldap connection factory +func NewLDAPConnectionFactoryImpl() *LDAPConnectionFactoryImpl { +	return &LDAPConnectionFactoryImpl{} +} + +// DialTLS contact ldap server over TLS. +func (lcf *LDAPConnectionFactoryImpl) DialTLS(network, addr string, config *tls.Config) (LDAPConnection, error) { +	conn, err := ldap.DialTLS(network, addr, config) +	if err != nil { +		return nil, err +	} +	return NewLDAPConnectionImpl(conn), nil +} + +// Dial contact ldap server over raw tcp. +func (lcf *LDAPConnectionFactoryImpl) Dial(network, addr string) (LDAPConnection, error) { +	conn, err := ldap.Dial(network, addr) +	if err != nil { +		return nil, err +	} +	return NewLDAPConnectionImpl(conn), nil +} diff --git a/internal/authentication/ldap_connection_factory_mock.go b/internal/authentication/ldap_connection_factory_mock.go new file mode 100644 index 000000000..f40983fb6 --- /dev/null +++ b/internal/authentication/ldap_connection_factory_mock.go @@ -0,0 +1,143 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: internal/authentication/ldap_connection_factory.go + +// Package authentication is a generated GoMock package. +package authentication + +import ( +	tls "crypto/tls" +	gomock "github.com/golang/mock/gomock" +	ldap_v3 "gopkg.in/ldap.v3" +	reflect "reflect" +) + +// MockLDAPConnection is a mock of LDAPConnection interface +type MockLDAPConnection struct { +	ctrl     *gomock.Controller +	recorder *MockLDAPConnectionMockRecorder +} + +// MockLDAPConnectionMockRecorder is the mock recorder for MockLDAPConnection +type MockLDAPConnectionMockRecorder struct { +	mock *MockLDAPConnection +} + +// NewMockLDAPConnection creates a new mock instance +func NewMockLDAPConnection(ctrl *gomock.Controller) *MockLDAPConnection { +	mock := &MockLDAPConnection{ctrl: ctrl} +	mock.recorder = &MockLDAPConnectionMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockLDAPConnection) EXPECT() *MockLDAPConnectionMockRecorder { +	return m.recorder +} + +// Bind mocks base method +func (m *MockLDAPConnection) Bind(username, password string) error { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Bind", username, password) +	ret0, _ := ret[0].(error) +	return ret0 +} + +// Bind indicates an expected call of Bind +func (mr *MockLDAPConnectionMockRecorder) Bind(username, password interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockLDAPConnection)(nil).Bind), username, password) +} + +// Close mocks base method +func (m *MockLDAPConnection) Close() { +	m.ctrl.T.Helper() +	m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close +func (mr *MockLDAPConnectionMockRecorder) Close() *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLDAPConnection)(nil).Close)) +} + +// Search mocks base method +func (m *MockLDAPConnection) Search(searchRequest *ldap_v3.SearchRequest) (*ldap_v3.SearchResult, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Search", searchRequest) +	ret0, _ := ret[0].(*ldap_v3.SearchResult) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// Search indicates an expected call of Search +func (mr *MockLDAPConnectionMockRecorder) Search(searchRequest interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockLDAPConnection)(nil).Search), searchRequest) +} + +// Modify mocks base method +func (m *MockLDAPConnection) Modify(modifyRequest *ldap_v3.ModifyRequest) error { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Modify", modifyRequest) +	ret0, _ := ret[0].(error) +	return ret0 +} + +// Modify indicates an expected call of Modify +func (mr *MockLDAPConnectionMockRecorder) Modify(modifyRequest interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Modify", reflect.TypeOf((*MockLDAPConnection)(nil).Modify), modifyRequest) +} + +// MockLDAPConnectionFactory is a mock of LDAPConnectionFactory interface +type MockLDAPConnectionFactory struct { +	ctrl     *gomock.Controller +	recorder *MockLDAPConnectionFactoryMockRecorder +} + +// MockLDAPConnectionFactoryMockRecorder is the mock recorder for MockLDAPConnectionFactory +type MockLDAPConnectionFactoryMockRecorder struct { +	mock *MockLDAPConnectionFactory +} + +// NewMockLDAPConnectionFactory creates a new mock instance +func NewMockLDAPConnectionFactory(ctrl *gomock.Controller) *MockLDAPConnectionFactory { +	mock := &MockLDAPConnectionFactory{ctrl: ctrl} +	mock.recorder = &MockLDAPConnectionFactoryMockRecorder{mock} +	return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockLDAPConnectionFactory) EXPECT() *MockLDAPConnectionFactoryMockRecorder { +	return m.recorder +} + +// DialTLS mocks base method +func (m *MockLDAPConnectionFactory) DialTLS(network, addr string, config *tls.Config) (LDAPConnection, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "DialTLS", network, addr, config) +	ret0, _ := ret[0].(LDAPConnection) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// DialTLS indicates an expected call of DialTLS +func (mr *MockLDAPConnectionFactoryMockRecorder) DialTLS(network, addr, config interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLS", reflect.TypeOf((*MockLDAPConnectionFactory)(nil).DialTLS), network, addr, config) +} + +// Dial mocks base method +func (m *MockLDAPConnectionFactory) Dial(network, addr string) (LDAPConnection, error) { +	m.ctrl.T.Helper() +	ret := m.ctrl.Call(m, "Dial", network, addr) +	ret0, _ := ret[0].(LDAPConnection) +	ret1, _ := ret[1].(error) +	return ret0, ret1 +} + +// Dial indicates an expected call of Dial +func (mr *MockLDAPConnectionFactoryMockRecorder) Dial(network, addr interface{}) *gomock.Call { +	mr.mock.ctrl.T.Helper() +	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockLDAPConnectionFactory)(nil).Dial), network, addr) +} diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go index d02d525c9..803ae0083 100644 --- a/internal/authentication/ldap_user_provider.go +++ b/internal/authentication/ldap_user_provider.go @@ -1,35 +1,70 @@  package authentication  import ( +	"crypto/tls"  	"fmt" +	"net/url"  	"strings"  	"github.com/clems4ever/authelia/internal/configuration/schema" +	"github.com/clems4ever/authelia/internal/logging"  	"gopkg.in/ldap.v3"  )  // LDAPUserProvider is a provider using a LDAP or AD as a user database.  type LDAPUserProvider struct {  	configuration schema.LDAPAuthenticationBackendConfiguration + +	connectionFactory LDAPConnectionFactory  } -func (p *LDAPUserProvider) connect(userDN string, password string) (*ldap.Conn, error) { -	conn, err := ldap.Dial("tcp", p.configuration.URL) -	if err != nil { -		return nil, err +// NewLDAPUserProvider creates a new instance of LDAPUserProvider. +func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider { +	return &LDAPUserProvider{ +		configuration:     configuration, +		connectionFactory: NewLDAPConnectionFactoryImpl(),  	} +} -	err = conn.Bind(userDN, password) +func NewLDAPUserProviderWithFactory(configuration schema.LDAPAuthenticationBackendConfiguration, +	connectionFactory LDAPConnectionFactory) *LDAPUserProvider { +	return &LDAPUserProvider{ +		configuration:     configuration, +		connectionFactory: connectionFactory, +	} +} + +func (p *LDAPUserProvider) connect(userDN string, password string) (LDAPConnection, error) { +	var newConnection LDAPConnection + +	url, err := url.Parse(p.configuration.URL)  	if err != nil { -		return nil, err +		return nil, fmt.Errorf("Unable to parse URL to LDAP: %s", url)  	} -	return conn, nil -} -// NewLDAPUserProvider creates a new instance of LDAPUserProvider. -func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider { -	return &LDAPUserProvider{configuration} +	if url.Scheme == "ldaps" { +		logging.Logger().Debug("LDAP client starts a TLS session") +		conn, err := p.connectionFactory.DialTLS("tcp", url.Host, &tls.Config{ +			InsecureSkipVerify: p.configuration.SkipVerify, +		}) +		if err != nil { +			return nil, err +		} +		newConnection = conn +	} else { +		logging.Logger().Debug("LDAP client starts a session over raw TCP") +		conn, err := p.connectionFactory.Dial("tcp", url.Host) +		if err != nil { +			return nil, err +		} +		newConnection = conn +	} + +	if err := newConnection.Bind(userDN, password); err != nil { +		return nil, err +	} +	return newConnection, nil  }  // CheckUserPassword checks if provided password matches for the given user. @@ -54,7 +89,7 @@ func (p *LDAPUserProvider) CheckUserPassword(username string, password string) (  	return true, nil  } -func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, attribute string) ([]string, error) { +func (p *LDAPUserProvider) getUserAttribute(conn LDAPConnection, username string, attribute string) ([]string, error) {  	client, err := p.connect(p.configuration.User, p.configuration.Password)  	if err != nil {  		return nil, err @@ -86,7 +121,7 @@ func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, at  	return sr.Entries[0].Attributes[0].Values, nil  } -func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string, error) { +func (p *LDAPUserProvider) getUserDN(conn LDAPConnection, username string) (string, error) {  	values, err := p.getUserAttribute(conn, username, "dn")  	if err != nil { @@ -100,7 +135,7 @@ func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string,  	return values[0], nil  } -func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string, error) { +func (p *LDAPUserProvider) getUserUID(conn LDAPConnection, username string) (string, error) {  	values, err := p.getUserAttribute(conn, username, "uid")  	if err != nil { @@ -114,7 +149,7 @@ func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string,  	return values[0], nil  } -func (p *LDAPUserProvider) createGroupsFilter(conn *ldap.Conn, username string) (string, error) { +func (p *LDAPUserProvider) createGroupsFilter(conn LDAPConnection, username string) (string, error) {  	if strings.Index(p.configuration.GroupsFilter, "{0}") >= 0 {  		return strings.Replace(p.configuration.GroupsFilter, "{0}", username, -1), nil  	} else if strings.Index(p.configuration.GroupsFilter, "{dn}") >= 0 { diff --git a/internal/authentication/ldap_user_provider_test.go b/internal/authentication/ldap_user_provider_test.go new file mode 100644 index 000000000..1e666170b --- /dev/null +++ b/internal/authentication/ldap_user_provider_test.go @@ -0,0 +1,57 @@ +package authentication + +import ( +	"testing" + +	"github.com/clems4ever/authelia/internal/configuration/schema" +	gomock "github.com/golang/mock/gomock" +	"github.com/stretchr/testify/require" +) + +func TestShouldCreateRawConnectionWhenSchemeIsLDAP(t *testing.T) { +	ctrl := gomock.NewController(t) +	defer ctrl.Finish() + +	mockFactory := NewMockLDAPConnectionFactory(ctrl) +	mockConn := NewMockLDAPConnection(ctrl) + +	ldap := NewLDAPUserProviderWithFactory(schema.LDAPAuthenticationBackendConfiguration{ +		URL: "ldap://127.0.0.1:389", +	}, mockFactory) + +	mockFactory.EXPECT(). +		Dial(gomock.Eq("tcp"), gomock.Eq("127.0.0.1:389")). +		Return(mockConn, nil) + +	mockConn.EXPECT(). +		Bind(gomock.Eq("cn=admin,dc=example,dc=com"), gomock.Eq("password")). +		Return(nil) + +	_, err := ldap.connect("cn=admin,dc=example,dc=com", "password") + +	require.NoError(t, err) +} + +func TestShouldCreateTLSConnectionWhenSchemeIsLDAPS(t *testing.T) { +	ctrl := gomock.NewController(t) +	defer ctrl.Finish() + +	mockFactory := NewMockLDAPConnectionFactory(ctrl) +	mockConn := NewMockLDAPConnection(ctrl) + +	ldap := NewLDAPUserProviderWithFactory(schema.LDAPAuthenticationBackendConfiguration{ +		URL: "ldaps://127.0.0.1:389", +	}, mockFactory) + +	mockFactory.EXPECT(). +		DialTLS(gomock.Eq("tcp"), gomock.Eq("127.0.0.1:389"), gomock.Any()). +		Return(mockConn, nil) + +	mockConn.EXPECT(). +		Bind(gomock.Eq("cn=admin,dc=example,dc=com"), gomock.Eq("password")). +		Return(nil) + +	_, err := ldap.connect("cn=admin,dc=example,dc=com", "password") + +	require.NoError(t, err) +}  | 
