summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider.go
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2020-07-16 15:56:08 +1000
committerGitHub <noreply@github.com>2020-07-16 15:56:08 +1000
commitea1fae6491fbd2b625262d128ae35cd27015b12e (patch)
treecc706df4becdc116230c9ec8ad51f1ae7f17ccda /internal/storage/sql_provider.go
parent29c467098c4c2e1878ad0a4e02b8cd170b705294 (diff)
[MISC] Storage Schema Versioning Model (#1057)
* [MISC] Storage Schema Versioning Model * fixup go.sum * remove pq * fix int to text issue * fix incorrect SQL text * use key_name vs key * use transactions for all queries during upgrades * fix missing parenthesis * move upgrades to their own file * add provider name for future usage in upgrades * fix missing create config table values * fix using the const instead of the provider SQL * import logging once and reuse * update docs * remove db at suite teardown * apply suggestions from code review * fix mysql * make errors more uniform * style changes * remove commented code sections * remove commented code sections * add schema version type * add sql mock unit tests * go mod tidy * test blank row situations
Diffstat (limited to 'internal/storage/sql_provider.go')
-rw-r--r--internal/storage/sql_provider.go129
1 files changed, 94 insertions, 35 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index f71bcff6d..891eec86e 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -6,19 +6,21 @@ import (
"fmt"
"time"
+ "github.com/sirupsen/logrus"
+
+ "github.com/authelia/authelia/internal/logging"
"github.com/authelia/authelia/internal/models"
+ "github.com/authelia/authelia/internal/utils"
)
// SQLProvider is a storage provider persisting data in a SQL database.
type SQLProvider struct {
- db *sql.DB
+ db *sql.DB
+ log *logrus.Logger
+ name string
- sqlCreateUserPreferencesTable string
- sqlCreateIdentityVerificationTokensTable string
- sqlCreateTOTPSecretsTable string
- sqlCreateU2FDeviceHandlesTable string
- sqlCreateAuthenticationLogsTable string
- sqlCreateAuthenticationLogsUserTimeIndex string
+ sqlUpgradesCreateTableStatements map[SchemaVersion]map[string]string
+ sqlUpgradesCreateTableIndexesStatements map[SchemaVersion][]string
sqlGetPreferencesByUsername string
sqlUpsertSecondFactorPreference string
@@ -36,50 +38,107 @@ type SQLProvider struct {
sqlInsertAuthenticationLog string
sqlGetLatestAuthenticationLogs string
+
+ sqlGetExistingTables string
+
+ sqlConfigSetValue string
+ sqlConfigGetValue string
}
func (p *SQLProvider) initialize(db *sql.DB) error {
p.db = db
+ p.log = logging.Logger()
- _, err := db.Exec(p.sqlCreateUserPreferencesTable)
- if err != nil {
- return fmt.Errorf("Unable to create table %s: %v", preferencesTableName, err)
- }
+ return p.upgrade()
+}
- _, err = db.Exec(p.sqlCreateIdentityVerificationTokensTable)
+func (p *SQLProvider) getSchemaBasicDetails() (version SchemaVersion, tables []string, err error) {
+ rows, err := p.db.Query(p.sqlGetExistingTables)
if err != nil {
- return fmt.Errorf("Unable to create table %s: %v", identityVerificationTokensTableName, err)
+ return version, tables, err
}
- _, err = db.Exec(p.sqlCreateTOTPSecretsTable)
- if err != nil {
- return fmt.Errorf("Unable to create table %s: %v", totpSecretsTableName, err)
+ defer rows.Close()
+
+ var table string
+
+ for rows.Next() {
+ err := rows.Scan(&table)
+ if err != nil {
+ return version, tables, err
+ }
+
+ tables = append(tables, table)
}
- // keyHandle and publicKey are stored in base64 format
- _, err = db.Exec(p.sqlCreateU2FDeviceHandlesTable)
- if err != nil {
- return fmt.Errorf("Unable to create table %s: %v", u2fDeviceHandlesTableName, err)
+ if utils.IsStringInSlice(configTableName, tables) {
+ rows, err := p.db.Query(p.sqlConfigGetValue, "schema", "version")
+ if err != nil {
+ return version, tables, err
+ }
+
+ for rows.Next() {
+ err := rows.Scan(&version)
+ if err != nil {
+ return version, tables, err
+ }
+ }
}
- _, err = db.Exec(p.sqlCreateAuthenticationLogsTable)
+ return version, tables, nil
+}
+
+func (p *SQLProvider) upgrade() error {
+ p.log.Debug("Storage schema is being checked to verify it is up to date")
+
+ version, tables, err := p.getSchemaBasicDetails()
if err != nil {
- return fmt.Errorf("Unable to create table %s: %v", authenticationLogsTableName, err)
+ return err
}
- // Create an index on (username, time) because this couple is highly used by the regulation module
- // to check whether a user is banned.
- if p.sqlCreateAuthenticationLogsUserTimeIndex != "" {
- _, err = db.Exec(p.sqlCreateAuthenticationLogsUserTimeIndex)
+ if version < storageSchemaCurrentVersion {
+ p.log.Debugf("Storage schema is v%d, latest is v%d", version, storageSchemaCurrentVersion)
+
+ tx, err := p.db.Begin()
if err != nil {
- return fmt.Errorf("Unable to create table %s: %v", authenticationLogsTableName, err)
+ return err
}
+
+ switch version {
+ case 0:
+ err := p.upgradeSchemaToVersion001(tx, tables)
+ if err != nil {
+ return p.handleUpgradeFailure(tx, 1, err)
+ }
+
+ fallthrough
+ default:
+ err := tx.Commit()
+ if err != nil {
+ return err
+ }
+
+ p.log.Infof("Storage schema upgrade to v%d completed", storageSchemaCurrentVersion)
+ }
+ } else {
+ p.log.Debug("Storage schema is up to date")
}
return nil
}
-// LoadPreferred2FAMethod load the preferred method for 2FA from sqlite db.
+func (p *SQLProvider) handleUpgradeFailure(tx *sql.Tx, version SchemaVersion, err error) error {
+ rollbackErr := tx.Rollback()
+ formattedErr := fmt.Errorf("%s%d: %v", storageSchemaUpgradeErrorText, version, err)
+
+ if rollbackErr != nil {
+ return fmt.Errorf("rollback error occurred: %v (inner error %v)", rollbackErr, formattedErr)
+ }
+
+ return formattedErr
+}
+
+// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) {
var method string
@@ -98,13 +157,13 @@ func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) {
return method, err
}
-// SavePreferred2FAMethod save the preferred method for 2FA in sqlite db.
+// SavePreferred2FAMethod save the preferred method for 2FA to the database.
func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) error {
_, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method)
return err
}
-// FindIdentityVerificationToken look for an identity verification token in DB.
+// FindIdentityVerificationToken look for an identity verification token in the database.
func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) {
var found bool
@@ -116,25 +175,25 @@ func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error)
return found, nil
}
-// SaveIdentityVerificationToken save an identity verification token in DB.
+// SaveIdentityVerificationToken save an identity verification token in the database.
func (p *SQLProvider) SaveIdentityVerificationToken(token string) error {
_, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token)
return err
}
-// RemoveIdentityVerificationToken remove an identity verification token from the DB.
+// RemoveIdentityVerificationToken remove an identity verification token from the database.
func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error {
_, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token)
return err
}
-// SaveTOTPSecret save a TOTP secret of a given user.
+// SaveTOTPSecret save a TOTP secret of a given user in the database.
func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error {
_, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret)
return err
}
-// LoadTOTPSecret load a TOTP secret given a username.
+// LoadTOTPSecret load a TOTP secret given a username from the database.
func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) {
var secret string
if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil {
@@ -148,7 +207,7 @@ func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) {
return secret, nil
}
-// DeleteTOTPSecret delete a TOTP secret given a username.
+// DeleteTOTPSecret delete a TOTP secret from the database given a username.
func (p *SQLProvider) DeleteTOTPSecret(username string) error {
_, err := p.db.Exec(p.sqlDeleteTOTPSecret, username)
return err