diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2022-11-25 23:44:55 +1100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-25 23:44:55 +1100 |
| commit | 3e4ac7821d51ac447bb39e7e1ea3c385dc3084d9 (patch) | |
| tree | 69594576856eb8b587158d9245f70aff4fec429a /internal/storage/sql_provider_schema.go | |
| parent | 3c291b5685212813f98f365c8d963e0f107860cb (diff) | |
refactor: remove pre1 migration path (#4356)
This removes pre1 migrations and improves a lot of tooling.
Diffstat (limited to 'internal/storage/sql_provider_schema.go')
| -rw-r--r-- | internal/storage/sql_provider_schema.go | 255 |
1 files changed, 141 insertions, 114 deletions
diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go index e01aee1ad..8c015e963 100644 --- a/internal/storage/sql_provider_schema.go +++ b/internal/storage/sql_provider_schema.go @@ -81,15 +81,41 @@ func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error return 0, nil } -func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *model.Migration, err error) { - migration = &model.Migration{} +// SchemaLatestVersion returns the latest version available for migration. +func (p *SQLProvider) SchemaLatestVersion() (version int, err error) { + return latestMigrationVersion(p.name) +} - err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration) +// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version. +func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) { + current, err := p.SchemaVersion(ctx) if err != nil { - return nil, err + return migrations, err } - return migration, nil + if version == 0 { + version = SchemaLatest + } + + if current >= version { + return migrations, ErrNoAvailableMigrations + } + + return loadMigrations(p.name, current, version) +} + +// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version. +func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) { + current, err := p.SchemaVersion(ctx) + if err != nil { + return migrations, err + } + + if current <= version { + return migrations, ErrNoAvailableMigrations + } + + return loadMigrations(p.name, current, version) } // SchemaMigrationHistory returns migration history rows. @@ -121,184 +147,185 @@ func (p *SQLProvider) SchemaMigrationHistory(ctx context.Context) (migrations [] // SchemaMigrate migrates from the current version to the provided version. func (p *SQLProvider) SchemaMigrate(ctx context.Context, up bool, version int) (err error) { - currentVersion, err := p.SchemaVersion(ctx) - if err != nil { - return err - } + var ( + tx *sqlx.Tx + conn SQLXConnection + ) + + if p.name != providerMySQL { + if tx, err = p.db.BeginTxx(ctx, nil); err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } - if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil { - return err + conn = tx + } else { + conn = p.db } - return p.schemaMigrate(ctx, currentVersion, version) -} - -//nolint:gocyclo // TODO: Consider refactoring time permitting. -func (p *SQLProvider) schemaMigrate(ctx context.Context, prior, target int) (err error) { - migrations, err := loadMigrations(p.name, prior, target) + currentVersion, err := p.SchemaVersion(ctx) if err != nil { return err } - if len(migrations) == 0 && (prior != 1 || target != -1) { - return ErrNoMigrationsFound + if currentVersion != 0 { + if err = p.schemaMigrateLock(ctx, conn); err != nil { + return err + } } - switch { - case prior == -1: - p.log.Infof(logFmtMigrationFromTo, "pre1", strconv.Itoa(migrations[len(migrations)-1].After())) - - err = p.schemaMigratePre1To1(ctx) - if err != nil { - if errRollback := p.schemaMigratePre1To1Rollback(ctx, true); errRollback != nil { - return fmt.Errorf(errFmtFailedMigrationPre1, err) - } - - return fmt.Errorf(errFmtFailedMigrationPre1, err) + if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil { + if tx != nil { + _ = tx.Rollback() } - case target == -1: - p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), "pre1") - default: - p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After())) + + return err } - for _, migration := range migrations { - if prior == -1 && migration.Version == 1 { - // Skip migration version 1 when upgrading from pre1 as it's applied as part of the pre1 upgrade. - continue + if err = p.schemaMigrate(ctx, conn, currentVersion, version); err != nil { + if tx != nil && err == ErrNoMigrationsFound { + _ = tx.Rollback() } - err = p.schemaMigrateApply(ctx, migration) - if err != nil { - return p.schemaMigrateRollback(ctx, prior, migration.After(), err) - } + return err } - switch { - case prior == -1: - p.log.Infof(logFmtMigrationComplete, "pre1", strconv.Itoa(migrations[len(migrations)-1].After())) - case target == -1: - err = p.schemaMigrate1ToPre1(ctx) - if err != nil { - if errRollback := p.schemaMigratePre1To1Rollback(ctx, false); errRollback != nil { - return fmt.Errorf(errFmtFailedMigrationPre1, err) + if tx != nil { + if err = tx.Commit(); err != nil { + if rerr := tx.Rollback(); rerr != nil { + return fmt.Errorf("failed to commit the transaction with: commit error: %w, rollback error: %+v", err, rerr) } - return fmt.Errorf(errFmtFailedMigrationPre1, err) + return fmt.Errorf("failed to commit the transaction but it has been rolled back: commit error: %w", err) } - - p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), "pre1") - default: - p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After())) } return nil } -func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, prior, after int, migrateErr error) (err error) { - migrations, err := loadMigrations(p.name, after, prior) +func (p *SQLProvider) schemaMigrate(ctx context.Context, conn SQLXConnection, prior, target int) (err error) { + migrations, err := loadMigrations(p.name, prior, target) if err != nil { - return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %+v", prior, after, err, migrateErr) + return err } - for _, migration := range migrations { - if prior == -1 && !migration.Up && migration.Version == 1 { - continue + if len(migrations) == 0 { + return ErrNoMigrationsFound + } + + p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After())) + + for i, migration := range migrations { + if migration.Up && prior == 0 && i == 1 { + if err = p.schemaMigrateLock(ctx, conn); err != nil { + return err + } } - err = p.schemaMigrateApply(ctx, migration) - if err != nil { - return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %+v", migration.Before(), migration.After(), err, migrateErr) + if err = p.schemaMigrateApply(ctx, conn, migration); err != nil { + return p.schemaMigrateRollback(ctx, conn, prior, migration.After(), err) } } - if prior == -1 { - if err = p.schemaMigrate1ToPre1(ctx); err != nil { - return fmt.Errorf("error applying migration version 1 to version pre1 for rollback: %+v. rollback caused by: %+v", err, migrateErr) - } + p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After())) + + return nil +} + +func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection) (err error) { + if p.name != providerPostgres { + return nil + } + + if _, err = conn.ExecContext(ctx, fmt.Sprintf(queryFmtPostgreSQLLockTable, tableMigrations, "ACCESS EXCLUSIVE")); err != nil { + return fmt.Errorf("failed to lock tables: %w", err) } - return fmt.Errorf("migration rollback complete. rollback caused by: %+v", migrateErr) + return nil } -func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration model.SchemaMigration) (err error) { - _, err = p.db.ExecContext(ctx, migration.Query) - if err != nil { +func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) { + if _, err = conn.ExecContext(ctx, migration.Query); err != nil { return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err) } - if migration.Version == 1 { - // Skip the migration history insertion in a migration to v0. - if !migration.Up { - return nil - } - + if migration.Version == 1 && migration.Up { // Add the schema encryption value if upgrading to v1. - if err = p.setNewEncryptionCheckValue(ctx, &p.key, nil); err != nil { + if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil { return err } } - if migration.Version == 1 && !migration.Up { - return nil + if err = p.schemaMigrateFinalize(ctx, conn, migration); err != nil { + return err } - return p.schemaMigrateFinalize(ctx, migration) + return nil } -func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) { - return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After()) -} +func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) { + if migration.Version == 1 && !migration.Up { + return nil + } -func (p *SQLProvider) schemaMigrateFinalizeAdvanced(ctx context.Context, before, after int) (err error) { - _, err = p.db.ExecContext(ctx, p.sqlInsertMigration, time.Now(), before, after, utils.Version()) - if err != nil { - return err + if _, err = conn.ExecContext(ctx, p.sqlInsertMigration, time.Now(), migration.Before(), migration.After(), utils.Version()); err != nil { + return fmt.Errorf("failed inserting migration record: %w", err) } - p.log.Debugf("Storage schema migrated from version %d to %d", before, after) + p.log.Debugf("Storage schema migrated from version %d to %d", migration.Before(), migration.After()) return nil } -// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version. -func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) { - current, err := p.SchemaVersion(ctx) - if err != nil { - return migrations, err - } - - if version == 0 { - version = SchemaLatest +func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, conn SQLXConnection, prior, after int, merr error) (err error) { + switch tx := conn.(type) { + case *sqlx.Tx: + return p.schemaMigrateRollbackWithTx(ctx, tx, merr) + default: + return p.schemaMigrateRollbackWithoutTx(ctx, prior, after, merr) } +} - if current >= version { - return migrations, ErrNoAvailableMigrations +func (p *SQLProvider) schemaMigrateRollbackWithTx(_ context.Context, tx *sqlx.Tx, merr error) (err error) { + if err = tx.Rollback(); err != nil { + return fmt.Errorf("error applying rollback %+v. rollback caused by: %w", err, merr) } - return loadMigrations(p.name, current, version) + return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr) } -// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version. -func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) { - current, err := p.SchemaVersion(ctx) +func (p *SQLProvider) schemaMigrateRollbackWithoutTx(ctx context.Context, prior, after int, merr error) (err error) { + migrations, err := loadMigrations(p.name, after, prior) if err != nil { - return migrations, err + return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %w", prior, after, err, merr) } - if current <= version { - return migrations, ErrNoAvailableMigrations + for _, migration := range migrations { + if err = p.schemaMigrateApply(ctx, p.db, migration); err != nil { + return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %w", migration.Before(), migration.After(), err, merr) + } } - return loadMigrations(p.name, current, version) + return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr) } -// SchemaLatestVersion returns the latest version available for migration. -func (p *SQLProvider) SchemaLatestVersion() (version int, err error) { - return latestMigrationVersion(p.name) +func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *model.Migration, err error) { + migration = &model.Migration{} + + if err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration); err != nil { + return nil, err + } + + return migration, nil } func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVersion int) (err error) { + switch { + case currentVersion == -1: + return fmt.Errorf(errFmtMigrationPre1, "up from", errFmtMigrationPre1SuggestedVersion) + case targetVersion == -1: + return fmt.Errorf(errFmtMigrationPre1, "down to", fmt.Sprintf("you should downgrade to schema version 1 using the current authelia version then use the suggested authelia version to downgrade to pre1: %s", errFmtMigrationPre1SuggestedVersion)) + } + if targetVersion == currentVersion { return fmt.Errorf(ErrFmtMigrateAlreadyOnTargetVersion, targetVersion, currentVersion) } @@ -325,7 +352,7 @@ func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVer return fmt.Errorf(ErrFmtMigrateUpTargetGreaterThanLatest, targetVersion, latest) } } else { - if targetVersion < -1 { + if targetVersion < 0 { return fmt.Errorf(ErrFmtMigrateDownTargetLessThanMinimum, targetVersion) } @@ -345,7 +372,7 @@ func SchemaVersionToString(version int) (versionStr string) { case -1: return "pre1" case 0: - return "N/A" + return na default: return strconv.Itoa(version) } |
