diff options
Diffstat (limited to 'internal/storage/sql_provider_backend_postgres.go')
| -rw-r--r-- | internal/storage/sql_provider_backend_postgres.go | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index a410e01fd..daa7eef02 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -11,6 +11,7 @@ import ( "strconv" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/stdlib" "github.com/authelia/authelia/v4/internal/configuration/schema" @@ -174,7 +175,12 @@ func dsnPostgreSQL(config *schema.StoragePostgreSQL, globalCACertPool *x509.Cert dsnConfig.TLSConfig = loadPostgreSQLTLSConfig(config, globalCACertPool) dsnConfig.ConnectTimeout = config.Timeout dsnConfig.RuntimeParams = map[string]string{ - "search_path": config.Schema, + "application_name": fmt.Sprintf("Authelia %s", utils.Version()), + "search_path": config.Schema, + } + + if len(config.Servers) != 0 { + dsnPostgreSQLFallbacks(config, globalCACertPool, dsnConfig) } return stdlib.RegisterConnConfig(dsnConfig) @@ -207,20 +213,40 @@ func dsnPostgreSQLHostPort(address *schema.AddressTCP) (host string, port uint16 return host, port } +func dsnPostgreSQLFallbacks(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool, dsnConfig *pgx.ConnConfig) { + dsnConfig.Fallbacks = make([]*pgconn.FallbackConfig, len(config.Servers)) + + for i, server := range config.Servers { + fallback := &pgconn.FallbackConfig{ + TLSConfig: loadPostgreSQLModernTLSConfig(server.TLS, globalCACertPool), + } + + fallback.Host, fallback.Port = dsnPostgreSQLHostPort(server.Address) + + if fallback.Port == 0 && !server.Address.IsUnixDomainSocket() { + fallback.Port = 5432 + } + + dsnConfig.Fallbacks[i] = fallback + } +} + func loadPostgreSQLTLSConfig(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool) (tlsConfig *tls.Config) { if config.TLS != nil { - return utils.NewTLSConfig(config.TLS, globalCACertPool) + return loadPostgreSQLModernTLSConfig(config.TLS, globalCACertPool) + } else if config.SSL != nil { //nolint:staticcheck + return loadPostgreSQLLegacyTLSConfig(config, globalCACertPool) } - return loadPostgreSQLLegacyTLSConfig(config, globalCACertPool) + return nil +} + +func loadPostgreSQLModernTLSConfig(config *schema.TLS, globalCACertPool *x509.CertPool) (tlsConfig *tls.Config) { + return utils.NewTLSConfig(config, globalCACertPool) } //nolint:staticcheck // Used for legacy purposes. func loadPostgreSQLLegacyTLSConfig(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool) (tlsConfig *tls.Config) { - if config.SSL == nil { - return nil - } - var ( ca *x509.Certificate certs []tls.Certificate @@ -238,7 +264,7 @@ func loadPostgreSQLLegacyTLSConfig(config *schema.StoragePostgreSQL, globalCACer case nil: caCertPool = globalCACertPool default: - caCertPool = globalCACertPool.Clone() + caCertPool = globalCACertPool caCertPool.AddCert(ca) } |
