diff options
Diffstat (limited to 'internal/authentication/ldap_user_provider.go')
| -rw-r--r-- | internal/authentication/ldap_user_provider.go | 168 |
1 files changed, 75 insertions, 93 deletions
diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go index 66687e1a1..905bf07e8 100644 --- a/internal/authentication/ldap_user_provider.go +++ b/internal/authentication/ldap_user_provider.go @@ -1,10 +1,8 @@ package authentication import ( - "crypto/tls" "crypto/x509" "fmt" - "net" "net/url" "strconv" "strings" @@ -21,11 +19,9 @@ import ( // LDAPUserProvider is a UserProvider that connects to LDAP servers like ActiveDirectory, OpenLDAP, OpenDJ, FreeIPA, etc. type LDAPUserProvider struct { - config schema.AuthenticationBackendLDAP - tlsConfig *tls.Config - dialOpts []ldap.DialOpt - log *logrus.Logger - factory LDAPClientFactory + config *schema.AuthenticationBackendLDAP + log *logrus.Logger + factory LDAPClientFactory clock clock.Provider @@ -53,37 +49,27 @@ type LDAPUserProvider struct { groupsFilterReplacementsMemberOfRDN bool } -// NewLDAPUserProvider creates a new instance of LDAPUserProvider with the ProductionLDAPClientFactory. -func NewLDAPUserProvider(config schema.AuthenticationBackend, certPool *x509.CertPool) (provider *LDAPUserProvider) { - provider = NewLDAPUserProviderWithFactory(*config.LDAP, config.PasswordReset.Disable, certPool, NewProductionLDAPClientFactory()) - - return provider -} - -// NewLDAPUserProviderWithFactory creates a new instance of LDAPUserProvider with the specified LDAPClientFactory. -func NewLDAPUserProviderWithFactory(config schema.AuthenticationBackendLDAP, disableResetPassword bool, certPool *x509.CertPool, factory LDAPClientFactory) (provider *LDAPUserProvider) { - if config.TLS == nil { - config.TLS = schema.DefaultLDAPAuthenticationBackendConfigurationImplementationCustom.TLS +// NewLDAPUserProvider creates a new instance of LDAPUserProvider with the StandardLDAPClientFactory. +func NewLDAPUserProvider(config schema.AuthenticationBackend, certs *x509.CertPool) (provider *LDAPUserProvider) { + if config.LDAP.TLS == nil { + config.LDAP.TLS = schema.DefaultLDAPAuthenticationBackendConfigurationImplementationCustom.TLS } - tlsConfig := utils.NewTLSConfig(config.TLS, certPool) + var factory LDAPClientFactory - var dialOpts = []ldap.DialOpt{ - ldap.DialWithDialer(&net.Dialer{Timeout: config.Timeout}), - } - - if tlsConfig != nil { - dialOpts = append(dialOpts, ldap.DialWithTLSConfig(tlsConfig)) + if config.LDAP.Pooling.Enable { + factory = NewPooledLDAPClientFactory(config.LDAP, certs, nil) + } else { + factory = NewStandardLDAPClientFactory(config.LDAP, certs, nil) } - if factory == nil { - factory = NewProductionLDAPClientFactory() - } + return NewLDAPUserProviderWithFactory(config.LDAP, config.PasswordReset.Disable, factory) +} +// NewLDAPUserProviderWithFactory creates a new instance of LDAPUserProvider with the specified LDAPClientFactory. +func NewLDAPUserProviderWithFactory(config *schema.AuthenticationBackendLDAP, disableResetPassword bool, factory LDAPClientFactory) (provider *LDAPUserProvider) { provider = &LDAPUserProvider{ config: config, - tlsConfig: tlsConfig, - dialOpts: dialOpts, log: logging.Logger(), factory: factory, disableResetPassword: disableResetPassword, @@ -99,25 +85,33 @@ func NewLDAPUserProviderWithFactory(config schema.AuthenticationBackendLDAP, dis // CheckUserPassword checks if provided password matches for the given user. func (p *LDAPUserProvider) CheckUserPassword(username string, password string) (valid bool, err error) { var ( - client, clientUser LDAPClient - profile *ldapUserProfile + client, uclient ldap.Client + profile *ldapUserProfile ) - if client, err = p.connect(); err != nil { + if client, err = p.factory.GetClient(); err != nil { return false, err } - defer client.Close() + defer func() { + if err := p.factory.ReleaseClient(client); err != nil { + p.log.WithError(err).Warn("Error occurred releasing the LDAP client") + } + }() if profile, err = p.getUserProfile(client, username); err != nil { return false, err } - if clientUser, err = p.connectCustom(p.config.Address.String(), profile.DN, password, p.config.StartTLS, p.dialOpts...); err != nil { + if uclient, err = p.factory.GetClient(WithUsername(profile.DN), WithPassword(password)); err != nil { return false, fmt.Errorf("authentication failed. Cause: %w", err) } - defer clientUser.Close() + defer func() { + if err := p.factory.ReleaseClient(uclient); err != nil { + p.log.WithError(err).Warn("Error occurred releasing the LDAP client") + } + }() return true, nil } @@ -125,15 +119,19 @@ func (p *LDAPUserProvider) CheckUserPassword(username string, password string) ( // GetDetails retrieve the groups a user belongs to. func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, err error) { var ( - client LDAPClient + client ldap.Client profile *ldapUserProfile ) - if client, err = p.connect(); err != nil { + if client, err = p.factory.GetClient(); err != nil { return nil, err } - defer client.Close() + defer func() { + if err := p.factory.ReleaseClient(client); err != nil { + p.log.WithError(err).Warn("Error occurred releasing the LDAP client") + } + }() if profile, err = p.getUserProfile(client, username); err != nil { return nil, err @@ -158,11 +156,11 @@ func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, er // GetDetailsExtended retrieves the UserDetailsExtended values. func (p *LDAPUserProvider) GetDetailsExtended(username string) (details *UserDetailsExtended, err error) { var ( - client LDAPClient + client ldap.Client profile *ldapUserProfileExtended ) - if client, err = p.connect(); err != nil { + if client, err = p.factory.GetClient(); err != nil { return nil, err } @@ -243,15 +241,19 @@ func (p *LDAPUserProvider) GetDetailsExtended(username string) (details *UserDet // UpdatePassword update the password of the given user. func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error) { var ( - client LDAPClient + client ldap.Client profile *ldapUserProfile ) - if client, err = p.connect(); err != nil { + if client, err = p.factory.GetClient(); err != nil { return fmt.Errorf("unable to update password. Cause: %w", err) } - defer client.Close() + defer func() { + if err := p.factory.ReleaseClient(client); err != nil { + p.log.WithError(err).Warn("Error occurred releasing the LDAP client") + } + }() if profile, err = p.getUserProfile(client, username); err != nil { return fmt.Errorf("unable to update password. Cause: %w", err) @@ -297,39 +299,7 @@ func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error) return nil } -func (p *LDAPUserProvider) connect() (client LDAPClient, err error) { - return p.connectCustom(p.config.Address.String(), p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...) -} - -func (p *LDAPUserProvider) connectCustom(url, username, password string, startTLS bool, opts ...ldap.DialOpt) (client LDAPClient, err error) { - if client, err = p.factory.DialURL(url, opts...); err != nil { - return nil, fmt.Errorf("dial failed with error: %w", err) - } - - if startTLS { - if err = client.StartTLS(p.tlsConfig); err != nil { - client.Close() - - return nil, fmt.Errorf("starttls failed with error: %w", err) - } - } - - if password == "" { - err = client.UnauthenticatedBind(username) - } else { - err = client.Bind(username, password) - } - - if err != nil { - client.Close() - - return nil, fmt.Errorf("bind failed with error: %w", err) - } - - return client, nil -} - -func (p *LDAPUserProvider) search(client LDAPClient, request *ldap.SearchRequest) (result *ldap.SearchResult, err error) { +func (p *LDAPUserProvider) search(client ldap.Client, request *ldap.SearchRequest) (result *ldap.SearchResult, err error) { if result, err = client.Search(request); err != nil { if referral, ok := p.getReferral(err); ok { if result == nil { @@ -357,15 +327,19 @@ func (p *LDAPUserProvider) search(client LDAPClient, request *ldap.SearchRequest func (p *LDAPUserProvider) searchReferral(referral string, request *ldap.SearchRequest, searchResult *ldap.SearchResult) (err error) { var ( - client LDAPClient + client ldap.Client result *ldap.SearchResult ) - if client, err = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); err != nil { + if client, err = p.factory.GetClient(WithAddress(referral)); err != nil { return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %w", referral, err) } - defer client.Close() + defer func() { + if err := p.factory.ReleaseClient(client); err != nil { + p.log.WithError(err).Warn("Error occurred releasing the LDAP client") + } + }() if result, err = client.Search(request); err != nil { return fmt.Errorf("error occurred performing search on referred LDAP server '%s': %w", referral, err) @@ -390,7 +364,7 @@ func (p *LDAPUserProvider) searchReferrals(request *ldap.SearchRequest, result * return nil } -func (p *LDAPUserProvider) getUserProfile(client LDAPClient, username string) (profile *ldapUserProfile, err error) { +func (p *LDAPUserProvider) getUserProfile(client ldap.Client, username string) (profile *ldapUserProfile, err error) { // Search for the given username. request := ldap.NewSearchRequest( p.usersBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, @@ -463,7 +437,7 @@ func (p *LDAPUserProvider) getUserProfileResultToProfile(username string, entry return &userProfile, nil } -func (p *LDAPUserProvider) getUserProfileExtended(client LDAPClient, username string) (profile *ldapUserProfileExtended, err error) { +func (p *LDAPUserProvider) getUserProfileExtended(client ldap.Client, username string) (profile *ldapUserProfileExtended, err error) { // Search for the given username. request := ldap.NewSearchRequest( p.usersBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, @@ -552,7 +526,7 @@ func (p *LDAPUserProvider) getUserProfileResultToProfileExtended(username string return &userProfile, nil } -func (p *LDAPUserProvider) getUserGroups(client LDAPClient, username string, profile *ldapUserProfile) (groups []string, err error) { +func (p *LDAPUserProvider) getUserGroups(client ldap.Client, username string, profile *ldapUserProfile) (groups []string, err error) { request := ldap.NewSearchRequest( p.groupsBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, p.resolveGroupsFilter(username, profile), p.groupsAttributes, nil, @@ -577,7 +551,7 @@ func (p *LDAPUserProvider) getUserGroups(client LDAPClient, username string, pro } } -func (p *LDAPUserProvider) getUserGroupsRequestFilter(client LDAPClient, username string, _ *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) { +func (p *LDAPUserProvider) getUserGroupsRequestFilter(client ldap.Client, username string, _ *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) { var result *ldap.SearchResult if result, err = p.search(client, request); err != nil { @@ -593,7 +567,7 @@ func (p *LDAPUserProvider) getUserGroupsRequestFilter(client LDAPClient, usernam return groups, nil } -func (p *LDAPUserProvider) getUserGroupsRequestMemberOf(client LDAPClient, username string, profile *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) { +func (p *LDAPUserProvider) getUserGroupsRequestMemberOf(client ldap.Client, username string, profile *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) { var result *ldap.SearchResult if result, err = p.search(client, request); err != nil { @@ -733,7 +707,7 @@ func (p *LDAPUserProvider) resolveGroupsFilter(input string, profile *ldapUserPr return filter } -func (p *LDAPUserProvider) modify(client LDAPClient, modifyRequest *ldap.ModifyRequest) (err error) { +func (p *LDAPUserProvider) modify(client ldap.Client, modifyRequest *ldap.ModifyRequest) (err error) { if err = client.Modify(modifyRequest); err != nil { var ( referral string @@ -747,15 +721,19 @@ func (p *LDAPUserProvider) modify(client LDAPClient, modifyRequest *ldap.ModifyR p.log.Debugf("Attempting Modify on referred URL %s", referral) var ( - clientRef LDAPClient + clientRef ldap.Client errRef error ) - if clientRef, errRef = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errRef != nil { + if clientRef, errRef = p.factory.GetClient(WithAddress(referral)); errRef != nil { return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err) } - defer clientRef.Close() + defer func() { + if err := p.factory.ReleaseClient(clientRef); err != nil { + p.log.WithError(err).Warn("Error occurred releasing the LDAP client") + } + }() if errRef = clientRef.Modify(modifyRequest); errRef != nil { return fmt.Errorf("error occurred performing modify on referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err) @@ -767,7 +745,7 @@ func (p *LDAPUserProvider) modify(client LDAPClient, modifyRequest *ldap.ModifyR return nil } -func (p *LDAPUserProvider) pwdModify(client LDAPClient, pwdModifyRequest *ldap.PasswordModifyRequest) (err error) { +func (p *LDAPUserProvider) pwdModify(client ldap.Client, pwdModifyRequest *ldap.PasswordModifyRequest) (err error) { if _, err = client.PasswordModify(pwdModifyRequest); err != nil { var ( referral string @@ -781,15 +759,19 @@ func (p *LDAPUserProvider) pwdModify(client LDAPClient, pwdModifyRequest *ldap.P p.log.Debugf("Attempting PwdModify ExOp (1.3.6.1.4.1.4203.1.11.1) on referred URL %s", referral) var ( - clientRef LDAPClient + clientRef ldap.Client errRef error ) - if clientRef, errRef = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errRef != nil { + if clientRef, errRef = p.factory.GetClient(WithAddress(referral)); errRef != nil { return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err) } - defer clientRef.Close() + defer func() { + if err := p.factory.ReleaseClient(clientRef); err != nil { + p.log.WithError(err).Warn("Error occurred releasing the LDAP client") + } + }() if _, errRef = clientRef.PasswordModify(pwdModifyRequest); errRef != nil { return fmt.Errorf("error occurred performing password modify on referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err) |
