From bcf402aab04eaf8a0756b41247875f200baae8c0 Mon Sep 17 00:00:00 2001 From: Bob Maertz <1771054+bobmaertz@users.noreply.github.com> Date: Wed, 22 Jan 2025 22:12:30 -0500 Subject: [PATCH] Adding context param to private helper funcs Signed-off-by: Bob Maertz <1771054+bobmaertz@users.noreply.github.com> --- storage/sql/crud.go | 77 +++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 0ce86fdc1d..a9ca38167d 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -160,7 +160,7 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { return c.ExecTx(func(tx *trans) error { - r, err := getAuthRequest(tx, id) + r, err := getAuthRequest(ctx, tx, id) if err != nil { return err } @@ -201,10 +201,10 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a } func (c *conn) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) { - return getAuthRequest(c, id) + return getAuthRequest(ctx, c, id) } -func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { +func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRequest, err error) { err = q.QueryRow(` select id, client_id, response_types, scopes, redirect_uri, nonce, state, @@ -312,7 +312,7 @@ func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { return c.ExecTx(func(tx *trans) error { - r, err := getRefresh(tx, id) + r, err := getRefresh(ctx, tx, id) if err != nil { return err } @@ -355,10 +355,10 @@ func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(o } func (c *conn) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) { - return getRefresh(c, id) + return getRefresh(ctx, c, id) } -func getRefresh(q querier, id string) (storage.RefreshToken, error) { +func getRefresh(ctx context.Context, q querier, id string) (storage.RefreshToken, error) { return scanRefresh(q.QueryRow(` select id, client_id, scopes, nonce, @@ -423,7 +423,7 @@ func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (s firstUpdate := false // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL // server. Test this, and consider adding a COUNT() command beforehand. - old, err := getKeys(tx) + old, err := getKeys(ctx, tx) if err != nil { if err != storage.ErrNotFound { return fmt.Errorf("get keys: %v", err) @@ -472,10 +472,10 @@ func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (s } func (c *conn) GetKeys(ctx context.Context) (keys storage.Keys, err error) { - return getKeys(c) + return getKeys(ctx, c) } -func getKeys(q querier) (keys storage.Keys, err error) { +func getKeys(ctx context.Context, q querier) (keys storage.Keys, err error) { err = q.QueryRow(` select verification_keys, signing_key, signing_key_pub, next_rotation @@ -496,7 +496,7 @@ func getKeys(q querier) (keys storage.Keys, err error) { func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error { return c.ExecTx(func(tx *trans) error { - cli, err := getClient(tx, id) + cli, err := getClient(ctx, tx, id) if err != nil { return err } @@ -543,7 +543,7 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { return nil } -func getClient(q querier, id string) (storage.Client, error) { +func getClient(ctx context.Context, q querier, id string) (storage.Client, error) { return scanClient(q.QueryRow(` select id, secret, redirect_uris, trusted_peers, public, name, logo_url @@ -552,7 +552,7 @@ func getClient(q querier, id string) (storage.Client, error) { } func (c *conn) GetClient(ctx context.Context, id string) (storage.Client, error) { - return getClient(c, id) + return getClient(ctx, c, id) } func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) { @@ -617,7 +617,7 @@ func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error { func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error { return c.ExecTx(func(tx *trans) error { - p, err := getPassword(tx, email) + p, err := getPassword(ctx, tx, email) if err != nil { return err } @@ -642,10 +642,10 @@ func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p } func (c *conn) GetPassword(ctx context.Context, email string) (storage.Password, error) { - return getPassword(c, email) + return getPassword(ctx, c, email) } -func getPassword(q querier, email string) (p storage.Password, err error) { +func getPassword(ctx context.Context, q querier, email string) (p storage.Password, err error) { return scanPassword(q.QueryRow(` select email, hash, username, user_id @@ -713,7 +713,7 @@ func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessi func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { return c.ExecTx(func(tx *trans) error { - s, err := getOfflineSessions(tx, userID, connID) + s, err := getOfflineSessions(ctx, tx, userID, connID) if err != nil { return err } @@ -739,10 +739,10 @@ func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID } func (c *conn) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) { - return getOfflineSessions(c, userID, connID) + return getOfflineSessions(ctx, c, userID, connID) } -func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { +func getOfflineSessions(ctx context.Context, q querier, userID string, connID string) (storage.OfflineSessions, error) { return scanOfflineSessions(q.QueryRow(` select user_id, conn_id, refresh, connector_data @@ -786,7 +786,7 @@ func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s storage.Connector) (storage.Connector, error)) error { return c.ExecTx(func(tx *trans) error { - connector, err := getConnector(tx, id) + connector, err := getConnector(ctx, tx, id) if err != nil { return err } @@ -814,10 +814,10 @@ func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s st } func (c *conn) GetConnector(ctx context.Context, id string) (storage.Connector, error) { - return getConnector(c, id) + return getConnector(ctx, c, id) } -func getConnector(q querier, id string) (storage.Connector, error) { +func getConnector(ctx context.Context, q querier, id string) (storage.Connector, error) { return scanConnector(q.QueryRow(` select id, type, name, resource_version, config @@ -864,14 +864,29 @@ func (c *conn) ListConnectors(ctx context.Context) ([]storage.Connector, error) return connectors, nil } -func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error { return c.delete("auth_request", "id", id) } -func (c *conn) DeleteAuthCode(ctx context.Context, id string) error { return c.delete("auth_code", "id", id) } -func (c *conn) DeleteClient(ctx context.Context, id string) error { return c.delete("client", "id", id) } -func (c *conn) DeleteRefresh(ctx context.Context, id string) error { return c.delete("refresh_token", "id", id) } +func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error { + return c.delete("auth_request", "id", id) +} + +func (c *conn) DeleteAuthCode(ctx context.Context, id string) error { + return c.delete("auth_code", "id", id) +} + +func (c *conn) DeleteClient(ctx context.Context, id string) error { + return c.delete("client", "id", id) +} + +func (c *conn) DeleteRefresh(ctx context.Context, id string) error { + return c.delete("refresh_token", "id", id) +} + func (c *conn) DeletePassword(ctx context.Context, email string) error { return c.delete("password", "email", strings.ToLower(email)) } -func (c *conn) DeleteConnector(ctx context.Context, id string) error { return c.delete("connector", "id", id) } + +func (c *conn) DeleteConnector(ctx context.Context, id string) error { + return c.delete("connector", "id", id) +} func (c *conn) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error { result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID) @@ -949,10 +964,10 @@ func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) err } func (c *conn) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) { - return getDeviceRequest(c, userCode) + return getDeviceRequest(ctx, c, userCode) } -func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) { +func getDeviceRequest(ctx context.Context, q querier, userCode string) (d storage.DeviceRequest, err error) { err = q.QueryRow(` select device_code, client_id, client_secret, scopes, expiry @@ -971,10 +986,10 @@ func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err } func (c *conn) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) { - return getDeviceToken(c, deviceCode) + return getDeviceToken(ctx, c, deviceCode) } -func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { +func getDeviceToken(ctx context.Context, q querier, deviceCode string) (a storage.DeviceToken, err error) { err = q.QueryRow(` select status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method @@ -994,7 +1009,7 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er func (c *conn) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { return c.ExecTx(func(tx *trans) error { - r, err := getDeviceToken(tx, deviceCode) + r, err := getDeviceToken(ctx, tx, deviceCode) if err != nil { return err }