diff options
Diffstat (limited to 'internal/storage/sql_provider.go')
| -rw-r--r-- | internal/storage/sql_provider.go | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index c458722be..c2c0e7621 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -121,6 +121,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession), sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), + sqlInsertOAuth2DeviceCodeSession: fmt.Sprintf(queryFmtInsertOAuth2DeviceCodeSession, tableOAuth2DeviceCodeSession), + sqlSelectOAuth2DeviceCodeSession: fmt.Sprintf(queryFmtSelectOAuth2DeviceCodeSession, tableOAuth2DeviceCodeSession), + sqlUpdateOAuth2DeviceCodeSession: fmt.Sprintf(queryFmtUpdateOAuth2DeviceCodeSession, tableOAuth2DeviceCodeSession), + sqlDeactivateOAuth2DeviceCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2DeviceCodeSession), + sqlSelectOAuth2DeviceCodeSessionByUserCode: fmt.Sprintf(queryFmtSelectOAuth2DeviceCodeSessionByUserCode, tableOAuth2DeviceCodeSession), + sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession), sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession), sqlRevokeOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2OpenIDConnectSession), @@ -263,6 +269,13 @@ type SQLProvider struct { sqlDeactivateOAuth2AuthorizeCodeSession string sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID string + // Table: oauth2_device_code_session. + sqlInsertOAuth2DeviceCodeSession string + sqlSelectOAuth2DeviceCodeSession string + sqlUpdateOAuth2DeviceCodeSession string + sqlDeactivateOAuth2DeviceCodeSession string + sqlSelectOAuth2DeviceCodeSessionByUserCode string + // Table: oauth2_access_token_session. sqlInsertOAuth2AccessTokenSession string sqlSelectOAuth2AccessTokenSession string @@ -1234,6 +1247,74 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S return session, nil } +func (p *SQLProvider) SaveOAuth2DeviceCodeSession(ctx context.Context, session *model.OAuth2DeviceCodeSession) (err error) { + if session.Session, err = p.encrypt(session.Session); err != nil { + return fmt.Errorf("error encrypting oauth2 device code session data for session with signature '%s' for subject '%s' and request id '%s': %w", session.Subject.String, session.Signature, session.RequestID, err) + } + + _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2DeviceCodeSession, + session.ChallengeID, session.RequestID, session.ClientID, session.Signature, session.UserCodeSignature, + session.Status, session.Subject, session.RequestedAt, session.CheckedAt, + session.RequestedScopes, session.GrantedScopes, + session.RequestedAudience, session.GrantedAudience, + session.Active, session.Revoked, session.Form, session.Session) + + if err != nil { + return fmt.Errorf("error inserting oauth2 device code session with device code signature '%s' and user code signature '%s' for subject '%s' and request id '%s': %w", session.Signature, session.UserCodeSignature, session.Subject.String, session.RequestID, err) + } + + return nil +} + +func (p *SQLProvider) UpdateOAuth2DeviceCodeSession(ctx context.Context, signature string, status int, checked time.Time) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2DeviceCodeSession, + checked, status, signature) + + if err != nil { + return fmt.Errorf("error updating oauth2 device code session data with device code signature '%s': %w", signature, err) + } + + return nil +} + +func (p *SQLProvider) DeactivateOAuth2DeviceCodeSession(ctx context.Context, signature string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlDeactivateOAuth2DeviceCodeSession, signature) + + if err != nil { + return fmt.Errorf("error deactivating oauth2 device code session with device code signature '%s': %w", signature, err) + } + + return nil +} + +func (p *SQLProvider) LoadOAuth2DeviceCodeSession(ctx context.Context, signature string) (session *model.OAuth2DeviceCodeSession, err error) { + session = &model.OAuth2DeviceCodeSession{} + + if err = p.db.GetContext(ctx, session, p.sqlSelectOAuth2DeviceCodeSession, signature); err != nil { + return nil, fmt.Errorf("error selecting oauth2 device code session with device code signature '%s': %w", signature, err) + } + + if session.Session, err = p.decrypt(session.Session); err != nil { + return nil, fmt.Errorf("error decrypting the oauth2 device code session data with device code signature '%s' for subject '%s' and request id '%s': %w", signature, session.Subject.String, session.RequestID, err) + } + + return session, nil +} + +func (p *SQLProvider) LoadOAuth2DeviceCodeSessionByUserCode(ctx context.Context, signature string) (session *model.OAuth2DeviceCodeSession, err error) { + session = &model.OAuth2DeviceCodeSession{} + + if err = p.db.GetContext(ctx, session, p.sqlSelectOAuth2DeviceCodeSessionByUserCode, signature); err != nil { + return nil, fmt.Errorf("error selecting oauth2 device code session with user code signature '%s': %w", signature, err) + } + + if session.Session, err = p.decrypt(session.Session); err != nil { + return nil, fmt.Errorf("error decrypting the oauth2 device code session data with user code signature '%s' for subject '%s' and request id '%s': %w", signature, session.Subject.String, session.RequestID, err) + } + + return session, nil +} + // SaveOAuth2PARContext save an OAuth2.0 PAR context to the storage provider. func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) { if par.Session, err = p.encrypt(par.Session); err != nil { |
