summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider_schema.go
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2022-11-25 23:44:55 +1100
committerGitHub <noreply@github.com>2022-11-25 23:44:55 +1100
commit3e4ac7821d51ac447bb39e7e1ea3c385dc3084d9 (patch)
tree69594576856eb8b587158d9245f70aff4fec429a /internal/storage/sql_provider_schema.go
parent3c291b5685212813f98f365c8d963e0f107860cb (diff)
refactor: remove pre1 migration path (#4356)
This removes pre1 migrations and improves a lot of tooling.
Diffstat (limited to 'internal/storage/sql_provider_schema.go')
-rw-r--r--internal/storage/sql_provider_schema.go255
1 files changed, 141 insertions, 114 deletions
diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go
index e01aee1ad..8c015e963 100644
--- a/internal/storage/sql_provider_schema.go
+++ b/internal/storage/sql_provider_schema.go
@@ -81,15 +81,41 @@ func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error
return 0, nil
}
-func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *model.Migration, err error) {
- migration = &model.Migration{}
+// SchemaLatestVersion returns the latest version available for migration.
+func (p *SQLProvider) SchemaLatestVersion() (version int, err error) {
+ return latestMigrationVersion(p.name)
+}
- err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration)
+// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version.
+func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
+ current, err := p.SchemaVersion(ctx)
if err != nil {
- return nil, err
+ return migrations, err
}
- return migration, nil
+ if version == 0 {
+ version = SchemaLatest
+ }
+
+ if current >= version {
+ return migrations, ErrNoAvailableMigrations
+ }
+
+ return loadMigrations(p.name, current, version)
+}
+
+// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version.
+func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
+ current, err := p.SchemaVersion(ctx)
+ if err != nil {
+ return migrations, err
+ }
+
+ if current <= version {
+ return migrations, ErrNoAvailableMigrations
+ }
+
+ return loadMigrations(p.name, current, version)
}
// SchemaMigrationHistory returns migration history rows.
@@ -121,184 +147,185 @@ func (p *SQLProvider) SchemaMigrationHistory(ctx context.Context) (migrations []
// SchemaMigrate migrates from the current version to the provided version.
func (p *SQLProvider) SchemaMigrate(ctx context.Context, up bool, version int) (err error) {
- currentVersion, err := p.SchemaVersion(ctx)
- if err != nil {
- return err
- }
+ var (
+ tx *sqlx.Tx
+ conn SQLXConnection
+ )
+
+ if p.name != providerMySQL {
+ if tx, err = p.db.BeginTxx(ctx, nil); err != nil {
+ return fmt.Errorf("failed to begin transaction: %w", err)
+ }
- if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil {
- return err
+ conn = tx
+ } else {
+ conn = p.db
}
- return p.schemaMigrate(ctx, currentVersion, version)
-}
-
-//nolint:gocyclo // TODO: Consider refactoring time permitting.
-func (p *SQLProvider) schemaMigrate(ctx context.Context, prior, target int) (err error) {
- migrations, err := loadMigrations(p.name, prior, target)
+ currentVersion, err := p.SchemaVersion(ctx)
if err != nil {
return err
}
- if len(migrations) == 0 && (prior != 1 || target != -1) {
- return ErrNoMigrationsFound
+ if currentVersion != 0 {
+ if err = p.schemaMigrateLock(ctx, conn); err != nil {
+ return err
+ }
}
- switch {
- case prior == -1:
- p.log.Infof(logFmtMigrationFromTo, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
-
- err = p.schemaMigratePre1To1(ctx)
- if err != nil {
- if errRollback := p.schemaMigratePre1To1Rollback(ctx, true); errRollback != nil {
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
- }
-
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
+ if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil {
+ if tx != nil {
+ _ = tx.Rollback()
}
- case target == -1:
- p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), "pre1")
- default:
- p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
+
+ return err
}
- for _, migration := range migrations {
- if prior == -1 && migration.Version == 1 {
- // Skip migration version 1 when upgrading from pre1 as it's applied as part of the pre1 upgrade.
- continue
+ if err = p.schemaMigrate(ctx, conn, currentVersion, version); err != nil {
+ if tx != nil && err == ErrNoMigrationsFound {
+ _ = tx.Rollback()
}
- err = p.schemaMigrateApply(ctx, migration)
- if err != nil {
- return p.schemaMigrateRollback(ctx, prior, migration.After(), err)
- }
+ return err
}
- switch {
- case prior == -1:
- p.log.Infof(logFmtMigrationComplete, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
- case target == -1:
- err = p.schemaMigrate1ToPre1(ctx)
- if err != nil {
- if errRollback := p.schemaMigratePre1To1Rollback(ctx, false); errRollback != nil {
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
+ if tx != nil {
+ if err = tx.Commit(); err != nil {
+ if rerr := tx.Rollback(); rerr != nil {
+ return fmt.Errorf("failed to commit the transaction with: commit error: %w, rollback error: %+v", err, rerr)
}
- return fmt.Errorf(errFmtFailedMigrationPre1, err)
+ return fmt.Errorf("failed to commit the transaction but it has been rolled back: commit error: %w", err)
}
-
- p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), "pre1")
- default:
- p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
}
return nil
}
-func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, prior, after int, migrateErr error) (err error) {
- migrations, err := loadMigrations(p.name, after, prior)
+func (p *SQLProvider) schemaMigrate(ctx context.Context, conn SQLXConnection, prior, target int) (err error) {
+ migrations, err := loadMigrations(p.name, prior, target)
if err != nil {
- return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %+v", prior, after, err, migrateErr)
+ return err
}
- for _, migration := range migrations {
- if prior == -1 && !migration.Up && migration.Version == 1 {
- continue
+ if len(migrations) == 0 {
+ return ErrNoMigrationsFound
+ }
+
+ p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
+
+ for i, migration := range migrations {
+ if migration.Up && prior == 0 && i == 1 {
+ if err = p.schemaMigrateLock(ctx, conn); err != nil {
+ return err
+ }
}
- err = p.schemaMigrateApply(ctx, migration)
- if err != nil {
- return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %+v", migration.Before(), migration.After(), err, migrateErr)
+ if err = p.schemaMigrateApply(ctx, conn, migration); err != nil {
+ return p.schemaMigrateRollback(ctx, conn, prior, migration.After(), err)
}
}
- if prior == -1 {
- if err = p.schemaMigrate1ToPre1(ctx); err != nil {
- return fmt.Errorf("error applying migration version 1 to version pre1 for rollback: %+v. rollback caused by: %+v", err, migrateErr)
- }
+ p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
+
+ return nil
+}
+
+func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection) (err error) {
+ if p.name != providerPostgres {
+ return nil
+ }
+
+ if _, err = conn.ExecContext(ctx, fmt.Sprintf(queryFmtPostgreSQLLockTable, tableMigrations, "ACCESS EXCLUSIVE")); err != nil {
+ return fmt.Errorf("failed to lock tables: %w", err)
}
- return fmt.Errorf("migration rollback complete. rollback caused by: %+v", migrateErr)
+ return nil
}
-func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration model.SchemaMigration) (err error) {
- _, err = p.db.ExecContext(ctx, migration.Query)
- if err != nil {
+func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
+ if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
- if migration.Version == 1 {
- // Skip the migration history insertion in a migration to v0.
- if !migration.Up {
- return nil
- }
-
+ if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1.
- if err = p.setNewEncryptionCheckValue(ctx, &p.key, nil); err != nil {
+ if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
return err
}
}
- if migration.Version == 1 && !migration.Up {
- return nil
+ if err = p.schemaMigrateFinalize(ctx, conn, migration); err != nil {
+ return err
}
- return p.schemaMigrateFinalize(ctx, migration)
+ return nil
}
-func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) {
- return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
-}
+func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
+ if migration.Version == 1 && !migration.Up {
+ return nil
+ }
-func (p *SQLProvider) schemaMigrateFinalizeAdvanced(ctx context.Context, before, after int) (err error) {
- _, err = p.db.ExecContext(ctx, p.sqlInsertMigration, time.Now(), before, after, utils.Version())
- if err != nil {
- return err
+ if _, err = conn.ExecContext(ctx, p.sqlInsertMigration, time.Now(), migration.Before(), migration.After(), utils.Version()); err != nil {
+ return fmt.Errorf("failed inserting migration record: %w", err)
}
- p.log.Debugf("Storage schema migrated from version %d to %d", before, after)
+ p.log.Debugf("Storage schema migrated from version %d to %d", migration.Before(), migration.After())
return nil
}
-// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version.
-func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
- current, err := p.SchemaVersion(ctx)
- if err != nil {
- return migrations, err
- }
-
- if version == 0 {
- version = SchemaLatest
+func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, conn SQLXConnection, prior, after int, merr error) (err error) {
+ switch tx := conn.(type) {
+ case *sqlx.Tx:
+ return p.schemaMigrateRollbackWithTx(ctx, tx, merr)
+ default:
+ return p.schemaMigrateRollbackWithoutTx(ctx, prior, after, merr)
}
+}
- if current >= version {
- return migrations, ErrNoAvailableMigrations
+func (p *SQLProvider) schemaMigrateRollbackWithTx(_ context.Context, tx *sqlx.Tx, merr error) (err error) {
+ if err = tx.Rollback(); err != nil {
+ return fmt.Errorf("error applying rollback %+v. rollback caused by: %w", err, merr)
}
- return loadMigrations(p.name, current, version)
+ return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr)
}
-// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version.
-func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
- current, err := p.SchemaVersion(ctx)
+func (p *SQLProvider) schemaMigrateRollbackWithoutTx(ctx context.Context, prior, after int, merr error) (err error) {
+ migrations, err := loadMigrations(p.name, after, prior)
if err != nil {
- return migrations, err
+ return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %w", prior, after, err, merr)
}
- if current <= version {
- return migrations, ErrNoAvailableMigrations
+ for _, migration := range migrations {
+ if err = p.schemaMigrateApply(ctx, p.db, migration); err != nil {
+ return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %w", migration.Before(), migration.After(), err, merr)
+ }
}
- return loadMigrations(p.name, current, version)
+ return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr)
}
-// SchemaLatestVersion returns the latest version available for migration.
-func (p *SQLProvider) SchemaLatestVersion() (version int, err error) {
- return latestMigrationVersion(p.name)
+func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *model.Migration, err error) {
+ migration = &model.Migration{}
+
+ if err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration); err != nil {
+ return nil, err
+ }
+
+ return migration, nil
}
func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVersion int) (err error) {
+ switch {
+ case currentVersion == -1:
+ return fmt.Errorf(errFmtMigrationPre1, "up from", errFmtMigrationPre1SuggestedVersion)
+ case targetVersion == -1:
+ return fmt.Errorf(errFmtMigrationPre1, "down to", fmt.Sprintf("you should downgrade to schema version 1 using the current authelia version then use the suggested authelia version to downgrade to pre1: %s", errFmtMigrationPre1SuggestedVersion))
+ }
+
if targetVersion == currentVersion {
return fmt.Errorf(ErrFmtMigrateAlreadyOnTargetVersion, targetVersion, currentVersion)
}
@@ -325,7 +352,7 @@ func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVer
return fmt.Errorf(ErrFmtMigrateUpTargetGreaterThanLatest, targetVersion, latest)
}
} else {
- if targetVersion < -1 {
+ if targetVersion < 0 {
return fmt.Errorf(ErrFmtMigrateDownTargetLessThanMinimum, targetVersion)
}
@@ -345,7 +372,7 @@ func SchemaVersionToString(version int) (versionStr string) {
case -1:
return "pre1"
case 0:
- return "N/A"
+ return na
default:
return strconv.Itoa(version)
}