summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider.go
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2021-11-23 20:45:38 +1100
committerGitHub <noreply@github.com>2021-11-23 20:45:38 +1100
commit3695aa8140eb91fd54a4cd849e1340ad4c36d987 (patch)
treee2cbb84db06b8058dc89ba9c616f016a223e6e67 /internal/storage/sql_provider.go
parent884dc99083ba280d1a93103c4e16d4446ff7fdcc (diff)
feat(storage): primary key for all tables and general qol refactoring (#2431)
This is a massive overhaul to the SQL Storage for Authelia. It facilitates a whole heap of utility commands to help manage the database, primary keys, ensures all database requests use a context for cancellations, and paves the way for a few other PR's which improve the database. Fixes #1337
Diffstat (limited to 'internal/storage/sql_provider.go')
-rw-r--r--internal/storage/sql_provider.go367
1 files changed, 191 insertions, 176 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index 8bd27a7e8..f136e4c6e 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -1,173 +1,199 @@
package storage
import (
+ "context"
"database/sql"
- "encoding/base64"
+ "errors"
"fmt"
"time"
+ "github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus"
+ "github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/models"
- "github.com/authelia/authelia/v4/internal/utils"
)
-// SQLProvider is a storage provider persisting data in a SQL database.
-type SQLProvider struct {
- db *sql.DB
- log *logrus.Logger
- name string
-
- sqlUpgradesCreateTableStatements map[SchemaVersion]map[string]string
- sqlUpgradesCreateTableIndexesStatements map[SchemaVersion][]string
+// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's.
+func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvider) {
+ db, err := sqlx.Open(driverName, dataSourceName)
- sqlGetPreferencesByUsername string
- sqlUpsertSecondFactorPreference string
+ provider = SQLProvider{
+ name: name,
+ driverName: driverName,
+ db: db,
+ log: logging.Logger(),
+ errOpen: err,
- sqlTestIdentityVerificationTokenExistence string
- sqlInsertIdentityVerificationToken string
- sqlDeleteIdentityVerificationToken string
+ sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
+ sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
- sqlGetTOTPSecretByUsername string
- sqlUpsertTOTPSecret string
- sqlDeleteTOTPSecret string
+ sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
+ sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification),
+ sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
- sqlGetU2FDeviceHandleByUsername string
- sqlUpsertU2FDeviceHandle string
+ sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
+ sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
+ sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
- sqlInsertAuthenticationLog string
- sqlGetLatestAuthenticationLogs string
+ sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
+ sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
- sqlGetExistingTables string
-
- sqlConfigSetValue string
- sqlConfigGetValue string
-}
+ sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences),
+ sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences),
+ sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableU2FDevices, tableUserPreferences),
-func (p *SQLProvider) initialize(db *sql.DB) error {
- p.db = db
- p.log = logging.Logger()
+ sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
+ sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
+ sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations),
- return p.upgrade()
-}
-
-func (p *SQLProvider) getSchemaBasicDetails() (version SchemaVersion, tables []string, err error) {
- rows, err := p.db.Query(p.sqlGetExistingTables)
- if err != nil {
- return version, tables, err
+ sqlFmtRenameTable: queryFmtRenameTable,
}
- defer rows.Close()
-
- var table string
+ return provider
+}
- for rows.Next() {
- err := rows.Scan(&table)
- if err != nil {
- return version, tables, err
- }
+// SQLProvider is a storage provider persisting data in a SQL database.
+type SQLProvider struct {
+ db *sqlx.DB
+ log *logrus.Logger
+ name string
+ driverName string
+ errOpen error
+
+ // Table: authentication_logs.
+ sqlInsertAuthenticationAttempt string
+ sqlSelectAuthenticationAttemptsByUsername string
+
+ // Table: identity_verification_tokens.
+ sqlInsertIdentityVerification string
+ sqlDeleteIdentityVerification string
+ sqlSelectExistsIdentityVerification string
+
+ // Table: totp_configurations.
+ sqlUpsertTOTPConfig string
+ sqlDeleteTOTPConfig string
+ sqlSelectTOTPConfig string
+
+ // Table: u2f_devices.
+ sqlUpsertU2FDevice string
+ sqlSelectU2FDevice string
+
+ // Table: user_preferences.
+ sqlUpsertPreferred2FAMethod string
+ sqlSelectPreferred2FAMethod string
+ sqlSelectUserInfo string
+
+ // Table: migrations.
+ sqlInsertMigration string
+ sqlSelectMigrations string
+ sqlSelectLatestMigration string
+
+ // Utility.
+ sqlSelectExistingTables string
+ sqlFmtRenameTable string
+}
- tables = append(tables, table)
+// StartupCheck implements the provider startup check interface.
+func (p *SQLProvider) StartupCheck() (err error) {
+ if p.errOpen != nil {
+ return p.errOpen
}
- if utils.IsStringInSlice(configTableName, tables) {
- rows, err := p.db.Query(p.sqlConfigGetValue, "schema", "version")
- if err != nil {
- return version, tables, err
+ // TODO: Decide if this is needed, or if it should be configurable.
+ for i := 0; i < 19; i++ {
+ err = p.db.Ping()
+ if err == nil {
+ break
}
- for rows.Next() {
- err := rows.Scan(&version)
- if err != nil {
- return version, tables, err
- }
- }
+ time.Sleep(time.Millisecond * 500)
}
- return version, tables, nil
-}
-
-func (p *SQLProvider) upgrade() error {
- p.log.Debug("Storage schema is being checked to verify it is up to date")
-
- version, tables, err := p.getSchemaBasicDetails()
if err != nil {
return err
}
- if version < storageSchemaCurrentVersion {
- p.log.Debugf("Storage schema is v%d, latest is v%d", version, storageSchemaCurrentVersion)
+ p.log.Infof("Storage schema is being checked for updates")
- tx, err := p.db.Begin()
- if err != nil {
- return err
- }
+ ctx := context.Background()
- switch version {
- case 0:
- err := p.upgradeSchemaToVersion001(tx, tables)
- if err != nil {
- return p.handleUpgradeFailure(tx, 1, err)
- }
-
- fallthrough
- default:
- err := tx.Commit()
- if err != nil {
- return err
- }
-
- p.log.Infof("Storage schema upgrade to v%d completed", storageSchemaCurrentVersion)
- }
- } else {
- p.log.Debug("Storage schema is up to date")
- }
+ err = p.SchemaMigrate(ctx, true, SchemaLatest)
- return nil
+ switch err {
+ case ErrSchemaAlreadyUpToDate:
+ p.log.Infof("Storage schema is already up to date")
+ return nil
+ case nil:
+ return nil
+ default:
+ return err
+ }
}
-func (p *SQLProvider) handleUpgradeFailure(tx *sql.Tx, version SchemaVersion, err error) error {
- rollbackErr := tx.Rollback()
- formattedErr := fmt.Errorf("%s%d: %v", storageSchemaUpgradeErrorText, version, err)
-
- if rollbackErr != nil {
- return fmt.Errorf("rollback error occurred: %v (inner error %v)", rollbackErr, formattedErr)
- }
+// SavePreferred2FAMethod save the preferred method for 2FA to the database.
+func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
+ _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method)
- return formattedErr
+ return err
}
// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
-func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) {
- var method string
+func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) {
+ err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username)
- rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username)
- if err != nil {
+ switch err {
+ case sql.ErrNoRows:
+ return "", nil
+ case nil:
+ return method, err
+ default:
return "", err
}
- defer rows.Close()
+}
- if !rows.Next() {
- return "", nil
- }
+// LoadUserInfo loads the models.UserInfo from the database.
+func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) {
+ err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username)
- err = rows.Scan(&method)
+ switch {
+ case err == nil:
+ return info, nil
+ case errors.Is(err, sql.ErrNoRows):
+ _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0])
+ if err != nil {
+ return models.UserInfo{}, err
+ }
+
+ err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username)
+ if err != nil {
+ return models.UserInfo{}, err
+ }
- return method, err
+ return info, nil
+ default:
+ return models.UserInfo{}, err
+ }
}
-// SavePreferred2FAMethod save the preferred method for 2FA to the database.
-func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) error {
- _, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method)
+// SaveIdentityVerification save an identity verification record to the database.
+func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
+ _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token)
+
return err
}
-// FindIdentityVerificationToken look for an identity verification token in the database.
-func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) {
- var found bool
+// RemoveIdentityVerification remove an identity verification record from the database.
+func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, token string) (err error) {
+ _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token)
+
+ return err
+}
- err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found)
+// FindIdentityVerification checks if an identity verification record is in the database and active.
+func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) {
+ err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, jti)
if err != nil {
return false, err
}
@@ -175,105 +201,94 @@ func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error)
return found, nil
}
-// SaveIdentityVerificationToken save an identity verification token in the database.
-func (p *SQLProvider) SaveIdentityVerificationToken(token string) error {
- _, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token)
- return err
-}
+// SaveTOTPConfiguration save a TOTP config of a given user in the database.
+func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) {
+ // TODO: Encrypt config.Secret here.
+ _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
+ config.Username,
+ config.Algorithm,
+ config.Digits,
+ config.Period,
+ config.Secret,
+ )
-// RemoveIdentityVerificationToken remove an identity verification token from the database.
-func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error {
- _, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token)
return err
}
-// SaveTOTPSecret save a TOTP secret of a given user in the database.
-func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error {
- _, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret)
+// DeleteTOTPConfiguration delete a TOTP secret from the database given a username.
+func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) {
+ _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username)
+
return err
}
-// LoadTOTPSecret load a TOTP secret given a username from the database.
-func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) {
- var secret string
- if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil {
+// LoadTOTPConfiguration load a TOTP secret given a username from the database.
+func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) {
+ config = &models.TOTPConfiguration{}
+
+ err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config)
+ if err != nil {
if err == sql.ErrNoRows {
- return "", ErrNoTOTPSecret
+ return nil, ErrNoTOTPSecret
}
- return "", err
+ return nil, err
}
- return secret, nil
+ // TODO: Decrypt config.Secret here.
+ return config, nil
}
-// DeleteTOTPSecret delete a TOTP secret from the database given a username.
-func (p *SQLProvider) DeleteTOTPSecret(username string) error {
- _, err := p.db.Exec(p.sqlDeleteTOTPSecret, username)
- return err
-}
-
-// SaveU2FDeviceHandle save a registered U2F device registration blob.
-func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error {
- _, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle,
- username,
- base64.StdEncoding.EncodeToString(keyHandle),
- base64.StdEncoding.EncodeToString(publicKey))
+// SaveU2FDevice saves a registered U2F device.
+func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
+ _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey)
return err
}
-// LoadU2FDeviceHandle load a U2F device registration blob for a given username.
-func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) {
- var keyHandleBase64, publicKeyBase64 string
- if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil {
- if err == sql.ErrNoRows {
- return nil, nil, ErrNoU2FDeviceHandle
- }
-
- return nil, nil, err
+// LoadU2FDevice loads a U2F device registration for a given username.
+func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) {
+ device = &models.U2FDevice{
+ Username: username,
}
- keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64)
-
+ err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username)
if err != nil {
- return nil, nil, err
- }
-
- publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
+ if err == sql.ErrNoRows {
+ return nil, ErrNoU2FDeviceHandle
+ }
- if err != nil {
- return nil, nil, err
+ return nil, err
}
- return keyHandle, publicKey, nil
+ return device, nil
}
// AppendAuthenticationLog append a mark to the authentication log.
-func (p *SQLProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error {
- _, err := p.db.Exec(p.sqlInsertAuthenticationLog, attempt.Username, attempt.Successful, attempt.Time.Unix())
+func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) {
+ _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username)
return err
}
-// LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log.
-func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) {
- var t int64
-
- rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username)
-
+// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log.
+func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) {
+ rows, err := p.db.QueryxContext(ctx, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page)
if err != nil {
return nil, err
}
- attempts := make([]models.AuthenticationAttempt, 0, 10)
+ defer func() {
+ if err := rows.Close(); err != nil {
+ p.log.Errorf(logFmtErrClosingConn, err)
+ }
+ }()
+
+ attempts = make([]models.AuthenticationAttempt, 0, limit)
for rows.Next() {
- attempt := models.AuthenticationAttempt{
- Username: username,
- }
- err = rows.Scan(&attempt.Successful, &t)
- attempt.Time = time.Unix(t, 0)
+ var attempt models.AuthenticationAttempt
+ err = rows.StructScan(&attempt)
if err != nil {
return nil, err
}