summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider.go
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2021-11-25 12:56:58 +1100
committerGitHub <noreply@github.com>2021-11-25 12:56:58 +1100
commit347bd1be779266742aac4a1576fe58ed20b96d3b (patch)
tree873f1cf8e4e6a36a29907189b25a0ef1e4cf3b9d /internal/storage/sql_provider.go
parenteb949603484c307827863729db22c1d38fc9e31f (diff)
feat(storage): encrypted secret values (#2588)
This adds an AES-GCM 256bit encryption layer for storage for sensitive items. This is only TOTP secrets for the time being but this may be expanded later. This will require a configuration change as per https://www.authelia.com/docs/configuration/migration.html#4330. Closes #682
Diffstat (limited to 'internal/storage/sql_provider.go')
-rw-r--r--internal/storage/sql_provider.go229
1 files changed, 162 insertions, 67 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index f136e4c6e..9ff5dad2b 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -2,6 +2,7 @@ package storage
import (
"context"
+ "crypto/sha256"
"database/sql"
"errors"
"fmt"
@@ -16,15 +17,16 @@ import (
)
// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's.
-func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvider) {
+func NewSQLProvider(name, driverName, dataSourceName, encryptionKey string) (provider SQLProvider) {
db, err := sqlx.Open(driverName, dataSourceName)
provider = SQLProvider{
+ db: db,
+ key: sha256.Sum256([]byte(encryptionKey)),
name: name,
driverName: driverName,
- db: db,
- log: logging.Logger(),
errOpen: err,
+ log: logging.Logger(),
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
@@ -33,9 +35,13 @@ func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvid
sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification),
sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
- sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
- sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
- sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
+ sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
+ sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
+ sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
+ sqlSelectTOTPConfigs: fmt.Sprintf(queryFmtSelectTOTPConfigurations, tableTOTPConfigurations),
+
+ sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations),
+ sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations),
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
@@ -48,20 +54,29 @@ func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvid
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations),
+ sqlUpsertEncryptionValue: fmt.Sprintf(queryFmtUpsertEncryptionValue, tableEncryption),
+ sqlSelectEncryptionValue: fmt.Sprintf(queryFmtSelectEncryptionValue, tableEncryption),
+
sqlFmtRenameTable: queryFmtRenameTable,
}
+ key := sha256.Sum256([]byte(encryptionKey))
+
+ provider.key = key
+
return provider
}
// SQLProvider is a storage provider persisting data in a SQL database.
type SQLProvider struct {
db *sqlx.DB
- log *logrus.Logger
+ key [32]byte
name string
driverName string
errOpen error
+ log *logrus.Logger
+
// Table: authentication_logs.
sqlInsertAuthenticationAttempt string
sqlSelectAuthenticationAttemptsByUsername string
@@ -72,9 +87,13 @@ type SQLProvider struct {
sqlSelectExistsIdentityVerification string
// Table: totp_configurations.
- sqlUpsertTOTPConfig string
- sqlDeleteTOTPConfig string
- sqlSelectTOTPConfig string
+ sqlUpsertTOTPConfig string
+ sqlDeleteTOTPConfig string
+ sqlSelectTOTPConfig string
+ sqlSelectTOTPConfigs string
+
+ sqlUpdateTOTPConfigSecret string
+ sqlUpdateTOTPConfigSecretByUsername string
// Table: u2f_devices.
sqlUpsertU2FDevice string
@@ -90,21 +109,29 @@ type SQLProvider struct {
sqlSelectMigrations string
sqlSelectLatestMigration string
+ // Table: encryption.
+ sqlUpsertEncryptionValue string
+ sqlSelectEncryptionValue string
+
// Utility.
sqlSelectExistingTables string
sqlFmtRenameTable string
}
+// Close the underlying database connection.
+func (p *SQLProvider) Close() (err error) {
+ return p.db.Close()
+}
+
// StartupCheck implements the provider startup check interface.
func (p *SQLProvider) StartupCheck() (err error) {
if p.errOpen != nil {
- return p.errOpen
+ return fmt.Errorf("error opening database: %w", p.errOpen)
}
// TODO: Decide if this is needed, or if it should be configurable.
for i := 0; i < 19; i++ {
- err = p.db.Ping()
- if err == nil {
+ if err = p.db.Ping(); err == nil {
break
}
@@ -112,13 +139,17 @@ func (p *SQLProvider) StartupCheck() (err error) {
}
if err != nil {
- return err
+ return fmt.Errorf("error pinging database: %w", err)
}
p.log.Infof("Storage schema is being checked for updates")
ctx := context.Background()
+ if err = p.SchemaEncryptionCheckKey(ctx, false); err != nil && !errors.Is(err, ErrSchemaEncryptionVersionUnsupported) {
+ return err
+ }
+
err = p.SchemaMigrate(ctx, true, SchemaLatest)
switch err {
@@ -128,7 +159,7 @@ func (p *SQLProvider) StartupCheck() (err error) {
case nil:
return nil
default:
- return err
+ return fmt.Errorf("error during schema migrate: %w", err)
}
}
@@ -143,13 +174,13 @@ func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username strin
func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) {
err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username)
- switch err {
- case sql.ErrNoRows:
+ switch {
+ case err == nil:
+ return method, nil
+ case errors.Is(err, sql.ErrNoRows):
return "", nil
- case nil:
- return method, err
default:
- return "", err
+ return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err)
}
}
@@ -161,89 +192,148 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m
case err == nil:
return info, nil
case errors.Is(err, sql.ErrNoRows):
- _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0])
- if err != nil {
- return models.UserInfo{}, err
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0]); err != nil {
+ return models.UserInfo{}, fmt.Errorf("error upserting preferred two factor method while selecting user info for user '%s': %w", username, err)
}
- err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username)
- if err != nil {
- return models.UserInfo{}, err
+ if err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username); err != nil {
+ return models.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err)
}
return info, nil
default:
- return models.UserInfo{}, err
+ return models.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err)
}
}
// SaveIdentityVerification save an identity verification record to the database.
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
- _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token)
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token); err != nil {
+ return fmt.Errorf("error inserting identity verification: %w", err)
+ }
- return err
+ return nil
}
// RemoveIdentityVerification remove an identity verification record from the database.
func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, token string) (err error) {
- _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token)
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token); err != nil {
+ return fmt.Errorf("error updating identity verification: %w", err)
+ }
- return err
+ return nil
}
// FindIdentityVerification checks if an identity verification record is in the database and active.
-func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) {
- err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, jti)
- if err != nil {
- return false, err
+func (p *SQLProvider) FindIdentityVerification(ctx context.Context, token string) (found bool, err error) {
+ if err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, token); err != nil {
+ return false, fmt.Errorf("error selecting identity verification exists: %w", err)
}
return found, nil
}
-// SaveTOTPConfiguration save a TOTP config of a given user in the database.
+// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) {
- // TODO: Encrypt config.Secret here.
- _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
- config.Username,
- config.Algorithm,
- config.Digits,
- config.Period,
- config.Secret,
- )
+ if config.Secret, err = p.encrypt(config.Secret); err != nil {
+ return fmt.Errorf("error encrypting the TOTP configuration secret: %v", err)
+ }
- return err
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
+ config.Username, config.Algorithm, config.Digits, config.Period, config.Secret); err != nil {
+ return fmt.Errorf("error upserting TOTP configuration: %w", err)
+ }
+
+ return nil
}
-// DeleteTOTPConfiguration delete a TOTP secret from the database given a username.
+// DeleteTOTPConfiguration delete a TOTP configuration from the database given a username.
func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) {
- _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username)
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil {
+ return fmt.Errorf("error deleting TOTP configuration: %w", err)
+ }
- return err
+ return nil
}
-// LoadTOTPConfiguration load a TOTP secret given a username from the database.
+// LoadTOTPConfiguration load a TOTP configuration given a username from the database.
func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) {
config = &models.TOTPConfiguration{}
- err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config)
- if err != nil {
- if err == sql.ErrNoRows {
+ if err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNoTOTPSecret
}
- return nil, err
+ return nil, fmt.Errorf("error selecting TOTP configuration: %w", err)
+ }
+
+ if config.Secret, err = p.decrypt(config.Secret); err != nil {
+ return nil, fmt.Errorf("error decrypting the TOTP secret: %v", err)
}
- // TODO: Decrypt config.Secret here.
return config, nil
}
+// LoadTOTPConfigurations load a set of TOTP configurations.
+func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []models.TOTPConfiguration, err error) {
+ rows, err := p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, limit, limit*page)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return configs, nil
+ }
+
+ return nil, fmt.Errorf("error selecting TOTP configurations: %w", err)
+ }
+
+ defer func() {
+ if err := rows.Close(); err != nil {
+ p.log.Errorf(logFmtErrClosingConn, err)
+ }
+ }()
+
+ configs = make([]models.TOTPConfiguration, 0, limit)
+
+ var config models.TOTPConfiguration
+
+ for rows.Next() {
+ if err = rows.StructScan(&config); err != nil {
+ return nil, fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
+ }
+
+ if config.Secret, err = p.decrypt(config.Secret); err != nil {
+ return nil, fmt.Errorf("error decrypting the TOTP secret: %v", err)
+ }
+
+ configs = append(configs, config)
+ }
+
+ return configs, nil
+}
+
+// UpdateTOTPConfigurationSecret updates a TOTP configuration secret.
+func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) {
+ switch config.ID {
+ case 0:
+ _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username)
+ default:
+ _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecret, config.Secret, config.ID)
+ }
+
+ if err != nil {
+ return fmt.Errorf("error updating TOTP configuration secret: %w", err)
+ }
+
+ return nil
+}
+
// SaveU2FDevice saves a registered U2F device.
func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
- _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey)
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey); err != nil {
+ return fmt.Errorf("error upserting U2F device secret: %v", err)
+ }
- return err
+ return nil
}
// LoadU2FDevice loads a U2F device registration for a given username.
@@ -252,13 +342,12 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic
Username: username,
}
- err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username)
- if err != nil {
- if err == sql.ErrNoRows {
+ if err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNoU2FDeviceHandle
}
- return nil, err
+ return nil, fmt.Errorf("error selecting U2F device: %w", err)
}
return device, nil
@@ -266,15 +355,22 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic
// AppendAuthenticationLog append a mark to the authentication log.
func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) {
- _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username)
- return err
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username); err != nil {
+ return fmt.Errorf("error inserting authentiation attempt: %w", err)
+ }
+
+ return nil
}
// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log.
func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) {
rows, err := p.db.QueryxContext(ctx, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page)
if err != nil {
- return nil, err
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, ErrNoAuthenticationLogs
+ }
+
+ return nil, fmt.Errorf("error selecting authentication logs: %w", err)
}
defer func() {
@@ -283,13 +379,12 @@ func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username strin
}
}()
+ var attempt models.AuthenticationAttempt
+
attempts = make([]models.AuthenticationAttempt, 0, limit)
for rows.Next() {
- var attempt models.AuthenticationAttempt
-
- err = rows.StructScan(&attempt)
- if err != nil {
+ if err = rows.StructScan(&attempt); err != nil {
return nil, err
}