Skip to content

Commit

Permalink
Adding context param to private helper funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
bobmaertz committed Jan 23, 2025
1 parent 0082e23 commit 24b0f2b
Showing 1 changed file with 46 additions and 31 deletions.
77 changes: 46 additions & 31 deletions storage/sql/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err

func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getAuthRequest(tx, id)
r, err := getAuthRequest(ctx, tx, id)
if err != nil {
return err
}
Expand Down Expand Up @@ -201,10 +201,10 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a
}

func (c *conn) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) {
return getAuthRequest(c, id)
return getAuthRequest(ctx, c, id)
}

func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(`
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
Expand Down Expand Up @@ -312,7 +312,7 @@ func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error

func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id)
r, err := getRefresh(ctx, tx, id)
if err != nil {
return err
}
Expand Down Expand Up @@ -355,10 +355,10 @@ func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(o
}

func (c *conn) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) {
return getRefresh(c, id)
return getRefresh(ctx, c, id)
}

func getRefresh(q querier, id string) (storage.RefreshToken, error) {
func getRefresh(ctx context.Context, q querier, id string) (storage.RefreshToken, error) {
return scanRefresh(q.QueryRow(`
select
id, client_id, scopes, nonce,
Expand Down Expand Up @@ -423,7 +423,7 @@ func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (s
firstUpdate := false
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
// server. Test this, and consider adding a COUNT() command beforehand.
old, err := getKeys(tx)
old, err := getKeys(ctx, tx)
if err != nil {
if err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
Expand Down Expand Up @@ -472,10 +472,10 @@ func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (s
}

func (c *conn) GetKeys(ctx context.Context) (keys storage.Keys, err error) {
return getKeys(c)
return getKeys(ctx, c)
}

func getKeys(q querier) (keys storage.Keys, err error) {
func getKeys(ctx context.Context, q querier) (keys storage.Keys, err error) {
err = q.QueryRow(`
select
verification_keys, signing_key, signing_key_pub, next_rotation
Expand All @@ -496,7 +496,7 @@ func getKeys(q querier) (keys storage.Keys, err error) {

func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error {
return c.ExecTx(func(tx *trans) error {
cli, err := getClient(tx, id)
cli, err := getClient(ctx, tx, id)
if err != nil {
return err
}
Expand Down Expand Up @@ -543,7 +543,7 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error {
return nil
}

func getClient(q querier, id string) (storage.Client, error) {
func getClient(ctx context.Context, q querier, id string) (storage.Client, error) {
return scanClient(q.QueryRow(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
Expand All @@ -552,7 +552,7 @@ func getClient(q querier, id string) (storage.Client, error) {
}

func (c *conn) GetClient(ctx context.Context, id string) (storage.Client, error) {
return getClient(c, id)
return getClient(ctx, c, id)
}

func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) {
Expand Down Expand Up @@ -617,7 +617,7 @@ func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error {

func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error {
return c.ExecTx(func(tx *trans) error {
p, err := getPassword(tx, email)
p, err := getPassword(ctx, tx, email)
if err != nil {
return err
}
Expand All @@ -642,10 +642,10 @@ func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p
}

func (c *conn) GetPassword(ctx context.Context, email string) (storage.Password, error) {
return getPassword(c, email)
return getPassword(ctx, c, email)
}

func getPassword(q querier, email string) (p storage.Password, err error) {
func getPassword(ctx context.Context, q querier, email string) (p storage.Password, err error) {
return scanPassword(q.QueryRow(`
select
email, hash, username, user_id
Expand Down Expand Up @@ -713,7 +713,7 @@ func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessi

func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID)
s, err := getOfflineSessions(ctx, tx, userID, connID)
if err != nil {
return err
}
Expand All @@ -739,10 +739,10 @@ func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID
}

func (c *conn) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(c, userID, connID)
return getOfflineSessions(ctx, c, userID, connID)
}

func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
func getOfflineSessions(ctx context.Context, q querier, userID string, connID string) (storage.OfflineSessions, error) {
return scanOfflineSessions(q.QueryRow(`
select
user_id, conn_id, refresh, connector_data
Expand Down Expand Up @@ -786,7 +786,7 @@ func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector)

func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s storage.Connector) (storage.Connector, error)) error {
return c.ExecTx(func(tx *trans) error {
connector, err := getConnector(tx, id)
connector, err := getConnector(ctx, tx, id)
if err != nil {
return err
}
Expand Down Expand Up @@ -814,10 +814,10 @@ func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s st
}

func (c *conn) GetConnector(ctx context.Context, id string) (storage.Connector, error) {
return getConnector(c, id)
return getConnector(ctx, c, id)
}

func getConnector(q querier, id string) (storage.Connector, error) {
func getConnector(ctx context.Context, q querier, id string) (storage.Connector, error) {
return scanConnector(q.QueryRow(`
select
id, type, name, resource_version, config
Expand Down Expand Up @@ -864,14 +864,29 @@ func (c *conn) ListConnectors(ctx context.Context) ([]storage.Connector, error)
return connectors, nil
}

func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error { return c.delete("auth_request", "id", id) }
func (c *conn) DeleteAuthCode(ctx context.Context, id string) error { return c.delete("auth_code", "id", id) }
func (c *conn) DeleteClient(ctx context.Context, id string) error { return c.delete("client", "id", id) }
func (c *conn) DeleteRefresh(ctx context.Context, id string) error { return c.delete("refresh_token", "id", id) }
func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error {
return c.delete("auth_request", "id", id)
}

func (c *conn) DeleteAuthCode(ctx context.Context, id string) error {
return c.delete("auth_code", "id", id)
}

func (c *conn) DeleteClient(ctx context.Context, id string) error {
return c.delete("client", "id", id)
}

func (c *conn) DeleteRefresh(ctx context.Context, id string) error {
return c.delete("refresh_token", "id", id)
}

func (c *conn) DeletePassword(ctx context.Context, email string) error {
return c.delete("password", "email", strings.ToLower(email))
}
func (c *conn) DeleteConnector(ctx context.Context, id string) error { return c.delete("connector", "id", id) }

func (c *conn) DeleteConnector(ctx context.Context, id string) error {
return c.delete("connector", "id", id)
}

func (c *conn) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error {
result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
Expand Down Expand Up @@ -949,10 +964,10 @@ func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) err
}

func (c *conn) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) {
return getDeviceRequest(c, userCode)
return getDeviceRequest(ctx, c, userCode)
}

func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
func getDeviceRequest(ctx context.Context, q querier, userCode string) (d storage.DeviceRequest, err error) {
err = q.QueryRow(`
select
device_code, client_id, client_secret, scopes, expiry
Expand All @@ -971,10 +986,10 @@ func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err
}

func (c *conn) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(c, deviceCode)
return getDeviceToken(ctx, c, deviceCode)
}

func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
func getDeviceToken(ctx context.Context, q querier, deviceCode string) (a storage.DeviceToken, err error) {
err = q.QueryRow(`
select
status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method
Expand All @@ -994,7 +1009,7 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er

func (c *conn) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getDeviceToken(tx, deviceCode)
r, err := getDeviceToken(ctx, tx, deviceCode)
if err != nil {
return err
}
Expand Down

0 comments on commit 24b0f2b

Please sign in to comment.