summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider_backend_postgres.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage/sql_provider_backend_postgres.go')
-rw-r--r--internal/storage/sql_provider_backend_postgres.go145
1 files changed, 122 insertions, 23 deletions
diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go
index 8a3ee7779..80f9929aa 100644
--- a/internal/storage/sql_provider_backend_postgres.go
+++ b/internal/storage/sql_provider_backend_postgres.go
@@ -1,11 +1,15 @@
package storage
import (
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
"fmt"
- "strings"
- "time"
+ "os"
+ "path"
- _ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/stdlib"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
@@ -16,9 +20,9 @@ type PostgreSQLProvider struct {
}
// NewPostgreSQLProvider a PostgreSQL provider.
-func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLProvider) {
+func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPool) (provider *PostgreSQLProvider) {
provider = &PostgreSQLProvider{
- SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dataSourceNamePostgreSQL(*config.Storage.PostgreSQL)),
+ SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dsnPostgreSQL(config.Storage.PostgreSQL, caCertPool)),
}
// All providers have differing SELECT existing table statements.
@@ -128,33 +132,128 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
return provider
}
-func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) {
- args := []string{
- fmt.Sprintf("host=%s", config.Host),
- fmt.Sprintf("user='%s'", config.Username),
- fmt.Sprintf("password='%s'", config.Password),
- fmt.Sprintf("dbname=%s", config.Database),
- fmt.Sprintf("search_path=%s", config.Schema),
- fmt.Sprintf("sslmode=%s", config.SSL.Mode),
+func dsnPostgreSQL(config *schema.PostgreSQLStorageConfiguration, globalCACertPool *x509.CertPool) (dsn string) {
+ dsnConfig, _ := pgx.ParseConfig("")
+
+ ca, certs := loadPostgreSQLLegacyTLS(config)
+
+ switch config.SSL.Mode {
+ case "disable":
+ break
+ default:
+ var caCertPool *x509.CertPool
+
+ switch ca {
+ case nil:
+ caCertPool = globalCACertPool
+ default:
+ caCertPool = globalCACertPool.Clone()
+ caCertPool.AddCert(ca)
+ }
+
+ dsnConfig.TLSConfig = &tls.Config{
+ Certificates: certs,
+ RootCAs: caCertPool,
+ InsecureSkipVerify: true, //nolint:gosec
+ }
+
+ switch {
+ case config.SSL.Mode == "require" && config.SSL.RootCertificate != "" || config.SSL.Mode == "verify-ca":
+ dsnConfig.TLSConfig.VerifyPeerCertificate = newPostgreSQLVerifyCAFunc(dsnConfig.TLSConfig)
+ case config.SSL.Mode == "verify-full":
+ dsnConfig.TLSConfig.InsecureSkipVerify = false
+ dsnConfig.TLSConfig.ServerName = config.Host
+ }
}
- if config.Port > 0 {
- args = append(args, fmt.Sprintf("port=%d", config.Port))
+ dsnConfig.Host = config.Host
+ dsnConfig.Port = uint16(config.Port)
+ dsnConfig.Database = config.Database
+ dsnConfig.User = config.Username
+ dsnConfig.Password = config.Password
+ dsnConfig.ConnectTimeout = config.Timeout
+ dsnConfig.RuntimeParams = map[string]string{
+ "search_path": config.Schema,
}
- if config.SSL.RootCertificate != "" {
- args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate))
+ if dsnConfig.Port == 0 && !path.IsAbs(dsnConfig.Host) {
+ dsnConfig.Port = 4321
}
- if config.SSL.Certificate != "" {
- args = append(args, fmt.Sprintf("sslcert=%s", config.SSL.Certificate))
+ return stdlib.RegisterConnConfig(dsnConfig)
+}
+
+func loadPostgreSQLLegacyTLS(config *schema.PostgreSQLStorageConfiguration) (ca *x509.Certificate, certs []tls.Certificate) {
+ var (
+ err error
+ )
+
+ if config.SSL.RootCertificate != "" {
+ var caPEMBlock []byte
+
+ if caPEMBlock, err = os.ReadFile(config.SSL.RootCertificate); err != nil {
+ return nil, nil
+ }
+
+ if ca, err = x509.ParseCertificate(caPEMBlock); err != nil {
+ return nil, nil
+ }
}
- if config.SSL.Key != "" {
- args = append(args, fmt.Sprintf("sslkey=%s", config.SSL.Key))
+ if config.SSL.Certificate != "" && config.SSL.Key != "" {
+ var (
+ keyPEMBlock []byte
+ certPEMBlock []byte
+ )
+
+ if keyPEMBlock, err = os.ReadFile(config.SSL.Key); err != nil {
+ return nil, nil
+ }
+
+ if certPEMBlock, err = os.ReadFile(config.SSL.Certificate); err != nil {
+ return nil, nil
+ }
+
+ var cert tls.Certificate
+
+ if cert, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock); err != nil {
+ return nil, nil
+ }
+
+ certs = []tls.Certificate{cert}
}
- args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second)))
+ return ca, certs
+}
+
+func newPostgreSQLVerifyCAFunc(config *tls.Config) func(certificates [][]byte, _ [][]*x509.Certificate) (err error) {
+ return func(certificates [][]byte, _ [][]*x509.Certificate) (err error) {
+ certs := make([]*x509.Certificate, len(certificates))
+
+ var cert *x509.Certificate
+
+ for i, asn1Data := range certificates {
+ if cert, err = x509.ParseCertificate(asn1Data); err != nil {
+ return errors.New("failed to parse certificate from server: " + err.Error())
+ }
- return strings.Join(args, " ")
+ certs[i] = cert
+ }
+
+ // Leave DNSName empty to skip hostname verification.
+ opts := x509.VerifyOptions{
+ Roots: config.RootCAs,
+ Intermediates: x509.NewCertPool(),
+ }
+
+ // Skip the first cert because it's the leaf. All others
+ // are intermediates.
+ for _, cert = range certs[1:] {
+ opts.Intermediates.AddCert(cert)
+ }
+
+ _, err = certs[0].Verify(opts)
+
+ return err
+ }
}