diff options
Diffstat (limited to 'internal/storage/sql_provider_encryption.go')
| -rw-r--r-- | internal/storage/sql_provider_encryption.go | 349 |
1 files changed, 206 insertions, 143 deletions
diff --git a/internal/storage/sql_provider_encryption.go b/internal/storage/sql_provider_encryption.go index 338bb27bf..29d334510 100644 --- a/internal/storage/sql_provider_encryption.go +++ b/internal/storage/sql_provider_encryption.go @@ -1,38 +1,65 @@ package storage import ( + "bytes" "context" "crypto/sha256" + "database/sql" + "errors" "fmt" "github.com/google/uuid" "github.com/jmoiron/sqlx" - "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/utils" ) // SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided // by this command to encrypt the values again and update them using a transaction. -func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error) { +func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) { + skey := sha256.Sum256([]byte(key)) + + if bytes.Equal(skey[:], p.key[:]) { + return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same") + } + + if _, err = p.SchemaEncryptionCheckKey(ctx, false); err != nil { + return fmt.Errorf("error changing the storage encryption key: %w", err) + } + tx, err := p.db.Beginx() if err != nil { return fmt.Errorf("error beginning transaction to change encryption key: %w", err) } - key := sha256.Sum256([]byte(encryptionKey)) + encChangeFuncs := []EncryptionChangeKeyFunc{ + schemaEncryptionChangeKeyTOTP, + schemaEncryptionChangeKeyWebauthn, + } - if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil { - return err + for i := 0; true; i++ { + typeOAuth2Session := OAuth2SessionType(i) + + if typeOAuth2Session.Table() == "" { + break + } + + encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session)) } - if err = p.schemaEncryptionChangeKeyWebauthn(ctx, tx, key); err != nil { - return err + for _, encChangeFunc := range encChangeFuncs { + if err = encChangeFunc(ctx, p, tx, skey); err != nil { + if rerr := tx.Rollback(); rerr != nil { + return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err) + } + + return fmt.Errorf("rollback due to error: %w", err) + } } - if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) + if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil { + if rerr := tx.Rollback(); rerr != nil { + return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err) } return fmt.Errorf("rollback due to error: %w", err) @@ -41,222 +68,262 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK return tx.Commit() } -func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) { - var configs []model.TOTPConfiguration +// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database. +func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (result EncryptionValidationResult, err error) { + version, err := p.SchemaVersion(ctx) + if err != nil { + return result, err + } - for page := 0; true; page++ { - if configs, err = p.LoadTOTPConfigurations(ctx, 10, page); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) - } + if version < 1 { + return result, ErrSchemaEncryptionVersionUnsupported + } - return fmt.Errorf("rollback due to error: %w", err) - } + result = EncryptionValidationResult{ + Tables: map[string]EncryptionValidationTableResult{}, + } - for _, config := range configs { - if config.Secret, err = utils.Encrypt(config.Secret, &key); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) - } + if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil { + result.InvalidCheckValue = true + } - return fmt.Errorf("rollback due to error: %w", err) - } + if verbose { + encCheckFuncs := []EncryptionCheckKeyFunc{ + schemaEncryptionCheckKeyTOTP, + schemaEncryptionCheckKeyWebauthn, + } - if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) - } + for i := 0; true; i++ { + typeOAuth2Session := OAuth2SessionType(i) - return fmt.Errorf("rollback due to error: %w", err) + if typeOAuth2Session.Table() == "" { + break } + + encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session)) } - if len(configs) != 10 { - break + for _, encCheckFunc := range encCheckFuncs { + table, tableResult := encCheckFunc(ctx, p) + + result.Tables[table] = tableResult } } - return nil + return result, nil } -func (p *SQLProvider) schemaEncryptionChangeKeyWebauthn(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) { - var devices []model.WebauthnDevice +func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) { + var count int - for page := 0; true; page++ { - if devices, err = p.LoadWebauthnDevices(ctx, 10, page); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) - } + if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableTOTPConfigurations)); err != nil { + return err + } - return fmt.Errorf("rollback due to error: %w", err) + if count == 0 { + return nil + } + + configs := make([]encTOTPConfiguration, 0, count) + + if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil } - for _, device := range devices { - if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) - } + return fmt.Errorf("error selecting TOTP configurations: %w", err) + } - return fmt.Errorf("rollback due to error: %w", err) - } + query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations)) - if err = p.updateWebauthnDevicePublicKey(ctx, device); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) - } + for _, c := range configs { + if c.Secret, err = provider.decrypt(c.Secret); err != nil { + return fmt.Errorf("error decrypting TOTP configuration secret with id '%d': %w", c.ID, err) + } - return fmt.Errorf("rollback due to error: %w", err) - } + if c.Secret, err = utils.Encrypt(c.Secret, &key); err != nil { + return fmt.Errorf("error encrypting TOTP configuration secret with id '%d': %w", c.ID, err) } - if len(devices) != 10 { - break + if _, err = tx.ExecContext(ctx, query, c.Secret, c.ID); err != nil { + return fmt.Errorf("error updating TOTP configuration secret with id '%d': %w", c.ID, err) } } return nil } -// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database. -func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error) { - version, err := p.SchemaVersion(ctx) - if err != nil { +func schemaEncryptionChangeKeyWebauthn(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) { + var count int + + if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableWebauthnDevices)); err != nil { return err } - if version < 1 { - return ErrSchemaEncryptionVersionUnsupported + if count == 0 { + return nil } - var errs []error + devices := make([]encWebauthnDevice, 0, count) - if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil { - errs = append(errs, ErrSchemaEncryptionInvalidKey) - } - - if verbose { - if err = p.schemaEncryptionCheckTOTP(ctx); err != nil { - errs = append(errs, err) + if err = tx.SelectContext(ctx, &devices, fmt.Sprintf(queryFmtSelectWebauthnDevicesEncryptedData, tableWebauthnDevices)); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil } - if err = p.schemaEncryptionCheckWebauthn(ctx); err != nil { - errs = append(errs, err) - } + return fmt.Errorf("error selecting Webauthn devices: %w", err) } - if len(errs) != 0 { - for i, e := range errs { - if i == 0 { - err = e + query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKey, tableWebauthnDevices)) - continue - } + for _, d := range devices { + if d.PublicKey, err = provider.decrypt(d.PublicKey); err != nil { + return fmt.Errorf("error decrypting Webauthn device public key with id '%d': %w", d.ID, err) + } - err = fmt.Errorf("%w, %v", err, e) + if d.PublicKey, err = utils.Encrypt(d.PublicKey, &key); err != nil { + return fmt.Errorf("error encrypting Webauthn device public key with id '%d': %w", d.ID, err) } - return err + if _, err = tx.ExecContext(ctx, query, d.PublicKey, d.ID); err != nil { + return fmt.Errorf("error updating Webauthn device public key with id '%d': %w", d.ID, err) + } } return nil } -func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) { - var ( - config model.TOTPConfiguration - row int - invalid int - total int - ) +func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionChangeKeyFunc { + return func(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) { + var count int - pageSize := 10 + if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, typeOAuth2Session.Table())); err != nil { + return err + } - var rows *sqlx.Rows + if count == 0 { + return nil + } - for page := 0; true; page++ { - if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil { - _ = rows.Close() + sessions := make([]encOAuth2Session, 0, count) - return fmt.Errorf("error selecting TOTP configurations: %w", err) + if err = tx.SelectContext(ctx, &sessions, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil { + return fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err) } - row = 0 + query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionSessionData, typeOAuth2Session.Table())) - for rows.Next() { - total++ - row++ + for _, s := range sessions { + if s.Session, err = provider.decrypt(s.Session); err != nil { + return fmt.Errorf("error decrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err) + } - if err = rows.StructScan(&config); err != nil { - _ = rows.Close() - return fmt.Errorf("error scanning TOTP configuration to struct: %w", err) + if s.Session, err = utils.Encrypt(s.Session, &key); err != nil { + return fmt.Errorf("error encrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err) } - if _, err = p.decrypt(config.Secret); err != nil { - invalid++ + if _, err = tx.ExecContext(ctx, query, s.Session, s.ID); err != nil { + return fmt.Errorf("error updating oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err) } } - _ = rows.Close() + return nil + } +} - if row < pageSize { - break - } +func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) { + var ( + rows *sqlx.Rows + err error + ) + + if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil { + return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting TOTP configurations: %w", err)} } - if invalid != 0 { - return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total) + var config encTOTPConfiguration + + for rows.Next() { + result.Total++ + + if err = rows.StructScan(&config); err != nil { + _ = rows.Close() + + return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning TOTP configuration to struct: %w", err)} + } + + if _, err = provider.decrypt(config.Secret); err != nil { + result.Invalid++ + } } - return nil + _ = rows.Close() + + return tableTOTPConfigurations, result } -func (p *SQLProvider) schemaEncryptionCheckWebauthn(ctx context.Context) (err error) { +func schemaEncryptionCheckKeyWebauthn(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) { var ( - device model.WebauthnDevice - row int - invalid int - total int + rows *sqlx.Rows + err error ) - pageSize := 10 + if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectWebauthnDevicesEncryptedData, tableWebauthnDevices)); err != nil { + return tableWebauthnDevices, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting Webauthn devices: %w", err)} + } + + var device encWebauthnDevice - var rows *sqlx.Rows + for rows.Next() { + result.Total++ - for page := 0; true; page++ { - if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil { + if err = rows.StructScan(&device); err != nil { _ = rows.Close() - return fmt.Errorf("error selecting Webauthn devices: %w", err) + return tableWebauthnDevices, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning Webauthn device to struct: %w", err)} + } + + if _, err = provider.decrypt(device.PublicKey); err != nil { + result.Invalid++ + } + } + + _ = rows.Close() + + return tableWebauthnDevices, result +} + +func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionCheckKeyFunc { + return func(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) { + var ( + rows *sqlx.Rows + err error + ) + + if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil { + return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)} } - row = 0 + var session encOAuth2Session for rows.Next() { - total++ - row++ + result.Total++ - if err = rows.StructScan(&device); err != nil { + if err = rows.StructScan(&session); err != nil { _ = rows.Close() - return fmt.Errorf("error scanning Webauthn device to struct: %w", err) + + return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error scanning oauth2 %s session to struct: %w", typeOAuth2Session.String(), err)} } - if _, err = p.decrypt(device.PublicKey); err != nil { - invalid++ + if _, err = provider.decrypt(session.Session); err != nil { + result.Invalid++ } } _ = rows.Close() - if row < pageSize { - break - } - } - - if invalid != 0 { - return fmt.Errorf("%d of %d total Webauthn devices were invalid", invalid, total) + return typeOAuth2Session.Table(), result } - - return nil } func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) { @@ -278,7 +345,7 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu return p.decrypt(encryptedValue) } -func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]byte, e sqlx.ExecerContext) (err error) { +func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) { valueClearText, err := uuid.NewRandom() if err != nil { return err @@ -289,11 +356,7 @@ func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]b return err } - if e != nil { - _, err = e.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value) - } else { - _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value) - } + _, err = conn.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value) return err } |
