summaryrefslogtreecommitdiff
path: root/internal/authentication/ldap_user_provider.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/authentication/ldap_user_provider.go')
-rw-r--r--internal/authentication/ldap_user_provider.go168
1 files changed, 75 insertions, 93 deletions
diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go
index 66687e1a1..905bf07e8 100644
--- a/internal/authentication/ldap_user_provider.go
+++ b/internal/authentication/ldap_user_provider.go
@@ -1,10 +1,8 @@
package authentication
import (
- "crypto/tls"
"crypto/x509"
"fmt"
- "net"
"net/url"
"strconv"
"strings"
@@ -21,11 +19,9 @@ import (
// LDAPUserProvider is a UserProvider that connects to LDAP servers like ActiveDirectory, OpenLDAP, OpenDJ, FreeIPA, etc.
type LDAPUserProvider struct {
- config schema.AuthenticationBackendLDAP
- tlsConfig *tls.Config
- dialOpts []ldap.DialOpt
- log *logrus.Logger
- factory LDAPClientFactory
+ config *schema.AuthenticationBackendLDAP
+ log *logrus.Logger
+ factory LDAPClientFactory
clock clock.Provider
@@ -53,37 +49,27 @@ type LDAPUserProvider struct {
groupsFilterReplacementsMemberOfRDN bool
}
-// NewLDAPUserProvider creates a new instance of LDAPUserProvider with the ProductionLDAPClientFactory.
-func NewLDAPUserProvider(config schema.AuthenticationBackend, certPool *x509.CertPool) (provider *LDAPUserProvider) {
- provider = NewLDAPUserProviderWithFactory(*config.LDAP, config.PasswordReset.Disable, certPool, NewProductionLDAPClientFactory())
-
- return provider
-}
-
-// NewLDAPUserProviderWithFactory creates a new instance of LDAPUserProvider with the specified LDAPClientFactory.
-func NewLDAPUserProviderWithFactory(config schema.AuthenticationBackendLDAP, disableResetPassword bool, certPool *x509.CertPool, factory LDAPClientFactory) (provider *LDAPUserProvider) {
- if config.TLS == nil {
- config.TLS = schema.DefaultLDAPAuthenticationBackendConfigurationImplementationCustom.TLS
+// NewLDAPUserProvider creates a new instance of LDAPUserProvider with the StandardLDAPClientFactory.
+func NewLDAPUserProvider(config schema.AuthenticationBackend, certs *x509.CertPool) (provider *LDAPUserProvider) {
+ if config.LDAP.TLS == nil {
+ config.LDAP.TLS = schema.DefaultLDAPAuthenticationBackendConfigurationImplementationCustom.TLS
}
- tlsConfig := utils.NewTLSConfig(config.TLS, certPool)
+ var factory LDAPClientFactory
- var dialOpts = []ldap.DialOpt{
- ldap.DialWithDialer(&net.Dialer{Timeout: config.Timeout}),
- }
-
- if tlsConfig != nil {
- dialOpts = append(dialOpts, ldap.DialWithTLSConfig(tlsConfig))
+ if config.LDAP.Pooling.Enable {
+ factory = NewPooledLDAPClientFactory(config.LDAP, certs, nil)
+ } else {
+ factory = NewStandardLDAPClientFactory(config.LDAP, certs, nil)
}
- if factory == nil {
- factory = NewProductionLDAPClientFactory()
- }
+ return NewLDAPUserProviderWithFactory(config.LDAP, config.PasswordReset.Disable, factory)
+}
+// NewLDAPUserProviderWithFactory creates a new instance of LDAPUserProvider with the specified LDAPClientFactory.
+func NewLDAPUserProviderWithFactory(config *schema.AuthenticationBackendLDAP, disableResetPassword bool, factory LDAPClientFactory) (provider *LDAPUserProvider) {
provider = &LDAPUserProvider{
config: config,
- tlsConfig: tlsConfig,
- dialOpts: dialOpts,
log: logging.Logger(),
factory: factory,
disableResetPassword: disableResetPassword,
@@ -99,25 +85,33 @@ func NewLDAPUserProviderWithFactory(config schema.AuthenticationBackendLDAP, dis
// CheckUserPassword checks if provided password matches for the given user.
func (p *LDAPUserProvider) CheckUserPassword(username string, password string) (valid bool, err error) {
var (
- client, clientUser LDAPClient
- profile *ldapUserProfile
+ client, uclient ldap.Client
+ profile *ldapUserProfile
)
- if client, err = p.connect(); err != nil {
+ if client, err = p.factory.GetClient(); err != nil {
return false, err
}
- defer client.Close()
+ defer func() {
+ if err := p.factory.ReleaseClient(client); err != nil {
+ p.log.WithError(err).Warn("Error occurred releasing the LDAP client")
+ }
+ }()
if profile, err = p.getUserProfile(client, username); err != nil {
return false, err
}
- if clientUser, err = p.connectCustom(p.config.Address.String(), profile.DN, password, p.config.StartTLS, p.dialOpts...); err != nil {
+ if uclient, err = p.factory.GetClient(WithUsername(profile.DN), WithPassword(password)); err != nil {
return false, fmt.Errorf("authentication failed. Cause: %w", err)
}
- defer clientUser.Close()
+ defer func() {
+ if err := p.factory.ReleaseClient(uclient); err != nil {
+ p.log.WithError(err).Warn("Error occurred releasing the LDAP client")
+ }
+ }()
return true, nil
}
@@ -125,15 +119,19 @@ func (p *LDAPUserProvider) CheckUserPassword(username string, password string) (
// GetDetails retrieve the groups a user belongs to.
func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, err error) {
var (
- client LDAPClient
+ client ldap.Client
profile *ldapUserProfile
)
- if client, err = p.connect(); err != nil {
+ if client, err = p.factory.GetClient(); err != nil {
return nil, err
}
- defer client.Close()
+ defer func() {
+ if err := p.factory.ReleaseClient(client); err != nil {
+ p.log.WithError(err).Warn("Error occurred releasing the LDAP client")
+ }
+ }()
if profile, err = p.getUserProfile(client, username); err != nil {
return nil, err
@@ -158,11 +156,11 @@ func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, er
// GetDetailsExtended retrieves the UserDetailsExtended values.
func (p *LDAPUserProvider) GetDetailsExtended(username string) (details *UserDetailsExtended, err error) {
var (
- client LDAPClient
+ client ldap.Client
profile *ldapUserProfileExtended
)
- if client, err = p.connect(); err != nil {
+ if client, err = p.factory.GetClient(); err != nil {
return nil, err
}
@@ -243,15 +241,19 @@ func (p *LDAPUserProvider) GetDetailsExtended(username string) (details *UserDet
// UpdatePassword update the password of the given user.
func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error) {
var (
- client LDAPClient
+ client ldap.Client
profile *ldapUserProfile
)
- if client, err = p.connect(); err != nil {
+ if client, err = p.factory.GetClient(); err != nil {
return fmt.Errorf("unable to update password. Cause: %w", err)
}
- defer client.Close()
+ defer func() {
+ if err := p.factory.ReleaseClient(client); err != nil {
+ p.log.WithError(err).Warn("Error occurred releasing the LDAP client")
+ }
+ }()
if profile, err = p.getUserProfile(client, username); err != nil {
return fmt.Errorf("unable to update password. Cause: %w", err)
@@ -297,39 +299,7 @@ func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error)
return nil
}
-func (p *LDAPUserProvider) connect() (client LDAPClient, err error) {
- return p.connectCustom(p.config.Address.String(), p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...)
-}
-
-func (p *LDAPUserProvider) connectCustom(url, username, password string, startTLS bool, opts ...ldap.DialOpt) (client LDAPClient, err error) {
- if client, err = p.factory.DialURL(url, opts...); err != nil {
- return nil, fmt.Errorf("dial failed with error: %w", err)
- }
-
- if startTLS {
- if err = client.StartTLS(p.tlsConfig); err != nil {
- client.Close()
-
- return nil, fmt.Errorf("starttls failed with error: %w", err)
- }
- }
-
- if password == "" {
- err = client.UnauthenticatedBind(username)
- } else {
- err = client.Bind(username, password)
- }
-
- if err != nil {
- client.Close()
-
- return nil, fmt.Errorf("bind failed with error: %w", err)
- }
-
- return client, nil
-}
-
-func (p *LDAPUserProvider) search(client LDAPClient, request *ldap.SearchRequest) (result *ldap.SearchResult, err error) {
+func (p *LDAPUserProvider) search(client ldap.Client, request *ldap.SearchRequest) (result *ldap.SearchResult, err error) {
if result, err = client.Search(request); err != nil {
if referral, ok := p.getReferral(err); ok {
if result == nil {
@@ -357,15 +327,19 @@ func (p *LDAPUserProvider) search(client LDAPClient, request *ldap.SearchRequest
func (p *LDAPUserProvider) searchReferral(referral string, request *ldap.SearchRequest, searchResult *ldap.SearchResult) (err error) {
var (
- client LDAPClient
+ client ldap.Client
result *ldap.SearchResult
)
- if client, err = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); err != nil {
+ if client, err = p.factory.GetClient(WithAddress(referral)); err != nil {
return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %w", referral, err)
}
- defer client.Close()
+ defer func() {
+ if err := p.factory.ReleaseClient(client); err != nil {
+ p.log.WithError(err).Warn("Error occurred releasing the LDAP client")
+ }
+ }()
if result, err = client.Search(request); err != nil {
return fmt.Errorf("error occurred performing search on referred LDAP server '%s': %w", referral, err)
@@ -390,7 +364,7 @@ func (p *LDAPUserProvider) searchReferrals(request *ldap.SearchRequest, result *
return nil
}
-func (p *LDAPUserProvider) getUserProfile(client LDAPClient, username string) (profile *ldapUserProfile, err error) {
+func (p *LDAPUserProvider) getUserProfile(client ldap.Client, username string) (profile *ldapUserProfile, err error) {
// Search for the given username.
request := ldap.NewSearchRequest(
p.usersBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
@@ -463,7 +437,7 @@ func (p *LDAPUserProvider) getUserProfileResultToProfile(username string, entry
return &userProfile, nil
}
-func (p *LDAPUserProvider) getUserProfileExtended(client LDAPClient, username string) (profile *ldapUserProfileExtended, err error) {
+func (p *LDAPUserProvider) getUserProfileExtended(client ldap.Client, username string) (profile *ldapUserProfileExtended, err error) {
// Search for the given username.
request := ldap.NewSearchRequest(
p.usersBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
@@ -552,7 +526,7 @@ func (p *LDAPUserProvider) getUserProfileResultToProfileExtended(username string
return &userProfile, nil
}
-func (p *LDAPUserProvider) getUserGroups(client LDAPClient, username string, profile *ldapUserProfile) (groups []string, err error) {
+func (p *LDAPUserProvider) getUserGroups(client ldap.Client, username string, profile *ldapUserProfile) (groups []string, err error) {
request := ldap.NewSearchRequest(
p.groupsBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
0, 0, false, p.resolveGroupsFilter(username, profile), p.groupsAttributes, nil,
@@ -577,7 +551,7 @@ func (p *LDAPUserProvider) getUserGroups(client LDAPClient, username string, pro
}
}
-func (p *LDAPUserProvider) getUserGroupsRequestFilter(client LDAPClient, username string, _ *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) {
+func (p *LDAPUserProvider) getUserGroupsRequestFilter(client ldap.Client, username string, _ *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) {
var result *ldap.SearchResult
if result, err = p.search(client, request); err != nil {
@@ -593,7 +567,7 @@ func (p *LDAPUserProvider) getUserGroupsRequestFilter(client LDAPClient, usernam
return groups, nil
}
-func (p *LDAPUserProvider) getUserGroupsRequestMemberOf(client LDAPClient, username string, profile *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) {
+func (p *LDAPUserProvider) getUserGroupsRequestMemberOf(client ldap.Client, username string, profile *ldapUserProfile, request *ldap.SearchRequest) (groups []string, err error) {
var result *ldap.SearchResult
if result, err = p.search(client, request); err != nil {
@@ -733,7 +707,7 @@ func (p *LDAPUserProvider) resolveGroupsFilter(input string, profile *ldapUserPr
return filter
}
-func (p *LDAPUserProvider) modify(client LDAPClient, modifyRequest *ldap.ModifyRequest) (err error) {
+func (p *LDAPUserProvider) modify(client ldap.Client, modifyRequest *ldap.ModifyRequest) (err error) {
if err = client.Modify(modifyRequest); err != nil {
var (
referral string
@@ -747,15 +721,19 @@ func (p *LDAPUserProvider) modify(client LDAPClient, modifyRequest *ldap.ModifyR
p.log.Debugf("Attempting Modify on referred URL %s", referral)
var (
- clientRef LDAPClient
+ clientRef ldap.Client
errRef error
)
- if clientRef, errRef = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errRef != nil {
+ if clientRef, errRef = p.factory.GetClient(WithAddress(referral)); errRef != nil {
return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)
}
- defer clientRef.Close()
+ defer func() {
+ if err := p.factory.ReleaseClient(clientRef); err != nil {
+ p.log.WithError(err).Warn("Error occurred releasing the LDAP client")
+ }
+ }()
if errRef = clientRef.Modify(modifyRequest); errRef != nil {
return fmt.Errorf("error occurred performing modify on referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)
@@ -767,7 +745,7 @@ func (p *LDAPUserProvider) modify(client LDAPClient, modifyRequest *ldap.ModifyR
return nil
}
-func (p *LDAPUserProvider) pwdModify(client LDAPClient, pwdModifyRequest *ldap.PasswordModifyRequest) (err error) {
+func (p *LDAPUserProvider) pwdModify(client ldap.Client, pwdModifyRequest *ldap.PasswordModifyRequest) (err error) {
if _, err = client.PasswordModify(pwdModifyRequest); err != nil {
var (
referral string
@@ -781,15 +759,19 @@ func (p *LDAPUserProvider) pwdModify(client LDAPClient, pwdModifyRequest *ldap.P
p.log.Debugf("Attempting PwdModify ExOp (1.3.6.1.4.1.4203.1.11.1) on referred URL %s", referral)
var (
- clientRef LDAPClient
+ clientRef ldap.Client
errRef error
)
- if clientRef, errRef = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errRef != nil {
+ if clientRef, errRef = p.factory.GetClient(WithAddress(referral)); errRef != nil {
return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)
}
- defer clientRef.Close()
+ defer func() {
+ if err := p.factory.ReleaseClient(clientRef); err != nil {
+ p.log.WithError(err).Warn("Error occurred releasing the LDAP client")
+ }
+ }()
if _, errRef = clientRef.PasswordModify(pwdModifyRequest); errRef != nil {
return fmt.Errorf("error occurred performing password modify on referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)