diff options
Diffstat (limited to 'internal/storage/sql_provider_backend_postgres.go')
| -rw-r--r-- | internal/storage/sql_provider_backend_postgres.go | 145 |
1 files changed, 122 insertions, 23 deletions
diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 8a3ee7779..80f9929aa 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -1,11 +1,15 @@ package storage import ( + "crypto/tls" + "crypto/x509" + "errors" "fmt" - "strings" - "time" + "os" + "path" - _ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string. + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" "github.com/authelia/authelia/v4/internal/configuration/schema" ) @@ -16,9 +20,9 @@ type PostgreSQLProvider struct { } // NewPostgreSQLProvider a PostgreSQL provider. -func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLProvider) { +func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider *PostgreSQLProvider) { provider = &PostgreSQLProvider{ - SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dataSourceNamePostgreSQL(*config.Storage.PostgreSQL)), + SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dsnPostgreSQL(config.Storage.PostgreSQL, caCertPool)), } // All providers have differing SELECT existing table statements. @@ -128,33 +132,128 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr return provider } -func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) { - args := []string{ - fmt.Sprintf("host=%s", config.Host), - fmt.Sprintf("user='%s'", config.Username), - fmt.Sprintf("password='%s'", config.Password), - fmt.Sprintf("dbname=%s", config.Database), - fmt.Sprintf("search_path=%s", config.Schema), - fmt.Sprintf("sslmode=%s", config.SSL.Mode), +func dsnPostgreSQL(config *schema.PostgreSQLStorageConfiguration, globalCACertPool *x509.CertPool) (dsn string) { + dsnConfig, _ := pgx.ParseConfig("") + + ca, certs := loadPostgreSQLLegacyTLS(config) + + switch config.SSL.Mode { + case "disable": + break + default: + var caCertPool *x509.CertPool + + switch ca { + case nil: + caCertPool = globalCACertPool + default: + caCertPool = globalCACertPool.Clone() + caCertPool.AddCert(ca) + } + + dsnConfig.TLSConfig = &tls.Config{ + Certificates: certs, + RootCAs: caCertPool, + InsecureSkipVerify: true, //nolint:gosec + } + + switch { + case config.SSL.Mode == "require" && config.SSL.RootCertificate != "" || config.SSL.Mode == "verify-ca": + dsnConfig.TLSConfig.VerifyPeerCertificate = newPostgreSQLVerifyCAFunc(dsnConfig.TLSConfig) + case config.SSL.Mode == "verify-full": + dsnConfig.TLSConfig.InsecureSkipVerify = false + dsnConfig.TLSConfig.ServerName = config.Host + } } - if config.Port > 0 { - args = append(args, fmt.Sprintf("port=%d", config.Port)) + dsnConfig.Host = config.Host + dsnConfig.Port = uint16(config.Port) + dsnConfig.Database = config.Database + dsnConfig.User = config.Username + dsnConfig.Password = config.Password + dsnConfig.ConnectTimeout = config.Timeout + dsnConfig.RuntimeParams = map[string]string{ + "search_path": config.Schema, } - if config.SSL.RootCertificate != "" { - args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate)) + if dsnConfig.Port == 0 && !path.IsAbs(dsnConfig.Host) { + dsnConfig.Port = 4321 } - if config.SSL.Certificate != "" { - args = append(args, fmt.Sprintf("sslcert=%s", config.SSL.Certificate)) + return stdlib.RegisterConnConfig(dsnConfig) +} + +func loadPostgreSQLLegacyTLS(config *schema.PostgreSQLStorageConfiguration) (ca *x509.Certificate, certs []tls.Certificate) { + var ( + err error + ) + + if config.SSL.RootCertificate != "" { + var caPEMBlock []byte + + if caPEMBlock, err = os.ReadFile(config.SSL.RootCertificate); err != nil { + return nil, nil + } + + if ca, err = x509.ParseCertificate(caPEMBlock); err != nil { + return nil, nil + } } - if config.SSL.Key != "" { - args = append(args, fmt.Sprintf("sslkey=%s", config.SSL.Key)) + if config.SSL.Certificate != "" && config.SSL.Key != "" { + var ( + keyPEMBlock []byte + certPEMBlock []byte + ) + + if keyPEMBlock, err = os.ReadFile(config.SSL.Key); err != nil { + return nil, nil + } + + if certPEMBlock, err = os.ReadFile(config.SSL.Certificate); err != nil { + return nil, nil + } + + var cert tls.Certificate + + if cert, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock); err != nil { + return nil, nil + } + + certs = []tls.Certificate{cert} } - args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second))) + return ca, certs +} + +func newPostgreSQLVerifyCAFunc(config *tls.Config) func(certificates [][]byte, _ [][]*x509.Certificate) (err error) { + return func(certificates [][]byte, _ [][]*x509.Certificate) (err error) { + certs := make([]*x509.Certificate, len(certificates)) + + var cert *x509.Certificate + + for i, asn1Data := range certificates { + if cert, err = x509.ParseCertificate(asn1Data); err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } - return strings.Join(args, " ") + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: config.RootCAs, + Intermediates: x509.NewCertPool(), + } + + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert = range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + _, err = certs[0].Verify(opts) + + return err + } } |
