diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2023-03-06 14:58:50 +1100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-06 14:58:50 +1100 |
| commit | ff6be40f5e5497da1f312d2896e210201a24b048 (patch) | |
| tree | b38883ed1623f4878b791929d519596438b8b054 /internal/storage/sql_provider.go | |
| parent | 42671d3edb0d336794de1e164d147fb742364e11 (diff) | |
feat(oidc): pushed authorization requests (#4546)
This implements RFC9126 OAuth 2.0 Pushed Authorization Requests. See https://datatracker.ietf.org/doc/html/rfc9126 for the specification details.
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 230 |
1 files changed, 139 insertions, 91 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index a55a41cea..2a4cce037 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -70,6 +70,13 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlSelectUserOpaqueIdentifiers: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifiers, tableUserOpaqueIdentifier), sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier), + sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), + sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), + + sqlInsertOAuth2PARContext: fmt.Sprintf(queryFmtInsertOAuth2PARContext, tableOAuth2PARContext), + sqlSelectOAuth2PARContext: fmt.Sprintf(queryFmtSelectOAuth2PARContext, tableOAuth2PARContext), + sqlRevokeOAuth2PARContext: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PARContext), + sqlInsertOAuth2ConsentPreConfiguration: fmt.Sprintf(queryFmtInsertOAuth2ConsentPreConfiguration, tableOAuth2ConsentPreConfiguration), sqlSelectOAuth2ConsentPreConfigurations: fmt.Sprintf(queryFmtSelectOAuth2ConsentPreConfigurations, tableOAuth2ConsentPreConfiguration), @@ -79,13 +86,6 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession), sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession), - sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession), - sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession), - sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession), - sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), - sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession), - sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), - sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession), sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession), sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession), @@ -93,19 +93,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession), sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession), - sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession), - sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession), - sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession), - sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), - sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession), - sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), - - sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession), - sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession), - sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession), - sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), - sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession), - sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), + sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), + sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession), sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession), @@ -114,8 +107,19 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession), sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession), - sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), - sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), + sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession), + sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession), + sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession), + sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), + sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession), + sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), + + sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession), + sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession), + sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession), + sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), + sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession), + sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), @@ -224,13 +228,18 @@ type SQLProvider struct { sqlDeactivateOAuth2AccessTokenSession string sqlDeactivateOAuth2AccessTokenSessionByRequestID string - // Table: oauth2_refresh_token_session. - sqlInsertOAuth2RefreshTokenSession string - sqlSelectOAuth2RefreshTokenSession string - sqlRevokeOAuth2RefreshTokenSession string - sqlRevokeOAuth2RefreshTokenSessionByRequestID string - sqlDeactivateOAuth2RefreshTokenSession string - sqlDeactivateOAuth2RefreshTokenSessionByRequestID string + // Table: oauth2_openid_connect_session. + sqlInsertOAuth2OpenIDConnectSession string + sqlSelectOAuth2OpenIDConnectSession string + sqlRevokeOAuth2OpenIDConnectSession string + sqlRevokeOAuth2OpenIDConnectSessionByRequestID string + sqlDeactivateOAuth2OpenIDConnectSession string + sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string + + // Table: oauth2_par_context. + sqlInsertOAuth2PARContext string + sqlSelectOAuth2PARContext string + sqlRevokeOAuth2PARContext string // Table: oauth2_pkce_request_session. sqlInsertOAuth2PKCERequestSession string @@ -240,13 +249,13 @@ type SQLProvider struct { sqlDeactivateOAuth2PKCERequestSession string sqlDeactivateOAuth2PKCERequestSessionByRequestID string - // Table: oauth2_openid_connect_session. - sqlInsertOAuth2OpenIDConnectSession string - sqlSelectOAuth2OpenIDConnectSession string - sqlRevokeOAuth2OpenIDConnectSession string - sqlRevokeOAuth2OpenIDConnectSessionByRequestID string - sqlDeactivateOAuth2OpenIDConnectSession string - sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string + // Table: oauth2_refresh_token_session. + sqlInsertOAuth2RefreshTokenSession string + sqlSelectOAuth2RefreshTokenSession string + sqlRevokeOAuth2RefreshTokenSession string + sqlRevokeOAuth2RefreshTokenSessionByRequestID string + sqlDeactivateOAuth2RefreshTokenSession string + sqlDeactivateOAuth2RefreshTokenSessionByRequestID string sqlUpsertOAuth2BlacklistedJTI string sqlSelectOAuth2BlacklistedJTI string @@ -339,19 +348,19 @@ func (p *SQLProvider) Rollback(ctx context.Context) (err error) { } // SaveUserOpaqueIdentifier saves a new opaque user identifier to the database. -func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil { - return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err) +func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil { + return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err) } return nil } // LoadUserOpaqueIdentifier selects an opaque user identifier from the database. -func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) { - opaqueID = &model.UserOpaqueIdentifier{} +func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) { + subject = &model.UserOpaqueIdentifier{} - if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil { + if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifier, identifier); err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, nil @@ -360,11 +369,11 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID u } } - return opaqueID, nil + return subject, nil } // LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database. -func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error) { +func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error) { var rows *sqlx.Rows if rows, err = p.db.QueryxContext(ctx, p.sqlSelectUserOpaqueIdentifiers); err != nil { @@ -380,17 +389,17 @@ func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs return nil, fmt.Errorf("error selecting user opaque identifiers: error scanning row: %w", err) } - opaqueIDs = append(opaqueIDs, *opaqueID) + identifiers = append(identifiers, *opaqueID) } - return opaqueIDs, nil + return identifiers, nil } -// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username. -func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { - opaqueID = &model.UserOpaqueIdentifier{} +// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username. +func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) { + subject = &model.UserOpaqueIdentifier{} - if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil { + if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, nil @@ -399,7 +408,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s } } - return opaqueID, nil + return subject, nil } // SaveOAuth2ConsentSession inserts an OAuth2.0 consent session. @@ -496,22 +505,22 @@ func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2S var query string switch sessionType { - case OAuth2SessionTypeAuthorizeCode: - query = p.sqlInsertOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlInsertOAuth2AccessTokenSession - case OAuth2SessionTypeRefreshToken: - query = p.sqlInsertOAuth2RefreshTokenSession - case OAuth2SessionTypePKCEChallenge: - query = p.sqlInsertOAuth2PKCERequestSession + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlInsertOAuth2AuthorizeCodeSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlInsertOAuth2OpenIDConnectSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlInsertOAuth2PKCERequestSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlInsertOAuth2RefreshTokenSession default: return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType) } if session.Session, err = p.encrypt(session.Session); err != nil { - return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err) + return fmt.Errorf("error encrypting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err) } _, err = p.db.ExecContext(ctx, query, @@ -532,16 +541,16 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth var query string switch sessionType { - case OAuth2SessionTypeAuthorizeCode: - query = p.sqlRevokeOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlRevokeOAuth2AccessTokenSession - case OAuth2SessionTypeRefreshToken: - query = p.sqlRevokeOAuth2RefreshTokenSession - case OAuth2SessionTypePKCEChallenge: - query = p.sqlRevokeOAuth2PKCERequestSession + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlRevokeOAuth2AuthorizeCodeSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlRevokeOAuth2OpenIDConnectSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlRevokeOAuth2PKCERequestSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlRevokeOAuth2RefreshTokenSession default: return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String()) } @@ -558,16 +567,16 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio var query string switch sessionType { - case OAuth2SessionTypeAuthorizeCode: - query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID case OAuth2SessionTypeAccessToken: query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID - case OAuth2SessionTypeRefreshToken: - query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID - case OAuth2SessionTypePKCEChallenge: - query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID case OAuth2SessionTypeOpenIDConnect: query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID + case OAuth2SessionTypePKCEChallenge: + query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID + case OAuth2SessionTypeRefreshToken: + query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID default: return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String()) } @@ -584,16 +593,16 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O var query string switch sessionType { - case OAuth2SessionTypeAuthorizeCode: - query = p.sqlDeactivateOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlDeactivateOAuth2AccessTokenSession - case OAuth2SessionTypeRefreshToken: - query = p.sqlDeactivateOAuth2RefreshTokenSession - case OAuth2SessionTypePKCEChallenge: - query = p.sqlDeactivateOAuth2PKCERequestSession + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlDeactivateOAuth2AuthorizeCodeSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlDeactivateOAuth2OpenIDConnectSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlDeactivateOAuth2PKCERequestSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlDeactivateOAuth2RefreshTokenSession default: return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String()) } @@ -610,16 +619,16 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se var query string switch sessionType { - case OAuth2SessionTypeAuthorizeCode: - query = p.sqlDeactivateOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID - case OAuth2SessionTypeRefreshToken: - query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID - case OAuth2SessionTypePKCEChallenge: - query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlDeactivateOAuth2AuthorizeCodeSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID + case OAuth2SessionTypePKCEChallenge: + query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID + case OAuth2SessionTypeRefreshToken: + query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID default: return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String()) } @@ -636,16 +645,16 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S var query string switch sessionType { - case OAuth2SessionTypeAuthorizeCode: - query = p.sqlSelectOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlSelectOAuth2AccessTokenSession - case OAuth2SessionTypeRefreshToken: - query = p.sqlSelectOAuth2RefreshTokenSession - case OAuth2SessionTypePKCEChallenge: - query = p.sqlSelectOAuth2PKCERequestSession + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlSelectOAuth2AuthorizeCodeSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlSelectOAuth2OpenIDConnectSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlSelectOAuth2PKCERequestSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlSelectOAuth2RefreshTokenSession default: return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType.String()) } @@ -663,6 +672,45 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S return session, nil } +// SaveOAuth2PARContext save a OAuth2PARContext to the database. +func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) { + if par.Session, err = p.encrypt(par.Session); err != nil { + return fmt.Errorf("error encrypting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err) + } + + if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2PARContext, + par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes, + par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session); err != nil { + return fmt.Errorf("error inserting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err) + } + + return nil +} + +// LoadOAuth2PARContext loads a OAuth2PARContext from the database. +func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error) { + par = &model.OAuth2PARContext{} + + if err = p.db.GetContext(ctx, par, p.sqlSelectOAuth2PARContext, signature); err != nil { + return nil, fmt.Errorf("error selecting oauth2 pushed authorization request context with signature '%s': %w", signature, err) + } + + if par.Session, err = p.decrypt(par.Session); err != nil { + return nil, fmt.Errorf("error decrypting oauth2 oauth2 pushed authorization request context data with signature '%s' and request id '%s': %w", signature, par.RequestID, err) + } + + return par, nil +} + +// RevokeOAuth2PARContext marks a OAuth2PARContext as revoked in the database. +func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature string) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlRevokeOAuth2PARContext, signature); err != nil { + return fmt.Errorf("error revoking oauth2 pushed authorization request context with signature '%s': %w", signature, err) + } + + return nil +} + // SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database. func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil { @@ -762,7 +810,7 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) // SaveTOTPConfiguration save a TOTP configuration of a given user in the database. func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) { if config.Secret, err = p.encrypt(config.Secret); err != nil { - return fmt.Errorf("error encrypting the TOTP configuration secret for user '%s': %w", config.Username, err) + return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err) } if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, @@ -806,7 +854,7 @@ func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string } if config.Secret, err = p.decrypt(config.Secret); err != nil { - return nil, fmt.Errorf("error decrypting the TOTP secret for user '%s': %w", username, err) + return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err) } return config, nil @@ -836,7 +884,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in // SaveWebauthnDevice saves a registered Webauthn device. func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) { if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil { - return fmt.Errorf("error encrypting the Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err) + return fmt.Errorf("error encrypting Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err) } if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebauthnDevice, |
