diff options
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 42 |
1 files changed, 39 insertions, 3 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index dfa23e66b..9f448165f 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -46,9 +46,13 @@ func NewSQLProvider(name, driverName, dataSourceName, encryptionKey string) (pro sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices), sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices), + sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices), + sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices), + sqlSelectDuoDevice: fmt.Sprintf(queryFmtSelectDuoDevice, tableDuoDevices), + sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences), sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences), - sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableU2FDevices, tableUserPreferences), + sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableU2FDevices, tableDuoDevices, tableUserPreferences), sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), @@ -99,6 +103,11 @@ type SQLProvider struct { sqlUpsertU2FDevice string sqlSelectU2FDevice string + // Table: duo_devices + sqlUpsertDuoDevice string + sqlDeleteDuoDevice string + sqlSelectDuoDevice string + // Table: user_preferences. sqlUpsertPreferred2FAMethod string sqlSelectPreferred2FAMethod string @@ -186,7 +195,7 @@ func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username strin // LoadUserInfo loads the models.UserInfo from the database. func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) { - err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username) + err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username) switch { case err == nil: @@ -196,7 +205,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m return models.UserInfo{}, fmt.Errorf("error upserting preferred two factor method while selecting user info for user '%s': %w", username, err) } - if err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username); err != nil { + if err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username); err != nil { return models.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) } @@ -355,6 +364,33 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic return device, nil } +// SavePreferredDuoDevice saves a Duo device. +func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method) + return err +} + +// DeletePreferredDuoDevice deletes a Duo device of a given user. +func (p *SQLProvider) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlDeleteDuoDevice, username) + return err +} + +// LoadPreferredDuoDevice loads a Duo device of a given user. +func (p *SQLProvider) LoadPreferredDuoDevice(ctx context.Context, username string) (device *models.DuoDevice, err error) { + device = &models.DuoDevice{} + + if err := p.db.QueryRowxContext(ctx, p.sqlSelectDuoDevice, username).StructScan(device); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNoDuoDevice + } + + return nil, err + } + + return device, nil +} + // AppendAuthenticationLog append a mark to the authentication log. func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, |
