From 31565e447ba1e357828c99db5410de879bfd7669 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Mon, 30 Dec 2024 17:59:36 +1100 Subject: fix(configuration): allow unix socket ports (#8520) This allows unix sockets to include ports in the address URL. In addition allows for a absolute path for the PostgreSQL socket type. Both options are only used by PostgreSQL but other unix sockets will not expressly error if this is included. Fixes #8509 --- internal/storage/sql_provider_backend_postgres.go | 55 +++++++++++++++++------ 1 file changed, 42 insertions(+), 13 deletions(-) (limited to 'internal/storage/sql_provider_backend_postgres.go') diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 41dd83db0..2886d387e 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -3,9 +3,12 @@ package storage import ( "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "os" + "path/filepath" + "strconv" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" @@ -153,8 +156,7 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo func dsnPostgreSQL(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool) (dsn string) { dsnConfig, _ := pgx.ParseConfig("") - dsnConfig.Host = config.Address.SocketHostname() - dsnConfig.Port = config.Address.Port() + dsnConfig.Host, dsnConfig.Port = dsnPostgreSQLHostPort(config.Address) dsnConfig.Database = config.Database dsnConfig.User = config.Username dsnConfig.Password = config.Password @@ -164,11 +166,34 @@ func dsnPostgreSQL(config *schema.StoragePostgreSQL, globalCACertPool *x509.Cert "search_path": config.Schema, } - if dsnConfig.Port == 0 && config.Address.IsUnixDomainSocket() { - dsnConfig.Port = 5432 + return stdlib.RegisterConnConfig(dsnConfig) +} + +func dsnPostgreSQLHostPort(address *schema.AddressTCP) (host string, port uint16) { + if !address.IsUnixDomainSocket() { + return address.SocketHostname(), address.Port() } - return stdlib.RegisterConnConfig(dsnConfig) + host, port = address.SocketHostname(), address.Port() + + if port == 0 { + port = 5432 + } + + dir, base := filepath.Dir(host), filepath.Base(host) + + matches := rePostgreSQLUnixDomainSocket.FindStringSubmatch(base) + + if len(matches) != 2 { + return host, port + } + + if raw, err := strconv.ParseUint(matches[1], 10, 16); err == nil { + host = dir + port = uint16(raw) + } + + return host, port } func loadPostgreSQLTLSConfig(config *schema.StoragePostgreSQL, globalCACertPool *x509.CertPool) (tlsConfig *tls.Config) { @@ -231,34 +256,38 @@ func loadPostgreSQLLegacyTLSConfigFiles(config *schema.StoragePostgreSQL) (ca *x ) if config.SSL.RootCertificate != "" { - var caPEMBlock []byte + var ( + data []byte + block *pem.Block + ) - if caPEMBlock, err = os.ReadFile(config.SSL.RootCertificate); err != nil { + if data, err = os.ReadFile(config.SSL.RootCertificate); err != nil { return nil, nil } - if ca, err = x509.ParseCertificate(caPEMBlock); err != nil { + block, _ = pem.Decode(data) + + if ca, err = x509.ParseCertificate(block.Bytes); err != nil { return nil, nil } } if config.SSL.Certificate != "" && config.SSL.Key != "" { var ( - keyPEMBlock []byte - certPEMBlock []byte + dataKey, dataCert []byte ) - if keyPEMBlock, err = os.ReadFile(config.SSL.Key); err != nil { + if dataKey, err = os.ReadFile(config.SSL.Key); err != nil { return nil, nil } - if certPEMBlock, err = os.ReadFile(config.SSL.Certificate); err != nil { + if dataCert, err = os.ReadFile(config.SSL.Certificate); err != nil { return nil, nil } var cert tls.Certificate - if cert, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock); err != nil { + if cert, err = tls.X509KeyPair(dataCert, dataKey); err != nil { return nil, nil } -- cgit v1.2.3