diff options
Diffstat (limited to 'internal/authentication/ldap_client_factory.go')
| -rw-r--r-- | internal/authentication/ldap_client_factory.go | 312 |
1 files changed, 304 insertions, 8 deletions
diff --git a/internal/authentication/ldap_client_factory.go b/internal/authentication/ldap_client_factory.go index d360f05b3..9bb55502c 100644 --- a/internal/authentication/ldap_client_factory.go +++ b/internal/authentication/ldap_client_factory.go @@ -1,18 +1,314 @@ package authentication import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "sync" + "time" + "github.com/go-ldap/ldap/v3" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/utils" ) -// ProductionLDAPClientFactory the production implementation of an ldap connection factory. -type ProductionLDAPClientFactory struct{} +// LDAPClientFactory an interface describing factories that produce LDAPConnection implementations. +type LDAPClientFactory interface { + Initialize() (err error) + GetClient(opts ...LDAPClientFactoryOption) (client ldap.Client, err error) + ReleaseClient(client ldap.Client) (err error) + Close() (err error) +} + +// NewStandardLDAPClientFactory create a concrete ldap connection factory. +func NewStandardLDAPClientFactory(config *schema.AuthenticationBackendLDAP, certs *x509.CertPool, dialer LDAPClientDialer) LDAPClientFactory { + if dialer == nil { + dialer = &LDAPClientDialerStandard{} + } + + tlsc := utils.NewTLSConfig(config.TLS, certs) + + opts := []ldap.DialOpt{ + ldap.DialWithDialer(&net.Dialer{Timeout: config.Timeout}), + ldap.DialWithTLSConfig(tlsc), + } + + return &StandardLDAPClientFactory{ + config: config, + tls: tlsc, + opts: opts, + dialer: dialer, + } +} + +// StandardLDAPClientFactory the production implementation of an ldap connection factory. +type StandardLDAPClientFactory struct { + config *schema.AuthenticationBackendLDAP + tls *tls.Config + opts []ldap.DialOpt + dialer LDAPClientDialer +} + +func (f *StandardLDAPClientFactory) Initialize() (err error) { + return nil +} + +func (f *StandardLDAPClientFactory) GetClient(opts ...LDAPClientFactoryOption) (client ldap.Client, err error) { + return getLDAPClient(f.config.Address.String(), f.config.User, f.config.Password, f.dialer, f.tls, f.config.StartTLS, f.opts, opts...) +} + +func (f *StandardLDAPClientFactory) ReleaseClient(client ldap.Client) (err error) { + if err = client.Close(); err != nil { + return fmt.Errorf("error occurred closing LDAP client: %w", err) + } + + return nil +} -// NewProductionLDAPClientFactory create a concrete ldap connection factory. -func NewProductionLDAPClientFactory() *ProductionLDAPClientFactory { - return &ProductionLDAPClientFactory{} +func (f *StandardLDAPClientFactory) Close() (err error) { + return nil } -// DialURL creates a client from an LDAP URL when successful. -func (f *ProductionLDAPClientFactory) DialURL(addr string, opts ...ldap.DialOpt) (client LDAPClient, err error) { - return ldap.DialURL(addr, opts...) +// NewPooledLDAPClientFactory is a decorator for a LDAPClientFactory that performs pooling. +func NewPooledLDAPClientFactory(config *schema.AuthenticationBackendLDAP, certs *x509.CertPool, dialer LDAPClientDialer) (factory LDAPClientFactory) { + if dialer == nil { + dialer = &LDAPClientDialerStandard{} + } + + tlsc := utils.NewTLSConfig(config.TLS, certs) + + opts := []ldap.DialOpt{ + ldap.DialWithDialer(&net.Dialer{Timeout: config.Timeout}), + ldap.DialWithTLSConfig(tlsc), + } + + if config.Pooling.Count <= 0 { + config.Pooling.Count = 3 + } + + if config.Pooling.Retries <= 0 { + config.Pooling.Retries = 3 + } + + if config.Pooling.Timeout <= 0 { + config.Pooling.Timeout = time.Second + } + + sleep := config.Pooling.Timeout / time.Duration(config.Pooling.Retries) + + return &PooledLDAPClientFactory{ + config: config, + tls: tlsc, + opts: opts, + dialer: dialer, + sleep: sleep, + } +} + +// PooledLDAPClientFactory is a LDAPClientFactory that takes another LDAPClientFactory and pools the +// factory generated connections using a channel for thread safety. +type PooledLDAPClientFactory struct { + config *schema.AuthenticationBackendLDAP + tls *tls.Config + opts []ldap.DialOpt + dialer LDAPClientDialer + + pool chan *LDAPClientPooled + mu sync.Mutex + + sleep time.Duration + + closing bool +} + +func (f *PooledLDAPClientFactory) Initialize() (err error) { + f.mu.Lock() + + defer f.mu.Unlock() + + if f.pool != nil { + return nil + } + + f.pool = make(chan *LDAPClientPooled, f.config.Pooling.Count) + + var ( + errs []error + client *LDAPClientPooled + ) + + for i := 0; i < f.config.Pooling.Count; i++ { + if client, err = f.new(); err != nil { + errs = append(errs, err) + + continue + } + + f.pool <- client + } + + if len(errs) == f.config.Pooling.Count { + return fmt.Errorf("errors occurred initializing the client pool: no connections could be established") + } + + return nil +} + +// GetClient opens new client using the pool. +func (f *PooledLDAPClientFactory) GetClient(opts ...LDAPClientFactoryOption) (conn ldap.Client, err error) { + if len(opts) != 0 { + return getLDAPClient(f.config.Address.String(), f.config.User, f.config.Password, f.dialer, f.tls, f.config.StartTLS, f.opts, opts...) + } + + return f.acquire(context.Background()) +} + +// The new function creates a pool based client. This function is not thread safe. +func (f *PooledLDAPClientFactory) new() (pooled *LDAPClientPooled, err error) { + var client ldap.Client + + if client, err = getLDAPClient(f.config.Address.String(), f.config.User, f.config.Password, f.dialer, f.tls, f.config.StartTLS, f.opts); err != nil { + return nil, fmt.Errorf("error occurred establishing new client for the pool: %w", err) + } + + return &LDAPClientPooled{Client: client}, nil +} + +// ReleaseClient returns a client using the pool or closes it. +func (f *PooledLDAPClientFactory) ReleaseClient(client ldap.Client) (err error) { + f.mu.Lock() + + defer f.mu.Unlock() + + if f.closing { + return client.Close() + } + + if pool, ok := client.(*LDAPClientPooled); !ok || cap(f.pool) == len(f.pool) { + // Prevent extra or non-pool connections from being returned into the pool. + return client.Close() + } else { + f.pool <- pool + } + + return nil +} + +func (f *PooledLDAPClientFactory) acquire(ctx context.Context) (client *LDAPClientPooled, err error) { + f.mu.Lock() + + defer f.mu.Unlock() + + if f.closing { + return nil, fmt.Errorf("error acquiring client: the pool is closed") + } + + if cap(f.pool) != f.config.Pooling.Count { + if err = f.Initialize(); err != nil { + return nil, err + } + } + + ctx, cancel := context.WithTimeout(ctx, f.config.Pooling.Timeout) + defer cancel() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case client = <-f.pool: + if client.IsClosing() || client.Client == nil { + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + if client, err = f.new(); err != nil { + time.Sleep(f.sleep) + + continue + } + + return client, nil + } + } + } + + return client, nil + } +} + +func (f *PooledLDAPClientFactory) Close() (err error) { + f.mu.Lock() + + defer f.mu.Unlock() + + f.closing = true + + close(f.pool) + + var errs []error + + for client := range f.pool { + if client.IsClosing() { + continue + } + + if err = client.Close(); err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return fmt.Errorf("errors occurred closing the client pool: %w", errors.Join(errs...)) + } + + return nil +} + +// LDAPClientPooled is a decorator for the ldap.Client which handles the pooling functionality. i.e. prevents the client +// from being closed and instead relinquishes the connection back to the pool. +type LDAPClientPooled struct { + ldap.Client +} + +func getLDAPClient(address, username, password string, dialer LDAPClientDialer, tls *tls.Config, startTLS bool, dialerOpts []ldap.DialOpt, opts ...LDAPClientFactoryOption) (client ldap.Client, err error) { + config := &LDAPClientFactoryOptions{ + Address: address, + Username: username, + Password: password, + } + + for _, opt := range opts { + opt(config) + } + + if client, err = dialer.DialURL(config.Address, dialerOpts...); err != nil { + return nil, fmt.Errorf("error occurred dialing address: %w", err) + } + + if tls != nil && startTLS { + if err = client.StartTLS(tls); err != nil { + _ = client.Close() + + return nil, fmt.Errorf("error occurred performing starttls: %w", err) + } + } + + if config.Password == "" { + err = client.UnauthenticatedBind(config.Username) + } else { + err = client.Bind(config.Username, config.Password) + } + + if err != nil { + _ = client.Close() + + return nil, fmt.Errorf("error occurred performing bind: %w", err) + } + + return client, nil } |
