summaryrefslogtreecommitdiff
path: root/internal/storage/sql_provider.go
diff options
context:
space:
mode:
authorJames Elliott <james-d-elliott@users.noreply.github.com>2023-03-06 14:58:50 +1100
committerGitHub <noreply@github.com>2023-03-06 14:58:50 +1100
commitff6be40f5e5497da1f312d2896e210201a24b048 (patch)
treeb38883ed1623f4878b791929d519596438b8b054 /internal/storage/sql_provider.go
parent42671d3edb0d336794de1e164d147fb742364e11 (diff)
feat(oidc): pushed authorization requests (#4546)
This implements RFC9126 OAuth 2.0 Pushed Authorization Requests. See https://datatracker.ietf.org/doc/html/rfc9126 for the specification details.
Diffstat (limited to 'internal/storage/sql_provider.go')
-rw-r--r--internal/storage/sql_provider.go230
1 files changed, 139 insertions, 91 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index a55a41cea..2a4cce037 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -70,6 +70,13 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlSelectUserOpaqueIdentifiers: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifiers, tableUserOpaqueIdentifier),
sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier),
+ sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
+ sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
+
+ sqlInsertOAuth2PARContext: fmt.Sprintf(queryFmtInsertOAuth2PARContext, tableOAuth2PARContext),
+ sqlSelectOAuth2PARContext: fmt.Sprintf(queryFmtSelectOAuth2PARContext, tableOAuth2PARContext),
+ sqlRevokeOAuth2PARContext: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PARContext),
+
sqlInsertOAuth2ConsentPreConfiguration: fmt.Sprintf(queryFmtInsertOAuth2ConsentPreConfiguration, tableOAuth2ConsentPreConfiguration),
sqlSelectOAuth2ConsentPreConfigurations: fmt.Sprintf(queryFmtSelectOAuth2ConsentPreConfigurations, tableOAuth2ConsentPreConfiguration),
@@ -79,13 +86,6 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession),
sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession),
- sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
- sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
- sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
- sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
- sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
- sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
-
sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession),
sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession),
sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession),
@@ -93,19 +93,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession),
sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
- sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
- sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
- sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
- sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
- sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
- sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
-
- sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
- sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
- sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
- sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
- sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
- sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
+ sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
+ sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession),
@@ -114,8 +107,19 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
- sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
- sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
+ sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
+ sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
+
+ sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
+ sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
@@ -224,13 +228,18 @@ type SQLProvider struct {
sqlDeactivateOAuth2AccessTokenSession string
sqlDeactivateOAuth2AccessTokenSessionByRequestID string
- // Table: oauth2_refresh_token_session.
- sqlInsertOAuth2RefreshTokenSession string
- sqlSelectOAuth2RefreshTokenSession string
- sqlRevokeOAuth2RefreshTokenSession string
- sqlRevokeOAuth2RefreshTokenSessionByRequestID string
- sqlDeactivateOAuth2RefreshTokenSession string
- sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
+ // Table: oauth2_openid_connect_session.
+ sqlInsertOAuth2OpenIDConnectSession string
+ sqlSelectOAuth2OpenIDConnectSession string
+ sqlRevokeOAuth2OpenIDConnectSession string
+ sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
+ sqlDeactivateOAuth2OpenIDConnectSession string
+ sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
+
+ // Table: oauth2_par_context.
+ sqlInsertOAuth2PARContext string
+ sqlSelectOAuth2PARContext string
+ sqlRevokeOAuth2PARContext string
// Table: oauth2_pkce_request_session.
sqlInsertOAuth2PKCERequestSession string
@@ -240,13 +249,13 @@ type SQLProvider struct {
sqlDeactivateOAuth2PKCERequestSession string
sqlDeactivateOAuth2PKCERequestSessionByRequestID string
- // Table: oauth2_openid_connect_session.
- sqlInsertOAuth2OpenIDConnectSession string
- sqlSelectOAuth2OpenIDConnectSession string
- sqlRevokeOAuth2OpenIDConnectSession string
- sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
- sqlDeactivateOAuth2OpenIDConnectSession string
- sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
+ // Table: oauth2_refresh_token_session.
+ sqlInsertOAuth2RefreshTokenSession string
+ sqlSelectOAuth2RefreshTokenSession string
+ sqlRevokeOAuth2RefreshTokenSession string
+ sqlRevokeOAuth2RefreshTokenSessionByRequestID string
+ sqlDeactivateOAuth2RefreshTokenSession string
+ sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
sqlUpsertOAuth2BlacklistedJTI string
sqlSelectOAuth2BlacklistedJTI string
@@ -339,19 +348,19 @@ func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
}
// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database.
-func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) {
- if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil {
- return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err)
+func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil {
+ return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err)
}
return nil
}
// LoadUserOpaqueIdentifier selects an opaque user identifier from the database.
-func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) {
- opaqueID = &model.UserOpaqueIdentifier{}
+func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) {
+ subject = &model.UserOpaqueIdentifier{}
- if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil {
+ if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifier, identifier); err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, nil
@@ -360,11 +369,11 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID u
}
}
- return opaqueID, nil
+ return subject, nil
}
// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database.
-func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error) {
+func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error) {
var rows *sqlx.Rows
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectUserOpaqueIdentifiers); err != nil {
@@ -380,17 +389,17 @@ func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs
return nil, fmt.Errorf("error selecting user opaque identifiers: error scanning row: %w", err)
}
- opaqueIDs = append(opaqueIDs, *opaqueID)
+ identifiers = append(identifiers, *opaqueID)
}
- return opaqueIDs, nil
+ return identifiers, nil
}
-// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
-func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
- opaqueID = &model.UserOpaqueIdentifier{}
+// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
+func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) {
+ subject = &model.UserOpaqueIdentifier{}
- if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
+ if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, nil
@@ -399,7 +408,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s
}
}
- return opaqueID, nil
+ return subject, nil
}
// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session.
@@ -496,22 +505,22 @@ func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2S
var query string
switch sessionType {
- case OAuth2SessionTypeAuthorizeCode:
- query = p.sqlInsertOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlInsertOAuth2AccessTokenSession
- case OAuth2SessionTypeRefreshToken:
- query = p.sqlInsertOAuth2RefreshTokenSession
- case OAuth2SessionTypePKCEChallenge:
- query = p.sqlInsertOAuth2PKCERequestSession
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlInsertOAuth2AuthorizeCodeSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlInsertOAuth2OpenIDConnectSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlInsertOAuth2PKCERequestSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlInsertOAuth2RefreshTokenSession
default:
return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType)
}
if session.Session, err = p.encrypt(session.Session); err != nil {
- return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
+ return fmt.Errorf("error encrypting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
}
_, err = p.db.ExecContext(ctx, query,
@@ -532,16 +541,16 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth
var query string
switch sessionType {
- case OAuth2SessionTypeAuthorizeCode:
- query = p.sqlRevokeOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlRevokeOAuth2AccessTokenSession
- case OAuth2SessionTypeRefreshToken:
- query = p.sqlRevokeOAuth2RefreshTokenSession
- case OAuth2SessionTypePKCEChallenge:
- query = p.sqlRevokeOAuth2PKCERequestSession
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlRevokeOAuth2AuthorizeCodeSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlRevokeOAuth2OpenIDConnectSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlRevokeOAuth2PKCERequestSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlRevokeOAuth2RefreshTokenSession
default:
return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
}
@@ -558,16 +567,16 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio
var query string
switch sessionType {
- case OAuth2SessionTypeAuthorizeCode:
- query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
case OAuth2SessionTypeAccessToken:
query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID
- case OAuth2SessionTypeRefreshToken:
- query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
- case OAuth2SessionTypePKCEChallenge:
- query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
default:
return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
}
@@ -584,16 +593,16 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O
var query string
switch sessionType {
- case OAuth2SessionTypeAuthorizeCode:
- query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlDeactivateOAuth2AccessTokenSession
- case OAuth2SessionTypeRefreshToken:
- query = p.sqlDeactivateOAuth2RefreshTokenSession
- case OAuth2SessionTypePKCEChallenge:
- query = p.sqlDeactivateOAuth2PKCERequestSession
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlDeactivateOAuth2OpenIDConnectSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlDeactivateOAuth2PKCERequestSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlDeactivateOAuth2RefreshTokenSession
default:
return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
}
@@ -610,16 +619,16 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se
var query string
switch sessionType {
- case OAuth2SessionTypeAuthorizeCode:
- query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID
- case OAuth2SessionTypeRefreshToken:
- query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
- case OAuth2SessionTypePKCEChallenge:
- query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
default:
return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
}
@@ -636,16 +645,16 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
var query string
switch sessionType {
- case OAuth2SessionTypeAuthorizeCode:
- query = p.sqlSelectOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlSelectOAuth2AccessTokenSession
- case OAuth2SessionTypeRefreshToken:
- query = p.sqlSelectOAuth2RefreshTokenSession
- case OAuth2SessionTypePKCEChallenge:
- query = p.sqlSelectOAuth2PKCERequestSession
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlSelectOAuth2AuthorizeCodeSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlSelectOAuth2OpenIDConnectSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlSelectOAuth2PKCERequestSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlSelectOAuth2RefreshTokenSession
default:
return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType.String())
}
@@ -663,6 +672,45 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
return session, nil
}
+// SaveOAuth2PARContext save a OAuth2PARContext to the database.
+func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) {
+ if par.Session, err = p.encrypt(par.Session); err != nil {
+ return fmt.Errorf("error encrypting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err)
+ }
+
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2PARContext,
+ par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes,
+ par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session); err != nil {
+ return fmt.Errorf("error inserting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err)
+ }
+
+ return nil
+}
+
+// LoadOAuth2PARContext loads a OAuth2PARContext from the database.
+func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error) {
+ par = &model.OAuth2PARContext{}
+
+ if err = p.db.GetContext(ctx, par, p.sqlSelectOAuth2PARContext, signature); err != nil {
+ return nil, fmt.Errorf("error selecting oauth2 pushed authorization request context with signature '%s': %w", signature, err)
+ }
+
+ if par.Session, err = p.decrypt(par.Session); err != nil {
+ return nil, fmt.Errorf("error decrypting oauth2 oauth2 pushed authorization request context data with signature '%s' and request id '%s': %w", signature, par.RequestID, err)
+ }
+
+ return par, nil
+}
+
+// RevokeOAuth2PARContext marks a OAuth2PARContext as revoked in the database.
+func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature string) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlRevokeOAuth2PARContext, signature); err != nil {
+ return fmt.Errorf("error revoking oauth2 pushed authorization request context with signature '%s': %w", signature, err)
+ }
+
+ return nil
+}
+
// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database.
func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
@@ -762,7 +810,7 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string)
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
if config.Secret, err = p.encrypt(config.Secret); err != nil {
- return fmt.Errorf("error encrypting the TOTP configuration secret for user '%s': %w", config.Username, err)
+ return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err)
}
if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
@@ -806,7 +854,7 @@ func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string
}
if config.Secret, err = p.decrypt(config.Secret); err != nil {
- return nil, fmt.Errorf("error decrypting the TOTP secret for user '%s': %w", username, err)
+ return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err)
}
return config, nil
@@ -836,7 +884,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
// SaveWebauthnDevice saves a registered Webauthn device.
func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) {
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
- return fmt.Errorf("error encrypting the Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
+ return fmt.Errorf("error encrypting Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
}
if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebauthnDevice,