diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2023-10-27 20:20:29 +1100 |
|---|---|---|
| committer | James Elliott <james-d-elliott@users.noreply.github.com> | 2024-03-04 20:28:24 +1100 |
| commit | c0dbdd97ab2ac580e3da07a0137dbc7a1b9c9b83 (patch) | |
| tree | 57daff9cacd6a06524a87e40d9ee5d1dbcb483d3 /internal/storage/sql_provider.go | |
| parent | 358b6679b545d5227a8d5bd2c9e0f95e59ebc4f7 (diff) | |
feat(web): multiple webauthn credential registration
This implements multiple WebAuthn Credential registrations by means of a generic user settings UI.
Closes #275, Closes #4366
Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
Co-authored-by: Clément Michaud <clement.michaud34@gmail.com>
Co-authored-by: Stephen Kent <smkent@smkent.net>
Co-authored-by: Amir Zarrinkafsh <nightah@me.com>
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 204 |
1 files changed, 136 insertions, 68 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 86fb1dacf..a5f1d0727 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -46,16 +46,19 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlUpdateTOTPConfigRecordSignIn: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignIn, tableTOTPConfigurations), sqlUpdateTOTPConfigRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignInByUsername, tableTOTPConfigurations), - sqlUpsertWebAuthnDevice: fmt.Sprintf(queryFmtUpsertWebAuthnDevice, tableWebAuthnDevices), - sqlSelectWebAuthnDevices: fmt.Sprintf(queryFmtSelectWebAuthnDevices, tableWebAuthnDevices), - sqlSelectWebAuthnDevicesByUsername: fmt.Sprintf(queryFmtSelectWebAuthnDevicesByUsername, tableWebAuthnDevices), - - sqlUpdateWebAuthnDeviceRecordSignIn: fmt.Sprintf(queryFmtUpdateWebAuthnDeviceRecordSignIn, tableWebAuthnDevices), - sqlUpdateWebAuthnDeviceRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateWebAuthnDeviceRecordSignInByUsername, tableWebAuthnDevices), - - sqlDeleteWebAuthnDevice: fmt.Sprintf(queryFmtDeleteWebAuthnDevice, tableWebAuthnDevices), - sqlDeleteWebAuthnDeviceByUsername: fmt.Sprintf(queryFmtDeleteWebAuthnDeviceByUsername, tableWebAuthnDevices), - sqlDeleteWebAuthnDeviceByUsernameAndDescription: fmt.Sprintf(queryFmtDeleteWebAuthnDeviceByUsernameAndDescription, tableWebAuthnDevices), + sqlInsertWebAuthnUser: fmt.Sprintf(queryFmtInsertWebAuthnUser, tableWebAuthnUsers), + sqlSelectWebAuthnUser: fmt.Sprintf(queryFmtSelectWebAuthnUser, tableWebAuthnUsers), + + sqlInsertWebAuthnCredential: fmt.Sprintf(queryFmtInsertWebAuthnCredential, tableWebAuthnCredentials), + sqlSelectWebAuthnCredentials: fmt.Sprintf(queryFmtSelectWebAuthnCredentials, tableWebAuthnCredentials), + sqlSelectWebAuthnCredentialsByUsername: fmt.Sprintf(queryFmtSelectWebAuthnCredentialsByUsername, tableWebAuthnCredentials), + sqlSelectWebAuthnCredentialsByRPIDByUsername: fmt.Sprintf(queryFmtSelectWebAuthnCredentialsByRPIDByUsername, tableWebAuthnCredentials), + sqlSelectWebAuthnCredentialByID: fmt.Sprintf(queryFmtSelectWebAuthnCredentialByID, tableWebAuthnCredentials), + sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID: fmt.Sprintf(queryFmtUpdateUpdateWebAuthnCredentialDescriptionByUsernameAndID, tableWebAuthnCredentials), + sqlUpdateWebAuthnCredentialRecordSignIn: fmt.Sprintf(queryFmtUpdateWebAuthnCredentialRecordSignIn, tableWebAuthnCredentials), + sqlDeleteWebAuthnCredential: fmt.Sprintf(queryFmtDeleteWebAuthnCredential, tableWebAuthnCredentials), + sqlDeleteWebAuthnCredentialByUsername: fmt.Sprintf(queryFmtDeleteWebAuthnCredentialByUsername, tableWebAuthnCredentials), + sqlDeleteWebAuthnCredentialByUsernameAndDisplayName: fmt.Sprintf(queryFmtDeleteWebAuthnCredentialByUsernameAndDescription, tableWebAuthnCredentials), sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices), sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices), @@ -63,7 +66,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences), sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences), - sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebAuthnDevices, tableDuoDevices, tableUserPreferences), + sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebAuthnCredentials, tableDuoDevices, tableUserPreferences), sqlInsertUserOpaqueIdentifier: fmt.Sprintf(queryFmtInsertUserOpaqueIdentifier, tableUserOpaqueIdentifier), sqlSelectUserOpaqueIdentifier: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifier, tableUserOpaqueIdentifier), @@ -165,17 +168,23 @@ type SQLProvider struct { sqlUpdateTOTPConfigRecordSignIn string sqlUpdateTOTPConfigRecordSignInByUsername string - // Table: webauthn_devices. - sqlUpsertWebAuthnDevice string - sqlSelectWebAuthnDevices string - sqlSelectWebAuthnDevicesByUsername string + // Table: webauthn_users. + sqlInsertWebAuthnUser string + sqlSelectWebAuthnUser string + + // Table: webauthn_credentials. + sqlInsertWebAuthnCredential string + sqlSelectWebAuthnCredentials string + sqlSelectWebAuthnCredentialsByUsername string + sqlSelectWebAuthnCredentialsByRPIDByUsername string + sqlSelectWebAuthnCredentialByID string - sqlUpdateWebAuthnDeviceRecordSignIn string - sqlUpdateWebAuthnDeviceRecordSignInByUsername string + sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID string + sqlUpdateWebAuthnCredentialRecordSignIn string - sqlDeleteWebAuthnDevice string - sqlDeleteWebAuthnDeviceByUsername string - sqlDeleteWebAuthnDeviceByUsernameAndDescription string + sqlDeleteWebAuthnCredential string + sqlDeleteWebAuthnCredentialByUsername string + sqlDeleteWebAuthnCredentialByUsernameAndDisplayName string // Table: duo_devices. sqlUpsertDuoDevice string @@ -367,7 +376,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier u case errors.Is(err, sql.ErrNoRows): return nil, nil default: - return nil, err + return nil, fmt.Errorf("error selecting user opaque id with value '%s': %w", identifier.String(), err) } } @@ -406,7 +415,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s case errors.Is(err, sql.ErrNoRows): return nil, nil default: - return nil, err + return nil, fmt.Errorf("error selecting user opaque with service '%s' and sector '%s' for username '%s': %w", service, sectorID, username, err) } } @@ -844,7 +853,7 @@ func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TO return nil } -// UpdateTOTPConfigurationSignIn updates a registered WebAuthn devices sign in information. +// UpdateTOTPConfigurationSignIn updates a registered TOTP configurations sign in information. func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil { return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err) @@ -866,7 +875,7 @@ func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username stri func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) { config = &model.TOTPConfiguration{} - if err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config); err != nil { + if err = p.db.GetContext(ctx, config, p.sqlSelectTOTPConfig, username); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoTOTPConfiguration } @@ -902,99 +911,158 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in return configs, nil } -// SaveWebAuthnDevice saves a registered WebAuthn device. -func (p *SQLProvider) SaveWebAuthnDevice(ctx context.Context, device model.WebAuthnDevice) (err error) { - if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil { - return fmt.Errorf("error encrypting WebAuthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err) +// SaveWebAuthnUser saves a registered WebAuthn user. +func (p *SQLProvider) SaveWebAuthnUser(ctx context.Context, user model.WebAuthnUser) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnUser, user.RPID, user.Username, user.UserID); err != nil { + return fmt.Errorf("error inserting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err) + } + + return nil +} + +// LoadWebAuthnUser loads a registered WebAuthn user. +func (p *SQLProvider) LoadWebAuthnUser(ctx context.Context, rpid, username string) (user *model.WebAuthnUser, err error) { + user = &model.WebAuthnUser{} + + if err = p.db.GetContext(ctx, user, p.sqlSelectWebAuthnUser, rpid, username); err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, nil + default: + return nil, fmt.Errorf("error selecting WebAuthn user '%s' with relying party id '%s': %w", user.Username, user.RPID, err) + } + } + + return user, nil +} + +// SaveWebAuthnCredential saves a registered WebAuthn credential. +func (p *SQLProvider) SaveWebAuthnCredential(ctx context.Context, credential model.WebAuthnCredential) (err error) { + if credential.PublicKey, err = p.encrypt(credential.PublicKey); err != nil { + return fmt.Errorf("error encrypting WebAuthn credential public key for user '%s' kid '%x': %w", credential.Username, credential.KID, err) } - if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebAuthnDevice, - device.CreatedAt, device.LastUsedAt, - device.RPID, device.Username, device.Description, - device.KID, device.PublicKey, - device.AttestationType, device.Transport, device.AAGUID, device.SignCount, device.CloneWarning, + if _, err = p.db.ExecContext(ctx, p.sqlInsertWebAuthnCredential, + credential.CreatedAt, credential.LastUsedAt, credential.RPID, credential.Username, credential.Description, + credential.KID, credential.AAGUID, credential.AttestationType, credential.Attachment, credential.Transport, + credential.SignCount, credential.CloneWarning, credential.Discoverable, credential.Present, credential.Verified, + credential.BackupEligible, credential.BackupState, credential.PublicKey, ); err != nil { - return fmt.Errorf("error upserting WebAuthn device for user '%s' kid '%x': %w", device.Username, device.KID, err) + return fmt.Errorf("error inserting WebAuthn credential for user '%s' kid '%x': %w", credential.Username, credential.KID, err) } return nil } -// UpdateWebAuthnDeviceSignIn updates a registered WebAuthn devices sign in information. -func (p *SQLProvider) UpdateWebAuthnDeviceSignIn(ctx context.Context, id int, rpid string, lastUsedAt sql.NullTime, signCount uint32, cloneWarning bool) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnDeviceRecordSignIn, rpid, lastUsedAt, signCount, cloneWarning, id); err != nil { - return fmt.Errorf("error updating WebAuthn signin metadata for id '%x': %w", id, err) +// UpdateWebAuthnCredentialDescription updates a registered WebAuthn credentials description. +func (p *SQLProvider) UpdateWebAuthnCredentialDescription(ctx context.Context, username string, credentialID int, description string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialDescriptionByUsernameAndID, description, username, credentialID); err != nil { + return fmt.Errorf("error updating WebAuthn credential description to '%s' for credential id '%d': %w", description, credentialID, err) } return nil } -// DeleteWebAuthnDevice deletes a registered WebAuthn device. -func (p *SQLProvider) DeleteWebAuthnDevice(ctx context.Context, kid string) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnDevice, kid); err != nil { - return fmt.Errorf("error deleting WebAuthn device with kid '%s': %w", kid, err) +// UpdateWebAuthnCredentialSignIn updates a registered WebAuthn credentials sign in information. +func (p *SQLProvider) UpdateWebAuthnCredentialSignIn(ctx context.Context, credential model.WebAuthnCredential) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebAuthnCredentialRecordSignIn, + credential.RPID, credential.LastUsedAt, credential.SignCount, credential.Discoverable, credential.Present, credential.Verified, + credential.BackupEligible, credential.BackupState, credential.CloneWarning, credential.ID, + ); err != nil { + return fmt.Errorf("error updating WebAuthn credentials authentication metadata for id '%x': %w", credential.ID, err) } return nil } -// DeleteWebAuthnDeviceByUsername deletes registered WebAuthn devices by username or username and description. -func (p *SQLProvider) DeleteWebAuthnDeviceByUsername(ctx context.Context, username, description string) (err error) { +// DeleteWebAuthnCredential deletes a registered WebAuthn credential. +func (p *SQLProvider) DeleteWebAuthnCredential(ctx context.Context, kid string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredential, kid); err != nil { + return fmt.Errorf("error deleting WebAuthn credential with kid '%s': %w", kid, err) + } + + return nil +} + +// DeleteWebAuthnCredentialByUsername deletes registered WebAuthn credential by username or username and description. +func (p *SQLProvider) DeleteWebAuthnCredentialByUsername(ctx context.Context, username, displayname string) (err error) { if len(username) == 0 { - return fmt.Errorf("error deleting WebAuthn device with username '%s' and description '%s': username must not be empty", username, description) + return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': username must not be empty", username, displayname) } - if len(description) == 0 { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnDeviceByUsername, username); err != nil { - return fmt.Errorf("error deleting WebAuthn devices for username '%s': %w", username, err) + if len(displayname) == 0 { + if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsername, username); err != nil { + return fmt.Errorf("error deleting WebAuthn credential for username '%s': %w", username, err) } } else { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnDeviceByUsernameAndDescription, username, description); err != nil { - return fmt.Errorf("error deleting WebAuthn device with username '%s' and description '%s': %w", username, description, err) + if _, err = p.db.ExecContext(ctx, p.sqlDeleteWebAuthnCredentialByUsernameAndDisplayName, username, displayname); err != nil { + return fmt.Errorf("error deleting WebAuthn credential with username '%s' and displayname '%s': %w", username, displayname, err) } } return nil } -// LoadWebAuthnDevices loads WebAuthn device registrations. -func (p *SQLProvider) LoadWebAuthnDevices(ctx context.Context, limit, page int) (devices []model.WebAuthnDevice, err error) { - devices = make([]model.WebAuthnDevice, 0, limit) +// LoadWebAuthnCredentials loads WebAuthn credential registrations. +func (p *SQLProvider) LoadWebAuthnCredentials(ctx context.Context, limit, page int) (credentials []model.WebAuthnCredential, err error) { + credentials = make([]model.WebAuthnCredential, 0, limit) - if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebAuthnDevices, limit, limit*page); err != nil { + if err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentials, limit, limit*page); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } - return nil, fmt.Errorf("error selecting WebAuthn devices: %w", err) + return nil, fmt.Errorf("error selecting WebAuthn credentials: %w", err) + } + + for i, credential := range credentials { + if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil { + return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err) + } } - for i, device := range devices { - if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil { - return nil, fmt.Errorf("error decrypting WebAuthn public key for user '%s': %w", device.Username, err) + return credentials, nil +} + +// LoadWebAuthnCredentialByID loads a WebAuthn credential registration for a given id. +func (p *SQLProvider) LoadWebAuthnCredentialByID(ctx context.Context, id int) (credential *model.WebAuthnCredential, err error) { + credential = &model.WebAuthnCredential{} + + if err = p.db.GetContext(ctx, credential, p.sqlSelectWebAuthnCredentialByID, id); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, sql.ErrNoRows } + + return nil, fmt.Errorf("error selecting WebAuthn credential with id '%d': %w", id, err) } - return devices, nil + return credential, nil } -// LoadWebAuthnDevicesByUsername loads all WebAuthn devices registration for a given username. -func (p *SQLProvider) LoadWebAuthnDevicesByUsername(ctx context.Context, username string) (devices []model.WebAuthnDevice, err error) { - if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebAuthnDevicesByUsername, username); err != nil { +// LoadWebAuthnCredentialsByUsername loads all WebAuthn credential registrations for a given username. +func (p *SQLProvider) LoadWebAuthnCredentialsByUsername(ctx context.Context, rpid, username string) (credentials []model.WebAuthnCredential, err error) { + switch len(rpid) { + case 0: + err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByUsername, username) + default: + err = p.db.SelectContext(ctx, &credentials, p.sqlSelectWebAuthnCredentialsByRPIDByUsername, rpid, username) + } + + if err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNoWebAuthnDevice + return credentials, ErrNoWebAuthnCredential } - return nil, fmt.Errorf("error selecting WebAuthn devices for user '%s': %w", username, err) + return nil, fmt.Errorf("error selecting WebAuthn credentials for user '%s': %w", username, err) } - for i, device := range devices { - if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil { - return nil, fmt.Errorf("error decrypting WebAuthn public key for user '%s': %w", username, err) + for i, credential := range credentials { + if credentials[i].PublicKey, err = p.decrypt(credential.PublicKey); err != nil { + return nil, fmt.Errorf("error decrypting WebAuthn credential public key of credential with id '%d' for user '%s': %w", credential.ID, credential.Username, err) } } - return devices, nil + return credentials, nil } // SavePreferredDuoDevice saves a Duo device. |
