diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2020-07-16 15:56:08 +1000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-07-16 15:56:08 +1000 |
| commit | ea1fae6491fbd2b625262d128ae35cd27015b12e (patch) | |
| tree | cc706df4becdc116230c9ec8ad51f1ae7f17ccda /internal/storage/sql_provider.go | |
| parent | 29c467098c4c2e1878ad0a4e02b8cd170b705294 (diff) | |
[MISC] Storage Schema Versioning Model (#1057)
* [MISC] Storage Schema Versioning Model
* fixup go.sum
* remove pq
* fix int to text issue
* fix incorrect SQL text
* use key_name vs key
* use transactions for all queries during upgrades
* fix missing parenthesis
* move upgrades to their own file
* add provider name for future usage in upgrades
* fix missing create config table values
* fix using the const instead of the provider SQL
* import logging once and reuse
* update docs
* remove db at suite teardown
* apply suggestions from code review
* fix mysql
* make errors more uniform
* style changes
* remove commented code sections
* remove commented code sections
* add schema version type
* add sql mock unit tests
* go mod tidy
* test blank row situations
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 129 |
1 files changed, 94 insertions, 35 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index f71bcff6d..891eec86e 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -6,19 +6,21 @@ import ( "fmt" "time" + "github.com/sirupsen/logrus" + + "github.com/authelia/authelia/internal/logging" "github.com/authelia/authelia/internal/models" + "github.com/authelia/authelia/internal/utils" ) // SQLProvider is a storage provider persisting data in a SQL database. type SQLProvider struct { - db *sql.DB + db *sql.DB + log *logrus.Logger + name string - sqlCreateUserPreferencesTable string - sqlCreateIdentityVerificationTokensTable string - sqlCreateTOTPSecretsTable string - sqlCreateU2FDeviceHandlesTable string - sqlCreateAuthenticationLogsTable string - sqlCreateAuthenticationLogsUserTimeIndex string + sqlUpgradesCreateTableStatements map[SchemaVersion]map[string]string + sqlUpgradesCreateTableIndexesStatements map[SchemaVersion][]string sqlGetPreferencesByUsername string sqlUpsertSecondFactorPreference string @@ -36,50 +38,107 @@ type SQLProvider struct { sqlInsertAuthenticationLog string sqlGetLatestAuthenticationLogs string + + sqlGetExistingTables string + + sqlConfigSetValue string + sqlConfigGetValue string } func (p *SQLProvider) initialize(db *sql.DB) error { p.db = db + p.log = logging.Logger() - _, err := db.Exec(p.sqlCreateUserPreferencesTable) - if err != nil { - return fmt.Errorf("Unable to create table %s: %v", preferencesTableName, err) - } + return p.upgrade() +} - _, err = db.Exec(p.sqlCreateIdentityVerificationTokensTable) +func (p *SQLProvider) getSchemaBasicDetails() (version SchemaVersion, tables []string, err error) { + rows, err := p.db.Query(p.sqlGetExistingTables) if err != nil { - return fmt.Errorf("Unable to create table %s: %v", identityVerificationTokensTableName, err) + return version, tables, err } - _, err = db.Exec(p.sqlCreateTOTPSecretsTable) - if err != nil { - return fmt.Errorf("Unable to create table %s: %v", totpSecretsTableName, err) + defer rows.Close() + + var table string + + for rows.Next() { + err := rows.Scan(&table) + if err != nil { + return version, tables, err + } + + tables = append(tables, table) } - // keyHandle and publicKey are stored in base64 format - _, err = db.Exec(p.sqlCreateU2FDeviceHandlesTable) - if err != nil { - return fmt.Errorf("Unable to create table %s: %v", u2fDeviceHandlesTableName, err) + if utils.IsStringInSlice(configTableName, tables) { + rows, err := p.db.Query(p.sqlConfigGetValue, "schema", "version") + if err != nil { + return version, tables, err + } + + for rows.Next() { + err := rows.Scan(&version) + if err != nil { + return version, tables, err + } + } } - _, err = db.Exec(p.sqlCreateAuthenticationLogsTable) + return version, tables, nil +} + +func (p *SQLProvider) upgrade() error { + p.log.Debug("Storage schema is being checked to verify it is up to date") + + version, tables, err := p.getSchemaBasicDetails() if err != nil { - return fmt.Errorf("Unable to create table %s: %v", authenticationLogsTableName, err) + return err } - // Create an index on (username, time) because this couple is highly used by the regulation module - // to check whether a user is banned. - if p.sqlCreateAuthenticationLogsUserTimeIndex != "" { - _, err = db.Exec(p.sqlCreateAuthenticationLogsUserTimeIndex) + if version < storageSchemaCurrentVersion { + p.log.Debugf("Storage schema is v%d, latest is v%d", version, storageSchemaCurrentVersion) + + tx, err := p.db.Begin() if err != nil { - return fmt.Errorf("Unable to create table %s: %v", authenticationLogsTableName, err) + return err } + + switch version { + case 0: + err := p.upgradeSchemaToVersion001(tx, tables) + if err != nil { + return p.handleUpgradeFailure(tx, 1, err) + } + + fallthrough + default: + err := tx.Commit() + if err != nil { + return err + } + + p.log.Infof("Storage schema upgrade to v%d completed", storageSchemaCurrentVersion) + } + } else { + p.log.Debug("Storage schema is up to date") } return nil } -// LoadPreferred2FAMethod load the preferred method for 2FA from sqlite db. +func (p *SQLProvider) handleUpgradeFailure(tx *sql.Tx, version SchemaVersion, err error) error { + rollbackErr := tx.Rollback() + formattedErr := fmt.Errorf("%s%d: %v", storageSchemaUpgradeErrorText, version, err) + + if rollbackErr != nil { + return fmt.Errorf("rollback error occurred: %v (inner error %v)", rollbackErr, formattedErr) + } + + return formattedErr +} + +// LoadPreferred2FAMethod load the preferred method for 2FA from the database. func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) { var method string @@ -98,13 +157,13 @@ func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) { return method, err } -// SavePreferred2FAMethod save the preferred method for 2FA in sqlite db. +// SavePreferred2FAMethod save the preferred method for 2FA to the database. func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) error { _, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method) return err } -// FindIdentityVerificationToken look for an identity verification token in DB. +// FindIdentityVerificationToken look for an identity verification token in the database. func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) { var found bool @@ -116,25 +175,25 @@ func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) return found, nil } -// SaveIdentityVerificationToken save an identity verification token in DB. +// SaveIdentityVerificationToken save an identity verification token in the database. func (p *SQLProvider) SaveIdentityVerificationToken(token string) error { _, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token) return err } -// RemoveIdentityVerificationToken remove an identity verification token from the DB. +// RemoveIdentityVerificationToken remove an identity verification token from the database. func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error { _, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token) return err } -// SaveTOTPSecret save a TOTP secret of a given user. +// SaveTOTPSecret save a TOTP secret of a given user in the database. func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error { _, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret) return err } -// LoadTOTPSecret load a TOTP secret given a username. +// LoadTOTPSecret load a TOTP secret given a username from the database. func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) { var secret string if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil { @@ -148,7 +207,7 @@ func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) { return secret, nil } -// DeleteTOTPSecret delete a TOTP secret given a username. +// DeleteTOTPSecret delete a TOTP secret from the database given a username. func (p *SQLProvider) DeleteTOTPSecret(username string) error { _, err := p.db.Exec(p.sqlDeleteTOTPSecret, username) return err |
