Skip to content

Commit

Permalink
Adding context parameter to update and delete functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bobmaertz committed Jan 23, 2025
1 parent 9884b74 commit 0082e23
Show file tree
Hide file tree
Showing 29 changed files with 226 additions and 226 deletions.
18 changes: 9 additions & 9 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap
return nil, errors.New("update client: no client ID supplied")
}

err := d.s.UpdateClient(req.Id, func(old storage.Client) (storage.Client, error) {
err := d.s.UpdateClient(ctx, req.Id, func(old storage.Client) (storage.Client, error) {
if req.RedirectUris != nil {
old.RedirectURIs = req.RedirectUris
}
Expand All @@ -134,7 +134,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap
}

func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*api.DeleteClientResp, error) {
err := d.s.DeleteClient(req.Id)
err := d.s.DeleteClient(ctx, req.Id)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeleteClientResp{NotFound: true}, nil
Expand Down Expand Up @@ -219,7 +219,7 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq)
return old, nil
}

if err := d.s.UpdatePassword(req.Email, updater); err != nil {
if err := d.s.UpdatePassword(ctx, req.Email, updater); err != nil {
if err == storage.ErrNotFound {
return &api.UpdatePasswordResp{NotFound: true}, nil
}
Expand All @@ -235,7 +235,7 @@ func (d dexAPI) DeletePassword(ctx context.Context, req *api.DeletePasswordReq)
return nil, errors.New("no email supplied")
}

err := d.s.DeletePassword(req.Email)
err := d.s.DeletePassword(ctx, req.Email)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeletePasswordResp{NotFound: true}, nil
Expand Down Expand Up @@ -381,7 +381,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
return old, nil
}

if err := d.s.UpdateOfflineSessions(id.UserId, id.ConnId, updater); err != nil {
if err := d.s.UpdateOfflineSessions(ctx, id.UserId, id.ConnId, updater); err != nil {
if err == storage.ErrNotFound {
return &api.RevokeRefreshResp{NotFound: true}, nil
}
Expand All @@ -397,7 +397,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
//
// TODO(ericchiang): we don't have any good recourse if this call fails.
// Consider garbage collection of refresh tokens with no associated ref.
if err := d.s.DeleteRefresh(refreshID); err != nil {
if err := d.s.DeleteRefresh(ctx, refreshID); err != nil {
d.logger.Error("failed to delete refresh token", "err", err)
return nil, err
}
Expand Down Expand Up @@ -448,7 +448,7 @@ func (d dexAPI) CreateConnector(ctx context.Context, req *api.CreateConnectorReq
return &api.CreateConnectorResp{}, nil
}

func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) {
func (d dexAPI) UpdateConnector(ctx context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) {
if !featureflags.APIConnectorsCRUD.Enabled() {
return nil, fmt.Errorf("%s feature flag is not enabled", featureflags.APIConnectorsCRUD.Name)
}
Expand Down Expand Up @@ -485,7 +485,7 @@ func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq)
return old, nil
}

if err := d.s.UpdateConnector(req.Id, updater); err != nil {
if err := d.s.UpdateConnector(ctx, req.Id, updater); err != nil {
if err == storage.ErrNotFound {
return &api.UpdateConnectorResp{NotFound: true}, nil
}
Expand All @@ -505,7 +505,7 @@ func (d dexAPI) DeleteConnector(ctx context.Context, req *api.DeleteConnectorReq
return nil, errors.New("no id supplied")
}

err := d.s.DeleteConnector(req.Id)
err := d.s.DeleteConnector(ctx, req.Id)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeleteConnectorResp{NotFound: true}, nil
Expand Down
4 changes: 2 additions & 2 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
return old, nil
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
if err := s.storage.UpdateDeviceToken(ctx, deviceCode, updater); err != nil {
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
Expand Down Expand Up @@ -374,7 +374,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}

// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
if err := s.storage.UpdateDeviceToken(ctx, deviceReq.DeviceCode, updater); err != nil {
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
Expand Down
20 changes: 10 additions & 10 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
a.ConnectorData = identity.ConnectorData
return a, nil
}
if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil {
if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, updater); err != nil {
return "", false, fmt.Errorf("failed to update auth request: %v", err)
}

Expand Down Expand Up @@ -565,7 +565,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
}
case err == nil:
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if len(identity.ConnectorData) > 0 {
old.ConnectorData = identity.ConnectorData
}
Expand Down Expand Up @@ -657,7 +657,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
return
}

if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
if err := s.storage.DeleteAuthRequest(ctx, authReq.ID); err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "Failed to delete authorization request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
Expand Down Expand Up @@ -954,7 +954,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
return nil, err
}

if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
if err := s.storage.DeleteAuthCode(ctx, authCode.ID); err != nil {
s.logger.ErrorContext(ctx, "failed to delete auth code", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err
Expand Down Expand Up @@ -1020,7 +1020,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil {
s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -1061,7 +1061,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
Expand All @@ -1070,7 +1070,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
}

// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
Expand Down Expand Up @@ -1272,7 +1272,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil {
s.logger.ErrorContext(r.Context(), "failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -1314,7 +1314,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil {
if err == storage.ErrNotFound {
s.logger.Warn("database inconsistent, refresh token missing", "token_id", oldTokenRef.ID)
} else {
Expand All @@ -1327,7 +1327,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
}

// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
old.ConnectorData = identity.ConnectorData
return old, nil
Expand Down
2 changes: 1 addition & 1 deletion server/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ func TestValidRedirectURI(t *testing.T) {

func TestStorageKeySet(t *testing.T) {
s := memory.New(logger)
if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
if err := s.UpdateKeys(context.TODO(), func(keys storage.Keys) (storage.Keys, error) {
keys.SigningKey = &jose.JSONWebKey{
Key: testKey,
KeyID: "testkey",
Expand Down
4 changes: 2 additions & 2 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr

// Update LastUsed time stamp in refresh token reference object
// in offline session for the user.
err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
err := s.storage.UpdateOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
if err != nil {
s.logger.ErrorContext(ctx, "failed to update offline session", "err", err)
return newInternalServerError()
Expand Down Expand Up @@ -314,7 +314,7 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
}

// Update refresh token in the storage.
err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater)
err := s.storage.UpdateRefreshToken(ctx, rCtx.storageToken.ID, refreshTokenUpdater)
if err != nil {
s.logger.ErrorContext(ctx, "failed to update refresh token", "err", err)
return nil, ident, newInternalServerError()
Expand Down
2 changes: 1 addition & 1 deletion server/rotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (k keyRotator) rotate() error {
}

var nextRotation time.Time
err = k.Storage.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
err = k.Storage.UpdateKeys(context.Background(), func(keys storage.Keys) (storage.Keys, error) {
tNow := k.now()

// if you are running multiple instances of dex, another instance
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
case <-ctx.Done():
return
case <-time.After(frequency):
if r, err := s.storage.GarbageCollect(now()); err != nil {
if r, err := s.storage.GarbageCollect(ctx, now()); err != nil {
s.logger.ErrorContext(ctx, "garbage collection failed", "err", err)
} else if !r.IsEmpty() {
s.logger.InfoContext(ctx, "garbage collection run, delete auth",
Expand Down
6 changes: 3 additions & 3 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ func (s storageWithKeysTrigger) GetKeys(ctx context.Context) (storage.Keys, erro
func TestKeyCacher(t *testing.T) {
tNow := time.Now()
now := func() time.Time { return tNow }

ctx := context.TODO()
s := memory.New(logger)

tests := []struct {
Expand All @@ -1390,7 +1390,7 @@ func TestKeyCacher(t *testing.T) {
},
{
before: func() {
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) {
old.NextRotation = tNow.Add(time.Minute)
return old, nil
})
Expand All @@ -1410,7 +1410,7 @@ func TestKeyCacher(t *testing.T) {
{
before: func() {
tNow = tNow.Add(time.Hour)
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) {
old.NextRotation = tNow.Add(time.Minute)
return old, nil
})
Expand Down
Loading

0 comments on commit 0082e23

Please sign in to comment.