diff options
| author | James Elliott <james-d-elliott@users.noreply.github.com> | 2023-08-20 13:00:00 +1000 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-20 13:00:00 +1000 | 
| commit | 321a3803f52b01324fcbf0e5b12ae014bf075c1e (patch) | |
| tree | 7e434d9ec3128cf83d59922a5eb493a7035e0c90 /internal/storage | |
| parent | e42bbca1efa3a596aaa7289a9a8c61e108d13a52 (diff) | |
fix(oidc): par consent state error (#5880)
This fixes a state error during a PAR session were if the session requires consent the flow fails.
Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
Diffstat (limited to 'internal/storage')
| -rw-r--r-- | internal/storage/provider.go | 1 | ||||
| -rw-r--r-- | internal/storage/sql_provider.go | 21 | ||||
| -rw-r--r-- | internal/storage/sql_provider_backend_postgres.go | 1 | ||||
| -rw-r--r-- | internal/storage/sql_provider_queries.go | 9 | 
4 files changed, 31 insertions, 1 deletions
diff --git a/internal/storage/provider.go b/internal/storage/provider.go index 651cdadfa..59e9b2219 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -68,6 +68,7 @@ type Provider interface {  	SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error)  	LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error)  	RevokeOAuth2PARContext(ctx context.Context, signature string) (err error) +	UpdateOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error)  	SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error)  	LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 98964ae4f..7d8f1c14b 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -74,6 +74,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa  		sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),  		sqlInsertOAuth2PARContext: fmt.Sprintf(queryFmtInsertOAuth2PARContext, tableOAuth2PARContext), +		sqlUpdateOAuth2PARContext: fmt.Sprintf(queryFmtUpdateOAuth2PARContext, tableOAuth2PARContext),  		sqlSelectOAuth2PARContext: fmt.Sprintf(queryFmtSelectOAuth2PARContext, tableOAuth2PARContext),  		sqlRevokeOAuth2PARContext: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PARContext), @@ -238,6 +239,7 @@ type SQLProvider struct {  	// Table: oauth2_par_context.  	sqlInsertOAuth2PARContext string +	sqlUpdateOAuth2PARContext string  	sqlSelectOAuth2PARContext string  	sqlRevokeOAuth2PARContext string @@ -687,6 +689,25 @@ func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2  	return nil  } +// UpdateOAuth2PARContext updates an existing OAuth2PARContext in the database. +func (p *SQLProvider) UpdateOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) { +	if par.ID == 0 { +		return fmt.Errorf("error updating oauth2 pushed authorization request context data with signature '%s' and request id '%s': the id was a zero value", par.Signature, par.RequestID) +	} + +	if par.Session, err = p.encrypt(par.Session); err != nil { +		return fmt.Errorf("error encrypting oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err) +	} + +	if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2PARContext, +		par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes, +		par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session, par.ID); err != nil { +		return fmt.Errorf("error updating oauth2 pushed authorization request context data with id '%d' and signature '%s' and request id '%s': %w", par.ID, par.Signature, par.RequestID, err) +	} + +	return nil +} +  // LoadOAuth2PARContext loads a OAuth2PARContext from the database.  func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error) {  	par = &model.OAuth2PARContext{} diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index ef7054951..e97ef09e5 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -108,6 +108,7 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo  	provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession)  	provider.sqlInsertOAuth2PARContext = provider.db.Rebind(provider.sqlInsertOAuth2PARContext) +	provider.sqlUpdateOAuth2PARContext = provider.db.Rebind(provider.sqlUpdateOAuth2PARContext)  	provider.sqlRevokeOAuth2PARContext = provider.db.Rebind(provider.sqlRevokeOAuth2PARContext)  	provider.sqlSelectOAuth2PARContext = provider.db.Rebind(provider.sqlSelectOAuth2PARContext) diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go index b089e23ce..46a9635b6 100644 --- a/internal/storage/sql_provider_queries.go +++ b/internal/storage/sql_provider_queries.go @@ -319,7 +319,7 @@ const (  		handled_response_types, response_mode, response_mode_default, revoked,  		form_data, session_data  		FROM %s -		WHERE signature = ? AND revoked = FALSE;` +		WHERE signature = ?;`  	queryFmtInsertOAuth2PARContext = `  		INSERT INTO %s (signature, request_id, client_id, requested_at, scopes, audience, @@ -327,6 +327,13 @@ const (  		form_data, session_data)  		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` +	queryFmtUpdateOAuth2PARContext = ` +	UPDATE %s +	SET signature = ?, request_id = ?, client_id = ?, requested_at = ?, scopes = ?, audience = ?, +	    handled_response_types = ?, response_mode = ?, response_mode_default = ?, revoked = ?, +	    form_data = ?, session_data = ? +	WHERE id = ?;` +  	queryFmtSelectOAuth2BlacklistedJTI = `  		SELECT id, signature, expires_at  		FROM %s  | 
