diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2022-10-22 16:41:27 +1100 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-22 16:41:27 +1100 | 
| commit | 1ea29cb2c24b44d15dffed3964e41e56b32da02d (patch) | |
| tree | f31fb4a2b54478c5f887e45564d6b8585d9878cc /internal/storage | |
| parent | 1d821a0d3a15b853e50c44aa786b846213f0355c (diff) | |
feat(storage): unix socket support (#4231)
Support for unix sockets for MySQL and PostgreSQL.
Diffstat (limited to 'internal/storage')
| -rw-r--r-- | internal/storage/const.go | 3 | ||||
| -rw-r--r-- | internal/storage/sql_provider_backend_mysql.go | 42 | ||||
| -rw-r--r-- | internal/storage/sql_provider_backend_postgres.go | 145 | 
3 files changed, 147 insertions, 43 deletions
diff --git a/internal/storage/const.go b/internal/storage/const.go index 825caae04..39645979d 100644 --- a/internal/storage/const.go +++ b/internal/storage/const.go @@ -41,7 +41,8 @@ const (  )  const ( -	sqlNetworkTypeTCP = "tcp" +	sqlNetworkTypeTCP        = "tcp" +	sqlNetworkTypeUnixSocket = "unix"  )  const ( diff --git a/internal/storage/sql_provider_backend_mysql.go b/internal/storage/sql_provider_backend_mysql.go index a336aeac9..b0382ed7d 100644 --- a/internal/storage/sql_provider_backend_mysql.go +++ b/internal/storage/sql_provider_backend_mysql.go @@ -1,11 +1,12 @@  package storage  import ( +	"crypto/x509"  	"fmt" +	"path"  	"time"  	"github.com/go-sql-driver/mysql" -	_ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string.  	"github.com/authelia/authelia/v4/internal/configuration/schema"  ) @@ -16,9 +17,9 @@ type MySQLProvider struct {  }  // NewMySQLProvider a MySQL provider. -func NewMySQLProvider(config *schema.Configuration) (provider *MySQLProvider) { +func NewMySQLProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider *MySQLProvider) {  	provider = &MySQLProvider{ -		SQLProvider: NewSQLProvider(config, providerMySQL, providerMySQL, dataSourceNameMySQL(*config.Storage.MySQL)), +		SQLProvider: NewSQLProvider(config, providerMySQL, providerMySQL, dsnMySQL(config.Storage.MySQL)),  	}  	// All providers have differing SELECT existing table statements. @@ -30,32 +31,35 @@ func NewMySQLProvider(config *schema.Configuration) (provider *MySQLProvider) {  	return provider  } -func dataSourceNameMySQL(config schema.MySQLStorageConfiguration) (dataSourceName string) { -	dconfig := mysql.NewConfig() +func dsnMySQL(config *schema.MySQLStorageConfiguration) (dataSourceName string) { +	dsnConfig := mysql.NewConfig()  	switch { +	case path.IsAbs(config.Host): +		dsnConfig.Net = sqlNetworkTypeUnixSocket +		dsnConfig.Addr = config.Host  	case config.Port == 0: -		dconfig.Net = sqlNetworkTypeTCP -		dconfig.Addr = fmt.Sprintf("%s:%d", config.Host, 3306) +		dsnConfig.Net = sqlNetworkTypeTCP +		dsnConfig.Addr = fmt.Sprintf("%s:%d", config.Host, 3306)  	default: -		dconfig.Net = sqlNetworkTypeTCP -		dconfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port) +		dsnConfig.Net = sqlNetworkTypeTCP +		dsnConfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port)  	}  	switch config.Port {  	case 0: -		dconfig.Addr = config.Host +		dsnConfig.Addr = config.Host  	default: -		dconfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port) +		dsnConfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port)  	} -	dconfig.DBName = config.Database -	dconfig.User = config.Username -	dconfig.Passwd = config.Password -	dconfig.Timeout = config.Timeout -	dconfig.MultiStatements = true -	dconfig.ParseTime = true -	dconfig.Loc = time.Local +	dsnConfig.DBName = config.Database +	dsnConfig.User = config.Username +	dsnConfig.Passwd = config.Password +	dsnConfig.Timeout = config.Timeout +	dsnConfig.MultiStatements = true +	dsnConfig.ParseTime = true +	dsnConfig.Loc = time.Local -	return dconfig.FormatDSN() +	return dsnConfig.FormatDSN()  } diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 8a3ee7779..80f9929aa 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -1,11 +1,15 @@  package storage  import ( +	"crypto/tls" +	"crypto/x509" +	"errors"  	"fmt" -	"strings" -	"time" +	"os" +	"path" -	_ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string. +	"github.com/jackc/pgx/v5" +	"github.com/jackc/pgx/v5/stdlib"  	"github.com/authelia/authelia/v4/internal/configuration/schema"  ) @@ -16,9 +20,9 @@ type PostgreSQLProvider struct {  }  // NewPostgreSQLProvider a PostgreSQL provider. -func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLProvider) { +func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider *PostgreSQLProvider) {  	provider = &PostgreSQLProvider{ -		SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dataSourceNamePostgreSQL(*config.Storage.PostgreSQL)), +		SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dsnPostgreSQL(config.Storage.PostgreSQL, caCertPool)),  	}  	// All providers have differing SELECT existing table statements. @@ -128,33 +132,128 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr  	return provider  } -func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) { -	args := []string{ -		fmt.Sprintf("host=%s", config.Host), -		fmt.Sprintf("user='%s'", config.Username), -		fmt.Sprintf("password='%s'", config.Password), -		fmt.Sprintf("dbname=%s", config.Database), -		fmt.Sprintf("search_path=%s", config.Schema), -		fmt.Sprintf("sslmode=%s", config.SSL.Mode), +func dsnPostgreSQL(config *schema.PostgreSQLStorageConfiguration, globalCACertPool *x509.CertPool) (dsn string) { +	dsnConfig, _ := pgx.ParseConfig("") + +	ca, certs := loadPostgreSQLLegacyTLS(config) + +	switch config.SSL.Mode { +	case "disable": +		break +	default: +		var caCertPool *x509.CertPool + +		switch ca { +		case nil: +			caCertPool = globalCACertPool +		default: +			caCertPool = globalCACertPool.Clone() +			caCertPool.AddCert(ca) +		} + +		dsnConfig.TLSConfig = &tls.Config{ +			Certificates:       certs, +			RootCAs:            caCertPool, +			InsecureSkipVerify: true, //nolint:gosec +		} + +		switch { +		case config.SSL.Mode == "require" && config.SSL.RootCertificate != "" || config.SSL.Mode == "verify-ca": +			dsnConfig.TLSConfig.VerifyPeerCertificate = newPostgreSQLVerifyCAFunc(dsnConfig.TLSConfig) +		case config.SSL.Mode == "verify-full": +			dsnConfig.TLSConfig.InsecureSkipVerify = false +			dsnConfig.TLSConfig.ServerName = config.Host +		}  	} -	if config.Port > 0 { -		args = append(args, fmt.Sprintf("port=%d", config.Port)) +	dsnConfig.Host = config.Host +	dsnConfig.Port = uint16(config.Port) +	dsnConfig.Database = config.Database +	dsnConfig.User = config.Username +	dsnConfig.Password = config.Password +	dsnConfig.ConnectTimeout = config.Timeout +	dsnConfig.RuntimeParams = map[string]string{ +		"search_path": config.Schema,  	} -	if config.SSL.RootCertificate != "" { -		args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate)) +	if dsnConfig.Port == 0 && !path.IsAbs(dsnConfig.Host) { +		dsnConfig.Port = 4321  	} -	if config.SSL.Certificate != "" { -		args = append(args, fmt.Sprintf("sslcert=%s", config.SSL.Certificate)) +	return stdlib.RegisterConnConfig(dsnConfig) +} + +func loadPostgreSQLLegacyTLS(config *schema.PostgreSQLStorageConfiguration) (ca *x509.Certificate, certs []tls.Certificate) { +	var ( +		err error +	) + +	if config.SSL.RootCertificate != "" { +		var caPEMBlock []byte + +		if caPEMBlock, err = os.ReadFile(config.SSL.RootCertificate); err != nil { +			return nil, nil +		} + +		if ca, err = x509.ParseCertificate(caPEMBlock); err != nil { +			return nil, nil +		}  	} -	if config.SSL.Key != "" { -		args = append(args, fmt.Sprintf("sslkey=%s", config.SSL.Key)) +	if config.SSL.Certificate != "" && config.SSL.Key != "" { +		var ( +			keyPEMBlock  []byte +			certPEMBlock []byte +		) + +		if keyPEMBlock, err = os.ReadFile(config.SSL.Key); err != nil { +			return nil, nil +		} + +		if certPEMBlock, err = os.ReadFile(config.SSL.Certificate); err != nil { +			return nil, nil +		} + +		var cert tls.Certificate + +		if cert, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock); err != nil { +			return nil, nil +		} + +		certs = []tls.Certificate{cert}  	} -	args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second))) +	return ca, certs +} + +func newPostgreSQLVerifyCAFunc(config *tls.Config) func(certificates [][]byte, _ [][]*x509.Certificate) (err error) { +	return func(certificates [][]byte, _ [][]*x509.Certificate) (err error) { +		certs := make([]*x509.Certificate, len(certificates)) + +		var cert *x509.Certificate + +		for i, asn1Data := range certificates { +			if cert, err = x509.ParseCertificate(asn1Data); err != nil { +				return errors.New("failed to parse certificate from server: " + err.Error()) +			} -	return strings.Join(args, " ") +			certs[i] = cert +		} + +		// Leave DNSName empty to skip hostname verification. +		opts := x509.VerifyOptions{ +			Roots:         config.RootCAs, +			Intermediates: x509.NewCertPool(), +		} + +		// Skip the first cert because it's the leaf. All others +		// are intermediates. +		for _, cert = range certs[1:] { +			opts.Intermediates.AddCert(cert) +		} + +		_, err = certs[0].Verify(opts) + +		return err +	}  }  | 
