diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2021-11-23 20:45:38 +1100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-11-23 20:45:38 +1100 |
| commit | 3695aa8140eb91fd54a4cd849e1340ad4c36d987 (patch) | |
| tree | e2cbb84db06b8058dc89ba9c616f016a223e6e67 /internal/storage/sql_provider.go | |
| parent | 884dc99083ba280d1a93103c4e16d4446ff7fdcc (diff) | |
feat(storage): primary key for all tables and general qol refactoring (#2431)
This is a massive overhaul to the SQL Storage for Authelia. It facilitates a whole heap of utility commands to help manage the database, primary keys, ensures all database requests use a context for cancellations, and paves the way for a few other PR's which improve the database.
Fixes #1337
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 367 |
1 files changed, 191 insertions, 176 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 8bd27a7e8..f136e4c6e 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -1,173 +1,199 @@ package storage import ( + "context" "database/sql" - "encoding/base64" + "errors" "fmt" "time" + "github.com/jmoiron/sqlx" "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/models" - "github.com/authelia/authelia/v4/internal/utils" ) -// SQLProvider is a storage provider persisting data in a SQL database. -type SQLProvider struct { - db *sql.DB - log *logrus.Logger - name string - - sqlUpgradesCreateTableStatements map[SchemaVersion]map[string]string - sqlUpgradesCreateTableIndexesStatements map[SchemaVersion][]string +// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's. +func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvider) { + db, err := sqlx.Open(driverName, dataSourceName) - sqlGetPreferencesByUsername string - sqlUpsertSecondFactorPreference string + provider = SQLProvider{ + name: name, + driverName: driverName, + db: db, + log: logging.Logger(), + errOpen: err, - sqlTestIdentityVerificationTokenExistence string - sqlInsertIdentityVerificationToken string - sqlDeleteIdentityVerificationToken string + sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), + sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), - sqlGetTOTPSecretByUsername string - sqlUpsertTOTPSecret string - sqlDeleteTOTPSecret string + sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification), + sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification), + sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification), - sqlGetU2FDeviceHandleByUsername string - sqlUpsertU2FDeviceHandle string + sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), + sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), + sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), - sqlInsertAuthenticationLog string - sqlGetLatestAuthenticationLogs string + sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices), + sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices), - sqlGetExistingTables string - - sqlConfigSetValue string - sqlConfigGetValue string -} + sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences), + sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences), + sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableU2FDevices, tableUserPreferences), -func (p *SQLProvider) initialize(db *sql.DB) error { - p.db = db - p.log = logging.Logger() + sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), + sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), + sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations), - return p.upgrade() -} - -func (p *SQLProvider) getSchemaBasicDetails() (version SchemaVersion, tables []string, err error) { - rows, err := p.db.Query(p.sqlGetExistingTables) - if err != nil { - return version, tables, err + sqlFmtRenameTable: queryFmtRenameTable, } - defer rows.Close() - - var table string + return provider +} - for rows.Next() { - err := rows.Scan(&table) - if err != nil { - return version, tables, err - } +// SQLProvider is a storage provider persisting data in a SQL database. +type SQLProvider struct { + db *sqlx.DB + log *logrus.Logger + name string + driverName string + errOpen error + + // Table: authentication_logs. + sqlInsertAuthenticationAttempt string + sqlSelectAuthenticationAttemptsByUsername string + + // Table: identity_verification_tokens. + sqlInsertIdentityVerification string + sqlDeleteIdentityVerification string + sqlSelectExistsIdentityVerification string + + // Table: totp_configurations. + sqlUpsertTOTPConfig string + sqlDeleteTOTPConfig string + sqlSelectTOTPConfig string + + // Table: u2f_devices. + sqlUpsertU2FDevice string + sqlSelectU2FDevice string + + // Table: user_preferences. + sqlUpsertPreferred2FAMethod string + sqlSelectPreferred2FAMethod string + sqlSelectUserInfo string + + // Table: migrations. + sqlInsertMigration string + sqlSelectMigrations string + sqlSelectLatestMigration string + + // Utility. + sqlSelectExistingTables string + sqlFmtRenameTable string +} - tables = append(tables, table) +// StartupCheck implements the provider startup check interface. +func (p *SQLProvider) StartupCheck() (err error) { + if p.errOpen != nil { + return p.errOpen } - if utils.IsStringInSlice(configTableName, tables) { - rows, err := p.db.Query(p.sqlConfigGetValue, "schema", "version") - if err != nil { - return version, tables, err + // TODO: Decide if this is needed, or if it should be configurable. + for i := 0; i < 19; i++ { + err = p.db.Ping() + if err == nil { + break } - for rows.Next() { - err := rows.Scan(&version) - if err != nil { - return version, tables, err - } - } + time.Sleep(time.Millisecond * 500) } - return version, tables, nil -} - -func (p *SQLProvider) upgrade() error { - p.log.Debug("Storage schema is being checked to verify it is up to date") - - version, tables, err := p.getSchemaBasicDetails() if err != nil { return err } - if version < storageSchemaCurrentVersion { - p.log.Debugf("Storage schema is v%d, latest is v%d", version, storageSchemaCurrentVersion) + p.log.Infof("Storage schema is being checked for updates") - tx, err := p.db.Begin() - if err != nil { - return err - } + ctx := context.Background() - switch version { - case 0: - err := p.upgradeSchemaToVersion001(tx, tables) - if err != nil { - return p.handleUpgradeFailure(tx, 1, err) - } - - fallthrough - default: - err := tx.Commit() - if err != nil { - return err - } - - p.log.Infof("Storage schema upgrade to v%d completed", storageSchemaCurrentVersion) - } - } else { - p.log.Debug("Storage schema is up to date") - } + err = p.SchemaMigrate(ctx, true, SchemaLatest) - return nil + switch err { + case ErrSchemaAlreadyUpToDate: + p.log.Infof("Storage schema is already up to date") + return nil + case nil: + return nil + default: + return err + } } -func (p *SQLProvider) handleUpgradeFailure(tx *sql.Tx, version SchemaVersion, err error) error { - rollbackErr := tx.Rollback() - formattedErr := fmt.Errorf("%s%d: %v", storageSchemaUpgradeErrorText, version, err) - - if rollbackErr != nil { - return fmt.Errorf("rollback error occurred: %v (inner error %v)", rollbackErr, formattedErr) - } +// SavePreferred2FAMethod save the preferred method for 2FA to the database. +func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method) - return formattedErr + return err } // LoadPreferred2FAMethod load the preferred method for 2FA from the database. -func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) { - var method string +func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) { + err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username) - rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username) - if err != nil { + switch err { + case sql.ErrNoRows: + return "", nil + case nil: + return method, err + default: return "", err } - defer rows.Close() +} - if !rows.Next() { - return "", nil - } +// LoadUserInfo loads the models.UserInfo from the database. +func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) { + err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username) - err = rows.Scan(&method) + switch { + case err == nil: + return info, nil + case errors.Is(err, sql.ErrNoRows): + _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0]) + if err != nil { + return models.UserInfo{}, err + } + + err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username) + if err != nil { + return models.UserInfo{}, err + } - return method, err + return info, nil + default: + return models.UserInfo{}, err + } } -// SavePreferred2FAMethod save the preferred method for 2FA to the database. -func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) error { - _, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method) +// SaveIdentityVerification save an identity verification record to the database. +func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token) + return err } -// FindIdentityVerificationToken look for an identity verification token in the database. -func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) { - var found bool +// RemoveIdentityVerification remove an identity verification record from the database. +func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, token string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token) + + return err +} - err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found) +// FindIdentityVerification checks if an identity verification record is in the database and active. +func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) { + err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, jti) if err != nil { return false, err } @@ -175,105 +201,94 @@ func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) return found, nil } -// SaveIdentityVerificationToken save an identity verification token in the database. -func (p *SQLProvider) SaveIdentityVerificationToken(token string) error { - _, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token) - return err -} +// SaveTOTPConfiguration save a TOTP config of a given user in the database. +func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) { + // TODO: Encrypt config.Secret here. + _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, + config.Username, + config.Algorithm, + config.Digits, + config.Period, + config.Secret, + ) -// RemoveIdentityVerificationToken remove an identity verification token from the database. -func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error { - _, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token) return err } -// SaveTOTPSecret save a TOTP secret of a given user in the database. -func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error { - _, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret) +// DeleteTOTPConfiguration delete a TOTP secret from the database given a username. +func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username) + return err } -// LoadTOTPSecret load a TOTP secret given a username from the database. -func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) { - var secret string - if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil { +// LoadTOTPConfiguration load a TOTP secret given a username from the database. +func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) { + config = &models.TOTPConfiguration{} + + err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config) + if err != nil { if err == sql.ErrNoRows { - return "", ErrNoTOTPSecret + return nil, ErrNoTOTPSecret } - return "", err + return nil, err } - return secret, nil + // TODO: Decrypt config.Secret here. + return config, nil } -// DeleteTOTPSecret delete a TOTP secret from the database given a username. -func (p *SQLProvider) DeleteTOTPSecret(username string) error { - _, err := p.db.Exec(p.sqlDeleteTOTPSecret, username) - return err -} - -// SaveU2FDeviceHandle save a registered U2F device registration blob. -func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error { - _, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle, - username, - base64.StdEncoding.EncodeToString(keyHandle), - base64.StdEncoding.EncodeToString(publicKey)) +// SaveU2FDevice saves a registered U2F device. +func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey) return err } -// LoadU2FDeviceHandle load a U2F device registration blob for a given username. -func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) { - var keyHandleBase64, publicKeyBase64 string - if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil { - if err == sql.ErrNoRows { - return nil, nil, ErrNoU2FDeviceHandle - } - - return nil, nil, err +// LoadU2FDevice loads a U2F device registration for a given username. +func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) { + device = &models.U2FDevice{ + Username: username, } - keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64) - + err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username) if err != nil { - return nil, nil, err - } - - publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64) + if err == sql.ErrNoRows { + return nil, ErrNoU2FDeviceHandle + } - if err != nil { - return nil, nil, err + return nil, err } - return keyHandle, publicKey, nil + return device, nil } // AppendAuthenticationLog append a mark to the authentication log. -func (p *SQLProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error { - _, err := p.db.Exec(p.sqlInsertAuthenticationLog, attempt.Username, attempt.Successful, attempt.Time.Unix()) +func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username) return err } -// LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log. -func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) { - var t int64 - - rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username) - +// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log. +func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) { + rows, err := p.db.QueryxContext(ctx, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page) if err != nil { return nil, err } - attempts := make([]models.AuthenticationAttempt, 0, 10) + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + attempts = make([]models.AuthenticationAttempt, 0, limit) for rows.Next() { - attempt := models.AuthenticationAttempt{ - Username: username, - } - err = rows.Scan(&attempt.Successful, &t) - attempt.Time = time.Unix(t, 0) + var attempt models.AuthenticationAttempt + err = rows.StructScan(&attempt) if err != nil { return nil, err } |
