summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage/sql_provider.go')
-rw-r--r--internal/storage/sql_provider.go974
1 files changed, 562 insertions, 412 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index a5f1d0727..78866451e 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -24,12 +24,16 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
provider = SQLProvider{
db: db,
- key: sha256.Sum256([]byte(config.Storage.EncryptionKey)),
name: name,
driverName: driverName,
config: config,
errOpen: err,
- log: logging.Logger(),
+
+ keys: SQLProviderKeys{
+ encryption: sha256.Sum256([]byte(config.Storage.EncryptionKey)),
+ },
+
+ log: logging.Logger(),
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
@@ -38,6 +42,14 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification),
+ sqlInsertOneTimeCode: fmt.Sprintf(queryFmtInsertOTC, tableOneTimeCode),
+ sqlConsumeOneTimeCode: fmt.Sprintf(queryFmtConsumeOTC, tableOneTimeCode),
+ sqlRevokeOneTimeCode: fmt.Sprintf(queryFmtRevokeOTC, tableOneTimeCode),
+ sqlSelectOneTimeCode: fmt.Sprintf(queryFmtSelectOTCBySignatureAndUsername, tableOneTimeCode),
+ sqlSelectOneTimeCodeBySignature: fmt.Sprintf(queryFmtSelectOTCBySignature, tableOneTimeCode),
+ sqlSelectOneTimeCodeByID: fmt.Sprintf(queryFmtSelectOTCByID, tableOneTimeCode),
+ sqlSelectOneTimeCodeByPublicID: fmt.Sprintf(queryFmtSelectOTCByPublicID, tableOneTimeCode),
+
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
@@ -140,14 +152,16 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
// SQLProvider is a storage provider persisting data in a SQL database.
type SQLProvider struct {
- db *sqlx.DB
- key [32]byte
+ db *sqlx.DB
+
name string
driverName string
schema string
config *schema.Configuration
errOpen error
+ keys SQLProviderKeys
+
log *logrus.Logger
// Table: authentication_logs.
@@ -159,6 +173,15 @@ type SQLProvider struct {
sqlConsumeIdentityVerification string
sqlSelectIdentityVerification string
+ // Table: one_time_code.
+ sqlInsertOneTimeCode string
+ sqlConsumeOneTimeCode string
+ sqlRevokeOneTimeCode string
+ sqlSelectOneTimeCode string
+ sqlSelectOneTimeCodeBySignature string
+ sqlSelectOneTimeCodeByID string
+ sqlSelectOneTimeCodeByPublicID string
+
// Table: totp_configurations.
sqlUpsertTOTPConfig string
sqlDeleteTOTPConfig string
@@ -276,9 +299,10 @@ type SQLProvider struct {
sqlFmtRenameTable string
}
-// Close the underlying database connection.
-func (p *SQLProvider) Close() (err error) {
- return p.db.Close()
+// SQLProviderKeys are the cryptography keys used by a SQLProvider.
+type SQLProviderKeys struct {
+ encryption [32]byte
+ signature []byte
}
// StartupCheck implements the provider startup check interface.
@@ -315,17 +339,22 @@ func (p *SQLProvider) StartupCheck() (err error) {
}
switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err {
+ case nil:
+ break
case ErrSchemaAlreadyUpToDate:
p.log.Infof("Storage schema is already up to date")
- return nil
- case nil:
- return nil
default:
return fmt.Errorf("error during schema migrate: %w", err)
}
+
+ if p.keys.signature, err = p.getKeySigHMAC(ctx); err != nil {
+ return fmt.Errorf("failed to initialize the hmac signature key during startup: %w", err)
+ }
+
+ return nil
}
-// BeginTX begins a transaction.
+// BeginTX begins a transaction with the storage provider when applicable.
func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error) {
var tx *sql.Tx
@@ -336,7 +365,7 @@ func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error
return context.WithValue(ctx, ctxKeyTransaction, tx), nil
}
-// Commit performs a database commit.
+// Commit performs a storage provider commit when applicable.
func (p *SQLProvider) Commit(ctx context.Context) (err error) {
tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
@@ -347,7 +376,7 @@ func (p *SQLProvider) Commit(ctx context.Context) (err error) {
return tx.Commit()
}
-// Rollback performs a database rollback.
+// Rollback performs a storage provider rollback when applicable.
func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
@@ -358,7 +387,47 @@ func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
return tx.Rollback()
}
-// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database.
+// Close the underlying storage provider.
+func (p *SQLProvider) Close() (err error) {
+ return p.db.Close()
+}
+
+// SavePreferred2FAMethod save the preferred method for 2FA for a username to the storage provider.
+func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil {
+ return fmt.Errorf("error upserting preferred two factor method for user '%s': %w", username, err)
+ }
+
+ return nil
+}
+
+// LoadPreferred2FAMethod load the preferred method for 2FA for a username from the storage provider.
+func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) {
+ err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username)
+
+ switch {
+ case err == nil:
+ return method, nil
+ case errors.Is(err, sql.ErrNoRows):
+ return "", sql.ErrNoRows
+ default:
+ return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err)
+ }
+}
+
+// LoadUserInfo loads the model.UserInfo from the storage provider.
+func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) {
+ err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username)
+
+ switch {
+ case err == nil, errors.Is(err, sql.ErrNoRows):
+ return info, nil
+ default:
+ return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err)
+ }
+}
+
+// SaveUserOpaqueIdentifier saves a new opaque user identifier to the storage provider.
func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil {
return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err)
@@ -367,7 +436,7 @@ func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject mode
return nil
}
-// LoadUserOpaqueIdentifier selects an opaque user identifier from the database.
+// LoadUserOpaqueIdentifier selects an opaque user identifier from the storage provider.
func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) {
subject = &model.UserOpaqueIdentifier{}
@@ -383,7 +452,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier u
return subject, nil
}
-// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database.
+// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the storage provider.
func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error) {
var rows *sqlx.Rows
@@ -406,7 +475,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifier
return identifiers, nil
}
-// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
+// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the storage provider given a service name, sector id, and username.
func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) {
subject = &model.UserOpaqueIdentifier{}
@@ -422,57 +491,424 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s
return subject, nil
}
-// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session.
-func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession,
- consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted,
- consent.RequestedAt, consent.RespondedAt, consent.Form,
- consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience, consent.PreConfiguration); err != nil {
- return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.UUID.String(), err)
+// SaveTOTPConfiguration save a TOTP configuration of a given user in the storage provider.
+func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
+ if config.Secret, err = p.encrypt(config.Secret); err != nil {
+ return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err)
+ }
+
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
+ config.CreatedAt, config.LastUsedAt,
+ config.Username, config.Issuer,
+ config.Algorithm, config.Digits, config.Period, config.Secret); err != nil {
+ return fmt.Errorf("error upserting TOTP configuration for user '%s': %w", config.Username, err)
}
return nil
}
-// SaveOAuth2ConsentSessionSubject updates an OAuth2.0 consent session with the subject.
-func (p *SQLProvider) SaveOAuth2ConsentSessionSubject(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionSubject, consent.Subject, consent.ID); err != nil {
- return fmt.Errorf("error updating oauth2 consent session subject with id '%d' and challenge id '%s' for subject '%s': %w", consent.ID, consent.ChallengeID, consent.Subject.UUID, err)
+// UpdateTOTPConfigurationSignIn updates a registered TOTP configuration in the storage provider with the relevant sign in information.
+func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil {
+ return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err)
}
return nil
}
-// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent session with the response.
-func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.GrantedScopes, consent.GrantedAudience, consent.PreConfiguration, consent.ID); err != nil {
- return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject.UUID, err)
+// DeleteTOTPConfiguration delete a TOTP configuration from the storage provider given a username.
+func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil {
+ return fmt.Errorf("error deleting TOTP configuration for user '%s': %w", username, err)
}
return nil
}
-// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent recording that it has been granted by the authorization endpoint.
-func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil {
- return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err)
+// LoadTOTPConfiguration load a TOTP configuration given a username from the storage provider.
+func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) {
+ config = &model.TOTPConfiguration{}
+
+ if err = p.db.GetContext(ctx, config, p.sqlSelectTOTPConfig, username); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, ErrNoTOTPConfiguration
+ }
+
+ return nil, fmt.Errorf("error selecting TOTP configuration for user '%s': %w", username, err)
+ }
+
+ if config.Secret, err = p.decrypt(config.Secret); err != nil {
+ return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err)
+ }
+
+ return config, nil
+}
+
+// LoadTOTPConfigurations load a set of TOTP configurations from the storage provider.
+func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []model.TOTPConfiguration, err error) {
+ configs = make([]model.TOTPConfiguration, 0, limit)
+
+ if err = p.db.SelectContext(ctx, &configs, p.sqlSelectTOTPConfigs, limit, limit*page); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting TOTP configurations: %w", err)
+ }
+
+ for i, c := range configs {
+ if configs[i].Secret, err = p.decrypt(c.Secret); err != nil {
+ return nil, fmt.Errorf("error decrypting TOTP configuration for user '%s': %w", c.Username, err)
+ }
+ }
+
+ return configs, nil
+}
+
+// SaveWebAuthnUser saves a registered WebAuthn user to the storage provider.
+func (p *SQLProvider) SaveWebAuthnUser(ctx context.Context, user model.WebAuthnUser) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnUser, user.RPID, user.Username, user.UserID); err != nil {
+ return fmt.Errorf("error inserting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err)
}
return nil
}
-// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent given the challenge ID.
-func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) {
- consent = &model.OAuth2ConsentSession{}
+// LoadWebAuthnUser loads a registered WebAuthn user from the storage provider.
+func (p *SQLProvider) LoadWebAuthnUser(ctx context.Context, rpid, username string) (user *model.WebAuthnUser, err error) {
+ user = &model.WebAuthnUser{}
- if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil {
- return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err)
+ if err = p.db.GetContext(ctx, user, p.sqlSelectWebAuthnUser, rpid, username); err != nil {
+ switch {
+ case errors.Is(err, sql.ErrNoRows):
+ return nil, nil
+ default:
+ return nil, fmt.Errorf("error selecting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err)
+ }
}
- return consent, nil
+ return user, nil
+}
+
+// SaveWebAuthnCredential saves a registered WebAuthn credential to the storage provider.
+func (p *SQLProvider) SaveWebAuthnCredential(ctx context.Context, credential model.WebAuthnCredential) (err error) {
+ if credential.PublicKey, err = p.encrypt(credential.PublicKey); err != nil {
+ return fmt.Errorf("error encrypting WebAuthn credential public key for user '%s' kid '%x': %w", credential.Username, credential.KID, err)
+ }
+
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnCredential,
+ credential.CreatedAt, credential.LastUsedAt, credential.RPID, credential.Username, credential.Description,
+ credential.KID, credential.AAGUID, credential.AttestationType, credential.Attachment, credential.Transport,
+ credential.SignCount, credential.CloneWarning, credential.Discoverable, credential.Present, credential.Verified,
+ credential.BackupEligible, credential.BackupState, credential.PublicKey,
+ ); err != nil {
+ return fmt.Errorf("error inserting WebAuthn credential for user '%s' kid '%x': %w", credential.Username, credential.KID, err)
+ }
+
+ return nil
+}
+
+// UpdateWebAuthnCredentialDescription updates a registered WebAuthn credential in the storage provider changing the
+// description.
+func (p *SQLProvider) UpdateWebAuthnCredentialDescription(ctx context.Context, username string, credentialID int, description string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID, description, username, credentialID); err != nil {
+ return fmt.Errorf("error updating WebAuthn credential description to '%s' for credential id '%d': %w", description, credentialID, err)
+ }
+
+ return nil
+}
+
+// UpdateWebAuthnCredentialSignIn updates a registered WebAuthn credential in the storage provider changing the
+// information that should be changed in the event of a successful sign in.
+func (p *SQLProvider) UpdateWebAuthnCredentialSignIn(ctx context.Context, credential model.WebAuthnCredential) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialRecordSignIn,
+ credential.RPID, credential.LastUsedAt, credential.SignCount, credential.Discoverable, credential.Present, credential.Verified,
+ credential.BackupEligible, credential.BackupState, credential.CloneWarning, credential.ID,
+ ); err != nil {
+ return fmt.Errorf("error updating WebAuthn credentials authentication metadata for id '%x': %w", credential.ID, err)
+ }
+
+ return nil
+}
+
+// DeleteWebAuthnCredential deletes a registered WebAuthn credential from the storage provider.
+func (p *SQLProvider) DeleteWebAuthnCredential(ctx context.Context, kid string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredential, kid); err != nil {
+ return fmt.Errorf("error deleting WebAuthn credential with kid '%s': %w", kid, err)
+ }
+
+ return nil
+}
+
+// DeleteWebAuthnCredentialByUsername deletes registered WebAuthn credential from the storage provider by username or
+// username and description.
+func (p *SQLProvider) DeleteWebAuthnCredentialByUsername(ctx context.Context, username, displayname string) (err error) {
+ if len(username) == 0 {
+ return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': username must not be empty", username, displayname)
+ }
+
+ if len(displayname) == 0 {
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsername, username); err != nil {
+ return fmt.Errorf("error deleting WebAuthn credential for username '%s': %w", username, err)
+ }
+ } else {
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsernameAndDisplayName, username, displayname); err != nil {
+ return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': %w", username, displayname, err)
+ }
+ }
+
+ return nil
+}
+
+// LoadWebAuthnCredentials loads WebAuthn credential registrations from the storage provider.
+func (p *SQLProvider) LoadWebAuthnCredentials(ctx context.Context, limit, page int) (credentials []model.WebAuthnCredential, err error) {
+ credentials = make([]model.WebAuthnCredential, 0, limit)
+
+ if err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentials, limit, limit*page); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting WebAuthn credentials: %w", err)
+ }
+
+ for i, credential := range credentials {
+ if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil {
+ return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err)
+ }
+ }
+
+ return credentials, nil
+}
+
+// LoadWebAuthnCredentialByID loads a WebAuthn credential registration from the storage provider for a given id.
+func (p *SQLProvider) LoadWebAuthnCredentialByID(ctx context.Context, id int) (credential *model.WebAuthnCredential, err error) {
+ credential = &model.WebAuthnCredential{}
+
+ if err = p.db.GetContext(ctx, credential, p.sqlSelectWebAuthnCredentialByID, id); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, sql.ErrNoRows
+ }
+
+ return nil, fmt.Errorf("error selecting WebAuthn credential with id '%d': %w", id, err)
+ }
+
+ return credential, nil
+}
+
+// LoadWebAuthnCredentialsByUsername loads all WebAuthn credential registrations from the storage provider for a
+// given username.
+func (p *SQLProvider) LoadWebAuthnCredentialsByUsername(ctx context.Context, rpid, username string) (credentials []model.WebAuthnCredential, err error) {
+ switch len(rpid) {
+ case 0:
+ err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByUsername, username)
+ default:
+ err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByRPIDByUsername, rpid, username)
+ }
+
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return credentials, ErrNoWebAuthnCredential
+ }
+
+ return nil, fmt.Errorf("error selecting WebAuthn credentials for user '%s': %w", username, err)
+ }
+
+ for i, credential := range credentials {
+ if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil {
+ return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err)
+ }
+ }
+
+ return credentials, nil
+}
+
+// SavePreferredDuoDevice saves a Duo device to the storage provider.
+func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method); err != nil {
+ return fmt.Errorf("error upserting preferred duo device for user '%s': %w", device.Username, err)
+ }
+
+ return nil
+}
+
+// DeletePreferredDuoDevice deletes a Duo device from the storage provider for a given username.
+func (p *SQLProvider) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteDuoDevice, username); err != nil {
+ return fmt.Errorf("error deleting preferred duo device for user '%s': %w", username, err)
+ }
+
+ return nil
+}
+
+// LoadPreferredDuoDevice loads a Duo device from the storage provider for a given username.
+func (p *SQLProvider) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) {
+ device = &model.DuoDevice{}
+
+ if err = p.db.QueryRowxContext(ctx, p.sqlSelectDuoDevice, username).StructScan(device); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, ErrNoDuoDevice
+ }
+
+ return nil, fmt.Errorf("error selecting preferred duo device for user '%s': %w", username, err)
+ }
+
+ return device, nil
+}
+
+// SaveIdentityVerification save an identity verification record to the storage provider.
+func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification,
+ verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt,
+ verification.Username, verification.Action); err != nil {
+ return fmt.Errorf("error inserting identity verification for user '%s' with uuid '%s': %w", verification.Username, verification.JTI, err)
+ }
+
+ return nil
}
-// SaveOAuth2ConsentPreConfiguration inserts an OAuth2.0 consent pre-configuration.
+// ConsumeIdentityVerification marks an identity verification record in the storage provider as consumed.
+func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil {
+ return fmt.Errorf("error updating identity verification: %w", err)
+ }
+
+ return nil
+}
+
+// FindIdentityVerification checks if an identity verification record is in the storage provider and active.
+func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) {
+ verification := model.IdentityVerification{}
+ if err = p.db.GetContext(ctx, &verification, p.sqlSelectIdentityVerification, jti); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
+
+ return false, fmt.Errorf("error selecting identity verification exists: %w", err)
+ }
+
+ switch {
+ case verification.Consumed.Valid:
+ return false, fmt.Errorf("the token has already been consumed")
+ case verification.ExpiresAt.Before(time.Now()):
+ return false, fmt.Errorf("the token expired %s ago", time.Since(verification.ExpiresAt))
+ default:
+ return true, nil
+ }
+}
+
+// SaveOneTimeCode saves a one-time code to the storage provider after generating the signature which is returned
+// along with any error.
+func (p *SQLProvider) SaveOneTimeCode(ctx context.Context, code model.OneTimeCode) (signature string, err error) {
+ code.Signature = p.hmacSignature([]byte(code.Username), []byte(code.Intent), code.Code)
+
+ if code.Code, err = p.encrypt(code.Code); err != nil {
+ return "", fmt.Errorf("error encrypting the one-time code value for user '%s' with signature '%s': %w", code.Username, code.Signature, err)
+ }
+
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertOneTimeCode,
+ code.PublicID, code.Signature, code.IssuedAt, code.IssuedIP, code.ExpiresAt,
+ code.Username, code.Intent, code.Code); err != nil {
+ return "", fmt.Errorf("error inserting one-time code for user '%s' with signature '%s': %w", code.Username, code.Signature, err)
+ }
+
+ return code.Signature, nil
+}
+
+// ConsumeOneTimeCode consumes a one-time code using the signature.
+func (p *SQLProvider) ConsumeOneTimeCode(ctx context.Context, code *model.OneTimeCode) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlConsumeOneTimeCode, code.ConsumedAt, code.ConsumedIP, code.Signature); err != nil {
+ return fmt.Errorf("error updating one-time code (consume): %w", err)
+ }
+
+ return nil
+}
+
+// RevokeOneTimeCode revokes a one-time code in the storage provider using the public ID.
+func (p *SQLProvider) RevokeOneTimeCode(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlRevokeOneTimeCode, ip, publicID); err != nil {
+ return fmt.Errorf("error updating one-time code (revoke): %w", err)
+ }
+
+ return nil
+}
+
+// LoadOneTimeCode loads a one-time code from the storage provider given a username, intent, and code.
+func (p *SQLProvider) LoadOneTimeCode(ctx context.Context, username, intent, raw string) (code *model.OneTimeCode, err error) {
+ code = &model.OneTimeCode{}
+
+ signature := p.hmacSignature([]byte(username), []byte(intent), []byte(raw))
+
+ if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCode, signature, username); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting one-time code: %w", err)
+ }
+
+ if code.Code, err = p.decrypt(code.Code); err != nil {
+ return nil, fmt.Errorf("error decrypting the one-time code value for user '%s' with signature '%s': %w", code.Username, code.Signature, err)
+ }
+
+ return code, nil
+}
+
+// LoadOneTimeCodeBySignature loads a one-time code from the storage provider given the signature.
+// This method should NOT be used to validate a One-Time Code, LoadOneTimeCode should be used instead.
+func (p *SQLProvider) LoadOneTimeCodeBySignature(ctx context.Context, signature string) (code *model.OneTimeCode, err error) {
+ code = &model.OneTimeCode{}
+
+ if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCodeBySignature, signature); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting one-time code: %w", err)
+ }
+
+ if code.Code, err = p.decrypt(code.Code); err != nil {
+ return nil, fmt.Errorf("error decrypting the one-time code value for user '%s' with signature '%s': %w", code.Username, code.Signature, err)
+ }
+
+ return code, nil
+}
+
+// LoadOneTimeCodeByID loads a one-time code from the storage provider given the id.
+// This does not decrypt the code. This method should NOT be used to validate a One-Time Code,
+// LoadOneTimeCode should be used instead.
+func (p *SQLProvider) LoadOneTimeCodeByID(ctx context.Context, id int) (code *model.OneTimeCode, err error) {
+ code = &model.OneTimeCode{}
+
+ if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCodeByID, id); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting one-time code: %w", err)
+ }
+
+ return code, nil
+}
+
+// LoadOneTimeCodeByPublicID loads a one-time code from the storage provider given the public identifier.
+// This does not decrypt the code. This method SHOULD ONLY be used to find the One-Time Code for the
+// purpose of deletion.
+func (p *SQLProvider) LoadOneTimeCodeByPublicID(ctx context.Context, id uuid.UUID) (code *model.OneTimeCode, err error) {
+ code = &model.OneTimeCode{}
+
+ if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCodeByPublicID, id); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting one-time code: %w", err)
+ }
+
+ return code, nil
+}
+
+// SaveOAuth2ConsentPreConfiguration inserts an OAuth2.0 consent pre-configuration in the storage provider.
func (p *SQLProvider) SaveOAuth2ConsentPreConfiguration(ctx context.Context, config model.OAuth2ConsentPreConfig) (insertedID int64, err error) {
switch p.name {
case providerPostgres:
@@ -496,7 +932,7 @@ func (p *SQLProvider) SaveOAuth2ConsentPreConfiguration(ctx context.Context, con
}
}
-// LoadOAuth2ConsentPreConfigurations returns an OAuth2.0 consents pre-configurations given the consent signature.
+// LoadOAuth2ConsentPreConfigurations returns an OAuth2.0 consents pre-configurations from the storage provider given the consent signature.
func (p *SQLProvider) LoadOAuth2ConsentPreConfigurations(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentPreConfigRows, err error) {
var r *sqlx.Rows
@@ -511,7 +947,58 @@ func (p *SQLProvider) LoadOAuth2ConsentPreConfigurations(ctx context.Context, cl
return &ConsentPreConfigRows{rows: r}, nil
}
-// SaveOAuth2Session saves a OAuth2Session to the database.
+// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session to the storage provider.
+func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession,
+ consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted,
+ consent.RequestedAt, consent.RespondedAt, consent.Form,
+ consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience, consent.PreConfiguration); err != nil {
+ return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.UUID.String(), err)
+ }
+
+ return nil
+}
+
+// SaveOAuth2ConsentSessionSubject updates an OAuth2.0 consent session in the storage provider with the subject.
+func (p *SQLProvider) SaveOAuth2ConsentSessionSubject(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionSubject, consent.Subject, consent.ID); err != nil {
+ return fmt.Errorf("error updating oauth2 consent session subject with id '%d' and challenge id '%s' for subject '%s': %w", consent.ID, consent.ChallengeID, consent.Subject.UUID, err)
+ }
+
+ return nil
+}
+
+// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent session in the storage provider with the response.
+func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.GrantedScopes, consent.GrantedAudience, consent.PreConfiguration, consent.ID); err != nil {
+ return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject.UUID, err)
+ }
+
+ return nil
+}
+
+// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent session in the storage provider recording that it
+// has been granted by the authorization endpoint.
+func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil {
+ return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err)
+ }
+
+ return nil
+}
+
+// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent session in the storage provider given the challenge ID.
+func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) {
+ consent = &model.OAuth2ConsentSession{}
+
+ if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil {
+ return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err)
+ }
+
+ return consent, nil
+}
+
+// SaveOAuth2Session saves an OAut2.0 session to the storage provider.
func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) {
var query string
@@ -547,7 +1034,7 @@ func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2S
return nil
}
-// RevokeOAuth2Session marks a OAuth2Session as revoked in the database.
+// RevokeOAuth2Session marks an OAuth2.0 session as revoked in the storage provider.
func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
var query string
@@ -573,7 +1060,7 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth
return nil
}
-// RevokeOAuth2SessionByRequestID marks a OAuth2Session as revoked in the database.
+// RevokeOAuth2SessionByRequestID marks an OAuth2.0 session as revoked in the storage provider.
func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
var query string
@@ -599,7 +1086,7 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio
return nil
}
-// DeactivateOAuth2Session marks a OAuth2Session as inactive in the database.
+// DeactivateOAuth2Session marks an OAuth2.0 session as inactive in the storage provider.
func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
var query string
@@ -625,7 +1112,7 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O
return nil
}
-// DeactivateOAuth2SessionByRequestID marks a OAuth2Session as inactive in the database.
+// DeactivateOAuth2SessionByRequestID marks an OAuth2.0 session as inactive in the storage provider.
func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
var query string
@@ -651,7 +1138,7 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se
return nil
}
-// LoadOAuth2Session saves a OAuth2Session from the database.
+// LoadOAuth2Session saves an OAuth2.0 session from the storage provider.
func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) {
var query string
@@ -683,7 +1170,7 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
return session, nil
}
-// SaveOAuth2PARContext save a OAuth2PARContext to the database.
+// SaveOAuth2PARContext save an OAuth2.0 PAR context to the storage provider.
func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) {
if par.Session, err = p.encrypt(par.Session); err != nil {
return fmt.Errorf("error encrypting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err)
@@ -698,26 +1185,7 @@ func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2
return nil
}
-// UpdateOAuth2PARContext updates an existing OAuth2PARContext in the database.
-func (p *SQLProvider) UpdateOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) {
- if par.ID == 0 {
- return fmt.Errorf("error updating oauth2 pushed authorization request context data with signature '%s' and request id '%s': the id was a zero value", par.Signature, par.RequestID)
- }
-
- if par.Session, err = p.encrypt(par.Session); err != nil {
- return fmt.Errorf("error encrypting oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err)
- }
-
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2PARContext,
- par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes,
- par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session, par.ID); err != nil {
- return fmt.Errorf("error updating oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err)
- }
-
- return nil
-}
-
-// LoadOAuth2PARContext loads a OAuth2PARContext from the database.
+// LoadOAuth2PARContext loads an OAuth2.0 PAR context from the storage provider.
func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error) {
par = &model.OAuth2PARContext{}
@@ -732,7 +1200,7 @@ func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string
return par, nil
}
-// RevokeOAuth2PARContext marks a OAuth2PARContext as revoked in the database.
+// RevokeOAuth2PARContext marks an OAuth2.0 PAR context as revoked in the storage provider.
func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature string) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlRevokeOAuth2PARContext, signature); err != nil {
return fmt.Errorf("error revoking oauth2 pushed authorization request context with signature '%s': %w", signature, err)
@@ -741,364 +1209,46 @@ func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature stri
return nil
}
-// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database.
-func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
- return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
- }
-
- return nil
-}
-
-// LoadOAuth2BlacklistedJTI loads a OAuth2BlacklistedJTI from the database.
-func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) {
- blacklistedJTI = &model.OAuth2BlacklistedJTI{}
-
- if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil {
- return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
- }
-
- return blacklistedJTI, nil
-}
-
-// SavePreferred2FAMethod save the preferred method for 2FA to the database.
-func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil {
- return fmt.Errorf("error upserting preferred two factor method for user '%s': %w", username, err)
- }
-
- return nil
-}
-
-// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
-func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) {
- err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username)
-
- switch {
- case err == nil:
- return method, nil
- case errors.Is(err, sql.ErrNoRows):
- return "", sql.ErrNoRows
- default:
- return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err)
- }
-}
-
-// LoadUserInfo loads the model.UserInfo from the database.
-func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) {
- err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username)
-
- switch {
- case err == nil, errors.Is(err, sql.ErrNoRows):
- return info, nil
- default:
- return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err)
- }
-}
-
-// SaveIdentityVerification save an identity verification record to the database.
-func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification,
- verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt,
- verification.Username, verification.Action); err != nil {
- return fmt.Errorf("error inserting identity verification for user '%s' with uuid '%s': %w", verification.Username, verification.JTI, err)
- }
-
- return nil
-}
-
-// ConsumeIdentityVerification marks an identity verification record in the database as consumed.
-func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil {
- return fmt.Errorf("error updating identity verification: %w", err)
- }
-
- return nil
-}
-
-// FindIdentityVerification checks if an identity verification record is in the database and active.
-func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) {
- verification := model.IdentityVerification{}
- if err = p.db.GetContext(ctx, &verification, p.sqlSelectIdentityVerification, jti); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return false, nil
- }
-
- return false, fmt.Errorf("error selecting identity verification exists: %w", err)
- }
-
- switch {
- case verification.Consumed.Valid:
- return false, fmt.Errorf("the token has already been consumed")
- case verification.ExpiresAt.Before(time.Now()):
- return false, fmt.Errorf("the token expired %s ago", time.Since(verification.ExpiresAt))
- default:
- return true, nil
- }
-}
-
-// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
-func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
- if config.Secret, err = p.encrypt(config.Secret); err != nil {
- return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err)
- }
-
- if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
- config.CreatedAt, config.LastUsedAt,
- config.Username, config.Issuer,
- config.Algorithm, config.Digits, config.Period, config.Secret); err != nil {
- return fmt.Errorf("error upserting TOTP configuration for user '%s': %w", config.Username, err)
- }
-
- return nil
-}
-
-// UpdateTOTPConfigurationSignIn updates a registered TOTP configurations sign in information.
-func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil {
- return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err)
- }
-
- return nil
-}
-
-// DeleteTOTPConfiguration delete a TOTP configuration from the database given a username.
-func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil {
- return fmt.Errorf("error deleting TOTP configuration for user '%s': %w", username, err)
- }
-
- return nil
-}
-
-// LoadTOTPConfiguration load a TOTP configuration given a username from the database.
-func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) {
- config = &model.TOTPConfiguration{}
-
- if err = p.db.GetContext(ctx, config, p.sqlSelectTOTPConfig, username); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, ErrNoTOTPConfiguration
- }
-
- return nil, fmt.Errorf("error selecting TOTP configuration for user '%s': %w", username, err)
- }
-
- if config.Secret, err = p.decrypt(config.Secret); err != nil {
- return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err)
- }
-
- return config, nil
-}
-
-// LoadTOTPConfigurations load a set of TOTP configurations.
-func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []model.TOTPConfiguration, err error) {
- configs = make([]model.TOTPConfiguration, 0, limit)
-
- if err = p.db.SelectContext(ctx, &configs, p.sqlSelectTOTPConfigs, limit, limit*page); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- }
-
- return nil, fmt.Errorf("error selecting TOTP configurations: %w", err)
- }
-
- for i, c := range configs {
- if configs[i].Secret, err = p.decrypt(c.Secret); err != nil {
- return nil, fmt.Errorf("error decrypting TOTP configuration for user '%s': %w", c.Username, err)
- }
- }
-
- return configs, nil
-}
-
-// SaveWebAuthnUser saves a registered WebAuthn user.
-func (p *SQLProvider) SaveWebAuthnUser(ctx context.Context, user model.WebAuthnUser) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnUser, user.RPID, user.Username, user.UserID); err != nil {
- return fmt.Errorf("error inserting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err)
- }
-
- return nil
-}
-
-// LoadWebAuthnUser loads a registered WebAuthn user.
-func (p *SQLProvider) LoadWebAuthnUser(ctx context.Context, rpid, username string) (user *model.WebAuthnUser, err error) {
- user = &model.WebAuthnUser{}
-
- if err = p.db.GetContext(ctx, user, p.sqlSelectWebAuthnUser, rpid, username); err != nil {
- switch {
- case errors.Is(err, sql.ErrNoRows):
- return nil, nil
- default:
- return nil, fmt.Errorf("error selecting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err)
- }
- }
-
- return user, nil
-}
-
-// SaveWebAuthnCredential saves a registered WebAuthn credential.
-func (p *SQLProvider) SaveWebAuthnCredential(ctx context.Context, credential model.WebAuthnCredential) (err error) {
- if credential.PublicKey, err = p.encrypt(credential.PublicKey); err != nil {
- return fmt.Errorf("error encrypting WebAuthn credential public key for user '%s' kid '%x': %w", credential.Username, credential.KID, err)
- }
-
- if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnCredential,
- credential.CreatedAt, credential.LastUsedAt, credential.RPID, credential.Username, credential.Description,
- credential.KID, credential.AAGUID, credential.AttestationType, credential.Attachment, credential.Transport,
- credential.SignCount, credential.CloneWarning, credential.Discoverable, credential.Present, credential.Verified,
- credential.BackupEligible, credential.BackupState, credential.PublicKey,
- ); err != nil {
- return fmt.Errorf("error inserting WebAuthn credential for user '%s' kid '%x': %w", credential.Username, credential.KID, err)
- }
-
- return nil
-}
-
-// UpdateWebAuthnCredentialDescription updates a registered WebAuthn credentials description.
-func (p *SQLProvider) UpdateWebAuthnCredentialDescription(ctx context.Context, username string, credentialID int, description string) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID, description, username, credentialID); err != nil {
- return fmt.Errorf("error updating WebAuthn credential description to '%s' for credential id '%d': %w", description, credentialID, err)
- }
-
- return nil
-}
-
-// UpdateWebAuthnCredentialSignIn updates a registered WebAuthn credentials sign in information.
-func (p *SQLProvider) UpdateWebAuthnCredentialSignIn(ctx context.Context, credential model.WebAuthnCredential) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialRecordSignIn,
- credential.RPID, credential.LastUsedAt, credential.SignCount, credential.Discoverable, credential.Present, credential.Verified,
- credential.BackupEligible, credential.BackupState, credential.CloneWarning, credential.ID,
- ); err != nil {
- return fmt.Errorf("error updating WebAuthn credentials authentication metadata for id '%x': %w", credential.ID, err)
- }
-
- return nil
-}
-
-// DeleteWebAuthnCredential deletes a registered WebAuthn credential.
-func (p *SQLProvider) DeleteWebAuthnCredential(ctx context.Context, kid string) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredential, kid); err != nil {
- return fmt.Errorf("error deleting WebAuthn credential with kid '%s': %w", kid, err)
- }
-
- return nil
-}
-
-// DeleteWebAuthnCredentialByUsername deletes registered WebAuthn credential by username or username and description.
-func (p *SQLProvider) DeleteWebAuthnCredentialByUsername(ctx context.Context, username, displayname string) (err error) {
- if len(username) == 0 {
- return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': username must not be empty", username, displayname)
- }
-
- if len(displayname) == 0 {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsername, username); err != nil {
- return fmt.Errorf("error deleting WebAuthn credential for username '%s': %w", username, err)
- }
- } else {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsernameAndDisplayName, username, displayname); err != nil {
- return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': %w", username, displayname, err)
- }
- }
-
- return nil
-}
-
-// LoadWebAuthnCredentials loads WebAuthn credential registrations.
-func (p *SQLProvider) LoadWebAuthnCredentials(ctx context.Context, limit, page int) (credentials []model.WebAuthnCredential, err error) {
- credentials = make([]model.WebAuthnCredential, 0, limit)
-
- if err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentials, limit, limit*page); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- }
-
- return nil, fmt.Errorf("error selecting WebAuthn credentials: %w", err)
- }
-
- for i, credential := range credentials {
- if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil {
- return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err)
- }
- }
-
- return credentials, nil
-}
-
-// LoadWebAuthnCredentialByID loads a WebAuthn credential registration for a given id.
-func (p *SQLProvider) LoadWebAuthnCredentialByID(ctx context.Context, id int) (credential *model.WebAuthnCredential, err error) {
- credential = &model.WebAuthnCredential{}
-
- if err = p.db.GetContext(ctx, credential, p.sqlSelectWebAuthnCredentialByID, id); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, sql.ErrNoRows
- }
-
- return nil, fmt.Errorf("error selecting WebAuthn credential with id '%d': %w", id, err)
- }
-
- return credential, nil
-}
-
-// LoadWebAuthnCredentialsByUsername loads all WebAuthn credential registrations for a given username.
-func (p *SQLProvider) LoadWebAuthnCredentialsByUsername(ctx context.Context, rpid, username string) (credentials []model.WebAuthnCredential, err error) {
- switch len(rpid) {
- case 0:
- err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByUsername, username)
- default:
- err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByRPIDByUsername, rpid, username)
- }
-
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return credentials, ErrNoWebAuthnCredential
- }
-
- return nil, fmt.Errorf("error selecting WebAuthn credentials for user '%s': %w", username, err)
+// UpdateOAuth2PARContext updates an existing OAuth2.0 PAR context in the storage provider.
+func (p *SQLProvider) UpdateOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) {
+ if par.ID == 0 {
+ return fmt.Errorf("error updating oauth2 pushed authorization request context data with signature '%s' and request id '%s': the id was a zero value", par.Signature, par.RequestID)
}
- for i, credential := range credentials {
- if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil {
- return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err)
- }
+ if par.Session, err = p.encrypt(par.Session); err != nil {
+ return fmt.Errorf("error encrypting oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err)
}
- return credentials, nil
-}
-
-// SavePreferredDuoDevice saves a Duo device.
-func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method); err != nil {
- return fmt.Errorf("error upserting preferred duo device for user '%s': %w", device.Username, err)
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2PARContext,
+ par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes,
+ par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session, par.ID); err != nil {
+ return fmt.Errorf("error updating oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err)
}
return nil
}
-// DeletePreferredDuoDevice deletes a Duo device of a given user.
-func (p *SQLProvider) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteDuoDevice, username); err != nil {
- return fmt.Errorf("error deleting preferred duo device for user '%s': %w", username, err)
+// SaveOAuth2BlacklistedJTI saves an OAuth2.0 blacklisted JTI to the storage provider.
+func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
+ return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
}
return nil
}
-// LoadPreferredDuoDevice loads a Duo device of a given user.
-func (p *SQLProvider) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) {
- device = &model.DuoDevice{}
-
- if err = p.db.QueryRowxContext(ctx, p.sqlSelectDuoDevice, username).StructScan(device); err != nil {
- if err == sql.ErrNoRows {
- return nil, ErrNoDuoDevice
- }
+// LoadOAuth2BlacklistedJTI loads an OAuth2.0 blacklisted JTI from the storage provider.
+func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) {
+ blacklistedJTI = &model.OAuth2BlacklistedJTI{}
- return nil, fmt.Errorf("error selecting preferred duo device for user '%s': %w", username, err)
+ if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil {
+ return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
}
- return device, nil
+ return blacklistedJTI, nil
}
-// AppendAuthenticationLog append a mark to the authentication log.
+// AppendAuthenticationLog saves an authentication attempt to the storage provider.
func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model.AuthenticationAttempt) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt,
attempt.Time, attempt.Successful, attempt.Banned, attempt.Username,
@@ -1109,7 +1259,7 @@ func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model
return nil
}
-// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log.
+// LoadAuthenticationLogs loads authentication attempts from the storage provider (paginated).
func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) {
attempts = make([]model.AuthenticationAttempt, 0, limit)