summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider.go
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2023-10-27 20:20:29 +1100
committerJames Elliott <james-d-elliott@users.noreply.github.com>2024-03-04 20:28:24 +1100
commitc0dbdd97ab2ac580e3da07a0137dbc7a1b9c9b83 (patch)
tree57daff9cacd6a06524a87e40d9ee5d1dbcb483d3 /internal/storage/sql_provider.go
parent358b6679b545d5227a8d5bd2c9e0f95e59ebc4f7 (diff)
feat(web): multiple webauthn credential registration
This implements multiple WebAuthn Credential registrations by means of a generic user settings UI. Closes #275, Closes #4366 Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com> Co-authored-by: Clément Michaud <clement.michaud34@gmail.com> Co-authored-by: Stephen Kent <smkent@smkent.net> Co-authored-by: Amir Zarrinkafsh <nightah@me.com>
Diffstat (limited to 'internal/storage/sql_provider.go')
-rw-r--r--internal/storage/sql_provider.go204
1 files changed, 136 insertions, 68 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index 86fb1dacf..a5f1d0727 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -46,16 +46,19 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlUpdateTOTPConfigRecordSignIn: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignIn, tableTOTPConfigurations),
sqlUpdateTOTPConfigRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignInByUsername, tableTOTPConfigurations),
- sqlUpsertWebAuthnDevice: fmt.Sprintf(queryFmtUpsertWebAuthnDevice, tableWebAuthnDevices),
- sqlSelectWebAuthnDevices: fmt.Sprintf(queryFmtSelectWebAuthnDevices, tableWebAuthnDevices),
- sqlSelectWebAuthnDevicesByUsername: fmt.Sprintf(queryFmtSelectWebAuthnDevicesByUsername, tableWebAuthnDevices),
-
- sqlUpdateWebAuthnDeviceRecordSignIn: fmt.Sprintf(queryFmtUpdateWebAuthnDeviceRecordSignIn, tableWebAuthnDevices),
- sqlUpdateWebAuthnDeviceRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateWebAuthnDeviceRecordSignInByUsername, tableWebAuthnDevices),
-
- sqlDeleteWebAuthnDevice: fmt.Sprintf(queryFmtDeleteWebAuthnDevice, tableWebAuthnDevices),
- sqlDeleteWebAuthnDeviceByUsername: fmt.Sprintf(queryFmtDeleteWebAuthnDeviceByUsername, tableWebAuthnDevices),
- sqlDeleteWebAuthnDeviceByUsernameAndDescription: fmt.Sprintf(queryFmtDeleteWebAuthnDeviceByUsernameAndDescription, tableWebAuthnDevices),
+ sqlInsertWebAuthnUser: fmt.Sprintf(queryFmtInsertWebAuthnUser, tableWebAuthnUsers),
+ sqlSelectWebAuthnUser: fmt.Sprintf(queryFmtSelectWebAuthnUser, tableWebAuthnUsers),
+
+ sqlInsertWebAuthnCredential: fmt.Sprintf(queryFmtInsertWebAuthnCredential, tableWebAuthnCredentials),
+ sqlSelectWebAuthnCredentials: fmt.Sprintf(queryFmtSelectWebAuthnCredentials, tableWebAuthnCredentials),
+ sqlSelectWebAuthnCredentialsByUsername: fmt.Sprintf(queryFmtSelectWebAuthnCredentialsByUsername, tableWebAuthnCredentials),
+ sqlSelectWebAuthnCredentialsByRPIDByUsername: fmt.Sprintf(queryFmtSelectWebAuthnCredentialsByRPIDByUsername, tableWebAuthnCredentials),
+ sqlSelectWebAuthnCredentialByID: fmt.Sprintf(queryFmtSelectWebAuthnCredentialByID, tableWebAuthnCredentials),
+ sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID: fmt.Sprintf(queryFmtUpdateUpdateWebAuthnCredentialDescriptionByUsernameAndID, tableWebAuthnCredentials),
+ sqlUpdateWebAuthnCredentialRecordSignIn: fmt.Sprintf(queryFmtUpdateWebAuthnCredentialRecordSignIn, tableWebAuthnCredentials),
+ sqlDeleteWebAuthnCredential: fmt.Sprintf(queryFmtDeleteWebAuthnCredential, tableWebAuthnCredentials),
+ sqlDeleteWebAuthnCredentialByUsername: fmt.Sprintf(queryFmtDeleteWebAuthnCredentialByUsername, tableWebAuthnCredentials),
+ sqlDeleteWebAuthnCredentialByUsernameAndDisplayName: fmt.Sprintf(queryFmtDeleteWebAuthnCredentialByUsernameAndDescription, tableWebAuthnCredentials),
sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices),
sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices),
@@ -63,7 +66,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences),
sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences),
- sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebAuthnDevices, tableDuoDevices, tableUserPreferences),
+ sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebAuthnCredentials, tableDuoDevices, tableUserPreferences),
sqlInsertUserOpaqueIdentifier: fmt.Sprintf(queryFmtInsertUserOpaqueIdentifier, tableUserOpaqueIdentifier),
sqlSelectUserOpaqueIdentifier: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifier, tableUserOpaqueIdentifier),
@@ -165,17 +168,23 @@ type SQLProvider struct {
sqlUpdateTOTPConfigRecordSignIn string
sqlUpdateTOTPConfigRecordSignInByUsername string
- // Table: webauthn_devices.
- sqlUpsertWebAuthnDevice string
- sqlSelectWebAuthnDevices string
- sqlSelectWebAuthnDevicesByUsername string
+ // Table: webauthn_users.
+ sqlInsertWebAuthnUser string
+ sqlSelectWebAuthnUser string
+
+ // Table: webauthn_credentials.
+ sqlInsertWebAuthnCredential string
+ sqlSelectWebAuthnCredentials string
+ sqlSelectWebAuthnCredentialsByUsername string
+ sqlSelectWebAuthnCredentialsByRPIDByUsername string
+ sqlSelectWebAuthnCredentialByID string
- sqlUpdateWebAuthnDeviceRecordSignIn string
- sqlUpdateWebAuthnDeviceRecordSignInByUsername string
+ sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID string
+ sqlUpdateWebAuthnCredentialRecordSignIn string
- sqlDeleteWebAuthnDevice string
- sqlDeleteWebAuthnDeviceByUsername string
- sqlDeleteWebAuthnDeviceByUsernameAndDescription string
+ sqlDeleteWebAuthnCredential string
+ sqlDeleteWebAuthnCredentialByUsername string
+ sqlDeleteWebAuthnCredentialByUsernameAndDisplayName string
// Table: duo_devices.
sqlUpsertDuoDevice string
@@ -367,7 +376,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier u
case errors.Is(err, sql.ErrNoRows):
return nil, nil
default:
- return nil, err
+ return nil, fmt.Errorf("error selecting user opaque id with value '%s': %w", identifier.String(), err)
}
}
@@ -406,7 +415,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s
case errors.Is(err, sql.ErrNoRows):
return nil, nil
default:
- return nil, err
+ return nil, fmt.Errorf("error selecting user opaque with service '%s' and sector '%s' for username '%s': %w", service, sectorID, username, err)
}
}
@@ -844,7 +853,7 @@ func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TO
return nil
}
-// UpdateTOTPConfigurationSignIn updates a registered WebAuthn devices sign in information.
+// UpdateTOTPConfigurationSignIn updates a registered TOTP configurations sign in information.
func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil {
return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err)
@@ -866,7 +875,7 @@ func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username stri
func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) {
config = &model.TOTPConfiguration{}
- if err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config); err != nil {
+ if err = p.db.GetContext(ctx, config, p.sqlSelectTOTPConfig, username); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNoTOTPConfiguration
}
@@ -902,99 +911,158 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
return configs, nil
}
-// SaveWebAuthnDevice saves a registered WebAuthn device.
-func (p *SQLProvider) SaveWebAuthnDevice(ctx context.Context, device model.WebAuthnDevice) (err error) {
- if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
- return fmt.Errorf("error encrypting WebAuthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
+// SaveWebAuthnUser saves a registered WebAuthn user.
+func (p *SQLProvider) SaveWebAuthnUser(ctx context.Context, user model.WebAuthnUser) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnUser, user.RPID, user.Username, user.UserID); err != nil {
+ return fmt.Errorf("error inserting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err)
+ }
+
+ return nil
+}
+
+// LoadWebAuthnUser loads a registered WebAuthn user.
+func (p *SQLProvider) LoadWebAuthnUser(ctx context.Context, rpid, username string) (user *model.WebAuthnUser, err error) {
+ user = &model.WebAuthnUser{}
+
+ if err = p.db.GetContext(ctx, user, p.sqlSelectWebAuthnUser, rpid, username); err != nil {
+ switch {
+ case errors.Is(err, sql.ErrNoRows):
+ return nil, nil
+ default:
+ return nil, fmt.Errorf("error selecting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err)
+ }
+ }
+
+ return user, nil
+}
+
+// SaveWebAuthnCredential saves a registered WebAuthn credential.
+func (p *SQLProvider) SaveWebAuthnCredential(ctx context.Context, credential model.WebAuthnCredential) (err error) {
+ if credential.PublicKey, err = p.encrypt(credential.PublicKey); err != nil {
+ return fmt.Errorf("error encrypting WebAuthn credential public key for user '%s' kid '%x': %w", credential.Username, credential.KID, err)
}
- if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebAuthnDevice,
- device.CreatedAt, device.LastUsedAt,
- device.RPID, device.Username, device.Description,
- device.KID, device.PublicKey,
- device.AttestationType, device.Transport, device.AAGUID, device.SignCount, device.CloneWarning,
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnCredential,
+ credential.CreatedAt, credential.LastUsedAt, credential.RPID, credential.Username, credential.Description,
+ credential.KID, credential.AAGUID, credential.AttestationType, credential.Attachment, credential.Transport,
+ credential.SignCount, credential.CloneWarning, credential.Discoverable, credential.Present, credential.Verified,
+ credential.BackupEligible, credential.BackupState, credential.PublicKey,
); err != nil {
- return fmt.Errorf("error upserting WebAuthn device for user '%s' kid '%x': %w", device.Username, device.KID, err)
+ return fmt.Errorf("error inserting WebAuthn credential for user '%s' kid '%x': %w", credential.Username, credential.KID, err)
}
return nil
}
-// UpdateWebAuthnDeviceSignIn updates a registered WebAuthn devices sign in information.
-func (p *SQLProvider) UpdateWebAuthnDeviceSignIn(ctx context.Context, id int, rpid string, lastUsedAt sql.NullTime, signCount uint32, cloneWarning bool) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnDeviceRecordSignIn, rpid, lastUsedAt, signCount, cloneWarning, id); err != nil {
- return fmt.Errorf("error updating WebAuthn signin metadata for id '%x': %w", id, err)
+// UpdateWebAuthnCredentialDescription updates a registered WebAuthn credentials description.
+func (p *SQLProvider) UpdateWebAuthnCredentialDescription(ctx context.Context, username string, credentialID int, description string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID, description, username, credentialID); err != nil {
+ return fmt.Errorf("error updating WebAuthn credential description to '%s' for credential id '%d': %w", description, credentialID, err)
}
return nil
}
-// DeleteWebAuthnDevice deletes a registered WebAuthn device.
-func (p *SQLProvider) DeleteWebAuthnDevice(ctx context.Context, kid string) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnDevice, kid); err != nil {
- return fmt.Errorf("error deleting WebAuthn device with kid '%s': %w", kid, err)
+// UpdateWebAuthnCredentialSignIn updates a registered WebAuthn credentials sign in information.
+func (p *SQLProvider) UpdateWebAuthnCredentialSignIn(ctx context.Context, credential model.WebAuthnCredential) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialRecordSignIn,
+ credential.RPID, credential.LastUsedAt, credential.SignCount, credential.Discoverable, credential.Present, credential.Verified,
+ credential.BackupEligible, credential.BackupState, credential.CloneWarning, credential.ID,
+ ); err != nil {
+ return fmt.Errorf("error updating WebAuthn credentials authentication metadata for id '%x': %w", credential.ID, err)
}
return nil
}
-// DeleteWebAuthnDeviceByUsername deletes registered WebAuthn devices by username or username and description.
-func (p *SQLProvider) DeleteWebAuthnDeviceByUsername(ctx context.Context, username, description string) (err error) {
+// DeleteWebAuthnCredential deletes a registered WebAuthn credential.
+func (p *SQLProvider) DeleteWebAuthnCredential(ctx context.Context, kid string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredential, kid); err != nil {
+ return fmt.Errorf("error deleting WebAuthn credential with kid '%s': %w", kid, err)
+ }
+
+ return nil
+}
+
+// DeleteWebAuthnCredentialByUsername deletes registered WebAuthn credential by username or username and description.
+func (p *SQLProvider) DeleteWebAuthnCredentialByUsername(ctx context.Context, username, displayname string) (err error) {
if len(username) == 0 {
- return fmt.Errorf("error deleting WebAuthn device with username '%s' and description '%s': username must not be empty", username, description)
+ return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': username must not be empty", username, displayname)
}
- if len(description) == 0 {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnDeviceByUsername, username); err != nil {
- return fmt.Errorf("error deleting WebAuthn devices for username '%s': %w", username, err)
+ if len(displayname) == 0 {
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsername, username); err != nil {
+ return fmt.Errorf("error deleting WebAuthn credential for username '%s': %w", username, err)
}
} else {
- if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnDeviceByUsernameAndDescription, username, description); err != nil {
- return fmt.Errorf("error deleting WebAuthn device with username '%s' and description '%s': %w", username, description, err)
+ if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsernameAndDisplayName, username, displayname); err != nil {
+ return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': %w", username, displayname, err)
}
}
return nil
}
-// LoadWebAuthnDevices loads WebAuthn device registrations.
-func (p *SQLProvider) LoadWebAuthnDevices(ctx context.Context, limit, page int) (devices []model.WebAuthnDevice, err error) {
- devices = make([]model.WebAuthnDevice, 0, limit)
+// LoadWebAuthnCredentials loads WebAuthn credential registrations.
+func (p *SQLProvider) LoadWebAuthnCredentials(ctx context.Context, limit, page int) (credentials []model.WebAuthnCredential, err error) {
+ credentials = make([]model.WebAuthnCredential, 0, limit)
- if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebAuthnDevices, limit, limit*page); err != nil {
+ if err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentials, limit, limit*page); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
- return nil, fmt.Errorf("error selecting WebAuthn devices: %w", err)
+ return nil, fmt.Errorf("error selecting WebAuthn credentials: %w", err)
+ }
+
+ for i, credential := range credentials {
+ if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil {
+ return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err)
+ }
}
- for i, device := range devices {
- if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil {
- return nil, fmt.Errorf("error decrypting WebAuthn public key for user '%s': %w", device.Username, err)
+ return credentials, nil
+}
+
+// LoadWebAuthnCredentialByID loads a WebAuthn credential registration for a given id.
+func (p *SQLProvider) LoadWebAuthnCredentialByID(ctx context.Context, id int) (credential *model.WebAuthnCredential, err error) {
+ credential = &model.WebAuthnCredential{}
+
+ if err = p.db.GetContext(ctx, credential, p.sqlSelectWebAuthnCredentialByID, id); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, sql.ErrNoRows
}
+
+ return nil, fmt.Errorf("error selecting WebAuthn credential with id '%d': %w", id, err)
}
- return devices, nil
+ return credential, nil
}
-// LoadWebAuthnDevicesByUsername loads all WebAuthn devices registration for a given username.
-func (p *SQLProvider) LoadWebAuthnDevicesByUsername(ctx context.Context, username string) (devices []model.WebAuthnDevice, err error) {
- if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebAuthnDevicesByUsername, username); err != nil {
+// LoadWebAuthnCredentialsByUsername loads all WebAuthn credential registrations for a given username.
+func (p *SQLProvider) LoadWebAuthnCredentialsByUsername(ctx context.Context, rpid, username string) (credentials []model.WebAuthnCredential, err error) {
+ switch len(rpid) {
+ case 0:
+ err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByUsername, username)
+ default:
+ err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByRPIDByUsername, rpid, username)
+ }
+
+ if err != nil {
if errors.Is(err, sql.ErrNoRows) {
- return nil, ErrNoWebAuthnDevice
+ return credentials, ErrNoWebAuthnCredential
}
- return nil, fmt.Errorf("error selecting WebAuthn devices for user '%s': %w", username, err)
+ return nil, fmt.Errorf("error selecting WebAuthn credentials for user '%s': %w", username, err)
}
- for i, device := range devices {
- if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil {
- return nil, fmt.Errorf("error decrypting WebAuthn public key for user '%s': %w", username, err)
+ for i, credential := range credentials {
+ if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil {
+ return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err)
}
}
- return devices, nil
+ return credentials, nil
}
// SavePreferredDuoDevice saves a Duo device.