summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider_backend_postgres.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage/sql_provider_backend_postgres.go')
-rw-r--r--internal/storage/sql_provider_backend_postgres.go42
1 files changed, 34 insertions, 8 deletions
diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go
index a410e01fd..daa7eef02 100644
--- a/internal/storage/sql_provider_backend_postgres.go
+++ b/internal/storage/sql_provider_backend_postgres.go
@@ -11,6 +11,7 @@ import (
"strconv"
"github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib"
"github.com/authelia/authelia/v4/internal/configuration/schema"
@@ -174,7 +175,12 @@ func dsnPostgreSQL(config *schema.StoragePostgreSQL, globalCACertPool *x509.Cert
dsnConfig.TLSConfig = loadPostgreSQLTLSConfig(config, globalCACertPool)
dsnConfig.ConnectTimeout = config.Timeout
dsnConfig.RuntimeParams = map[string]string{
- "search_path": config.Schema,
+ "application_name": fmt.Sprintf("Authelia %s", utils.Version()),
+ "search_path": config.Schema,
+ }
+
+ if len(config.Servers) != 0 {
+ dsnPostgreSQLFallbacks(config, globalCACertPool, dsnConfig)
}
return stdlib.RegisterConnConfig(dsnConfig)
@@ -207,20 +213,40 @@ func dsnPostgreSQLHostPort(address *schema.AddressTCP) (host string, port uint16
return host, port
}
+func dsnPostgreSQLFallbacks(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool, dsnConfig *pgx.ConnConfig) {
+ dsnConfig.Fallbacks = make([]*pgconn.FallbackConfig, len(config.Servers))
+
+ for i, server := range config.Servers {
+ fallback := &pgconn.FallbackConfig{
+ TLSConfig: loadPostgreSQLModernTLSConfig(server.TLS, globalCACertPool),
+ }
+
+ fallback.Host, fallback.Port = dsnPostgreSQLHostPort(server.Address)
+
+ if fallback.Port == 0 && !server.Address.IsUnixDomainSocket() {
+ fallback.Port = 5432
+ }
+
+ dsnConfig.Fallbacks[i] = fallback
+ }
+}
+
func loadPostgreSQLTLSConfig(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool) (tlsConfig *tls.Config) {
if config.TLS != nil {
- return utils.NewTLSConfig(config.TLS, globalCACertPool)
+ return loadPostgreSQLModernTLSConfig(config.TLS, globalCACertPool)
+ } else if config.SSL != nil { //nolint:staticcheck
+ return loadPostgreSQLLegacyTLSConfig(config, globalCACertPool)
}
- return loadPostgreSQLLegacyTLSConfig(config, globalCACertPool)
+ return nil
+}
+
+func loadPostgreSQLModernTLSConfig(config *schema.TLS, globalCACertPool *x509.CertPool) (tlsConfig *tls.Config) {
+ return utils.NewTLSConfig(config, globalCACertPool)
}
//nolint:staticcheck // Used for legacy purposes.
func loadPostgreSQLLegacyTLSConfig(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool) (tlsConfig *tls.Config) {
- if config.SSL == nil {
- return nil
- }
-
var (
ca *x509.Certificate
certs []tls.Certificate
@@ -238,7 +264,7 @@ func loadPostgreSQLLegacyTLSConfig(config *schema.StoragePostgreSQL, globalCACer
case nil:
caCertPool = globalCACertPool
default:
- caCertPool = globalCACertPool.Clone()
+ caCertPool = globalCACertPool
caCertPool.AddCert(ca)
}