summaryrefslogtreecommitdiff
path: root/internal/storage
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2022-11-25 23:44:55 +1100
committerGitHub <noreply@github.com>2022-11-25 23:44:55 +1100
commit3e4ac7821d51ac447bb39e7e1ea3c385dc3084d9 (patch)
tree69594576856eb8b587158d9245f70aff4fec429a /internal/storage
parent3c291b5685212813f98f365c8d963e0f107860cb (diff)
refactor: remove pre1 migration path (#4356)
This removes pre1 migrations and improves a lot of tooling.
Diffstat (limited to 'internal/storage')
-rw-r--r--internal/storage/const.go48
-rw-r--r--internal/storage/errors.go8
-rw-r--r--internal/storage/migrations.go36
-rw-r--r--internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.down.sql4
-rw-r--r--internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.up.sql4
-rw-r--r--internal/storage/migrations/V0007.ConsistencyFixes.sqlite.down.sql4
-rw-r--r--internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql4
-rw-r--r--internal/storage/provider.go4
-rw-r--r--internal/storage/sql_provider.go48
-rw-r--r--internal/storage/sql_provider_backend_postgres.go4
-rw-r--r--internal/storage/sql_provider_encryption.go349
-rw-r--r--internal/storage/sql_provider_queries.go28
-rw-r--r--internal/storage/sql_provider_queries_special.go104
-rw-r--r--internal/storage/sql_provider_schema.go255
-rw-r--r--internal/storage/sql_provider_schema_pre1.go470
-rw-r--r--internal/storage/sql_provider_schema_test.go15
-rw-r--r--internal/storage/types.go95
17 files changed, 519 insertions, 961 deletions
diff --git a/internal/storage/const.go b/internal/storage/const.go
index d899c82f9..5ca7a644f 100644
--- a/internal/storage/const.go
+++ b/internal/storage/const.go
@@ -15,17 +15,16 @@ const (
tableOAuth2ConsentSession = "oauth2_consent_session"
tableOAuth2ConsentPreConfiguration = "oauth2_consent_preconfiguration"
- tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
- tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential.
- tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential.
- tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
- tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session"
- tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti"
+
+ tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
+ tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential.
+ tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential.
+ tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
+ tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session"
+ tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti"
tableMigrations = "migrations"
tableEncryption = "encryption"
-
- tablePrefixBackup = "_bkp_"
)
// OAuth2SessionType represents the potential OAuth 2.0 session types.
@@ -58,6 +57,24 @@ func (s OAuth2SessionType) String() string {
}
}
+// Table returns the table name for this session type.
+func (s OAuth2SessionType) Table() string {
+ switch s {
+ case OAuth2SessionTypeAuthorizeCode:
+ return tableOAuth2AuthorizeCodeSession
+ case OAuth2SessionTypeAccessToken:
+ return tableOAuth2AccessTokenSession
+ case OAuth2SessionTypeRefreshToken:
+ return tableOAuth2RefreshTokenSession
+ case OAuth2SessionTypePKCEChallenge:
+ return tableOAuth2PKCERequestSession
+ case OAuth2SessionTypeOpenIDConnect:
+ return tableOAuth2OpenIDConnectSession
+ default:
+ return ""
+ }
+}
+
const (
sqlNetworkTypeTCP = "tcp"
sqlNetworkTypeUnixSocket = "unix"
@@ -72,16 +89,6 @@ const (
tablePre1TOTPSecrets = "totp_secrets"
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
tablePre1U2FDevices = "u2f_devices"
-
- tablePre1Config = "config"
-
- tableAlphaAuthenticationLogs = "AuthenticationLogs"
- tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens"
- tableAlphaPreferences = "Preferences"
- tableAlphaPreferencesTableName = "PreferencesTableName"
- tableAlphaSecondFactorPreferences = "SecondFactorPreferences"
- tableAlphaTOTPSecrets = "TOTPSecrets"
- tableAlphaU2FDeviceHandles = "U2FDeviceHandles"
)
var tablesPre1 = []string{
@@ -114,3 +121,8 @@ const (
var (
reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`)
)
+
+const (
+ na = "N/A"
+ invalid = "invalid"
+)
diff --git a/internal/storage/errors.go b/internal/storage/errors.go
index 388b69130..f3098f643 100644
--- a/internal/storage/errors.go
+++ b/internal/storage/errors.go
@@ -35,7 +35,7 @@ var (
// ErrSchemaEncryptionInvalidKey is returned when the schema is checked if the encryption key is valid for
// the database but the key doesn't appear to be valid.
- ErrSchemaEncryptionInvalidKey = errors.New("the encryption key is not valid against the schema check value")
+ ErrSchemaEncryptionInvalidKey = errors.New("the configured encryption key does not appear to be valid for this database which may occur if the encryption key was changed in the configuration without using the cli to change it in the database")
)
// Error formats for the storage provider.
@@ -49,7 +49,6 @@ const (
const (
errFmtFailedMigration = "schema migration %d (%s) failed: %w"
- errFmtFailedMigrationPre1 = "schema migration pre1 failed: %w"
errFmtSchemaCurrentGreaterThanLatestKnown = "current schema version is greater than the latest known schema " +
"version, you must downgrade to schema version %d before you can use this version of Authelia"
)
@@ -59,3 +58,8 @@ const (
logFmtMigrationComplete = "Storage schema migration from %s to %s is complete"
logFmtErrClosingConn = "Error occurred closing SQL connection: %v"
)
+
+const (
+ errFmtMigrationPre1 = "schema migration %s pre1 is no longer supported: you must use an older version of authelia to perform this migration: %s"
+ errFmtMigrationPre1SuggestedVersion = "the suggested authelia version is 4.37.2"
+)
diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go
index bb527c942..f634bccdd 100644
--- a/internal/storage/migrations.go
+++ b/internal/storage/migrations.go
@@ -46,42 +46,6 @@ func latestMigrationVersion(providerName string) (version int, err error) {
return version, nil
}
-func loadMigration(providerName string, version int, up bool) (migration *model.SchemaMigration, err error) {
- entries, err := migrationsFS.ReadDir("migrations")
- if err != nil {
- return nil, err
- }
-
- for _, entry := range entries {
- if entry.IsDir() {
- continue
- }
-
- m, err := scanMigration(entry.Name())
- if err != nil {
- return nil, err
- }
-
- migration = &m
-
- if up != migration.Up {
- continue
- }
-
- if migration.Provider != providerAll && migration.Provider != providerName {
- continue
- }
-
- if version != migration.Version {
- continue
- }
-
- return migration, nil
- }
-
- return nil, errors.New("migration not found")
-}
-
// loadMigrations scans the migrations fs and loads the appropriate migrations for a given providerName, prior and
// target versions. If the target version is -1 this indicates the latest version. If the target version is 0
// this indicates the database zero state.
diff --git a/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.down.sql b/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.down.sql
index c736d6b87..fac319005 100644
--- a/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.down.sql
+++ b/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.down.sql
@@ -1,7 +1,5 @@
PRAGMA foreign_keys=off;
-BEGIN TRANSACTION;
-
DELETE FROM oauth2_consent_session
WHERE subject IN(SELECT identifier FROM user_opaque_identifier WHERE username = '' AND service IN('openid', 'openid_connect'));
@@ -261,6 +259,4 @@ ORDER BY id;
DROP TABLE IF EXISTS _bkp_DOWN_V0005_oauth2_openid_connect_session;
-COMMIT;
-
PRAGMA foreign_keys=on;
diff --git a/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.up.sql b/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.up.sql
index c9347cb4d..8f35ee6ca 100644
--- a/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.up.sql
+++ b/internal/storage/migrations/V0005.ConsentSubjectNULL.sqlite.up.sql
@@ -1,7 +1,5 @@
PRAGMA foreign_keys=off;
-BEGIN TRANSACTION;
-
DELETE FROM oauth2_consent_session
WHERE subject IN(SELECT identifier FROM user_opaque_identifier WHERE username = '' AND service IN('openid', 'openid_connect'));
@@ -255,6 +253,4 @@ ORDER BY id;
DROP TABLE IF EXISTS _bkp_UP_V0005_oauth2_openid_connect_session;
-COMMIT;
-
PRAGMA foreign_keys=on;
diff --git a/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.down.sql b/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.down.sql
index b5a31858d..76f7d68e8 100644
--- a/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.down.sql
+++ b/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.down.sql
@@ -1,7 +1,5 @@
PRAGMA foreign_keys=off;
-BEGIN TRANSACTION;
-
ALTER TABLE webauthn_devices
RENAME TO _bkp_DOWN_V0007_webauthn_devices;
@@ -612,6 +610,4 @@ ORDER BY id;
DROP TABLE IF EXISTS _bkp_DOWN_V0007_oauth2_openid_connect_session;
-COMMIT;
-
PRAGMA foreign_keys=on;
diff --git a/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql b/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql
index 80847f593..1af55ed6b 100644
--- a/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql
+++ b/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql
@@ -1,7 +1,5 @@
PRAGMA foreign_keys=off;
-BEGIN TRANSACTION;
-
DROP TABLE IF EXISTS _bkp_UP_V0002_totp_configurations;
DROP TABLE IF EXISTS _bkp_UP_V0002_u2f_devices;
DROP TABLE IF EXISTS totp_secrets;
@@ -662,6 +660,4 @@ ORDER BY id;
DROP TABLE IF EXISTS _bkp_UP_V0007_oauth2_openid_connect_session;
-COMMIT;
-
PRAGMA foreign_keys=on;
diff --git a/internal/storage/provider.go b/internal/storage/provider.go
index ecfe104b0..d3c9f3b5a 100644
--- a/internal/storage/provider.go
+++ b/internal/storage/provider.go
@@ -77,8 +77,8 @@ type Provider interface {
SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error)
SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error)
- SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error)
- SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error)
+ SchemaEncryptionChangeKey(ctx context.Context, key string) (err error)
+ SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (result EncryptionValidationResult, err error)
Close() (err error)
}
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index fd6943e1a..a55a41cea 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -43,8 +43,6 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
sqlSelectTOTPConfigs: fmt.Sprintf(queryFmtSelectTOTPConfigurations, tableTOTPConfigurations),
- sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations),
- sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations),
sqlUpdateTOTPConfigRecordSignIn: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignIn, tableTOTPConfigurations),
sqlUpdateTOTPConfigRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignInByUsername, tableTOTPConfigurations),
@@ -52,8 +50,6 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlSelectWebauthnDevices: fmt.Sprintf(queryFmtSelectWebauthnDevices, tableWebauthnDevices),
sqlSelectWebauthnDevicesByUsername: fmt.Sprintf(queryFmtSelectWebauthnDevicesByUsername, tableWebauthnDevices),
- sqlUpdateWebauthnDevicePublicKey: fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKey, tableWebauthnDevices),
- sqlUpdateWebauthnDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateWebauthnDevicePublicKeyByUsername, tableWebauthnDevices),
sqlUpdateWebauthnDeviceRecordSignIn: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignIn, tableWebauthnDevices),
sqlUpdateWebauthnDeviceRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignInByUsername, tableWebauthnDevices),
@@ -161,8 +157,6 @@ type SQLProvider struct {
sqlSelectTOTPConfig string
sqlSelectTOTPConfigs string
- sqlUpdateTOTPConfigSecret string
- sqlUpdateTOTPConfigSecretByUsername string
sqlUpdateTOTPConfigRecordSignIn string
sqlUpdateTOTPConfigRecordSignInByUsername string
@@ -171,8 +165,6 @@ type SQLProvider struct {
sqlSelectWebauthnDevices string
sqlSelectWebauthnDevicesByUsername string
- sqlUpdateWebauthnDevicePublicKey string
- sqlUpdateWebauthnDevicePublicKeyByUsername string
sqlUpdateWebauthnDeviceRecordSignIn string
sqlUpdateWebauthnDeviceRecordSignInByUsername string
@@ -292,13 +284,17 @@ func (p *SQLProvider) StartupCheck() (err error) {
ctx := context.Background()
- if err = p.SchemaEncryptionCheckKey(ctx, false); err != nil && !errors.Is(err, ErrSchemaEncryptionVersionUnsupported) {
+ var result EncryptionValidationResult
+
+ if result, err = p.SchemaEncryptionCheckKey(ctx, false); err != nil && !errors.Is(err, ErrSchemaEncryptionVersionUnsupported) {
return err
}
- err = p.SchemaMigrate(ctx, true, SchemaLatest)
+ if !result.Success() {
+ return ErrSchemaEncryptionInvalidKey
+ }
- switch err {
+ switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err {
case ErrSchemaAlreadyUpToDate:
p.log.Infof("Storage schema is already up to date")
return nil
@@ -837,21 +833,6 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
return configs, nil
}
-func (p *SQLProvider) updateTOTPConfigurationSecret(ctx context.Context, config model.TOTPConfiguration) (err error) {
- switch config.ID {
- case 0:
- _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username)
- default:
- _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecret, config.Secret, config.ID)
- }
-
- if err != nil {
- return fmt.Errorf("error updating TOTP configuration secret for user '%s': %w", config.Username, err)
- }
-
- return nil
-}
-
// SaveWebauthnDevice saves a registered Webauthn device.
func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) {
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
@@ -947,21 +928,6 @@ func (p *SQLProvider) LoadWebauthnDevicesByUsername(ctx context.Context, usernam
return devices, nil
}
-func (p *SQLProvider) updateWebauthnDevicePublicKey(ctx context.Context, device model.WebauthnDevice) (err error) {
- switch device.ID {
- case 0:
- _, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDevicePublicKeyByUsername, device.PublicKey, device.Username, device.KID)
- default:
- _, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDevicePublicKey, device.PublicKey, device.ID)
- }
-
- if err != nil {
- return fmt.Errorf("error updating Webauthn public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
- }
-
- return nil
-}
-
// SavePreferredDuoDevice saves a Duo device.
func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method); err != nil {
diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go
index e8834db44..9e0c127a1 100644
--- a/internal/storage/sql_provider_backend_postgres.go
+++ b/internal/storage/sql_provider_backend_postgres.go
@@ -58,13 +58,9 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)
provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig)
provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs)
- provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret)
- provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername)
provider.sqlSelectWebauthnDevices = provider.db.Rebind(provider.sqlSelectWebauthnDevices)
provider.sqlSelectWebauthnDevicesByUsername = provider.db.Rebind(provider.sqlSelectWebauthnDevicesByUsername)
- provider.sqlUpdateWebauthnDevicePublicKey = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKey)
- provider.sqlUpdateWebauthnDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKeyByUsername)
provider.sqlUpdateWebauthnDeviceRecordSignIn = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignIn)
provider.sqlUpdateWebauthnDeviceRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignInByUsername)
provider.sqlDeleteWebauthnDevice = provider.db.Rebind(provider.sqlDeleteWebauthnDevice)
diff --git a/internal/storage/sql_provider_encryption.go b/internal/storage/sql_provider_encryption.go
index 338bb27bf..29d334510 100644
--- a/internal/storage/sql_provider_encryption.go
+++ b/internal/storage/sql_provider_encryption.go
@@ -1,38 +1,65 @@
package storage
import (
+ "bytes"
"context"
"crypto/sha256"
+ "database/sql"
+ "errors"
"fmt"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
- "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/utils"
)
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
// by this command to encrypt the values again and update them using a transaction.
-func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error) {
+func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) {
+ skey := sha256.Sum256([]byte(key))
+
+ if bytes.Equal(skey[:], p.key[:]) {
+ return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
+ }
+
+ if _, err = p.SchemaEncryptionCheckKey(ctx, false); err != nil {
+ return fmt.Errorf("error changing the storage encryption key: %w", err)
+ }
+
tx, err := p.db.Beginx()
if err != nil {
return fmt.Errorf("error beginning transaction to change encryption key: %w", err)
}
- key := sha256.Sum256([]byte(encryptionKey))
+ encChangeFuncs := []EncryptionChangeKeyFunc{
+ schemaEncryptionChangeKeyTOTP,
+ schemaEncryptionChangeKeyWebauthn,
+ }
- if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
- return err
+ for i := 0; true; i++ {
+ typeOAuth2Session := OAuth2SessionType(i)
+
+ if typeOAuth2Session.Table() == "" {
+ break
+ }
+
+ encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
}
- if err = p.schemaEncryptionChangeKeyWebauthn(ctx, tx, key); err != nil {
- return err
+ for _, encChangeFunc := range encChangeFuncs {
+ if err = encChangeFunc(ctx, p, tx, skey); err != nil {
+ if rerr := tx.Rollback(); rerr != nil {
+ return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
+ }
+
+ return fmt.Errorf("rollback due to error: %w", err)
+ }
}
- if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
+ if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil {
+ if rerr := tx.Rollback(); rerr != nil {
+ return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
}
return fmt.Errorf("rollback due to error: %w", err)
@@ -41,222 +68,262 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
return tx.Commit()
}
-func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
- var configs []model.TOTPConfiguration
+// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
+func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (result EncryptionValidationResult, err error) {
+ version, err := p.SchemaVersion(ctx)
+ if err != nil {
+ return result, err
+ }
- for page := 0; true; page++ {
- if configs, err = p.LoadTOTPConfigurations(ctx, 10, page); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ if version < 1 {
+ return result, ErrSchemaEncryptionVersionUnsupported
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ result = EncryptionValidationResult{
+ Tables: map[string]EncryptionValidationTableResult{},
+ }
- for _, config := range configs {
- if config.Secret, err = utils.Encrypt(config.Secret, &key); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil {
+ result.InvalidCheckValue = true
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ if verbose {
+ encCheckFuncs := []EncryptionCheckKeyFunc{
+ schemaEncryptionCheckKeyTOTP,
+ schemaEncryptionCheckKeyWebauthn,
+ }
- if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ for i := 0; true; i++ {
+ typeOAuth2Session := OAuth2SessionType(i)
- return fmt.Errorf("rollback due to error: %w", err)
+ if typeOAuth2Session.Table() == "" {
+ break
}
+
+ encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
}
- if len(configs) != 10 {
- break
+ for _, encCheckFunc := range encCheckFuncs {
+ table, tableResult := encCheckFunc(ctx, p)
+
+ result.Tables[table] = tableResult
}
}
- return nil
+ return result, nil
}
-func (p *SQLProvider) schemaEncryptionChangeKeyWebauthn(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
- var devices []model.WebauthnDevice
+func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
+ var count int
- for page := 0; true; page++ {
- if devices, err = p.LoadWebauthnDevices(ctx, 10, page); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableTOTPConfigurations)); err != nil {
+ return err
+ }
- return fmt.Errorf("rollback due to error: %w", err)
+ if count == 0 {
+ return nil
+ }
+
+ configs := make([]encTOTPConfiguration, 0, count)
+
+ if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil
}
- for _, device := range devices {
- if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ return fmt.Errorf("error selecting TOTP configurations: %w", err)
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations))
- if err = p.updateWebauthnDevicePublicKey(ctx, device); err != nil {
- if rollbackErr := tx.Rollback(); rollbackErr != nil {
- return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
- }
+ for _, c := range configs {
+ if c.Secret, err = provider.decrypt(c.Secret); err != nil {
+ return fmt.Errorf("error decrypting TOTP configuration secret with id '%d': %w", c.ID, err)
+ }
- return fmt.Errorf("rollback due to error: %w", err)
- }
+ if c.Secret, err = utils.Encrypt(c.Secret, &key); err != nil {
+ return fmt.Errorf("error encrypting TOTP configuration secret with id '%d': %w", c.ID, err)
}
- if len(devices) != 10 {
- break
+ if _, err = tx.ExecContext(ctx, query, c.Secret, c.ID); err != nil {
+ return fmt.Errorf("error updating TOTP configuration secret with id '%d': %w", c.ID, err)
}
}
return nil
}
-// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
-func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error) {
- version, err := p.SchemaVersion(ctx)
- if err != nil {
+func schemaEncryptionChangeKeyWebauthn(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
+ var count int
+
+ if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableWebauthnDevices)); err != nil {
return err
}
- if version < 1 {
- return ErrSchemaEncryptionVersionUnsupported
+ if count == 0 {
+ return nil
}
- var errs []error
+ devices := make([]encWebauthnDevice, 0, count)
- if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil {
- errs = append(errs, ErrSchemaEncryptionInvalidKey)
- }
-
- if verbose {
- if err = p.schemaEncryptionCheckTOTP(ctx); err != nil {
- errs = append(errs, err)
+ if err = tx.SelectContext(ctx, &devices, fmt.Sprintf(queryFmtSelectWebauthnDevicesEncryptedData, tableWebauthnDevices)); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil
}
- if err = p.schemaEncryptionCheckWebauthn(ctx); err != nil {
- errs = append(errs, err)
- }
+ return fmt.Errorf("error selecting Webauthn devices: %w", err)
}
- if len(errs) != 0 {
- for i, e := range errs {
- if i == 0 {
- err = e
+ query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKey, tableWebauthnDevices))
- continue
- }
+ for _, d := range devices {
+ if d.PublicKey, err = provider.decrypt(d.PublicKey); err != nil {
+ return fmt.Errorf("error decrypting Webauthn device public key with id '%d': %w", d.ID, err)
+ }
- err = fmt.Errorf("%w, %v", err, e)
+ if d.PublicKey, err = utils.Encrypt(d.PublicKey, &key); err != nil {
+ return fmt.Errorf("error encrypting Webauthn device public key with id '%d': %w", d.ID, err)
}
- return err
+ if _, err = tx.ExecContext(ctx, query, d.PublicKey, d.ID); err != nil {
+ return fmt.Errorf("error updating Webauthn device public key with id '%d': %w", d.ID, err)
+ }
}
return nil
}
-func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
- var (
- config model.TOTPConfiguration
- row int
- invalid int
- total int
- )
+func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionChangeKeyFunc {
+ return func(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
+ var count int
- pageSize := 10
+ if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, typeOAuth2Session.Table())); err != nil {
+ return err
+ }
- var rows *sqlx.Rows
+ if count == 0 {
+ return nil
+ }
- for page := 0; true; page++ {
- if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
- _ = rows.Close()
+ sessions := make([]encOAuth2Session, 0, count)
- return fmt.Errorf("error selecting TOTP configurations: %w", err)
+ if err = tx.SelectContext(ctx, &sessions, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil {
+ return fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)
}
- row = 0
+ query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionSessionData, typeOAuth2Session.Table()))
- for rows.Next() {
- total++
- row++
+ for _, s := range sessions {
+ if s.Session, err = provider.decrypt(s.Session); err != nil {
+ return fmt.Errorf("error decrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
+ }
- if err = rows.StructScan(&config); err != nil {
- _ = rows.Close()
- return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
+ if s.Session, err = utils.Encrypt(s.Session, &key); err != nil {
+ return fmt.Errorf("error encrypting oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
}
- if _, err = p.decrypt(config.Secret); err != nil {
- invalid++
+ if _, err = tx.ExecContext(ctx, query, s.Session, s.ID); err != nil {
+ return fmt.Errorf("error updating oauth2 %s session data with id '%d': %w", typeOAuth2Session.String(), s.ID, err)
}
}
- _ = rows.Close()
+ return nil
+ }
+}
- if row < pageSize {
- break
- }
+func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
+ var (
+ rows *sqlx.Rows
+ err error
+ )
+
+ if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectTOTPConfigurationsEncryptedData, tableTOTPConfigurations)); err != nil {
+ return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting TOTP configurations: %w", err)}
}
- if invalid != 0 {
- return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
+ var config encTOTPConfiguration
+
+ for rows.Next() {
+ result.Total++
+
+ if err = rows.StructScan(&config); err != nil {
+ _ = rows.Close()
+
+ return tableTOTPConfigurations, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning TOTP configuration to struct: %w", err)}
+ }
+
+ if _, err = provider.decrypt(config.Secret); err != nil {
+ result.Invalid++
+ }
}
- return nil
+ _ = rows.Close()
+
+ return tableTOTPConfigurations, result
}
-func (p *SQLProvider) schemaEncryptionCheckWebauthn(ctx context.Context) (err error) {
+func schemaEncryptionCheckKeyWebauthn(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var (
- device model.WebauthnDevice
- row int
- invalid int
- total int
+ rows *sqlx.Rows
+ err error
)
- pageSize := 10
+ if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectWebauthnDevicesEncryptedData, tableWebauthnDevices)); err != nil {
+ return tableWebauthnDevices, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting Webauthn devices: %w", err)}
+ }
+
+ var device encWebauthnDevice
- var rows *sqlx.Rows
+ for rows.Next() {
+ result.Total++
- for page := 0; true; page++ {
- if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
+ if err = rows.StructScan(&device); err != nil {
_ = rows.Close()
- return fmt.Errorf("error selecting Webauthn devices: %w", err)
+ return tableWebauthnDevices, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning Webauthn device to struct: %w", err)}
+ }
+
+ if _, err = provider.decrypt(device.PublicKey); err != nil {
+ result.Invalid++
+ }
+ }
+
+ _ = rows.Close()
+
+ return tableWebauthnDevices, result
+}
+
+func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) EncryptionCheckKeyFunc {
+ return func(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
+ var (
+ rows *sqlx.Rows
+ err error
+ )
+
+ if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOAuth2SessionEncryptedData, typeOAuth2Session.Table())); err != nil {
+ return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error selecting oauth2 %s sessions: %w", typeOAuth2Session.String(), err)}
}
- row = 0
+ var session encOAuth2Session
for rows.Next() {
- total++
- row++
+ result.Total++
- if err = rows.StructScan(&device); err != nil {
+ if err = rows.StructScan(&session); err != nil {
_ = rows.Close()
- return fmt.Errorf("error scanning Webauthn device to struct: %w", err)
+
+ return typeOAuth2Session.Table(), EncryptionValidationTableResult{Error: fmt.Errorf("error scanning oauth2 %s session to struct: %w", typeOAuth2Session.String(), err)}
}
- if _, err = p.decrypt(device.PublicKey); err != nil {
- invalid++
+ if _, err = provider.decrypt(session.Session); err != nil {
+ result.Invalid++
}
}
_ = rows.Close()
- if row < pageSize {
- break
- }
- }
-
- if invalid != 0 {
- return fmt.Errorf("%d of %d total Webauthn devices were invalid", invalid, total)
+ return typeOAuth2Session.Table(), result
}
-
- return nil
}
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
@@ -278,7 +345,7 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu
return p.decrypt(encryptedValue)
}
-func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]byte, e sqlx.ExecerContext) (err error) {
+func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
valueClearText, err := uuid.NewRandom()
if err != nil {
return err
@@ -289,11 +356,7 @@ func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]b
return err
}
- if e != nil {
- _, err = e.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
- } else {
- _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
- }
+ _, err = conn.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
return err
}
diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go
index 327ab546f..f062f8afb 100644
--- a/internal/storage/sql_provider_queries.go
+++ b/internal/storage/sql_provider_queries.go
@@ -83,18 +83,16 @@ const (
LIMIT ?
OFFSET ?;`
+ queryFmtSelectTOTPConfigurationsEncryptedData = `
+ SELECT id, secret
+ FROM %s;`
+
//nolint:gosec // These are not hardcoded credentials it's a query to obtain credentials.
queryFmtUpdateTOTPConfigurationSecret = `
UPDATE %s
SET secret = ?
WHERE id = ?;`
- //nolint:gosec // These are not hardcoded credentials it's a query to obtain credentials.
- queryFmtUpdateTOTPConfigurationSecretByUsername = `
- UPDATE %s
- SET secret = ?
- WHERE username = ?;`
-
queryFmtUpsertTOTPConfiguration = `
REPLACE INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`
@@ -127,6 +125,10 @@ const (
LIMIT ?
OFFSET ?;`
+ queryFmtSelectWebauthnDevicesEncryptedData = `
+ SELECT id, public_key
+ FROM %s;`
+
queryFmtSelectWebauthnDevicesByUsername = `
SELECT id, created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning
FROM %s
@@ -137,11 +139,6 @@ const (
SET public_key = ?
WHERE id = ?;`
- queryFmtUpdateUpdateWebauthnDevicePublicKeyByUsername = `
- UPDATE %s
- SET public_key = ?
- WHERE username = ? AND kid = ?;`
-
queryFmtUpdateWebauthnDeviceRecordSignIn = `
UPDATE %s
SET
@@ -265,6 +262,11 @@ const (
SET subject = ?
WHERE id = ?;`
+ queryFmtUpdateOAuth2ConsentSessionSessionData = `
+ UPDATE %s
+ SET session_data = ?
+ WHERE id = ?;`
+
queryFmtUpdateOAuth2ConsentSessionResponse = `
UPDATE %s
SET authorized = ?, responded_at = CURRENT_TIMESTAMP, granted_scopes = ?, granted_audience = ?, preconfiguration = ?
@@ -282,6 +284,10 @@ const (
FROM %s
WHERE signature = ? AND revoked = FALSE;`
+ queryFmtSelectOAuth2SessionEncryptedData = `
+ SELECT id, session_data
+ FROM %s;`
+
queryFmtInsertOAuth2Session = `
INSERT INTO %s (challenge_id, request_id, client_id, signature, subject, requested_at,
requested_scopes, granted_scopes, requested_audience, granted_audience,
diff --git a/internal/storage/sql_provider_queries_special.go b/internal/storage/sql_provider_queries_special.go
index 3023191ad..9b85e9150 100644
--- a/internal/storage/sql_provider_queries_special.go
+++ b/internal/storage/sql_provider_queries_special.go
@@ -1,8 +1,6 @@
package storage
const (
- queryFmtDropTableIfExists = `DROP TABLE IF EXISTS %s;`
-
queryFmtRenameTable = `
ALTER TABLE %s
RENAME TO %s;`
@@ -10,104 +8,10 @@ const (
queryFmtMySQLRenameTable = `
ALTER TABLE %s
RENAME %s;`
-)
-
-// Pre1 migration constants.
-const (
- queryFmtPre1To1SelectAuthenticationLogs = `
- SELECT username, successful, time
- FROM %s
- ORDER BY time ASC
- LIMIT 100 OFFSET ?;`
-
- queryFmtPre1To1InsertAuthenticationLogs = `
- INSERT INTO %s (username, successful, time, request_uri)
- VALUES (?, ?, ?, '');`
-
- queryFmtPre1InsertUserPreferencesFromSelect = `
- INSERT INTO %s (username, second_factor_method)
- SELECT username, second_factor_method
- FROM %s
- ORDER BY username ASC;`
-
- queryFmtPre1SelectTOTPConfigurations = `
- SELECT username, secret
- FROM %s
- ORDER BY username ASC;`
-
- queryFmtPre1To1InsertTOTPConfiguration = `
- INSERT INTO %s (username, issuer, period, secret)
- VALUES (?, ?, ?, ?);`
-
- queryFmt1ToPre1InsertTOTPConfiguration = `
- INSERT INTO %s (username, secret)
- VALUES (?, ?);`
-
- queryFmtPre1To1SelectU2FDevices = `
- SELECT username, keyHandle, publicKey
- FROM %s
- ORDER BY username ASC;`
-
- queryFmtPre1To1InsertU2FDevice = `
- INSERT INTO %s (username, key_handle, public_key)
- VALUES (?, ?, ?);`
-
- queryFmt1ToPre1InsertAuthenticationLogs = `
- INSERT INTO %s (username, successful, time)
- VALUES (?, ?, ?);`
-
- queryFmt1ToPre1SelectAuthenticationLogs = `
- SELECT username, successful, time
- FROM %s
- ORDER BY id ASC
- LIMIT 100 OFFSET ?;`
-
- queryFmt1ToPre1SelectU2FDevices = `
- SELECT username, key_handle, public_key
- FROM %s
- ORDER BY username ASC;`
- queryFmt1ToPre1InsertU2FDevice = `
- INSERT INTO %s (username, keyHandle, publicKey)
- VALUES (?, ?, ?);`
+ queryFmtPostgreSQLLockTable = `LOCK TABLE %s IN %s MODE;`
- queryCreatePre1 = `
- CREATE TABLE user_preferences (
- username VARCHAR(100),
- second_factor_method VARCHAR(11),
- PRIMARY KEY (username)
- );
-
- CREATE TABLE identity_verification_tokens (
- token VARCHAR(512)
- );
-
- CREATE TABLE totp_secrets (
- username VARCHAR(100),
- secret VARCHAR(64),
- PRIMARY KEY (username)
- );
-
- CREATE TABLE u2f_devices (
- username VARCHAR(100),
- keyHandle TEXT,
- publicKey TEXT,
- PRIMARY KEY (username)
- );
-
- CREATE TABLE authentication_logs (
- username VARCHAR(100),
- successful BOOL,
- time INTEGER
- );
-
- CREATE TABLE config (
- category VARCHAR(32) NOT NULL,
- key_name VARCHAR(32) NOT NULL,
- value TEXT,
- PRIMARY KEY (category, key_name)
- );
-
- INSERT INTO config (category, key_name, value)
- VALUES ('schema', 'version', '1');`
+ queryFmtSelectRowCount = `
+ SELECT COUNT(id)
+ FROM %s;`
)
diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go
index e01aee1ad..8c015e963 100644
--- a/internal/storage/sql_provider_schema.go
+++ b/internal/storage/sql_provider_schema.go
@@ -81,15 +81,41 @@ func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error
return 0, nil
}
-func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *model.Migration, err error) {
- migration = &model.Migration{}
+// SchemaLatestVersion returns the latest version available for migration.
+func (p *SQLProvider) SchemaLatestVersion() (version int, err error) {
+ return latestMigrationVersion(p.name)
+}
- err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration)
+// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version.
+func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
+ current, err := p.SchemaVersion(ctx)
if err != nil {
- return nil, err
+ return migrations, err
}
- return migration, nil
+ if version == 0 {
+ version = SchemaLatest
+ }
+
+ if current >= version {
+ return migrations, ErrNoAvailableMigrations
+ }
+
+ return loadMigrations(p.name, current, version)
+}
+
+// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version.
+func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
+ current, err := p.SchemaVersion(ctx)
+ if err != nil {
+ return migrations, err
+ }
+
+ if current <= version {
+ return migrations, ErrNoAvailableMigrations
+ }
+
+ return loadMigrations(p.name, current, version)
}
// SchemaMigrationHistory returns migration history rows.
@@ -121,184 +147,185 @@ func (p *SQLProvider) SchemaMigrationHistory(ctx context.Context) (migrations []
// SchemaMigrate migrates from the current version to the provided version.
func (p *SQLProvider) SchemaMigrate(ctx context.Context, up bool, version int) (err error) {
- currentVersion, err := p.SchemaVersion(ctx)
- if err != nil {
- return err
- }
+ var (
+ tx *sqlx.Tx
+ conn SQLXConnection
+ )
+
+ if p.name != providerMySQL {
+ if tx, err = p.db.BeginTxx(ctx, nil); err != nil {
+ return fmt.Errorf("failed to begin transaction: %w", err)
+ }
- if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil {
- return err
+ conn = tx
+ } else {
+ conn = p.db
}
- return p.schemaMigrate(ctx, currentVersion, version)
-}
-
-//nolint:gocyclo // TODO: Consider refactoring time permitting.
-func (p *SQLProvider) schemaMigrate(ctx context.Context, prior, target int) (err error) {
- migrations, err := loadMigrations(p.name, prior, target)
+ currentVersion, err := p.SchemaVersion(ctx)
if err != nil {
return err
}
- if len(migrations) == 0 && (prior != 1 || target != -1) {
- return ErrNoMigrationsFound
+ if currentVersion != 0 {
+ if err = p.schemaMigrateLock(ctx, conn); err != nil {
+ return err
+ }
}
- switch {
- case prior == -1:
- p.log.Infof(logFmtMigrationFromTo, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
-
- err = p.schemaMigratePre1To1(ctx)
- if err != nil {
- if errRollback := p.schemaMigratePre1To1Rollback(ctx, true); errRollback != nil {
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
- }
-
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
+ if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil {
+ if tx != nil {
+ _ = tx.Rollback()
}
- case target == -1:
- p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), "pre1")
- default:
- p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
+
+ return err
}
- for _, migration := range migrations {
- if prior == -1 && migration.Version == 1 {
- // Skip migration version 1 when upgrading from pre1 as it's applied as part of the pre1 upgrade.
- continue
+ if err = p.schemaMigrate(ctx, conn, currentVersion, version); err != nil {
+ if tx != nil && err == ErrNoMigrationsFound {
+ _ = tx.Rollback()
}
- err = p.schemaMigrateApply(ctx, migration)
- if err != nil {
- return p.schemaMigrateRollback(ctx, prior, migration.After(), err)
- }
+ return err
}
- switch {
- case prior == -1:
- p.log.Infof(logFmtMigrationComplete, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
- case target == -1:
- err = p.schemaMigrate1ToPre1(ctx)
- if err != nil {
- if errRollback := p.schemaMigratePre1To1Rollback(ctx, false); errRollback != nil {
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
+ if tx != nil {
+ if err = tx.Commit(); err != nil {
+ if rerr := tx.Rollback(); rerr != nil {
+ return fmt.Errorf("failed to commit the transaction with: commit error: %w, rollback error: %+v", err, rerr)
}
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
+ return fmt.Errorf("failed to commit the transaction but it has been rolled back: commit error: %w", err)
}
-
- p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), "pre1")
- default:
- p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
}
return nil
}
-func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, prior, after int, migrateErr error) (err error) {
- migrations, err := loadMigrations(p.name, after, prior)
+func (p *SQLProvider) schemaMigrate(ctx context.Context, conn SQLXConnection, prior, target int) (err error) {
+ migrations, err := loadMigrations(p.name, prior, target)
if err != nil {
- return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %+v", prior, after, err, migrateErr)
+ return err
}
- for _, migration := range migrations {
- if prior == -1 && !migration.Up && migration.Version == 1 {
- continue
+ if len(migrations) == 0 {
+ return ErrNoMigrationsFound
+ }
+
+ p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
+
+ for i, migration := range migrations {
+ if migration.Up && prior == 0 && i == 1 {
+ if err = p.schemaMigrateLock(ctx, conn); err != nil {
+ return err
+ }
}
- err = p.schemaMigrateApply(ctx, migration)
- if err != nil {
- return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %+v", migration.Before(), migration.After(), err, migrateErr)
+ if err = p.schemaMigrateApply(ctx, conn, migration); err != nil {
+ return p.schemaMigrateRollback(ctx, conn, prior, migration.After(), err)
}
}
- if prior == -1 {
- if err = p.schemaMigrate1ToPre1(ctx); err != nil {
- return fmt.Errorf("error applying migration version 1 to version pre1 for rollback: %+v. rollback caused by: %+v", err, migrateErr)
- }
+ p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
+
+ return nil
+}
+
+func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection) (err error) {
+ if p.name != providerPostgres {
+ return nil
+ }
+
+ if _, err = conn.ExecContext(ctx, fmt.Sprintf(queryFmtPostgreSQLLockTable, tableMigrations, "ACCESS EXCLUSIVE")); err != nil {
+ return fmt.Errorf("failed to lock tables: %w", err)
}
- return fmt.Errorf("migration rollback complete. rollback caused by: %+v", migrateErr)
+ return nil
}
-func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration model.SchemaMigration) (err error) {
- _, err = p.db.ExecContext(ctx, migration.Query)
- if err != nil {
+func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
+ if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
- if migration.Version == 1 {
- // Skip the migration history insertion in a migration to v0.
- if !migration.Up {
- return nil
- }
-
+ if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1.
- if err = p.setNewEncryptionCheckValue(ctx, &p.key, nil); err != nil {
+ if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
return err
}
}
- if migration.Version == 1 && !migration.Up {
- return nil
+ if err = p.schemaMigrateFinalize(ctx, conn, migration); err != nil {
+ return err
}
- return p.schemaMigrateFinalize(ctx, migration)
+ return nil
}
-func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) {
- return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
-}
+func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
+ if migration.Version == 1 && !migration.Up {
+ return nil
+ }
-func (p *SQLProvider) schemaMigrateFinalizeAdvanced(ctx context.Context, before, after int) (err error) {
- _, err = p.db.ExecContext(ctx, p.sqlInsertMigration, time.Now(), before, after, utils.Version())
- if err != nil {
- return err
+ if _, err = conn.ExecContext(ctx, p.sqlInsertMigration, time.Now(), migration.Before(), migration.After(), utils.Version()); err != nil {
+ return fmt.Errorf("failed inserting migration record: %w", err)
}
- p.log.Debugf("Storage schema migrated from version %d to %d", before, after)
+ p.log.Debugf("Storage schema migrated from version %d to %d", migration.Before(), migration.After())
return nil
}
-// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version.
-func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
- current, err := p.SchemaVersion(ctx)
- if err != nil {
- return migrations, err
- }
-
- if version == 0 {
- version = SchemaLatest
+func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, conn SQLXConnection, prior, after int, merr error) (err error) {
+ switch tx := conn.(type) {
+ case *sqlx.Tx:
+ return p.schemaMigrateRollbackWithTx(ctx, tx, merr)
+ default:
+ return p.schemaMigrateRollbackWithoutTx(ctx, prior, after, merr)
}
+}
- if current >= version {
- return migrations, ErrNoAvailableMigrations
+func (p *SQLProvider) schemaMigrateRollbackWithTx(_ context.Context, tx *sqlx.Tx, merr error) (err error) {
+ if err = tx.Rollback(); err != nil {
+ return fmt.Errorf("error applying rollback %+v. rollback caused by: %w", err, merr)
}
- return loadMigrations(p.name, current, version)
+ return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr)
}
-// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version.
-func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
- current, err := p.SchemaVersion(ctx)
+func (p *SQLProvider) schemaMigrateRollbackWithoutTx(ctx context.Context, prior, after int, merr error) (err error) {
+ migrations, err := loadMigrations(p.name, after, prior)
if err != nil {
- return migrations, err
+ return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %w", prior, after, err, merr)
}
- if current <= version {
- return migrations, ErrNoAvailableMigrations
+ for _, migration := range migrations {
+ if err = p.schemaMigrateApply(ctx, p.db, migration); err != nil {
+ return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %w", migration.Before(), migration.After(), err, merr)
+ }
}
- return loadMigrations(p.name, current, version)
+ return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr)
}
-// SchemaLatestVersion returns the latest version available for migration.
-func (p *SQLProvider) SchemaLatestVersion() (version int, err error) {
- return latestMigrationVersion(p.name)
+func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *model.Migration, err error) {
+ migration = &model.Migration{}
+
+ if err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration); err != nil {
+ return nil, err
+ }
+
+ return migration, nil
}
func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVersion int) (err error) {
+ switch {
+ case currentVersion == -1:
+ return fmt.Errorf(errFmtMigrationPre1, "up from", errFmtMigrationPre1SuggestedVersion)
+ case targetVersion == -1:
+ return fmt.Errorf(errFmtMigrationPre1, "down to", fmt.Sprintf("you should downgrade to schema version 1 using the current authelia version then use the suggested authelia version to downgrade to pre1: %s", errFmtMigrationPre1SuggestedVersion))
+ }
+
if targetVersion == currentVersion {
return fmt.Errorf(ErrFmtMigrateAlreadyOnTargetVersion, targetVersion, currentVersion)
}
@@ -325,7 +352,7 @@ func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVer
return fmt.Errorf(ErrFmtMigrateUpTargetGreaterThanLatest, targetVersion, latest)
}
} else {
- if targetVersion < -1 {
+ if targetVersion < 0 {
return fmt.Errorf(ErrFmtMigrateDownTargetLessThanMinimum, targetVersion)
}
@@ -345,7 +372,7 @@ func SchemaVersionToString(version int) (versionStr string) {
case -1:
return "pre1"
case 0:
- return "N/A"
+ return na
default:
return strconv.Itoa(version)
}
diff --git a/internal/storage/sql_provider_schema_pre1.go b/internal/storage/sql_provider_schema_pre1.go
deleted file mode 100644
index 04587bfb7..000000000
--- a/internal/storage/sql_provider_schema_pre1.go
+++ /dev/null
@@ -1,470 +0,0 @@
-package storage
-
-import (
- "context"
- "database/sql"
- "encoding/base64"
- "fmt"
- "strings"
- "time"
-
- "github.com/authelia/authelia/v4/internal/model"
- "github.com/authelia/authelia/v4/internal/utils"
-)
-
-// schemaMigratePre1To1 takes the v1 migration and migrates to this version.
-func (p *SQLProvider) schemaMigratePre1To1(ctx context.Context) (err error) {
- migration, err := loadMigration(p.name, 1, true)
- if err != nil {
- return err
- }
-
- // Get Tables list.
- tables, err := p.SchemaTables(ctx)
- if err != nil {
- return err
- }
-
- tablesRename := []string{
- tablePre1Config,
- tablePre1TOTPSecrets,
- tablePre1IdentityVerificationTokens,
- tablePre1U2FDevices,
- tableUserPreferences,
- tableAuthenticationLogs,
- tableAlphaPreferences,
- tableAlphaIdentityVerificationTokens,
- tableAlphaAuthenticationLogs,
- tableAlphaPreferencesTableName,
- tableAlphaSecondFactorPreferences,
- tableAlphaTOTPSecrets,
- tableAlphaU2FDeviceHandles,
- }
-
- if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil {
- return err
- }
-
- if _, err = p.db.ExecContext(ctx, migration.Query); err != nil {
- return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
- }
-
- if err = p.setNewEncryptionCheckValue(ctx, &p.key, nil); err != nil {
- return err
- }
-
- if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect),
- tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil {
- return err
- }
-
- if err = p.schemaMigratePre1To1AuthenticationLogs(ctx); err != nil {
- return err
- }
-
- if err = p.schemaMigratePre1To1U2F(ctx); err != nil {
- return err
- }
-
- if err = p.schemaMigratePre1To1TOTP(ctx); err != nil {
- return err
- }
-
- for _, table := range tablesRename {
- if _, err = p.db.Exec(fmt.Sprintf(p.db.Rebind(queryFmtDropTableIfExists), tablePrefixBackup+table)); err != nil {
- return err
- }
- }
-
- return p.schemaMigrateFinalizeAdvanced(ctx, -1, 1)
-}
-
-func (p *SQLProvider) schemaMigratePre1Rename(ctx context.Context, tables, tablesRename []string) (err error) {
- // Rename Tables and Indexes.
- for _, table := range tables {
- if !utils.IsStringInSlice(table, tablesRename) {
- continue
- }
-
- tableNew := tablePrefixBackup + table
-
- if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil {
- return err
- }
-
- if p.name == providerPostgres {
- if table == tablePre1U2FDevices || table == tableUserPreferences {
- if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`,
- tableNew, table, tableNew)); err != nil {
- continue
- }
- }
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigratePre1To1Rollback(ctx context.Context, up bool) (err error) {
- if up {
- migration, err := loadMigration(p.name, 1, false)
- if err != nil {
- return err
- }
-
- if _, err = p.db.ExecContext(ctx, migration.Query); err != nil {
- return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
- }
- }
-
- tables, err := p.SchemaTables(ctx)
- if err != nil {
- return err
- }
-
- for _, table := range tables {
- if !strings.HasPrefix(table, tablePrefixBackup) {
- continue
- }
-
- tableNew := strings.Replace(table, tablePrefixBackup, "", 1)
- if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil {
- return err
- }
-
- if p.name == providerPostgres && (tableNew == tablePre1U2FDevices || tableNew == tableUserPreferences) {
- if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`,
- tableNew, table, tableNew)); err != nil {
- continue
- }
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogs(ctx context.Context) (err error) {
- for page := 0; true; page++ {
- attempts, err := p.schemaMigratePre1To1AuthenticationLogsGetRows(ctx, page)
- if err != nil {
- if err == sql.ErrNoRows {
- break
- }
-
- return err
- }
-
- for _, attempt := range attempts {
- _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time)
- if err != nil {
- return err
- }
- }
-
- if len(attempts) != 100 {
- break
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []model.AuthenticationAttempt, err error) {
- rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100)
- if err != nil {
- return nil, err
- }
-
- attempts = make([]model.AuthenticationAttempt, 0, 100)
-
- for rows.Next() {
- var (
- username string
- successful bool
- timestamp int64
- )
-
- err = rows.Scan(&username, &successful, &timestamp)
- if err != nil {
- return nil, err
- }
-
- attempts = append(attempts, model.AuthenticationAttempt{Username: username, Successful: successful, Time: time.Unix(timestamp, 0)})
- }
-
- return attempts, nil
-}
-
-func (p *SQLProvider) schemaMigratePre1To1TOTP(ctx context.Context) (err error) {
- rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tablePre1TOTPSecrets))
- if err != nil {
- return err
- }
-
- var totpConfigs []model.TOTPConfiguration
-
- defer func() {
- if err := rows.Close(); err != nil {
- p.log.Errorf(logFmtErrClosingConn, err)
- }
- }()
-
- for rows.Next() {
- var username, secret string
-
- err = rows.Scan(&username, &secret)
- if err != nil {
- return err
- }
-
- encryptedSecret, err := p.encrypt([]byte(secret))
- if err != nil {
- return err
- }
-
- totpConfigs = append(totpConfigs, model.TOTPConfiguration{Username: username, Secret: encryptedSecret})
- }
-
- for _, config := range totpConfigs {
- _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertTOTPConfiguration), tableTOTPConfigurations), config.Username, p.config.TOTP.Issuer, p.config.TOTP.Period, config.Secret)
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) {
- rows, err := p.db.Queryx(fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectU2FDevices), tablePrefixBackup+tablePre1U2FDevices))
- if err != nil {
- return err
- }
-
- defer func() {
- if err := rows.Close(); err != nil {
- p.log.Errorf(logFmtErrClosingConn, err)
- }
- }()
-
- var devices []model.U2FDevice
-
- for rows.Next() {
- var username, keyHandleBase64, publicKeyBase64 string
-
- err = rows.Scan(&username, &keyHandleBase64, &publicKeyBase64)
- if err != nil {
- return err
- }
-
- keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64)
- if err != nil {
- return err
- }
-
- publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
- if err != nil {
- return err
- }
-
- encryptedPublicKey, err := p.encrypt(publicKey)
- if err != nil {
- return err
- }
-
- devices = append(devices, model.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: encryptedPublicKey})
- }
-
- for _, device := range devices {
- _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertU2FDevice), tablePre1U2FDevices), device.Username, device.KeyHandle, device.PublicKey)
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigrate1ToPre1(ctx context.Context) (err error) {
- tables, err := p.SchemaTables(ctx)
- if err != nil {
- return err
- }
-
- tablesRename := []string{
- tableMigrations,
- tableTOTPConfigurations,
- tableIdentityVerification,
- tablePre1U2FDevices,
- tableDuoDevices,
- tableUserPreferences,
- tableAuthenticationLogs,
- tableEncryption,
- }
-
- if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil {
- return err
- }
-
- if _, err := p.db.ExecContext(ctx, queryCreatePre1); err != nil {
- return err
- }
-
- if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect),
- tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil {
- return err
- }
-
- if err = p.schemaMigrate1ToPre1AuthenticationLogs(ctx); err != nil {
- return err
- }
-
- if err = p.schemaMigrate1ToPre1U2F(ctx); err != nil {
- return err
- }
-
- if err = p.schemaMigrate1ToPre1TOTP(ctx); err != nil {
- return err
- }
-
- queryFmtDropTableRebound := p.db.Rebind(queryFmtDropTableIfExists)
-
- for _, table := range tablesRename {
- if _, err = p.db.Exec(fmt.Sprintf(queryFmtDropTableRebound, tablePrefixBackup+table)); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogs(ctx context.Context) (err error) {
- for page := 0; true; page++ {
- attempts, err := p.schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx, page)
- if err != nil {
- if err == sql.ErrNoRows {
- break
- }
-
- return err
- }
-
- for _, attempt := range attempts {
- _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time.Unix())
- if err != nil {
- return err
- }
- }
-
- if len(attempts) != 100 {
- break
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []model.AuthenticationAttempt, err error) {
- rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100)
- if err != nil {
- return nil, err
- }
-
- attempts = make([]model.AuthenticationAttempt, 0, 100)
-
- var attempt model.AuthenticationAttempt
- for rows.Next() {
- err = rows.StructScan(&attempt)
- if err != nil {
- return nil, err
- }
-
- attempts = append(attempts, attempt)
- }
-
- return attempts, nil
-}
-
-func (p *SQLProvider) schemaMigrate1ToPre1TOTP(ctx context.Context) (err error) {
- rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tableTOTPConfigurations))
- if err != nil {
- return err
- }
-
- var totpConfigs []model.TOTPConfiguration
-
- defer func() {
- if err := rows.Close(); err != nil {
- p.log.Errorf(logFmtErrClosingConn, err)
- }
- }()
-
- for rows.Next() {
- var (
- username string
- secretCipherText []byte
- )
-
- err = rows.Scan(&username, &secretCipherText)
- if err != nil {
- return err
- }
-
- secretClearText, err := p.decrypt(secretCipherText)
- if err != nil {
- return err
- }
-
- totpConfigs = append(totpConfigs, model.TOTPConfiguration{Username: username, Secret: secretClearText})
- }
-
- for _, config := range totpConfigs {
- _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertTOTPConfiguration), tablePre1TOTPSecrets), config.Username, config.Secret)
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) {
- rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectU2FDevices), tablePrefixBackup+tablePre1U2FDevices))
- if err != nil {
- return err
- }
-
- defer func() {
- if err := rows.Close(); err != nil {
- p.log.Errorf(logFmtErrClosingConn, err)
- }
- }()
-
- var (
- devices []model.U2FDevice
- device model.U2FDevice
- )
-
- for rows.Next() {
- err = rows.StructScan(&device)
- if err != nil {
- return err
- }
-
- device.PublicKey, err = p.decrypt(device.PublicKey)
- if err != nil {
- return err
- }
-
- devices = append(devices, device)
- }
-
- for _, device := range devices {
- _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertU2FDevice), tablePre1U2FDevices), device.Username, base64.StdEncoding.EncodeToString(device.KeyHandle), base64.StdEncoding.EncodeToString(device.PublicKey))
- if err != nil {
- return err
- }
- }
-
- return nil
-}
diff --git a/internal/storage/sql_provider_schema_test.go b/internal/storage/sql_provider_schema_test.go
index 16352e451..c4e9868be 100644
--- a/internal/storage/sql_provider_schema_test.go
+++ b/internal/storage/sql_provider_schema_test.go
@@ -29,7 +29,7 @@ func TestShouldReturnErrOnTargetSameAsCurrent(t *testing.T) {
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
}
-func TestShouldReturnErrOnUpMigrationTargetVersionLessTHanCurrent(t *testing.T) {
+func TestShouldReturnErrOnUpMigrationTargetVersionLessThanCurrent(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerPostgres, true, 0, LatestVersion),
fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, LatestVersion))
@@ -80,7 +80,7 @@ func TestShouldReturnErrOnVersionDoesntExits(t *testing.T) {
fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, LatestVersion))
}
-func TestMigrationDownShouldReturnErrOnTargetLessThanPre1(t *testing.T) {
+func TestMigrationDownShouldReturnErrOnTargetLessThan1(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, false, -4, LatestVersion),
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -4))
@@ -93,8 +93,15 @@ func TestMigrationDownShouldReturnErrOnTargetLessThanPre1(t *testing.T) {
schemaMigrateChecks(providerPostgres, false, -2, LatestVersion),
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -2))
- assert.NoError(t,
- schemaMigrateChecks(providerPostgres, false, -1, LatestVersion))
+ assert.EqualError(t,
+ schemaMigrateChecks(providerPostgres, false, -1, LatestVersion),
+ "schema migration down to pre1 is no longer supported: you must use an older version of authelia to perform this migration: you should downgrade to schema version 1 using the current authelia version then use the suggested authelia version to downgrade to pre1: the suggested authelia version is 4.37.2")
+}
+
+func TestMigrationDownShouldReturnErrOnCurrentLessThan0(t *testing.T) {
+ assert.EqualError(t,
+ schemaMigrateChecks(providerPostgres, true, LatestVersion, -1),
+ "schema migration up from pre1 is no longer supported: you must use an older version of authelia to perform this migration: the suggested authelia version is 4.37.2")
}
func TestMigrationDownShouldReturnErrOnTargetVersionGreaterThanCurrent(t *testing.T) {
diff --git a/internal/storage/types.go b/internal/storage/types.go
new file mode 100644
index 000000000..327f52c4d
--- /dev/null
+++ b/internal/storage/types.go
@@ -0,0 +1,95 @@
+package storage
+
+import (
+ "context"
+
+ "github.com/jmoiron/sqlx"
+)
+
+// SQLXConnection is a *sqlx.DB or *sqlx.Tx.
+type SQLXConnection interface {
+ sqlx.Execer
+ sqlx.ExecerContext
+
+ sqlx.Preparer
+ sqlx.PreparerContext
+
+ sqlx.Queryer
+ sqlx.QueryerContext
+
+ sqlx.Ext
+ sqlx.ExtContext
+}
+
+// EncryptionChangeKeyFunc handles encryption key changes for a specific table or tables.
+type EncryptionChangeKeyFunc func(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error)
+
+// EncryptionCheckKeyFunc handles encryption key checking for a specific table or tables.
+type EncryptionCheckKeyFunc func(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult)
+
+type encOAuth2Session struct {
+ ID int `db:"id"`
+ Session []byte `db:"session_data"`
+}
+
+type encWebauthnDevice struct {
+ ID int `db:"id"`
+ PublicKey []byte `db:"public_key"`
+}
+
+type encTOTPConfiguration struct {
+ ID int `db:"id" json:"-"`
+ Secret []byte `db:"secret" json:"-"`
+}
+
+// EncryptionValidationResult contains information about the success of a schema encryption validation.
+type EncryptionValidationResult struct {
+ InvalidCheckValue bool
+ Tables map[string]EncryptionValidationTableResult
+}
+
+// Success returns true if no validation errors occurred.
+func (r EncryptionValidationResult) Success() bool {
+ if r.InvalidCheckValue {
+ return false
+ }
+
+ for _, table := range r.Tables {
+ if table.Invalid != 0 || table.Error != nil {
+ return false
+ }
+ }
+
+ return true
+}
+
+// Checked returns true the validation completed all phases even if there were errors.
+func (r EncryptionValidationResult) Checked() bool {
+ for _, table := range r.Tables {
+ if table.Error != nil {
+ return false
+ }
+ }
+
+ return true
+}
+
+// EncryptionValidationTableResult contains information about the success of a table schema encryption validation.
+type EncryptionValidationTableResult struct {
+ Error error
+ Total int
+ Invalid int
+}
+
+// ResultDescriptor returns a string representing the result.
+func (r EncryptionValidationTableResult) ResultDescriptor() string {
+ if r.Total == 0 {
+ return na
+ }
+
+ if r.Error != nil || r.Invalid != 0 {
+ return "FAILURE"
+ }
+
+ return "SUCCESS"
+}