summaryrefslogtreecommitdiff
path: root/internal/storage
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2022-10-22 16:41:27 +1100
committerGitHub <noreply@github.com>2022-10-22 16:41:27 +1100
commit1ea29cb2c24b44d15dffed3964e41e56b32da02d (patch)
treef31fb4a2b54478c5f887e45564d6b8585d9878cc /internal/storage
parent1d821a0d3a15b853e50c44aa786b846213f0355c (diff)
feat(storage): unix socket support (#4231)
Support for unix sockets for MySQL and PostgreSQL.
Diffstat (limited to 'internal/storage')
-rw-r--r--internal/storage/const.go3
-rw-r--r--internal/storage/sql_provider_backend_mysql.go42
-rw-r--r--internal/storage/sql_provider_backend_postgres.go145
3 files changed, 147 insertions, 43 deletions
diff --git a/internal/storage/const.go b/internal/storage/const.go
index 825caae04..39645979d 100644
--- a/internal/storage/const.go
+++ b/internal/storage/const.go
@@ -41,7 +41,8 @@ const (
)
const (
- sqlNetworkTypeTCP = "tcp"
+ sqlNetworkTypeTCP = "tcp"
+ sqlNetworkTypeUnixSocket = "unix"
)
const (
diff --git a/internal/storage/sql_provider_backend_mysql.go b/internal/storage/sql_provider_backend_mysql.go
index a336aeac9..b0382ed7d 100644
--- a/internal/storage/sql_provider_backend_mysql.go
+++ b/internal/storage/sql_provider_backend_mysql.go
@@ -1,11 +1,12 @@
package storage
import (
+ "crypto/x509"
"fmt"
+ "path"
"time"
"github.com/go-sql-driver/mysql"
- _ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string.
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
@@ -16,9 +17,9 @@ type MySQLProvider struct {
}
// NewMySQLProvider a MySQL provider.
-func NewMySQLProvider(config *schema.Configuration) (provider *MySQLProvider) {
+func NewMySQLProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider *MySQLProvider) {
provider = &MySQLProvider{
- SQLProvider: NewSQLProvider(config, providerMySQL, providerMySQL, dataSourceNameMySQL(*config.Storage.MySQL)),
+ SQLProvider: NewSQLProvider(config, providerMySQL, providerMySQL, dsnMySQL(config.Storage.MySQL)),
}
// All providers have differing SELECT existing table statements.
@@ -30,32 +31,35 @@ func NewMySQLProvider(config *schema.Configuration) (provider *MySQLProvider) {
return provider
}
-func dataSourceNameMySQL(config schema.MySQLStorageConfiguration) (dataSourceName string) {
- dconfig := mysql.NewConfig()
+func dsnMySQL(config *schema.MySQLStorageConfiguration) (dataSourceName string) {
+ dsnConfig := mysql.NewConfig()
switch {
+ case path.IsAbs(config.Host):
+ dsnConfig.Net = sqlNetworkTypeUnixSocket
+ dsnConfig.Addr = config.Host
case config.Port == 0:
- dconfig.Net = sqlNetworkTypeTCP
- dconfig.Addr = fmt.Sprintf("%s:%d", config.Host, 3306)
+ dsnConfig.Net = sqlNetworkTypeTCP
+ dsnConfig.Addr = fmt.Sprintf("%s:%d", config.Host, 3306)
default:
- dconfig.Net = sqlNetworkTypeTCP
- dconfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port)
+ dsnConfig.Net = sqlNetworkTypeTCP
+ dsnConfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port)
}
switch config.Port {
case 0:
- dconfig.Addr = config.Host
+ dsnConfig.Addr = config.Host
default:
- dconfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port)
+ dsnConfig.Addr = fmt.Sprintf("%s:%d", config.Host, config.Port)
}
- dconfig.DBName = config.Database
- dconfig.User = config.Username
- dconfig.Passwd = config.Password
- dconfig.Timeout = config.Timeout
- dconfig.MultiStatements = true
- dconfig.ParseTime = true
- dconfig.Loc = time.Local
+ dsnConfig.DBName = config.Database
+ dsnConfig.User = config.Username
+ dsnConfig.Passwd = config.Password
+ dsnConfig.Timeout = config.Timeout
+ dsnConfig.MultiStatements = true
+ dsnConfig.ParseTime = true
+ dsnConfig.Loc = time.Local
- return dconfig.FormatDSN()
+ return dsnConfig.FormatDSN()
}
diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go
index 8a3ee7779..80f9929aa 100644
--- a/internal/storage/sql_provider_backend_postgres.go
+++ b/internal/storage/sql_provider_backend_postgres.go
@@ -1,11 +1,15 @@
package storage
import (
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
"fmt"
- "strings"
- "time"
+ "os"
+ "path"
- _ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/stdlib"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
@@ -16,9 +20,9 @@ type PostgreSQLProvider struct {
}
// NewPostgreSQLProvider a PostgreSQL provider.
-func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLProvider) {
+func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider *PostgreSQLProvider) {
provider = &PostgreSQLProvider{
- SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dataSourceNamePostgreSQL(*config.Storage.PostgreSQL)),
+ SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dsnPostgreSQL(config.Storage.PostgreSQL, caCertPool)),
}
// All providers have differing SELECT existing table statements.
@@ -128,33 +132,128 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
return provider
}
-func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) {
- args := []string{
- fmt.Sprintf("host=%s", config.Host),
- fmt.Sprintf("user='%s'", config.Username),
- fmt.Sprintf("password='%s'", config.Password),
- fmt.Sprintf("dbname=%s", config.Database),
- fmt.Sprintf("search_path=%s", config.Schema),
- fmt.Sprintf("sslmode=%s", config.SSL.Mode),
+func dsnPostgreSQL(config *schema.PostgreSQLStorageConfiguration, globalCACertPool *x509.CertPool) (dsn string) {
+ dsnConfig, _ := pgx.ParseConfig("")
+
+ ca, certs := loadPostgreSQLLegacyTLS(config)
+
+ switch config.SSL.Mode {
+ case "disable":
+ break
+ default:
+ var caCertPool *x509.CertPool
+
+ switch ca {
+ case nil:
+ caCertPool = globalCACertPool
+ default:
+ caCertPool = globalCACertPool.Clone()
+ caCertPool.AddCert(ca)
+ }
+
+ dsnConfig.TLSConfig = &tls.Config{
+ Certificates: certs,
+ RootCAs: caCertPool,
+ InsecureSkipVerify: true, //nolint:gosec
+ }
+
+ switch {
+ case config.SSL.Mode == "require" && config.SSL.RootCertificate != "" || config.SSL.Mode == "verify-ca":
+ dsnConfig.TLSConfig.VerifyPeerCertificate = newPostgreSQLVerifyCAFunc(dsnConfig.TLSConfig)
+ case config.SSL.Mode == "verify-full":
+ dsnConfig.TLSConfig.InsecureSkipVerify = false
+ dsnConfig.TLSConfig.ServerName = config.Host
+ }
}
- if config.Port > 0 {
- args = append(args, fmt.Sprintf("port=%d", config.Port))
+ dsnConfig.Host = config.Host
+ dsnConfig.Port = uint16(config.Port)
+ dsnConfig.Database = config.Database
+ dsnConfig.User = config.Username
+ dsnConfig.Password = config.Password
+ dsnConfig.ConnectTimeout = config.Timeout
+ dsnConfig.RuntimeParams = map[string]string{
+ "search_path": config.Schema,
}
- if config.SSL.RootCertificate != "" {
- args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate))
+ if dsnConfig.Port == 0 && !path.IsAbs(dsnConfig.Host) {
+ dsnConfig.Port = 4321
}
- if config.SSL.Certificate != "" {
- args = append(args, fmt.Sprintf("sslcert=%s", config.SSL.Certificate))
+ return stdlib.RegisterConnConfig(dsnConfig)
+}
+
+func loadPostgreSQLLegacyTLS(config *schema.PostgreSQLStorageConfiguration) (ca *x509.Certificate, certs []tls.Certificate) {
+ var (
+ err error
+ )
+
+ if config.SSL.RootCertificate != "" {
+ var caPEMBlock []byte
+
+ if caPEMBlock, err = os.ReadFile(config.SSL.RootCertificate); err != nil {
+ return nil, nil
+ }
+
+ if ca, err = x509.ParseCertificate(caPEMBlock); err != nil {
+ return nil, nil
+ }
}
- if config.SSL.Key != "" {
- args = append(args, fmt.Sprintf("sslkey=%s", config.SSL.Key))
+ if config.SSL.Certificate != "" && config.SSL.Key != "" {
+ var (
+ keyPEMBlock []byte
+ certPEMBlock []byte
+ )
+
+ if keyPEMBlock, err = os.ReadFile(config.SSL.Key); err != nil {
+ return nil, nil
+ }
+
+ if certPEMBlock, err = os.ReadFile(config.SSL.Certificate); err != nil {
+ return nil, nil
+ }
+
+ var cert tls.Certificate
+
+ if cert, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock); err != nil {
+ return nil, nil
+ }
+
+ certs = []tls.Certificate{cert}
}
- args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second)))
+ return ca, certs
+}
+
+func newPostgreSQLVerifyCAFunc(config *tls.Config) func(certificates [][]byte, _ [][]*x509.Certificate) (err error) {
+ return func(certificates [][]byte, _ [][]*x509.Certificate) (err error) {
+ certs := make([]*x509.Certificate, len(certificates))
+
+ var cert *x509.Certificate
+
+ for i, asn1Data := range certificates {
+ if cert, err = x509.ParseCertificate(asn1Data); err != nil {
+ return errors.New("failed to parse certificate from server: " + err.Error())
+ }
- return strings.Join(args, " ")
+ certs[i] = cert
+ }
+
+ // Leave DNSName empty to skip hostname verification.
+ opts := x509.VerifyOptions{
+ Roots: config.RootCAs,
+ Intermediates: x509.NewCertPool(),
+ }
+
+ // Skip the first cert because it's the leaf. All others
+ // are intermediates.
+ for _, cert = range certs[1:] {
+ opts.Intermediates.AddCert(cert)
+ }
+
+ _, err = certs[0].Verify(opts)
+
+ return err
+ }
}