diff options
Diffstat (limited to 'internal/authentication/ldap_user_provider.go')
| -rw-r--r-- | internal/authentication/ldap_user_provider.go | 65 |
1 files changed, 50 insertions, 15 deletions
diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go index d02d525c9..803ae0083 100644 --- a/internal/authentication/ldap_user_provider.go +++ b/internal/authentication/ldap_user_provider.go @@ -1,35 +1,70 @@ package authentication import ( + "crypto/tls" "fmt" + "net/url" "strings" "github.com/clems4ever/authelia/internal/configuration/schema" + "github.com/clems4ever/authelia/internal/logging" "gopkg.in/ldap.v3" ) // LDAPUserProvider is a provider using a LDAP or AD as a user database. type LDAPUserProvider struct { configuration schema.LDAPAuthenticationBackendConfiguration + + connectionFactory LDAPConnectionFactory } -func (p *LDAPUserProvider) connect(userDN string, password string) (*ldap.Conn, error) { - conn, err := ldap.Dial("tcp", p.configuration.URL) - if err != nil { - return nil, err +// NewLDAPUserProvider creates a new instance of LDAPUserProvider. +func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider { + return &LDAPUserProvider{ + configuration: configuration, + connectionFactory: NewLDAPConnectionFactoryImpl(), } +} - err = conn.Bind(userDN, password) +func NewLDAPUserProviderWithFactory(configuration schema.LDAPAuthenticationBackendConfiguration, + connectionFactory LDAPConnectionFactory) *LDAPUserProvider { + return &LDAPUserProvider{ + configuration: configuration, + connectionFactory: connectionFactory, + } +} + +func (p *LDAPUserProvider) connect(userDN string, password string) (LDAPConnection, error) { + var newConnection LDAPConnection + + url, err := url.Parse(p.configuration.URL) if err != nil { - return nil, err + return nil, fmt.Errorf("Unable to parse URL to LDAP: %s", url) } - return conn, nil -} -// NewLDAPUserProvider creates a new instance of LDAPUserProvider. -func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider { - return &LDAPUserProvider{configuration} + if url.Scheme == "ldaps" { + logging.Logger().Debug("LDAP client starts a TLS session") + conn, err := p.connectionFactory.DialTLS("tcp", url.Host, &tls.Config{ + InsecureSkipVerify: p.configuration.SkipVerify, + }) + if err != nil { + return nil, err + } + newConnection = conn + } else { + logging.Logger().Debug("LDAP client starts a session over raw TCP") + conn, err := p.connectionFactory.Dial("tcp", url.Host) + if err != nil { + return nil, err + } + newConnection = conn + } + + if err := newConnection.Bind(userDN, password); err != nil { + return nil, err + } + return newConnection, nil } // CheckUserPassword checks if provided password matches for the given user. @@ -54,7 +89,7 @@ func (p *LDAPUserProvider) CheckUserPassword(username string, password string) ( return true, nil } -func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, attribute string) ([]string, error) { +func (p *LDAPUserProvider) getUserAttribute(conn LDAPConnection, username string, attribute string) ([]string, error) { client, err := p.connect(p.configuration.User, p.configuration.Password) if err != nil { return nil, err @@ -86,7 +121,7 @@ func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, at return sr.Entries[0].Attributes[0].Values, nil } -func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string, error) { +func (p *LDAPUserProvider) getUserDN(conn LDAPConnection, username string) (string, error) { values, err := p.getUserAttribute(conn, username, "dn") if err != nil { @@ -100,7 +135,7 @@ func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string, return values[0], nil } -func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string, error) { +func (p *LDAPUserProvider) getUserUID(conn LDAPConnection, username string) (string, error) { values, err := p.getUserAttribute(conn, username, "uid") if err != nil { @@ -114,7 +149,7 @@ func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string, return values[0], nil } -func (p *LDAPUserProvider) createGroupsFilter(conn *ldap.Conn, username string) (string, error) { +func (p *LDAPUserProvider) createGroupsFilter(conn LDAPConnection, username string) (string, error) { if strings.Index(p.configuration.GroupsFilter, "{0}") >= 0 { return strings.Replace(p.configuration.GroupsFilter, "{0}", username, -1), nil } else if strings.Index(p.configuration.GroupsFilter, "{dn}") >= 0 { |
