diff options
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 223 |
1 files changed, 198 insertions, 25 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index c3659f7f0..d738cd80c 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -36,8 +36,23 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa log: logging.Logger(), - sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), - sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), + sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), + sqlSelectAuthenticationLogsRegulationRecordsByUsername: fmt.Sprintf(queryFmtSelectAuthenticationLogsRegulationRecordsByUsername, tableAuthenticationLogs), + sqlSelectAuthenticationLogsRegulationRecordsByRemoteIP: fmt.Sprintf(queryFmtSelectAuthenticationLogsRegulationRecordsByRemoteIP, tableAuthenticationLogs), + + sqlInsertBannedUser: fmt.Sprintf(queryFmtInsertBannedUser, tableBannedUser), + sqlSelectBannedUser: fmt.Sprintf(queryFmtSelectBannedUser, tableBannedUser), + sqlSelectBannedUserByID: fmt.Sprintf(queryFmtSelectBannedUserByID, tableBannedUser), + sqlSelectBannedUsers: fmt.Sprintf(queryFmtSelectBannedUsers, tableBannedUser), + sqlSelectBannedUserLastTime: fmt.Sprintf(queryFmtSelectBannedUserLastExpires, tableBannedUser), + sqlRevokeBannedUser: fmt.Sprintf(queryFmtRevokeBannedEntry, tableBannedUser), + + sqlInsertBannedIP: fmt.Sprintf(queryFmtInsertBannedIP, tableBannedIP), + sqlSelectBannedIP: fmt.Sprintf(queryFmtSelectBannedIP, tableBannedIP), + sqlSelectBannedIPByID: fmt.Sprintf(queryFmtSelectBannedIPByID, tableBannedIP), + sqlSelectBannedIPs: fmt.Sprintf(queryFmtSelectBannedIPs, tableBannedIP), + sqlSelectBannedIPLastTime: fmt.Sprintf(queryFmtSelectBannedIPLastExpires, tableBannedIP), + sqlRevokeBannedIP: fmt.Sprintf(queryFmtRevokeBannedEntry, tableBannedIP), sqlUpsertCachedData: fmt.Sprintf(queryFmtUpsertCachedData, tableCachedData), sqlSelectCachedData: fmt.Sprintf(queryFmtSelectCachedData, tableCachedData), @@ -181,8 +196,25 @@ type SQLProvider struct { log *logrus.Logger // Table: authentication_logs. - sqlInsertAuthenticationAttempt string - sqlSelectAuthenticationAttemptsByUsername string + sqlInsertAuthenticationAttempt string + sqlSelectAuthenticationLogsRegulationRecordsByUsername string + sqlSelectAuthenticationLogsRegulationRecordsByRemoteIP string + + // Table: banned_user. + sqlInsertBannedUser string + sqlSelectBannedUser string + sqlSelectBannedUserByID string + sqlSelectBannedUsers string + sqlSelectBannedUserLastTime string + sqlRevokeBannedUser string + + // Table: banned_ip. + sqlInsertBannedIP string + sqlSelectBannedIP string + sqlSelectBannedIPByID string + sqlSelectBannedIPs string + sqlSelectBannedIPLastTime string + sqlRevokeBannedIP string // Table: cached_data. sqlUpsertCachedData string @@ -469,7 +501,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m // SaveUserOpaqueIdentifier saves a new opaque user identifier to the storage provider. func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil { - return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err) + return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier, err) } return nil @@ -484,7 +516,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier u case errors.Is(err, sql.ErrNoRows): return nil, nil default: - return nil, fmt.Errorf("error selecting user opaque id with value '%s': %w", identifier.String(), err) + return nil, fmt.Errorf("error selecting user opaque id with value '%s': %w", identifier, err) } } @@ -1104,7 +1136,7 @@ func (p *SQLProvider) LoadOAuth2ConsentPreConfigurations(ctx context.Context, cl return &ConsentPreConfigRows{}, nil } - return &ConsentPreConfigRows{}, fmt.Errorf("error selecting oauth2 consent pre-configurations by signature with client id '%s' and subject '%s': %w", clientID, subject.String(), err) + return &ConsentPreConfigRows{}, fmt.Errorf("error selecting oauth2 consent pre-configurations by signature with client id '%s' and subject '%s': %w", clientID, subject, err) } return &ConsentPreConfigRows{rows: r}, nil @@ -1155,7 +1187,7 @@ func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, consent = &model.OAuth2ConsentSession{} if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil { - return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err) + return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID, err) } return consent, nil @@ -1213,11 +1245,11 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth case OAuth2SessionTypeRefreshToken: query = p.sqlRevokeOAuth2RefreshTokenSession default: - return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String()) + return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType) } if _, err = p.db.ExecContext(ctx, query, signature); err != nil { - return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType.String(), signature, err) + return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType, signature, err) } return nil @@ -1239,11 +1271,11 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio case OAuth2SessionTypeRefreshToken: query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID default: - return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String()) + return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType) } if _, err = p.db.ExecContext(ctx, query, requestID); err != nil { - return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType.String(), requestID, err) + return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType, requestID, err) } return nil @@ -1265,11 +1297,11 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O case OAuth2SessionTypeRefreshToken: query = p.sqlDeactivateOAuth2RefreshTokenSession default: - return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String()) + return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType) } if _, err = p.db.ExecContext(ctx, query, signature); err != nil { - return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType.String(), signature, err) + return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType, signature, err) } return nil @@ -1291,7 +1323,7 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se case OAuth2SessionTypeRefreshToken: query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID default: - return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String()) + return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType) } if _, err = p.db.ExecContext(ctx, query, requestID); err != nil { @@ -1317,17 +1349,17 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S case OAuth2SessionTypeRefreshToken: query = p.sqlSelectOAuth2RefreshTokenSession default: - return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType.String()) + return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType) } session = &model.OAuth2Session{} if err = p.db.GetContext(ctx, session, query, signature); err != nil { - return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType.String(), signature, err) + return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType, signature, err) } if session.Session, err = p.decrypt(session.Session); err != nil { - return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType.String(), signature, session.Subject.String, session.RequestID, err) + return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType, signature, session.Subject.String, session.RequestID, err) } return session, nil @@ -1490,19 +1522,160 @@ func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model return nil } -// LoadAuthenticationLogs loads authentication attempts from the storage provider (paginated). -func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) { - attempts = make([]model.AuthenticationAttempt, 0, limit) +func (p *SQLProvider) LoadRegulationRecordsByUser(ctx context.Context, username string, since time.Time, limit int) (records []model.RegulationRecord, err error) { + exp := banExpiresExpired{} - if err = p.db.SelectContext(ctx, &attempts, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page); err != nil { + if err = p.db.GetContext(ctx, &exp, p.sqlSelectBannedUserLastTime, username); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("error selecting last banned user time for username '%s': %w", username, err) + } + } + + if expiration := exp.Expiration(); expiration.After(since) { + since = expiration + } + + records = make([]model.RegulationRecord, 0, limit) + + if err = p.db.SelectContext(ctx, &records, p.sqlSelectAuthenticationLogsRegulationRecordsByUsername, since, username, false, limit); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting regulation records for username '%s': %w", username, err) + } + + return records, nil +} + +func (p *SQLProvider) SaveBannedUser(ctx context.Context, ban *model.BannedUser) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertBannedUser, ban.Expires, ban.Username, ban.Source, ban.Reason); err != nil { + return fmt.Errorf("error inserting banned user with username '%s' and source '%s' and reason '%s': %w", ban.Username, ban.Source, ban.Reason.String, err) + } + + return nil +} + +func (p *SQLProvider) LoadBannedUser(ctx context.Context, username string) (bans []model.BannedUser, err error) { + bans = []model.BannedUser{} + + if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedUser, username, time.Now().UTC()); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting banned user records for username '%s': %w", username, err) + } + + return bans, nil +} + +func (p *SQLProvider) LoadBannedUserByID(ctx context.Context, id int) (ban model.BannedUser, err error) { + if err = p.db.GetContext(ctx, &ban, p.sqlSelectBannedUserByID, id); err != nil { + return model.BannedUser{}, fmt.Errorf("error selecting banned user with id '%d': %w", id, err) + } + + return ban, nil +} + +func (p *SQLProvider) LoadBannedUsers(ctx context.Context, limit, page int) (bans []model.BannedUser, err error) { + bans = []model.BannedUser{} + + if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedUsers, false, time.Now().UTC(), limit, limit*page); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting banned user records: %w", err) + } + + return bans, nil +} + +func (p *SQLProvider) RevokeBannedUser(ctx context.Context, id int, expired time.Time) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlRevokeBannedUser, expired, id); err != nil { + return fmt.Errorf("error revoking banned user with id '%d': %w", id, err) + } + + return nil +} + +func (p *SQLProvider) LoadRegulationRecordsByIP(ctx context.Context, ip model.IP, since time.Time, limit int) (records []model.RegulationRecord, err error) { + exp := banExpiresExpired{} + + if err = p.db.GetContext(ctx, &exp, p.sqlSelectBannedIPLastTime, ip); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("error selecting last banned time for ip '%s': %w", ip, err) + } + } + + if expiration := exp.Expiration(); expiration.After(since) { + since = expiration + } + + records = make([]model.RegulationRecord, 0, limit) + + if err = p.db.SelectContext(ctx, &records, p.sqlSelectAuthenticationLogsRegulationRecordsByRemoteIP, since, ip, false, limit); err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNoAuthenticationLogs + return nil, nil } - return nil, fmt.Errorf("error selecting authentication logs for user '%s': %w", username, err) + return nil, fmt.Errorf("error selecting regulation records for ip '%s': %w", ip, err) } - return attempts, nil + return records, nil +} + +func (p *SQLProvider) SaveBannedIP(ctx context.Context, ban *model.BannedIP) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertBannedIP, ban.Expires, ban.IP, ban.Source, ban.Reason); err != nil { + return fmt.Errorf("error inserting banned ip with ip '%s' and source '%s' and reason '%s': %w", ban.IP, ban.Source, ban.Reason.String, err) + } + + return nil +} + +func (p *SQLProvider) LoadBannedIP(ctx context.Context, ip model.IP) (bans []model.BannedIP, err error) { + bans = []model.BannedIP{} + + if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedIP, ip, false, time.Now().UTC()); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting banned ip records for ip '%s': %w", ip, err) + } + + return bans, nil +} + +func (p *SQLProvider) LoadBannedIPByID(ctx context.Context, id int) (ban model.BannedIP, err error) { + if err = p.db.GetContext(ctx, &ban, p.sqlSelectBannedIPByID, id); err != nil { + return model.BannedIP{}, fmt.Errorf("error selecting banned ip with id '%d': %w", id, err) + } + + return ban, nil +} + +func (p *SQLProvider) LoadBannedIPs(ctx context.Context, limit, page int) (bans []model.BannedIP, err error) { + bans = []model.BannedIP{} + + if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedIPs, time.Now().UTC(), limit, limit*page); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting banned ip records: %w", err) + } + + return bans, nil +} + +func (p *SQLProvider) RevokeBannedIP(ctx context.Context, id int, expired time.Time) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlRevokeBannedIP, expired, id); err != nil { + return fmt.Errorf("error revoking banned ip with id '%d': %w", id, err) + } + + return nil } func (p *SQLProvider) SaveCachedData(ctx context.Context, data model.CachedData) (err error) { |
