summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider_encryption.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage/sql_provider_encryption.go')
-rw-r--r--internal/storage/sql_provider_encryption.go349
1 files changed, 206 insertions, 143 deletions
diff --git a/internal/storage/sql_provider_encryption.go b/internal/storage/sql_provider_encryption.go
index 338bb27bf..29d334510 100644
--- a/internal/storage/sql_provider_encryption.go
+++ b/internal/storage/sql_provider_encryption.go
@@ -1,38 +1,65 @@
package storage
import (
+ "bytes"
"context"
"crypto/sha256"
+ "database/sql"
+ "errors"
"fmt"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
- "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/utils"
)
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
// by this command to encrypt the values again and update them using a transaction.
-func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error) {
+func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) {
+ skey := sha256.Sum256([]byte(key))
+
+ if bytes.Equal(skey[:], p.key[:]) {
+ return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
+ }
+
+ if _, err = p.SchemaEncryptionCheckKey(ctx, false); err != nil {
+ return fmt.Errorf("error changing the storage encryption key: %w", err)
+ }
+
tx, err := p.db.Beginx()
if err != nil {
return fmt.Errorf("error beginning transaction to change encryption key: %w", err)
}
- key := sha256.Sum256([]byte(encryptionKey))
+ encChangeFuncs := []EncryptionChangeKeyFunc{
+ schemaEncryptionChangeKeyTOTP,
+ schemaEncryptionChangeKeyWebauthn,
+ }
- if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
- return err
+ for i := 0; true; i++ {
+ typeOAuth2Session := OAuth2SessionType(i)
+
+ if typeOAuth2Session.Table() == "" {
+ break
+ }
+
+ encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
}
- if err = p.schemaEncryptionChangeKeyWebauthn(ctx, tx, key); err != nil {
- return err
+ for _, encChangeFunc := range encChangeFuncs {
+ if err = encChangeFunc(ctx, p, tx, skey); err != nil {
+ if rerr := tx.Rollback(); rerr != nil {
+ return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
+ }
+
+ return fmt.Errorf("rollback due to error: %w", err)
+ }
}
- if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
+ if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil {
+ if rerr := tx.Rollback(); rerr != nil {
+ return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
}
return fmt.Errorf("rollback due to error: %w", err)
@@ -41,222 +68,262 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
return tx.Commit()
}
-func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
- var configs []model.TOTPConfiguration
+// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
+func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (result EncryptionValidationResult, err error) {
+ version, err := p.SchemaVersion(ctx)
+ if err != nil {
+ return result, err
+ }
- for page := 0; true; page++ {
- if configs, err = p.LoadTOTPConfigurations(ctx, 10, page); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ if version < 1 {
+ return result, ErrSchemaEncryptionVersionUnsupported
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ result = EncryptionValidationResult{
+ Tables: map[string]EncryptionValidationTableResult{},
+ }
- for _, config := range configs {
- if config.Secret, err = utils.Encrypt(config.Secret, &key); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil {
+ result.InvalidCheckValue = true
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ if verbose {
+ encCheckFuncs := []EncryptionCheckKeyFunc{
+ schemaEncryptionCheckKeyTOTP,
+ schemaEncryptionCheckKeyWebauthn,
+ }
- if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ for i := 0; true; i++ {
+ typeOAuth2Session := OAuth2SessionType(i)
- return fmt.Errorf("rollback due to error: %w", err)
+ if typeOAuth2Session.Table() == "" {
+ break
}
+
+ encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
}
- if len(configs) != 10 {
- break
+ for _, encCheckFunc := range encCheckFuncs {
+ table, tableResult := encCheckFunc(ctx, p)
+
+ result.Tables[table] = tableResult
}
}
- return nil
+ return result, nil
}
-func (p *SQLProvider) schemaEncryptionChangeKeyWebauthn(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
- var devices []model.WebauthnDevice
+func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
+ var count int
- for page := 0; true; page++ {
- if devices, err = p.LoadWebauthnDevices(ctx, 10, page); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableTOTPConfigurations)); err != nil {
+ return err
+ }
- return fmt.Errorf("rollback due to error: %w", err)
+ if count == 0 {
+ return nil
+ }
+
+ configs := make([]encTOTPConfiguration, 0, count)
+
+ if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil
}
- for _, device := range devices {
- if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ return fmt.Errorf("error selecting TOTP configurations: %w", err)
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations))
- if err = p.updateWebauthnDevicePublicKey(ctx, device); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ for _, c := range configs {
+ if c.Secret, err = provider.decrypt(c.Secret); err != nil {
+ return fmt.Errorf("error decrypting TOTP configuration secret with id '%d': %w", c.ID, err)
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ if c.Secret, err = utils.Encrypt(c.Secret, &key); err != nil {
+ return fmt.Errorf("error encrypting TOTP configuration secret with id '%d': %w", c.ID, err)
}
- if len(devices) != 10 {
- break
+ if _, err = tx.ExecContext(ctx, query, c.Secret, c.ID); err != nil {
+ return fmt.Errorf("error updating TOTP configuration secret with id '%d': %w", c.ID, err)
}
}
return nil
}
-// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
-func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error) {
- version, err := p.SchemaVersion(ctx)
- if err != nil {
+func schemaEncryptionChangeKeyWebauthn(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
+ var count int
+
+ if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableWebauthnDevices)); err != nil {
return err
}
- if version < 1 {
- return ErrSchemaEncryptionVersionUnsupported
+ if count == 0 {
+ return nil
}
- var errs []error
+ devices := make([]encWebauthnDevice, 0, count)
- if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil {
- errs = append(errs, ErrSchemaEncryptionInvalidKey)
- }
-
- if verbose {
- if err = p.schemaEncryptionCheckTOTP(ctx); err != nil {
- errs = append(errs, err)
+ if err = tx.SelectContext(ctx, &devices, fmt.Sprintf(queryFmtSelectWebauthnDevicesEncryptedData, tableWebauthnDevices)); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil
}
- if err = p.schemaEncryptionCheckWebauthn(ctx); err != nil {
- errs = append(errs, err)
- }
+ return fmt.Errorf("error selecting Webauthn devices: %w", err)
}
- if len(errs) != 0 {
- for i, e := range errs {
- if i == 0 {
- err = e
+ query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKey, tableWebauthnDevices))
- continue
- }
+ for _, d := range devices {
+ if d.PublicKey, err = provider.decrypt(d.PublicKey); err != nil {
+ return fmt.Errorf("error decrypting Webauthn device public key with id '%d': %w", d.ID, err)
+ }
- err = fmt.Errorf("%w, %v", err, e)
+ if d.PublicKey, err = utils.Encrypt(d.PublicKey, &key); err != nil {
+ return fmt.Errorf("error encrypting Webauthn device public key with id '%d': %w", d.ID, err)
}
- return err
+ if _, err = tx.ExecContext(ctx, query, d.PublicKey, d.ID); err != nil {
+ return fmt.Errorf("error updating Webauthn device public key with id '%d': %w", d.ID, err)
+ }
}
return nil
}
-func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
- var (
- config model.TOTPConfiguration
- row int
- invalid int
- total int
- )
+func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionChangeKeyFunc {
+ return func(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
+ var count int
- pageSize := 10
+ if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, typeOAuth2Session.Table())); err != nil {
+ return err
+ }
- var rows *sqlx.Rows
+ if count == 0 {
+ return nil
+ }
- for page := 0; true; page++ {
- if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
- _ = rows.Close()
+ sessions := make([]encOAuth2Session, 0, count)
- return fmt.Errorf("error selecting TOTP configurations: %w", err)
+ if err = tx.SelectContext(ctx, &sessions, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil {
+ return fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)
}
- row = 0
+ query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionSessionData, typeOAuth2Session.Table()))
- for rows.Next() {
- total++
- row++
+ for _, s := range sessions {
+ if s.Session, err = provider.decrypt(s.Session); err != nil {
+ return fmt.Errorf("error decrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
+ }
- if err = rows.StructScan(&config); err != nil {
- _ = rows.Close()
- return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
+ if s.Session, err = utils.Encrypt(s.Session, &key); err != nil {
+ return fmt.Errorf("error encrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
}
- if _, err = p.decrypt(config.Secret); err != nil {
- invalid++
+ if _, err = tx.ExecContext(ctx, query, s.Session, s.ID); err != nil {
+ return fmt.Errorf("error updating oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
}
}
- _ = rows.Close()
+ return nil
+ }
+}
- if row < pageSize {
- break
- }
+func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
+ var (
+ rows *sqlx.Rows
+ err error
+ )
+
+ if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil {
+ return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting TOTP configurations: %w", err)}
}
- if invalid != 0 {
- return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
+ var config encTOTPConfiguration
+
+ for rows.Next() {
+ result.Total++
+
+ if err = rows.StructScan(&config); err != nil {
+ _ = rows.Close()
+
+ return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning TOTP configuration to struct: %w", err)}
+ }
+
+ if _, err = provider.decrypt(config.Secret); err != nil {
+ result.Invalid++
+ }
}
- return nil
+ _ = rows.Close()
+
+ return tableTOTPConfigurations, result
}
-func (p *SQLProvider) schemaEncryptionCheckWebauthn(ctx context.Context) (err error) {
+func schemaEncryptionCheckKeyWebauthn(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var (
- device model.WebauthnDevice
- row int
- invalid int
- total int
+ rows *sqlx.Rows
+ err error
)
- pageSize := 10
+ if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectWebauthnDevicesEncryptedData, tableWebauthnDevices)); err != nil {
+ return tableWebauthnDevices, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting Webauthn devices: %w", err)}
+ }
+
+ var device encWebauthnDevice
- var rows *sqlx.Rows
+ for rows.Next() {
+ result.Total++
- for page := 0; true; page++ {
- if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
+ if err = rows.StructScan(&device); err != nil {
_ = rows.Close()
- return fmt.Errorf("error selecting Webauthn devices: %w", err)
+ return tableWebauthnDevices, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning Webauthn device to struct: %w", err)}
+ }
+
+ if _, err = provider.decrypt(device.PublicKey); err != nil {
+ result.Invalid++
+ }
+ }
+
+ _ = rows.Close()
+
+ return tableWebauthnDevices, result
+}
+
+func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionCheckKeyFunc {
+ return func(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
+ var (
+ rows *sqlx.Rows
+ err error
+ )
+
+ if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil {
+ return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)}
}
- row = 0
+ var session encOAuth2Session
for rows.Next() {
- total++
- row++
+ result.Total++
- if err = rows.StructScan(&device); err != nil {
+ if err = rows.StructScan(&session); err != nil {
_ = rows.Close()
- return fmt.Errorf("error scanning Webauthn device to struct: %w", err)
+
+ return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error scanning oauth2 %s session to struct: %w", typeOAuth2Session.String(), err)}
}
- if _, err = p.decrypt(device.PublicKey); err != nil {
- invalid++
+ if _, err = provider.decrypt(session.Session); err != nil {
+ result.Invalid++
}
}
_ = rows.Close()
- if row < pageSize {
- break
- }
- }
-
- if invalid != 0 {
- return fmt.Errorf("%d of %d total Webauthn devices were invalid", invalid, total)
+ return typeOAuth2Session.Table(), result
}
-
- return nil
}
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
@@ -278,7 +345,7 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu
return p.decrypt(encryptedValue)
}
-func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]byte, e sqlx.ExecerContext) (err error) {
+func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
valueClearText, err := uuid.NewRandom()
if err != nil {
return err
@@ -289,11 +356,7 @@ func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]b
return err
}
- if e != nil {
- _, err = e.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
- } else {
- _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
- }
+ _, err = conn.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
return err
}