summaryrefslogtreecommitdiff
path: root/internal/authentication/ldap_user_provider.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/authentication/ldap_user_provider.go')
-rw-r--r--internal/authentication/ldap_user_provider.go65
1 files changed, 50 insertions, 15 deletions
diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go
index d02d525c9..803ae0083 100644
--- a/internal/authentication/ldap_user_provider.go
+++ b/internal/authentication/ldap_user_provider.go
@@ -1,35 +1,70 @@
package authentication
import (
+ "crypto/tls"
"fmt"
+ "net/url"
"strings"
"github.com/clems4ever/authelia/internal/configuration/schema"
+ "github.com/clems4ever/authelia/internal/logging"
"gopkg.in/ldap.v3"
)
// LDAPUserProvider is a provider using a LDAP or AD as a user database.
type LDAPUserProvider struct {
configuration schema.LDAPAuthenticationBackendConfiguration
+
+ connectionFactory LDAPConnectionFactory
}
-func (p *LDAPUserProvider) connect(userDN string, password string) (*ldap.Conn, error) {
- conn, err := ldap.Dial("tcp", p.configuration.URL)
- if err != nil {
- return nil, err
+// NewLDAPUserProvider creates a new instance of LDAPUserProvider.
+func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider {
+ return &LDAPUserProvider{
+ configuration: configuration,
+ connectionFactory: NewLDAPConnectionFactoryImpl(),
}
+}
- err = conn.Bind(userDN, password)
+func NewLDAPUserProviderWithFactory(configuration schema.LDAPAuthenticationBackendConfiguration,
+ connectionFactory LDAPConnectionFactory) *LDAPUserProvider {
+ return &LDAPUserProvider{
+ configuration: configuration,
+ connectionFactory: connectionFactory,
+ }
+}
+
+func (p *LDAPUserProvider) connect(userDN string, password string) (LDAPConnection, error) {
+ var newConnection LDAPConnection
+
+ url, err := url.Parse(p.configuration.URL)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("Unable to parse URL to LDAP: %s", url)
}
- return conn, nil
-}
-// NewLDAPUserProvider creates a new instance of LDAPUserProvider.
-func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration) *LDAPUserProvider {
- return &LDAPUserProvider{configuration}
+ if url.Scheme == "ldaps" {
+ logging.Logger().Debug("LDAP client starts a TLS session")
+ conn, err := p.connectionFactory.DialTLS("tcp", url.Host, &tls.Config{
+ InsecureSkipVerify: p.configuration.SkipVerify,
+ })
+ if err != nil {
+ return nil, err
+ }
+ newConnection = conn
+ } else {
+ logging.Logger().Debug("LDAP client starts a session over raw TCP")
+ conn, err := p.connectionFactory.Dial("tcp", url.Host)
+ if err != nil {
+ return nil, err
+ }
+ newConnection = conn
+ }
+
+ if err := newConnection.Bind(userDN, password); err != nil {
+ return nil, err
+ }
+ return newConnection, nil
}
// CheckUserPassword checks if provided password matches for the given user.
@@ -54,7 +89,7 @@ func (p *LDAPUserProvider) CheckUserPassword(username string, password string) (
return true, nil
}
-func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, attribute string) ([]string, error) {
+func (p *LDAPUserProvider) getUserAttribute(conn LDAPConnection, username string, attribute string) ([]string, error) {
client, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return nil, err
@@ -86,7 +121,7 @@ func (p *LDAPUserProvider) getUserAttribute(conn *ldap.Conn, username string, at
return sr.Entries[0].Attributes[0].Values, nil
}
-func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string, error) {
+func (p *LDAPUserProvider) getUserDN(conn LDAPConnection, username string) (string, error) {
values, err := p.getUserAttribute(conn, username, "dn")
if err != nil {
@@ -100,7 +135,7 @@ func (p *LDAPUserProvider) getUserDN(conn *ldap.Conn, username string) (string,
return values[0], nil
}
-func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string, error) {
+func (p *LDAPUserProvider) getUserUID(conn LDAPConnection, username string) (string, error) {
values, err := p.getUserAttribute(conn, username, "uid")
if err != nil {
@@ -114,7 +149,7 @@ func (p *LDAPUserProvider) getUserUID(conn *ldap.Conn, username string) (string,
return values[0], nil
}
-func (p *LDAPUserProvider) createGroupsFilter(conn *ldap.Conn, username string) (string, error) {
+func (p *LDAPUserProvider) createGroupsFilter(conn LDAPConnection, username string) (string, error) {
if strings.Index(p.configuration.GroupsFilter, "{0}") >= 0 {
return strings.Replace(p.configuration.GroupsFilter, "{0}", username, -1), nil
} else if strings.Index(p.configuration.GroupsFilter, "{dn}") >= 0 {