diff options
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 974 |
1 files changed, 562 insertions, 412 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index a5f1d0727..78866451e 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -24,12 +24,16 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa provider = SQLProvider{ db: db, - key: sha256.Sum256([]byte(config.Storage.EncryptionKey)), name: name, driverName: driverName, config: config, errOpen: err, - log: logging.Logger(), + + keys: SQLProviderKeys{ + encryption: sha256.Sum256([]byte(config.Storage.EncryptionKey)), + }, + + log: logging.Logger(), sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), @@ -38,6 +42,14 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification), sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification), + sqlInsertOneTimeCode: fmt.Sprintf(queryFmtInsertOTC, tableOneTimeCode), + sqlConsumeOneTimeCode: fmt.Sprintf(queryFmtConsumeOTC, tableOneTimeCode), + sqlRevokeOneTimeCode: fmt.Sprintf(queryFmtRevokeOTC, tableOneTimeCode), + sqlSelectOneTimeCode: fmt.Sprintf(queryFmtSelectOTCBySignatureAndUsername, tableOneTimeCode), + sqlSelectOneTimeCodeBySignature: fmt.Sprintf(queryFmtSelectOTCBySignature, tableOneTimeCode), + sqlSelectOneTimeCodeByID: fmt.Sprintf(queryFmtSelectOTCByID, tableOneTimeCode), + sqlSelectOneTimeCodeByPublicID: fmt.Sprintf(queryFmtSelectOTCByPublicID, tableOneTimeCode), + sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), @@ -140,14 +152,16 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa // SQLProvider is a storage provider persisting data in a SQL database. type SQLProvider struct { - db *sqlx.DB - key [32]byte + db *sqlx.DB + name string driverName string schema string config *schema.Configuration errOpen error + keys SQLProviderKeys + log *logrus.Logger // Table: authentication_logs. @@ -159,6 +173,15 @@ type SQLProvider struct { sqlConsumeIdentityVerification string sqlSelectIdentityVerification string + // Table: one_time_code. + sqlInsertOneTimeCode string + sqlConsumeOneTimeCode string + sqlRevokeOneTimeCode string + sqlSelectOneTimeCode string + sqlSelectOneTimeCodeBySignature string + sqlSelectOneTimeCodeByID string + sqlSelectOneTimeCodeByPublicID string + // Table: totp_configurations. sqlUpsertTOTPConfig string sqlDeleteTOTPConfig string @@ -276,9 +299,10 @@ type SQLProvider struct { sqlFmtRenameTable string } -// Close the underlying database connection. -func (p *SQLProvider) Close() (err error) { - return p.db.Close() +// SQLProviderKeys are the cryptography keys used by a SQLProvider. +type SQLProviderKeys struct { + encryption [32]byte + signature []byte } // StartupCheck implements the provider startup check interface. @@ -315,17 +339,22 @@ func (p *SQLProvider) StartupCheck() (err error) { } switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err { + case nil: + break case ErrSchemaAlreadyUpToDate: p.log.Infof("Storage schema is already up to date") - return nil - case nil: - return nil default: return fmt.Errorf("error during schema migrate: %w", err) } + + if p.keys.signature, err = p.getKeySigHMAC(ctx); err != nil { + return fmt.Errorf("failed to initialize the hmac signature key during startup: %w", err) + } + + return nil } -// BeginTX begins a transaction. +// BeginTX begins a transaction with the storage provider when applicable. func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error) { var tx *sql.Tx @@ -336,7 +365,7 @@ func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error return context.WithValue(ctx, ctxKeyTransaction, tx), nil } -// Commit performs a database commit. +// Commit performs a storage provider commit when applicable. func (p *SQLProvider) Commit(ctx context.Context) (err error) { tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx) @@ -347,7 +376,7 @@ func (p *SQLProvider) Commit(ctx context.Context) (err error) { return tx.Commit() } -// Rollback performs a database rollback. +// Rollback performs a storage provider rollback when applicable. func (p *SQLProvider) Rollback(ctx context.Context) (err error) { tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx) @@ -358,7 +387,47 @@ func (p *SQLProvider) Rollback(ctx context.Context) (err error) { return tx.Rollback() } -// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database. +// Close the underlying storage provider. +func (p *SQLProvider) Close() (err error) { + return p.db.Close() +} + +// SavePreferred2FAMethod save the preferred method for 2FA for a username to the storage provider. +func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil { + return fmt.Errorf("error upserting preferred two factor method for user '%s': %w", username, err) + } + + return nil +} + +// LoadPreferred2FAMethod load the preferred method for 2FA for a username from the storage provider. +func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) { + err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username) + + switch { + case err == nil: + return method, nil + case errors.Is(err, sql.ErrNoRows): + return "", sql.ErrNoRows + default: + return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err) + } +} + +// LoadUserInfo loads the model.UserInfo from the storage provider. +func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) { + err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username) + + switch { + case err == nil, errors.Is(err, sql.ErrNoRows): + return info, nil + default: + return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) + } +} + +// SaveUserOpaqueIdentifier saves a new opaque user identifier to the storage provider. func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil { return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err) @@ -367,7 +436,7 @@ func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject mode return nil } -// LoadUserOpaqueIdentifier selects an opaque user identifier from the database. +// LoadUserOpaqueIdentifier selects an opaque user identifier from the storage provider. func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) { subject = &model.UserOpaqueIdentifier{} @@ -383,7 +452,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier u return subject, nil } -// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database. +// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the storage provider. func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error) { var rows *sqlx.Rows @@ -406,7 +475,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifier return identifiers, nil } -// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username. +// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the storage provider given a service name, sector id, and username. func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) { subject = &model.UserOpaqueIdentifier{} @@ -422,57 +491,424 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s return subject, nil } -// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session. -func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession, - consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted, - consent.RequestedAt, consent.RespondedAt, consent.Form, - consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience, consent.PreConfiguration); err != nil { - return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.UUID.String(), err) +// SaveTOTPConfiguration save a TOTP configuration of a given user in the storage provider. +func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) { + if config.Secret, err = p.encrypt(config.Secret); err != nil { + return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err) + } + + if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, + config.CreatedAt, config.LastUsedAt, + config.Username, config.Issuer, + config.Algorithm, config.Digits, config.Period, config.Secret); err != nil { + return fmt.Errorf("error upserting TOTP configuration for user '%s': %w", config.Username, err) } return nil } -// SaveOAuth2ConsentSessionSubject updates an OAuth2.0 consent session with the subject. -func (p *SQLProvider) SaveOAuth2ConsentSessionSubject(ctx context.Context, consent model.OAuth2ConsentSession) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionSubject, consent.Subject, consent.ID); err != nil { - return fmt.Errorf("error updating oauth2 consent session subject with id '%d' and challenge id '%s' for subject '%s': %w", consent.ID, consent.ChallengeID, consent.Subject.UUID, err) +// UpdateTOTPConfigurationSignIn updates a registered TOTP configuration in the storage provider with the relevant sign in information. +func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil { + return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err) } return nil } -// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent session with the response. -func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.GrantedScopes, consent.GrantedAudience, consent.PreConfiguration, consent.ID); err != nil { - return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject.UUID, err) +// DeleteTOTPConfiguration delete a TOTP configuration from the storage provider given a username. +func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil { + return fmt.Errorf("error deleting TOTP configuration for user '%s': %w", username, err) } return nil } -// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent recording that it has been granted by the authorization endpoint. -func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil { - return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err) +// LoadTOTPConfiguration load a TOTP configuration given a username from the storage provider. +func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) { + config = &model.TOTPConfiguration{} + + if err = p.db.GetContext(ctx, config, p.sqlSelectTOTPConfig, username); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNoTOTPConfiguration + } + + return nil, fmt.Errorf("error selecting TOTP configuration for user '%s': %w", username, err) + } + + if config.Secret, err = p.decrypt(config.Secret); err != nil { + return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err) + } + + return config, nil +} + +// LoadTOTPConfigurations load a set of TOTP configurations from the storage provider. +func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []model.TOTPConfiguration, err error) { + configs = make([]model.TOTPConfiguration, 0, limit) + + if err = p.db.SelectContext(ctx, &configs, p.sqlSelectTOTPConfigs, limit, limit*page); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting TOTP configurations: %w", err) + } + + for i, c := range configs { + if configs[i].Secret, err = p.decrypt(c.Secret); err != nil { + return nil, fmt.Errorf("error decrypting TOTP configuration for user '%s': %w", c.Username, err) + } + } + + return configs, nil +} + +// SaveWebAuthnUser saves a registered WebAuthn user to the storage provider. +func (p *SQLProvider) SaveWebAuthnUser(ctx context.Context, user model.WebAuthnUser) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnUser, user.RPID, user.Username, user.UserID); err != nil { + return fmt.Errorf("error inserting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err) } return nil } -// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent given the challenge ID. -func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) { - consent = &model.OAuth2ConsentSession{} +// LoadWebAuthnUser loads a registered WebAuthn user from the storage provider. +func (p *SQLProvider) LoadWebAuthnUser(ctx context.Context, rpid, username string) (user *model.WebAuthnUser, err error) { + user = &model.WebAuthnUser{} - if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil { - return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err) + if err = p.db.GetContext(ctx, user, p.sqlSelectWebAuthnUser, rpid, username); err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, nil + default: + return nil, fmt.Errorf("error selecting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err) + } } - return consent, nil + return user, nil +} + +// SaveWebAuthnCredential saves a registered WebAuthn credential to the storage provider. +func (p *SQLProvider) SaveWebAuthnCredential(ctx context.Context, credential model.WebAuthnCredential) (err error) { + if credential.PublicKey, err = p.encrypt(credential.PublicKey); err != nil { + return fmt.Errorf("error encrypting WebAuthn credential public key for user '%s' kid '%x': %w", credential.Username, credential.KID, err) + } + + if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnCredential, + credential.CreatedAt, credential.LastUsedAt, credential.RPID, credential.Username, credential.Description, + credential.KID, credential.AAGUID, credential.AttestationType, credential.Attachment, credential.Transport, + credential.SignCount, credential.CloneWarning, credential.Discoverable, credential.Present, credential.Verified, + credential.BackupEligible, credential.BackupState, credential.PublicKey, + ); err != nil { + return fmt.Errorf("error inserting WebAuthn credential for user '%s' kid '%x': %w", credential.Username, credential.KID, err) + } + + return nil +} + +// UpdateWebAuthnCredentialDescription updates a registered WebAuthn credential in the storage provider changing the +// description. +func (p *SQLProvider) UpdateWebAuthnCredentialDescription(ctx context.Context, username string, credentialID int, description string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID, description, username, credentialID); err != nil { + return fmt.Errorf("error updating WebAuthn credential description to '%s' for credential id '%d': %w", description, credentialID, err) + } + + return nil +} + +// UpdateWebAuthnCredentialSignIn updates a registered WebAuthn credential in the storage provider changing the +// information that should be changed in the event of a successful sign in. +func (p *SQLProvider) UpdateWebAuthnCredentialSignIn(ctx context.Context, credential model.WebAuthnCredential) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialRecordSignIn, + credential.RPID, credential.LastUsedAt, credential.SignCount, credential.Discoverable, credential.Present, credential.Verified, + credential.BackupEligible, credential.BackupState, credential.CloneWarning, credential.ID, + ); err != nil { + return fmt.Errorf("error updating WebAuthn credentials authentication metadata for id '%x': %w", credential.ID, err) + } + + return nil +} + +// DeleteWebAuthnCredential deletes a registered WebAuthn credential from the storage provider. +func (p *SQLProvider) DeleteWebAuthnCredential(ctx context.Context, kid string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredential, kid); err != nil { + return fmt.Errorf("error deleting WebAuthn credential with kid '%s': %w", kid, err) + } + + return nil +} + +// DeleteWebAuthnCredentialByUsername deletes registered WebAuthn credential from the storage provider by username or +// username and description. +func (p *SQLProvider) DeleteWebAuthnCredentialByUsername(ctx context.Context, username, displayname string) (err error) { + if len(username) == 0 { + return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': username must not be empty", username, displayname) + } + + if len(displayname) == 0 { + if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsername, username); err != nil { + return fmt.Errorf("error deleting WebAuthn credential for username '%s': %w", username, err) + } + } else { + if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsernameAndDisplayName, username, displayname); err != nil { + return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': %w", username, displayname, err) + } + } + + return nil +} + +// LoadWebAuthnCredentials loads WebAuthn credential registrations from the storage provider. +func (p *SQLProvider) LoadWebAuthnCredentials(ctx context.Context, limit, page int) (credentials []model.WebAuthnCredential, err error) { + credentials = make([]model.WebAuthnCredential, 0, limit) + + if err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentials, limit, limit*page); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting WebAuthn credentials: %w", err) + } + + for i, credential := range credentials { + if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil { + return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err) + } + } + + return credentials, nil +} + +// LoadWebAuthnCredentialByID loads a WebAuthn credential registration from the storage provider for a given id. +func (p *SQLProvider) LoadWebAuthnCredentialByID(ctx context.Context, id int) (credential *model.WebAuthnCredential, err error) { + credential = &model.WebAuthnCredential{} + + if err = p.db.GetContext(ctx, credential, p.sqlSelectWebAuthnCredentialByID, id); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, sql.ErrNoRows + } + + return nil, fmt.Errorf("error selecting WebAuthn credential with id '%d': %w", id, err) + } + + return credential, nil +} + +// LoadWebAuthnCredentialsByUsername loads all WebAuthn credential registrations from the storage provider for a +// given username. +func (p *SQLProvider) LoadWebAuthnCredentialsByUsername(ctx context.Context, rpid, username string) (credentials []model.WebAuthnCredential, err error) { + switch len(rpid) { + case 0: + err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByUsername, username) + default: + err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByRPIDByUsername, rpid, username) + } + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return credentials, ErrNoWebAuthnCredential + } + + return nil, fmt.Errorf("error selecting WebAuthn credentials for user '%s': %w", username, err) + } + + for i, credential := range credentials { + if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil { + return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err) + } + } + + return credentials, nil +} + +// SavePreferredDuoDevice saves a Duo device to the storage provider. +func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method); err != nil { + return fmt.Errorf("error upserting preferred duo device for user '%s': %w", device.Username, err) + } + + return nil +} + +// DeletePreferredDuoDevice deletes a Duo device from the storage provider for a given username. +func (p *SQLProvider) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlDeleteDuoDevice, username); err != nil { + return fmt.Errorf("error deleting preferred duo device for user '%s': %w", username, err) + } + + return nil +} + +// LoadPreferredDuoDevice loads a Duo device from the storage provider for a given username. +func (p *SQLProvider) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) { + device = &model.DuoDevice{} + + if err = p.db.QueryRowxContext(ctx, p.sqlSelectDuoDevice, username).StructScan(device); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNoDuoDevice + } + + return nil, fmt.Errorf("error selecting preferred duo device for user '%s': %w", username, err) + } + + return device, nil +} + +// SaveIdentityVerification save an identity verification record to the storage provider. +func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, + verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt, + verification.Username, verification.Action); err != nil { + return fmt.Errorf("error inserting identity verification for user '%s' with uuid '%s': %w", verification.Username, verification.JTI, err) + } + + return nil } -// SaveOAuth2ConsentPreConfiguration inserts an OAuth2.0 consent pre-configuration. +// ConsumeIdentityVerification marks an identity verification record in the storage provider as consumed. +func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil { + return fmt.Errorf("error updating identity verification: %w", err) + } + + return nil +} + +// FindIdentityVerification checks if an identity verification record is in the storage provider and active. +func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) { + verification := model.IdentityVerification{} + if err = p.db.GetContext(ctx, &verification, p.sqlSelectIdentityVerification, jti); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + + return false, fmt.Errorf("error selecting identity verification exists: %w", err) + } + + switch { + case verification.Consumed.Valid: + return false, fmt.Errorf("the token has already been consumed") + case verification.ExpiresAt.Before(time.Now()): + return false, fmt.Errorf("the token expired %s ago", time.Since(verification.ExpiresAt)) + default: + return true, nil + } +} + +// SaveOneTimeCode saves a one-time code to the storage provider after generating the signature which is returned +// along with any error. +func (p *SQLProvider) SaveOneTimeCode(ctx context.Context, code model.OneTimeCode) (signature string, err error) { + code.Signature = p.hmacSignature([]byte(code.Username), []byte(code.Intent), code.Code) + + if code.Code, err = p.encrypt(code.Code); err != nil { + return "", fmt.Errorf("error encrypting the one-time code value for user '%s' with signature '%s': %w", code.Username, code.Signature, err) + } + + if _, err = p.db.ExecContext(ctx, p.sqlInsertOneTimeCode, + code.PublicID, code.Signature, code.IssuedAt, code.IssuedIP, code.ExpiresAt, + code.Username, code.Intent, code.Code); err != nil { + return "", fmt.Errorf("error inserting one-time code for user '%s' with signature '%s': %w", code.Username, code.Signature, err) + } + + return code.Signature, nil +} + +// ConsumeOneTimeCode consumes a one-time code using the signature. +func (p *SQLProvider) ConsumeOneTimeCode(ctx context.Context, code *model.OneTimeCode) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlConsumeOneTimeCode, code.ConsumedAt, code.ConsumedIP, code.Signature); err != nil { + return fmt.Errorf("error updating one-time code (consume): %w", err) + } + + return nil +} + +// RevokeOneTimeCode revokes a one-time code in the storage provider using the public ID. +func (p *SQLProvider) RevokeOneTimeCode(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlRevokeOneTimeCode, ip, publicID); err != nil { + return fmt.Errorf("error updating one-time code (revoke): %w", err) + } + + return nil +} + +// LoadOneTimeCode loads a one-time code from the storage provider given a username, intent, and code. +func (p *SQLProvider) LoadOneTimeCode(ctx context.Context, username, intent, raw string) (code *model.OneTimeCode, err error) { + code = &model.OneTimeCode{} + + signature := p.hmacSignature([]byte(username), []byte(intent), []byte(raw)) + + if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCode, signature, username); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting one-time code: %w", err) + } + + if code.Code, err = p.decrypt(code.Code); err != nil { + return nil, fmt.Errorf("error decrypting the one-time code value for user '%s' with signature '%s': %w", code.Username, code.Signature, err) + } + + return code, nil +} + +// LoadOneTimeCodeBySignature loads a one-time code from the storage provider given the signature. +// This method should NOT be used to validate a One-Time Code, LoadOneTimeCode should be used instead. +func (p *SQLProvider) LoadOneTimeCodeBySignature(ctx context.Context, signature string) (code *model.OneTimeCode, err error) { + code = &model.OneTimeCode{} + + if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCodeBySignature, signature); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting one-time code: %w", err) + } + + if code.Code, err = p.decrypt(code.Code); err != nil { + return nil, fmt.Errorf("error decrypting the one-time code value for user '%s' with signature '%s': %w", code.Username, code.Signature, err) + } + + return code, nil +} + +// LoadOneTimeCodeByID loads a one-time code from the storage provider given the id. +// This does not decrypt the code. This method should NOT be used to validate a One-Time Code, +// LoadOneTimeCode should be used instead. +func (p *SQLProvider) LoadOneTimeCodeByID(ctx context.Context, id int) (code *model.OneTimeCode, err error) { + code = &model.OneTimeCode{} + + if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCodeByID, id); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting one-time code: %w", err) + } + + return code, nil +} + +// LoadOneTimeCodeByPublicID loads a one-time code from the storage provider given the public identifier. +// This does not decrypt the code. This method SHOULD ONLY be used to find the One-Time Code for the +// purpose of deletion. +func (p *SQLProvider) LoadOneTimeCodeByPublicID(ctx context.Context, id uuid.UUID) (code *model.OneTimeCode, err error) { + code = &model.OneTimeCode{} + + if err = p.db.GetContext(ctx, code, p.sqlSelectOneTimeCodeByPublicID, id); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting one-time code: %w", err) + } + + return code, nil +} + +// SaveOAuth2ConsentPreConfiguration inserts an OAuth2.0 consent pre-configuration in the storage provider. func (p *SQLProvider) SaveOAuth2ConsentPreConfiguration(ctx context.Context, config model.OAuth2ConsentPreConfig) (insertedID int64, err error) { switch p.name { case providerPostgres: @@ -496,7 +932,7 @@ func (p *SQLProvider) SaveOAuth2ConsentPreConfiguration(ctx context.Context, con } } -// LoadOAuth2ConsentPreConfigurations returns an OAuth2.0 consents pre-configurations given the consent signature. +// LoadOAuth2ConsentPreConfigurations returns an OAuth2.0 consents pre-configurations from the storage provider given the consent signature. func (p *SQLProvider) LoadOAuth2ConsentPreConfigurations(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentPreConfigRows, err error) { var r *sqlx.Rows @@ -511,7 +947,58 @@ func (p *SQLProvider) LoadOAuth2ConsentPreConfigurations(ctx context.Context, cl return &ConsentPreConfigRows{rows: r}, nil } -// SaveOAuth2Session saves a OAuth2Session to the database. +// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session to the storage provider. +func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession, + consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted, + consent.RequestedAt, consent.RespondedAt, consent.Form, + consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience, consent.PreConfiguration); err != nil { + return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.UUID.String(), err) + } + + return nil +} + +// SaveOAuth2ConsentSessionSubject updates an OAuth2.0 consent session in the storage provider with the subject. +func (p *SQLProvider) SaveOAuth2ConsentSessionSubject(ctx context.Context, consent model.OAuth2ConsentSession) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionSubject, consent.Subject, consent.ID); err != nil { + return fmt.Errorf("error updating oauth2 consent session subject with id '%d' and challenge id '%s' for subject '%s': %w", consent.ID, consent.ChallengeID, consent.Subject.UUID, err) + } + + return nil +} + +// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent session in the storage provider with the response. +func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.GrantedScopes, consent.GrantedAudience, consent.PreConfiguration, consent.ID); err != nil { + return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject.UUID, err) + } + + return nil +} + +// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent session in the storage provider recording that it +// has been granted by the authorization endpoint. +func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil { + return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err) + } + + return nil +} + +// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent session in the storage provider given the challenge ID. +func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) { + consent = &model.OAuth2ConsentSession{} + + if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil { + return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err) + } + + return consent, nil +} + +// SaveOAuth2Session saves an OAut2.0 session to the storage provider. func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) { var query string @@ -547,7 +1034,7 @@ func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2S return nil } -// RevokeOAuth2Session marks a OAuth2Session as revoked in the database. +// RevokeOAuth2Session marks an OAuth2.0 session as revoked in the storage provider. func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) { var query string @@ -573,7 +1060,7 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth return nil } -// RevokeOAuth2SessionByRequestID marks a OAuth2Session as revoked in the database. +// RevokeOAuth2SessionByRequestID marks an OAuth2.0 session as revoked in the storage provider. func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) { var query string @@ -599,7 +1086,7 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio return nil } -// DeactivateOAuth2Session marks a OAuth2Session as inactive in the database. +// DeactivateOAuth2Session marks an OAuth2.0 session as inactive in the storage provider. func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) { var query string @@ -625,7 +1112,7 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O return nil } -// DeactivateOAuth2SessionByRequestID marks a OAuth2Session as inactive in the database. +// DeactivateOAuth2SessionByRequestID marks an OAuth2.0 session as inactive in the storage provider. func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) { var query string @@ -651,7 +1138,7 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se return nil } -// LoadOAuth2Session saves a OAuth2Session from the database. +// LoadOAuth2Session saves an OAuth2.0 session from the storage provider. func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) { var query string @@ -683,7 +1170,7 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S return session, nil } -// SaveOAuth2PARContext save a OAuth2PARContext to the database. +// SaveOAuth2PARContext save an OAuth2.0 PAR context to the storage provider. func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) { if par.Session, err = p.encrypt(par.Session); err != nil { return fmt.Errorf("error encrypting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err) @@ -698,26 +1185,7 @@ func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2 return nil } -// UpdateOAuth2PARContext updates an existing OAuth2PARContext in the database. -func (p *SQLProvider) UpdateOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) { - if par.ID == 0 { - return fmt.Errorf("error updating oauth2 pushed authorization request context data with signature '%s' and request id '%s': the id was a zero value", par.Signature, par.RequestID) - } - - if par.Session, err = p.encrypt(par.Session); err != nil { - return fmt.Errorf("error encrypting oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err) - } - - if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2PARContext, - par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes, - par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session, par.ID); err != nil { - return fmt.Errorf("error updating oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err) - } - - return nil -} - -// LoadOAuth2PARContext loads a OAuth2PARContext from the database. +// LoadOAuth2PARContext loads an OAuth2.0 PAR context from the storage provider. func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error) { par = &model.OAuth2PARContext{} @@ -732,7 +1200,7 @@ func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string return par, nil } -// RevokeOAuth2PARContext marks a OAuth2PARContext as revoked in the database. +// RevokeOAuth2PARContext marks an OAuth2.0 PAR context as revoked in the storage provider. func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature string) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlRevokeOAuth2PARContext, signature); err != nil { return fmt.Errorf("error revoking oauth2 pushed authorization request context with signature '%s': %w", signature, err) @@ -741,364 +1209,46 @@ func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature stri return nil } -// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database. -func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil { - return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) - } - - return nil -} - -// LoadOAuth2BlacklistedJTI loads a OAuth2BlacklistedJTI from the database. -func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) { - blacklistedJTI = &model.OAuth2BlacklistedJTI{} - - if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil { - return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) - } - - return blacklistedJTI, nil -} - -// SavePreferred2FAMethod save the preferred method for 2FA to the database. -func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil { - return fmt.Errorf("error upserting preferred two factor method for user '%s': %w", username, err) - } - - return nil -} - -// LoadPreferred2FAMethod load the preferred method for 2FA from the database. -func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) { - err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username) - - switch { - case err == nil: - return method, nil - case errors.Is(err, sql.ErrNoRows): - return "", sql.ErrNoRows - default: - return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err) - } -} - -// LoadUserInfo loads the model.UserInfo from the database. -func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) { - err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username) - - switch { - case err == nil, errors.Is(err, sql.ErrNoRows): - return info, nil - default: - return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) - } -} - -// SaveIdentityVerification save an identity verification record to the database. -func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, - verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt, - verification.Username, verification.Action); err != nil { - return fmt.Errorf("error inserting identity verification for user '%s' with uuid '%s': %w", verification.Username, verification.JTI, err) - } - - return nil -} - -// ConsumeIdentityVerification marks an identity verification record in the database as consumed. -func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil { - return fmt.Errorf("error updating identity verification: %w", err) - } - - return nil -} - -// FindIdentityVerification checks if an identity verification record is in the database and active. -func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) { - verification := model.IdentityVerification{} - if err = p.db.GetContext(ctx, &verification, p.sqlSelectIdentityVerification, jti); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } - - return false, fmt.Errorf("error selecting identity verification exists: %w", err) - } - - switch { - case verification.Consumed.Valid: - return false, fmt.Errorf("the token has already been consumed") - case verification.ExpiresAt.Before(time.Now()): - return false, fmt.Errorf("the token expired %s ago", time.Since(verification.ExpiresAt)) - default: - return true, nil - } -} - -// SaveTOTPConfiguration save a TOTP configuration of a given user in the database. -func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) { - if config.Secret, err = p.encrypt(config.Secret); err != nil { - return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err) - } - - if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, - config.CreatedAt, config.LastUsedAt, - config.Username, config.Issuer, - config.Algorithm, config.Digits, config.Period, config.Secret); err != nil { - return fmt.Errorf("error upserting TOTP configuration for user '%s': %w", config.Username, err) - } - - return nil -} - -// UpdateTOTPConfigurationSignIn updates a registered TOTP configurations sign in information. -func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil { - return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err) - } - - return nil -} - -// DeleteTOTPConfiguration delete a TOTP configuration from the database given a username. -func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil { - return fmt.Errorf("error deleting TOTP configuration for user '%s': %w", username, err) - } - - return nil -} - -// LoadTOTPConfiguration load a TOTP configuration given a username from the database. -func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) { - config = &model.TOTPConfiguration{} - - if err = p.db.GetContext(ctx, config, p.sqlSelectTOTPConfig, username); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNoTOTPConfiguration - } - - return nil, fmt.Errorf("error selecting TOTP configuration for user '%s': %w", username, err) - } - - if config.Secret, err = p.decrypt(config.Secret); err != nil { - return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err) - } - - return config, nil -} - -// LoadTOTPConfigurations load a set of TOTP configurations. -func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []model.TOTPConfiguration, err error) { - configs = make([]model.TOTPConfiguration, 0, limit) - - if err = p.db.SelectContext(ctx, &configs, p.sqlSelectTOTPConfigs, limit, limit*page); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - - return nil, fmt.Errorf("error selecting TOTP configurations: %w", err) - } - - for i, c := range configs { - if configs[i].Secret, err = p.decrypt(c.Secret); err != nil { - return nil, fmt.Errorf("error decrypting TOTP configuration for user '%s': %w", c.Username, err) - } - } - - return configs, nil -} - -// SaveWebAuthnUser saves a registered WebAuthn user. -func (p *SQLProvider) SaveWebAuthnUser(ctx context.Context, user model.WebAuthnUser) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnUser, user.RPID, user.Username, user.UserID); err != nil { - return fmt.Errorf("error inserting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err) - } - - return nil -} - -// LoadWebAuthnUser loads a registered WebAuthn user. -func (p *SQLProvider) LoadWebAuthnUser(ctx context.Context, rpid, username string) (user *model.WebAuthnUser, err error) { - user = &model.WebAuthnUser{} - - if err = p.db.GetContext(ctx, user, p.sqlSelectWebAuthnUser, rpid, username); err != nil { - switch { - case errors.Is(err, sql.ErrNoRows): - return nil, nil - default: - return nil, fmt.Errorf("error selecting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err) - } - } - - return user, nil -} - -// SaveWebAuthnCredential saves a registered WebAuthn credential. -func (p *SQLProvider) SaveWebAuthnCredential(ctx context.Context, credential model.WebAuthnCredential) (err error) { - if credential.PublicKey, err = p.encrypt(credential.PublicKey); err != nil { - return fmt.Errorf("error encrypting WebAuthn credential public key for user '%s' kid '%x': %w", credential.Username, credential.KID, err) - } - - if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnCredential, - credential.CreatedAt, credential.LastUsedAt, credential.RPID, credential.Username, credential.Description, - credential.KID, credential.AAGUID, credential.AttestationType, credential.Attachment, credential.Transport, - credential.SignCount, credential.CloneWarning, credential.Discoverable, credential.Present, credential.Verified, - credential.BackupEligible, credential.BackupState, credential.PublicKey, - ); err != nil { - return fmt.Errorf("error inserting WebAuthn credential for user '%s' kid '%x': %w", credential.Username, credential.KID, err) - } - - return nil -} - -// UpdateWebAuthnCredentialDescription updates a registered WebAuthn credentials description. -func (p *SQLProvider) UpdateWebAuthnCredentialDescription(ctx context.Context, username string, credentialID int, description string) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID, description, username, credentialID); err != nil { - return fmt.Errorf("error updating WebAuthn credential description to '%s' for credential id '%d': %w", description, credentialID, err) - } - - return nil -} - -// UpdateWebAuthnCredentialSignIn updates a registered WebAuthn credentials sign in information. -func (p *SQLProvider) UpdateWebAuthnCredentialSignIn(ctx context.Context, credential model.WebAuthnCredential) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialRecordSignIn, - credential.RPID, credential.LastUsedAt, credential.SignCount, credential.Discoverable, credential.Present, credential.Verified, - credential.BackupEligible, credential.BackupState, credential.CloneWarning, credential.ID, - ); err != nil { - return fmt.Errorf("error updating WebAuthn credentials authentication metadata for id '%x': %w", credential.ID, err) - } - - return nil -} - -// DeleteWebAuthnCredential deletes a registered WebAuthn credential. -func (p *SQLProvider) DeleteWebAuthnCredential(ctx context.Context, kid string) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredential, kid); err != nil { - return fmt.Errorf("error deleting WebAuthn credential with kid '%s': %w", kid, err) - } - - return nil -} - -// DeleteWebAuthnCredentialByUsername deletes registered WebAuthn credential by username or username and description. -func (p *SQLProvider) DeleteWebAuthnCredentialByUsername(ctx context.Context, username, displayname string) (err error) { - if len(username) == 0 { - return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': username must not be empty", username, displayname) - } - - if len(displayname) == 0 { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsername, username); err != nil { - return fmt.Errorf("error deleting WebAuthn credential for username '%s': %w", username, err) - } - } else { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsernameAndDisplayName, username, displayname); err != nil { - return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': %w", username, displayname, err) - } - } - - return nil -} - -// LoadWebAuthnCredentials loads WebAuthn credential registrations. -func (p *SQLProvider) LoadWebAuthnCredentials(ctx context.Context, limit, page int) (credentials []model.WebAuthnCredential, err error) { - credentials = make([]model.WebAuthnCredential, 0, limit) - - if err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentials, limit, limit*page); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - - return nil, fmt.Errorf("error selecting WebAuthn credentials: %w", err) - } - - for i, credential := range credentials { - if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil { - return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err) - } - } - - return credentials, nil -} - -// LoadWebAuthnCredentialByID loads a WebAuthn credential registration for a given id. -func (p *SQLProvider) LoadWebAuthnCredentialByID(ctx context.Context, id int) (credential *model.WebAuthnCredential, err error) { - credential = &model.WebAuthnCredential{} - - if err = p.db.GetContext(ctx, credential, p.sqlSelectWebAuthnCredentialByID, id); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, sql.ErrNoRows - } - - return nil, fmt.Errorf("error selecting WebAuthn credential with id '%d': %w", id, err) - } - - return credential, nil -} - -// LoadWebAuthnCredentialsByUsername loads all WebAuthn credential registrations for a given username. -func (p *SQLProvider) LoadWebAuthnCredentialsByUsername(ctx context.Context, rpid, username string) (credentials []model.WebAuthnCredential, err error) { - switch len(rpid) { - case 0: - err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByUsername, username) - default: - err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByRPIDByUsername, rpid, username) - } - - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return credentials, ErrNoWebAuthnCredential - } - - return nil, fmt.Errorf("error selecting WebAuthn credentials for user '%s': %w", username, err) +// UpdateOAuth2PARContext updates an existing OAuth2.0 PAR context in the storage provider. +func (p *SQLProvider) UpdateOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) { + if par.ID == 0 { + return fmt.Errorf("error updating oauth2 pushed authorization request context data with signature '%s' and request id '%s': the id was a zero value", par.Signature, par.RequestID) } - for i, credential := range credentials { - if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil { - return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err) - } + if par.Session, err = p.encrypt(par.Session); err != nil { + return fmt.Errorf("error encrypting oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err) } - return credentials, nil -} - -// SavePreferredDuoDevice saves a Duo device. -func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method); err != nil { - return fmt.Errorf("error upserting preferred duo device for user '%s': %w", device.Username, err) + if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2PARContext, + par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes, + par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session, par.ID); err != nil { + return fmt.Errorf("error updating oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err) } return nil } -// DeletePreferredDuoDevice deletes a Duo device of a given user. -func (p *SQLProvider) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteDuoDevice, username); err != nil { - return fmt.Errorf("error deleting preferred duo device for user '%s': %w", username, err) +// SaveOAuth2BlacklistedJTI saves an OAuth2.0 blacklisted JTI to the storage provider. +func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil { + return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) } return nil } -// LoadPreferredDuoDevice loads a Duo device of a given user. -func (p *SQLProvider) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) { - device = &model.DuoDevice{} - - if err = p.db.QueryRowxContext(ctx, p.sqlSelectDuoDevice, username).StructScan(device); err != nil { - if err == sql.ErrNoRows { - return nil, ErrNoDuoDevice - } +// LoadOAuth2BlacklistedJTI loads an OAuth2.0 blacklisted JTI from the storage provider. +func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) { + blacklistedJTI = &model.OAuth2BlacklistedJTI{} - return nil, fmt.Errorf("error selecting preferred duo device for user '%s': %w", username, err) + if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil { + return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) } - return device, nil + return blacklistedJTI, nil } -// AppendAuthenticationLog append a mark to the authentication log. +// AppendAuthenticationLog saves an authentication attempt to the storage provider. func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model.AuthenticationAttempt) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Banned, attempt.Username, @@ -1109,7 +1259,7 @@ func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model return nil } -// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log. +// LoadAuthenticationLogs loads authentication attempts from the storage provider (paginated). func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) { attempts = make([]model.AuthenticationAttempt, 0, limit) |
