diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2021-11-25 12:56:58 +1100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-11-25 12:56:58 +1100 |
| commit | 347bd1be779266742aac4a1576fe58ed20b96d3b (patch) | |
| tree | 873f1cf8e4e6a36a29907189b25a0ef1e4cf3b9d /internal/storage/sql_provider.go | |
| parent | eb949603484c307827863729db22c1d38fc9e31f (diff) | |
feat(storage): encrypted secret values (#2588)
This adds an AES-GCM 256bit encryption layer for storage for sensitive items. This is only TOTP secrets for the time being but this may be expanded later. This will require a configuration change as per https://www.authelia.com/docs/configuration/migration.html#4330.
Closes #682
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 229 |
1 files changed, 162 insertions, 67 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index f136e4c6e..9ff5dad2b 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -2,6 +2,7 @@ package storage import ( "context" + "crypto/sha256" "database/sql" "errors" "fmt" @@ -16,15 +17,16 @@ import ( ) // NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's. -func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvider) { +func NewSQLProvider(name, driverName, dataSourceName, encryptionKey string) (provider SQLProvider) { db, err := sqlx.Open(driverName, dataSourceName) provider = SQLProvider{ + db: db, + key: sha256.Sum256([]byte(encryptionKey)), name: name, driverName: driverName, - db: db, - log: logging.Logger(), errOpen: err, + log: logging.Logger(), sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), @@ -33,9 +35,13 @@ func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvid sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification), sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification), - sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), - sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), - sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), + sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), + sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), + sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), + sqlSelectTOTPConfigs: fmt.Sprintf(queryFmtSelectTOTPConfigurations, tableTOTPConfigurations), + + sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations), + sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations), sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices), sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices), @@ -48,20 +54,29 @@ func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvid sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations), + sqlUpsertEncryptionValue: fmt.Sprintf(queryFmtUpsertEncryptionValue, tableEncryption), + sqlSelectEncryptionValue: fmt.Sprintf(queryFmtSelectEncryptionValue, tableEncryption), + sqlFmtRenameTable: queryFmtRenameTable, } + key := sha256.Sum256([]byte(encryptionKey)) + + provider.key = key + return provider } // SQLProvider is a storage provider persisting data in a SQL database. type SQLProvider struct { db *sqlx.DB - log *logrus.Logger + key [32]byte name string driverName string errOpen error + log *logrus.Logger + // Table: authentication_logs. sqlInsertAuthenticationAttempt string sqlSelectAuthenticationAttemptsByUsername string @@ -72,9 +87,13 @@ type SQLProvider struct { sqlSelectExistsIdentityVerification string // Table: totp_configurations. - sqlUpsertTOTPConfig string - sqlDeleteTOTPConfig string - sqlSelectTOTPConfig string + sqlUpsertTOTPConfig string + sqlDeleteTOTPConfig string + sqlSelectTOTPConfig string + sqlSelectTOTPConfigs string + + sqlUpdateTOTPConfigSecret string + sqlUpdateTOTPConfigSecretByUsername string // Table: u2f_devices. sqlUpsertU2FDevice string @@ -90,21 +109,29 @@ type SQLProvider struct { sqlSelectMigrations string sqlSelectLatestMigration string + // Table: encryption. + sqlUpsertEncryptionValue string + sqlSelectEncryptionValue string + // Utility. sqlSelectExistingTables string sqlFmtRenameTable string } +// Close the underlying database connection. +func (p *SQLProvider) Close() (err error) { + return p.db.Close() +} + // StartupCheck implements the provider startup check interface. func (p *SQLProvider) StartupCheck() (err error) { if p.errOpen != nil { - return p.errOpen + return fmt.Errorf("error opening database: %w", p.errOpen) } // TODO: Decide if this is needed, or if it should be configurable. for i := 0; i < 19; i++ { - err = p.db.Ping() - if err == nil { + if err = p.db.Ping(); err == nil { break } @@ -112,13 +139,17 @@ func (p *SQLProvider) StartupCheck() (err error) { } if err != nil { - return err + return fmt.Errorf("error pinging database: %w", err) } p.log.Infof("Storage schema is being checked for updates") ctx := context.Background() + if err = p.SchemaEncryptionCheckKey(ctx, false); err != nil && !errors.Is(err, ErrSchemaEncryptionVersionUnsupported) { + return err + } + err = p.SchemaMigrate(ctx, true, SchemaLatest) switch err { @@ -128,7 +159,7 @@ func (p *SQLProvider) StartupCheck() (err error) { case nil: return nil default: - return err + return fmt.Errorf("error during schema migrate: %w", err) } } @@ -143,13 +174,13 @@ func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username strin func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) { err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username) - switch err { - case sql.ErrNoRows: + switch { + case err == nil: + return method, nil + case errors.Is(err, sql.ErrNoRows): return "", nil - case nil: - return method, err default: - return "", err + return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err) } } @@ -161,89 +192,148 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m case err == nil: return info, nil case errors.Is(err, sql.ErrNoRows): - _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0]) - if err != nil { - return models.UserInfo{}, err + if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0]); err != nil { + return models.UserInfo{}, fmt.Errorf("error upserting preferred two factor method while selecting user info for user '%s': %w", username, err) } - err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username) - if err != nil { - return models.UserInfo{}, err + if err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username); err != nil { + return models.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) } return info, nil default: - return models.UserInfo{}, err + return models.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) } } // SaveIdentityVerification save an identity verification record to the database. func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) { - _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token) + if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token); err != nil { + return fmt.Errorf("error inserting identity verification: %w", err) + } - return err + return nil } // RemoveIdentityVerification remove an identity verification record from the database. func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, token string) (err error) { - _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token) + if _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token); err != nil { + return fmt.Errorf("error updating identity verification: %w", err) + } - return err + return nil } // FindIdentityVerification checks if an identity verification record is in the database and active. -func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) { - err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, jti) - if err != nil { - return false, err +func (p *SQLProvider) FindIdentityVerification(ctx context.Context, token string) (found bool, err error) { + if err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, token); err != nil { + return false, fmt.Errorf("error selecting identity verification exists: %w", err) } return found, nil } -// SaveTOTPConfiguration save a TOTP config of a given user in the database. +// SaveTOTPConfiguration save a TOTP configuration of a given user in the database. func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) { - // TODO: Encrypt config.Secret here. - _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, - config.Username, - config.Algorithm, - config.Digits, - config.Period, - config.Secret, - ) + if config.Secret, err = p.encrypt(config.Secret); err != nil { + return fmt.Errorf("error encrypting the TOTP configuration secret: %v", err) + } - return err + if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, + config.Username, config.Algorithm, config.Digits, config.Period, config.Secret); err != nil { + return fmt.Errorf("error upserting TOTP configuration: %w", err) + } + + return nil } -// DeleteTOTPConfiguration delete a TOTP secret from the database given a username. +// DeleteTOTPConfiguration delete a TOTP configuration from the database given a username. func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) { - _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username) + if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil { + return fmt.Errorf("error deleting TOTP configuration: %w", err) + } - return err + return nil } -// LoadTOTPConfiguration load a TOTP secret given a username from the database. +// LoadTOTPConfiguration load a TOTP configuration given a username from the database. func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) { config = &models.TOTPConfiguration{} - err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config) - if err != nil { - if err == sql.ErrNoRows { + if err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config); err != nil { + if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoTOTPSecret } - return nil, err + return nil, fmt.Errorf("error selecting TOTP configuration: %w", err) + } + + if config.Secret, err = p.decrypt(config.Secret); err != nil { + return nil, fmt.Errorf("error decrypting the TOTP secret: %v", err) } - // TODO: Decrypt config.Secret here. return config, nil } +// LoadTOTPConfigurations load a set of TOTP configurations. +func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []models.TOTPConfiguration, err error) { + rows, err := p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, limit, limit*page) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return configs, nil + } + + return nil, fmt.Errorf("error selecting TOTP configurations: %w", err) + } + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + configs = make([]models.TOTPConfiguration, 0, limit) + + var config models.TOTPConfiguration + + for rows.Next() { + if err = rows.StructScan(&config); err != nil { + return nil, fmt.Errorf("error scanning TOTP configuration to struct: %w", err) + } + + if config.Secret, err = p.decrypt(config.Secret); err != nil { + return nil, fmt.Errorf("error decrypting the TOTP secret: %v", err) + } + + configs = append(configs, config) + } + + return configs, nil +} + +// UpdateTOTPConfigurationSecret updates a TOTP configuration secret. +func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) { + switch config.ID { + case 0: + _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username) + default: + _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecret, config.Secret, config.ID) + } + + if err != nil { + return fmt.Errorf("error updating TOTP configuration secret: %w", err) + } + + return nil +} + // SaveU2FDevice saves a registered U2F device. func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) { - _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey) + if _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey); err != nil { + return fmt.Errorf("error upserting U2F device secret: %v", err) + } - return err + return nil } // LoadU2FDevice loads a U2F device registration for a given username. @@ -252,13 +342,12 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic Username: username, } - err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username) - if err != nil { - if err == sql.ErrNoRows { + if err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username); err != nil { + if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoU2FDeviceHandle } - return nil, err + return nil, fmt.Errorf("error selecting U2F device: %w", err) } return device, nil @@ -266,15 +355,22 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic // AppendAuthenticationLog append a mark to the authentication log. func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) { - _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username) - return err + if _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username); err != nil { + return fmt.Errorf("error inserting authentiation attempt: %w", err) + } + + return nil } // LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log. func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) { rows, err := p.db.QueryxContext(ctx, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page) if err != nil { - return nil, err + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNoAuthenticationLogs + } + + return nil, fmt.Errorf("error selecting authentication logs: %w", err) } defer func() { @@ -283,13 +379,12 @@ func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username strin } }() + var attempt models.AuthenticationAttempt + attempts = make([]models.AuthenticationAttempt, 0, limit) for rows.Next() { - var attempt models.AuthenticationAttempt - - err = rows.StructScan(&attempt) - if err != nil { + if err = rows.StructScan(&attempt); err != nil { return nil, err } |
