summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage/sql_provider.go')
-rw-r--r--internal/storage/sql_provider.go223
1 files changed, 198 insertions, 25 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index c3659f7f0..d738cd80c 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -36,8 +36,23 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
log: logging.Logger(),
- sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
- sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
+ sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
+ sqlSelectAuthenticationLogsRegulationRecordsByUsername: fmt.Sprintf(queryFmtSelectAuthenticationLogsRegulationRecordsByUsername, tableAuthenticationLogs),
+ sqlSelectAuthenticationLogsRegulationRecordsByRemoteIP: fmt.Sprintf(queryFmtSelectAuthenticationLogsRegulationRecordsByRemoteIP, tableAuthenticationLogs),
+
+ sqlInsertBannedUser: fmt.Sprintf(queryFmtInsertBannedUser, tableBannedUser),
+ sqlSelectBannedUser: fmt.Sprintf(queryFmtSelectBannedUser, tableBannedUser),
+ sqlSelectBannedUserByID: fmt.Sprintf(queryFmtSelectBannedUserByID, tableBannedUser),
+ sqlSelectBannedUsers: fmt.Sprintf(queryFmtSelectBannedUsers, tableBannedUser),
+ sqlSelectBannedUserLastTime: fmt.Sprintf(queryFmtSelectBannedUserLastExpires, tableBannedUser),
+ sqlRevokeBannedUser: fmt.Sprintf(queryFmtRevokeBannedEntry, tableBannedUser),
+
+ sqlInsertBannedIP: fmt.Sprintf(queryFmtInsertBannedIP, tableBannedIP),
+ sqlSelectBannedIP: fmt.Sprintf(queryFmtSelectBannedIP, tableBannedIP),
+ sqlSelectBannedIPByID: fmt.Sprintf(queryFmtSelectBannedIPByID, tableBannedIP),
+ sqlSelectBannedIPs: fmt.Sprintf(queryFmtSelectBannedIPs, tableBannedIP),
+ sqlSelectBannedIPLastTime: fmt.Sprintf(queryFmtSelectBannedIPLastExpires, tableBannedIP),
+ sqlRevokeBannedIP: fmt.Sprintf(queryFmtRevokeBannedEntry, tableBannedIP),
sqlUpsertCachedData: fmt.Sprintf(queryFmtUpsertCachedData, tableCachedData),
sqlSelectCachedData: fmt.Sprintf(queryFmtSelectCachedData, tableCachedData),
@@ -181,8 +196,25 @@ type SQLProvider struct {
log *logrus.Logger
// Table: authentication_logs.
- sqlInsertAuthenticationAttempt string
- sqlSelectAuthenticationAttemptsByUsername string
+ sqlInsertAuthenticationAttempt string
+ sqlSelectAuthenticationLogsRegulationRecordsByUsername string
+ sqlSelectAuthenticationLogsRegulationRecordsByRemoteIP string
+
+ // Table: banned_user.
+ sqlInsertBannedUser string
+ sqlSelectBannedUser string
+ sqlSelectBannedUserByID string
+ sqlSelectBannedUsers string
+ sqlSelectBannedUserLastTime string
+ sqlRevokeBannedUser string
+
+ // Table: banned_ip.
+ sqlInsertBannedIP string
+ sqlSelectBannedIP string
+ sqlSelectBannedIPByID string
+ sqlSelectBannedIPs string
+ sqlSelectBannedIPLastTime string
+ sqlRevokeBannedIP string
// Table: cached_data.
sqlUpsertCachedData string
@@ -469,7 +501,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m
// SaveUserOpaqueIdentifier saves a new opaque user identifier to the storage provider.
func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil {
- return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err)
+ return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier, err)
}
return nil
@@ -484,7 +516,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier u
case errors.Is(err, sql.ErrNoRows):
return nil, nil
default:
- return nil, fmt.Errorf("error selecting user opaque id with value '%s': %w", identifier.String(), err)
+ return nil, fmt.Errorf("error selecting user opaque id with value '%s': %w", identifier, err)
}
}
@@ -1104,7 +1136,7 @@ func (p *SQLProvider) LoadOAuth2ConsentPreConfigurations(ctx context.Context, cl
return &ConsentPreConfigRows{}, nil
}
- return &ConsentPreConfigRows{}, fmt.Errorf("error selecting oauth2 consent pre-configurations by signature with client id '%s' and subject '%s': %w", clientID, subject.String(), err)
+ return &ConsentPreConfigRows{}, fmt.Errorf("error selecting oauth2 consent pre-configurations by signature with client id '%s' and subject '%s': %w", clientID, subject, err)
}
return &ConsentPreConfigRows{rows: r}, nil
@@ -1155,7 +1187,7 @@ func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context,
consent = &model.OAuth2ConsentSession{}
if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil {
- return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err)
+ return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID, err)
}
return consent, nil
@@ -1213,11 +1245,11 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth
case OAuth2SessionTypeRefreshToken:
query = p.sqlRevokeOAuth2RefreshTokenSession
default:
- return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
+ return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
- return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType.String(), signature, err)
+ return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType, signature, err)
}
return nil
@@ -1239,11 +1271,11 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio
case OAuth2SessionTypeRefreshToken:
query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
default:
- return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
+ return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
- return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType.String(), requestID, err)
+ return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType, requestID, err)
}
return nil
@@ -1265,11 +1297,11 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O
case OAuth2SessionTypeRefreshToken:
query = p.sqlDeactivateOAuth2RefreshTokenSession
default:
- return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
+ return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
- return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType.String(), signature, err)
+ return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType, signature, err)
}
return nil
@@ -1291,7 +1323,7 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se
case OAuth2SessionTypeRefreshToken:
query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
default:
- return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
+ return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
@@ -1317,17 +1349,17 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
case OAuth2SessionTypeRefreshToken:
query = p.sqlSelectOAuth2RefreshTokenSession
default:
- return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType.String())
+ return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType)
}
session = &model.OAuth2Session{}
if err = p.db.GetContext(ctx, session, query, signature); err != nil {
- return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType.String(), signature, err)
+ return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType, signature, err)
}
if session.Session, err = p.decrypt(session.Session); err != nil {
- return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType.String(), signature, session.Subject.String, session.RequestID, err)
+ return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType, signature, session.Subject.String, session.RequestID, err)
}
return session, nil
@@ -1490,19 +1522,160 @@ func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model
return nil
}
-// LoadAuthenticationLogs loads authentication attempts from the storage provider (paginated).
-func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) {
- attempts = make([]model.AuthenticationAttempt, 0, limit)
+func (p *SQLProvider) LoadRegulationRecordsByUser(ctx context.Context, username string, since time.Time, limit int) (records []model.RegulationRecord, err error) {
+ exp := banExpiresExpired{}
- if err = p.db.SelectContext(ctx, &attempts, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page); err != nil {
+ if err = p.db.GetContext(ctx, &exp, p.sqlSelectBannedUserLastTime, username); err != nil {
+ if !errors.Is(err, sql.ErrNoRows) {
+ return nil, fmt.Errorf("error selecting last banned user time for username '%s': %w", username, err)
+ }
+ }
+
+ if expiration := exp.Expiration(); expiration.After(since) {
+ since = expiration
+ }
+
+ records = make([]model.RegulationRecord, 0, limit)
+
+ if err = p.db.SelectContext(ctx, &records, p.sqlSelectAuthenticationLogsRegulationRecordsByUsername, since, username, false, limit); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting regulation records for username '%s': %w", username, err)
+ }
+
+ return records, nil
+}
+
+func (p *SQLProvider) SaveBannedUser(ctx context.Context, ban *model.BannedUser) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertBannedUser, ban.Expires, ban.Username, ban.Source, ban.Reason); err != nil {
+ return fmt.Errorf("error inserting banned user with username '%s' and source '%s' and reason '%s': %w", ban.Username, ban.Source, ban.Reason.String, err)
+ }
+
+ return nil
+}
+
+func (p *SQLProvider) LoadBannedUser(ctx context.Context, username string) (bans []model.BannedUser, err error) {
+ bans = []model.BannedUser{}
+
+ if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedUser, username, time.Now().UTC()); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting banned user records for username '%s': %w", username, err)
+ }
+
+ return bans, nil
+}
+
+func (p *SQLProvider) LoadBannedUserByID(ctx context.Context, id int) (ban model.BannedUser, err error) {
+ if err = p.db.GetContext(ctx, &ban, p.sqlSelectBannedUserByID, id); err != nil {
+ return model.BannedUser{}, fmt.Errorf("error selecting banned user with id '%d': %w", id, err)
+ }
+
+ return ban, nil
+}
+
+func (p *SQLProvider) LoadBannedUsers(ctx context.Context, limit, page int) (bans []model.BannedUser, err error) {
+ bans = []model.BannedUser{}
+
+ if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedUsers, false, time.Now().UTC(), limit, limit*page); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting banned user records: %w", err)
+ }
+
+ return bans, nil
+}
+
+func (p *SQLProvider) RevokeBannedUser(ctx context.Context, id int, expired time.Time) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlRevokeBannedUser, expired, id); err != nil {
+ return fmt.Errorf("error revoking banned user with id '%d': %w", id, err)
+ }
+
+ return nil
+}
+
+func (p *SQLProvider) LoadRegulationRecordsByIP(ctx context.Context, ip model.IP, since time.Time, limit int) (records []model.RegulationRecord, err error) {
+ exp := banExpiresExpired{}
+
+ if err = p.db.GetContext(ctx, &exp, p.sqlSelectBannedIPLastTime, ip); err != nil {
+ if !errors.Is(err, sql.ErrNoRows) {
+ return nil, fmt.Errorf("error selecting last banned time for ip '%s': %w", ip, err)
+ }
+ }
+
+ if expiration := exp.Expiration(); expiration.After(since) {
+ since = expiration
+ }
+
+ records = make([]model.RegulationRecord, 0, limit)
+
+ if err = p.db.SelectContext(ctx, &records, p.sqlSelectAuthenticationLogsRegulationRecordsByRemoteIP, since, ip, false, limit); err != nil {
if errors.Is(err, sql.ErrNoRows) {
- return nil, ErrNoAuthenticationLogs
+ return nil, nil
}
- return nil, fmt.Errorf("error selecting authentication logs for user '%s': %w", username, err)
+ return nil, fmt.Errorf("error selecting regulation records for ip '%s': %w", ip, err)
}
- return attempts, nil
+ return records, nil
+}
+
+func (p *SQLProvider) SaveBannedIP(ctx context.Context, ban *model.BannedIP) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertBannedIP, ban.Expires, ban.IP, ban.Source, ban.Reason); err != nil {
+ return fmt.Errorf("error inserting banned ip with ip '%s' and source '%s' and reason '%s': %w", ban.IP, ban.Source, ban.Reason.String, err)
+ }
+
+ return nil
+}
+
+func (p *SQLProvider) LoadBannedIP(ctx context.Context, ip model.IP) (bans []model.BannedIP, err error) {
+ bans = []model.BannedIP{}
+
+ if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedIP, ip, false, time.Now().UTC()); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting banned ip records for ip '%s': %w", ip, err)
+ }
+
+ return bans, nil
+}
+
+func (p *SQLProvider) LoadBannedIPByID(ctx context.Context, id int) (ban model.BannedIP, err error) {
+ if err = p.db.GetContext(ctx, &ban, p.sqlSelectBannedIPByID, id); err != nil {
+ return model.BannedIP{}, fmt.Errorf("error selecting banned ip with id '%d': %w", id, err)
+ }
+
+ return ban, nil
+}
+
+func (p *SQLProvider) LoadBannedIPs(ctx context.Context, limit, page int) (bans []model.BannedIP, err error) {
+ bans = []model.BannedIP{}
+
+ if err = p.db.SelectContext(ctx, &bans, p.sqlSelectBannedIPs, time.Now().UTC(), limit, limit*page); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("error selecting banned ip records: %w", err)
+ }
+
+ return bans, nil
+}
+
+func (p *SQLProvider) RevokeBannedIP(ctx context.Context, id int, expired time.Time) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlRevokeBannedIP, expired, id); err != nil {
+ return fmt.Errorf("error revoking banned ip with id '%d': %w", id, err)
+ }
+
+ return nil
}
func (p *SQLProvider) SaveCachedData(ctx context.Context, data model.CachedData) (err error) {