summaryrefslogtreecommitdiff
path: root/internal/authentication/ldap_client_factory.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/authentication/ldap_client_factory.go')
-rw-r--r--internal/authentication/ldap_client_factory.go312
1 files changed, 304 insertions, 8 deletions
diff --git a/internal/authentication/ldap_client_factory.go b/internal/authentication/ldap_client_factory.go
index d360f05b3..9bb55502c 100644
--- a/internal/authentication/ldap_client_factory.go
+++ b/internal/authentication/ldap_client_factory.go
@@ -1,18 +1,314 @@
package authentication
import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
"github.com/go-ldap/ldap/v3"
+
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/utils"
)
-// ProductionLDAPClientFactory the production implementation of an ldap connection factory.
-type ProductionLDAPClientFactory struct{}
+// LDAPClientFactory an interface describing factories that produce LDAPConnection implementations.
+type LDAPClientFactory interface {
+ Initialize() (err error)
+ GetClient(opts ...LDAPClientFactoryOption) (client ldap.Client, err error)
+ ReleaseClient(client ldap.Client) (err error)
+ Close() (err error)
+}
+
+// NewStandardLDAPClientFactory create a concrete ldap connection factory.
+func NewStandardLDAPClientFactory(config *schema.AuthenticationBackendLDAP, certs *x509.CertPool, dialer LDAPClientDialer) LDAPClientFactory {
+ if dialer == nil {
+ dialer = &LDAPClientDialerStandard{}
+ }
+
+ tlsc := utils.NewTLSConfig(config.TLS, certs)
+
+ opts := []ldap.DialOpt{
+ ldap.DialWithDialer(&net.Dialer{Timeout: config.Timeout}),
+ ldap.DialWithTLSConfig(tlsc),
+ }
+
+ return &StandardLDAPClientFactory{
+ config: config,
+ tls: tlsc,
+ opts: opts,
+ dialer: dialer,
+ }
+}
+
+// StandardLDAPClientFactory the production implementation of an ldap connection factory.
+type StandardLDAPClientFactory struct {
+ config *schema.AuthenticationBackendLDAP
+ tls *tls.Config
+ opts []ldap.DialOpt
+ dialer LDAPClientDialer
+}
+
+func (f *StandardLDAPClientFactory) Initialize() (err error) {
+ return nil
+}
+
+func (f *StandardLDAPClientFactory) GetClient(opts ...LDAPClientFactoryOption) (client ldap.Client, err error) {
+ return getLDAPClient(f.config.Address.String(), f.config.User, f.config.Password, f.dialer, f.tls, f.config.StartTLS, f.opts, opts...)
+}
+
+func (f *StandardLDAPClientFactory) ReleaseClient(client ldap.Client) (err error) {
+ if err = client.Close(); err != nil {
+ return fmt.Errorf("error occurred closing LDAP client: %w", err)
+ }
+
+ return nil
+}
-// NewProductionLDAPClientFactory create a concrete ldap connection factory.
-func NewProductionLDAPClientFactory() *ProductionLDAPClientFactory {
- return &ProductionLDAPClientFactory{}
+func (f *StandardLDAPClientFactory) Close() (err error) {
+ return nil
}
-// DialURL creates a client from an LDAP URL when successful.
-func (f *ProductionLDAPClientFactory) DialURL(addr string, opts ...ldap.DialOpt) (client LDAPClient, err error) {
- return ldap.DialURL(addr, opts...)
+// NewPooledLDAPClientFactory is a decorator for a LDAPClientFactory that performs pooling.
+func NewPooledLDAPClientFactory(config *schema.AuthenticationBackendLDAP, certs *x509.CertPool, dialer LDAPClientDialer) (factory LDAPClientFactory) {
+ if dialer == nil {
+ dialer = &LDAPClientDialerStandard{}
+ }
+
+ tlsc := utils.NewTLSConfig(config.TLS, certs)
+
+ opts := []ldap.DialOpt{
+ ldap.DialWithDialer(&net.Dialer{Timeout: config.Timeout}),
+ ldap.DialWithTLSConfig(tlsc),
+ }
+
+ if config.Pooling.Count <= 0 {
+ config.Pooling.Count = 3
+ }
+
+ if config.Pooling.Retries <= 0 {
+ config.Pooling.Retries = 3
+ }
+
+ if config.Pooling.Timeout <= 0 {
+ config.Pooling.Timeout = time.Second
+ }
+
+ sleep := config.Pooling.Timeout / time.Duration(config.Pooling.Retries)
+
+ return &PooledLDAPClientFactory{
+ config: config,
+ tls: tlsc,
+ opts: opts,
+ dialer: dialer,
+ sleep: sleep,
+ }
+}
+
+// PooledLDAPClientFactory is a LDAPClientFactory that takes another LDAPClientFactory and pools the
+// factory generated connections using a channel for thread safety.
+type PooledLDAPClientFactory struct {
+ config *schema.AuthenticationBackendLDAP
+ tls *tls.Config
+ opts []ldap.DialOpt
+ dialer LDAPClientDialer
+
+ pool chan *LDAPClientPooled
+ mu sync.Mutex
+
+ sleep time.Duration
+
+ closing bool
+}
+
+func (f *PooledLDAPClientFactory) Initialize() (err error) {
+ f.mu.Lock()
+
+ defer f.mu.Unlock()
+
+ if f.pool != nil {
+ return nil
+ }
+
+ f.pool = make(chan *LDAPClientPooled, f.config.Pooling.Count)
+
+ var (
+ errs []error
+ client *LDAPClientPooled
+ )
+
+ for i := 0; i < f.config.Pooling.Count; i++ {
+ if client, err = f.new(); err != nil {
+ errs = append(errs, err)
+
+ continue
+ }
+
+ f.pool <- client
+ }
+
+ if len(errs) == f.config.Pooling.Count {
+ return fmt.Errorf("errors occurred initializing the client pool: no connections could be established")
+ }
+
+ return nil
+}
+
+// GetClient opens new client using the pool.
+func (f *PooledLDAPClientFactory) GetClient(opts ...LDAPClientFactoryOption) (conn ldap.Client, err error) {
+ if len(opts) != 0 {
+ return getLDAPClient(f.config.Address.String(), f.config.User, f.config.Password, f.dialer, f.tls, f.config.StartTLS, f.opts, opts...)
+ }
+
+ return f.acquire(context.Background())
+}
+
+// The new function creates a pool based client. This function is not thread safe.
+func (f *PooledLDAPClientFactory) new() (pooled *LDAPClientPooled, err error) {
+ var client ldap.Client
+
+ if client, err = getLDAPClient(f.config.Address.String(), f.config.User, f.config.Password, f.dialer, f.tls, f.config.StartTLS, f.opts); err != nil {
+ return nil, fmt.Errorf("error occurred establishing new client for the pool: %w", err)
+ }
+
+ return &LDAPClientPooled{Client: client}, nil
+}
+
+// ReleaseClient returns a client using the pool or closes it.
+func (f *PooledLDAPClientFactory) ReleaseClient(client ldap.Client) (err error) {
+ f.mu.Lock()
+
+ defer f.mu.Unlock()
+
+ if f.closing {
+ return client.Close()
+ }
+
+ if pool, ok := client.(*LDAPClientPooled); !ok || cap(f.pool) == len(f.pool) {
+ // Prevent extra or non-pool connections from being returned into the pool.
+ return client.Close()
+ } else {
+ f.pool <- pool
+ }
+
+ return nil
+}
+
+func (f *PooledLDAPClientFactory) acquire(ctx context.Context) (client *LDAPClientPooled, err error) {
+ f.mu.Lock()
+
+ defer f.mu.Unlock()
+
+ if f.closing {
+ return nil, fmt.Errorf("error acquiring client: the pool is closed")
+ }
+
+ if cap(f.pool) != f.config.Pooling.Count {
+ if err = f.Initialize(); err != nil {
+ return nil, err
+ }
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, f.config.Pooling.Timeout)
+ defer cancel()
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case client = <-f.pool:
+ if client.IsClosing() || client.Client == nil {
+ for {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ if client, err = f.new(); err != nil {
+ time.Sleep(f.sleep)
+
+ continue
+ }
+
+ return client, nil
+ }
+ }
+ }
+
+ return client, nil
+ }
+}
+
+func (f *PooledLDAPClientFactory) Close() (err error) {
+ f.mu.Lock()
+
+ defer f.mu.Unlock()
+
+ f.closing = true
+
+ close(f.pool)
+
+ var errs []error
+
+ for client := range f.pool {
+ if client.IsClosing() {
+ continue
+ }
+
+ if err = client.Close(); err != nil {
+ errs = append(errs, err)
+ }
+ }
+
+ if len(errs) > 0 {
+ return fmt.Errorf("errors occurred closing the client pool: %w", errors.Join(errs...))
+ }
+
+ return nil
+}
+
+// LDAPClientPooled is a decorator for the ldap.Client which handles the pooling functionality. i.e. prevents the client
+// from being closed and instead relinquishes the connection back to the pool.
+type LDAPClientPooled struct {
+ ldap.Client
+}
+
+func getLDAPClient(address, username, password string, dialer LDAPClientDialer, tls *tls.Config, startTLS bool, dialerOpts []ldap.DialOpt, opts ...LDAPClientFactoryOption) (client ldap.Client, err error) {
+ config := &LDAPClientFactoryOptions{
+ Address: address,
+ Username: username,
+ Password: password,
+ }
+
+ for _, opt := range opts {
+ opt(config)
+ }
+
+ if client, err = dialer.DialURL(config.Address, dialerOpts...); err != nil {
+ return nil, fmt.Errorf("error occurred dialing address: %w", err)
+ }
+
+ if tls != nil && startTLS {
+ if err = client.StartTLS(tls); err != nil {
+ _ = client.Close()
+
+ return nil, fmt.Errorf("error occurred performing starttls: %w", err)
+ }
+ }
+
+ if config.Password == "" {
+ err = client.UnauthenticatedBind(config.Username)
+ } else {
+ err = client.Bind(config.Username, config.Password)
+ }
+
+ if err != nil {
+ _ = client.Close()
+
+ return nil, fmt.Errorf("error occurred performing bind: %w", err)
+ }
+
+ return client, nil
}