From c4cda40b1ce9c35938f52ad367aef31299daa8e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Schwartz?= <40674593+3schwartz@users.noreply.github.com> Date: Fri, 15 Nov 2024 19:40:21 +0100 Subject: [PATCH 1/4] update claims instead of overwrite --- oauth2/token_hook.go | 18 +++++++-- oauth2/token_hook_test.go | 84 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 oauth2/token_hook_test.go diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go index d32cadd7e4d..2d3b84ae77f 100644 --- a/oauth2/token_hook.go +++ b/oauth2/token_hook.go @@ -143,13 +143,25 @@ func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, ) } - // Overwrite existing session data (extra claims). - session.Extra = respBody.Session.AccessToken + // Update existing session data (extra claims). + session.Extra = updateExtraClaims(session.Extra, respBody.Session.AccessToken) idTokenClaims := session.IDTokenClaims() - idTokenClaims.Extra = respBody.Session.IDToken + idTokenClaims.Extra = updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) return nil } +func updateExtraClaims(priorExtraClaims, webhookExtraClaims map[string]interface{}) map[string]interface{} { + updatedClaims := make(map[string]interface{}) + for key, value := range priorExtraClaims { + updatedClaims[key] = value + } + for key, value := range webhookExtraClaims { + updatedClaims[key] = value + } + + return updatedClaims +} + // TokenHook is an AccessRequestHook called for all grant types. func TokenHook(reg interface { config.Provider diff --git a/oauth2/token_hook_test.go b/oauth2/token_hook_test.go new file mode 100644 index 00000000000..2af6d6acc79 --- /dev/null +++ b/oauth2/token_hook_test.go @@ -0,0 +1,84 @@ +package oauth2 + +import ( + "reflect" + "testing" +) + +func TestUpdateExtraClaims(t *testing.T) { + tests := []struct { + name string + priorExtraClaims map[string]interface{} + webhookExtraClaims map[string]interface{} + expected map[string]interface{} + }{ + { + name: "Merge with no updates", + priorExtraClaims: map[string]interface{}{ + "claim1": "value1", + "claim2": "value2", + }, + webhookExtraClaims: map[string]interface{}{ + "claim3": "value3", + "claim4": "value4", + }, + expected: map[string]interface{}{ + "claim1": "value1", + "claim2": "value2", + "claim3": "value3", + "claim4": "value4", + }, + }, + { + name: "Merge with updates", + priorExtraClaims: map[string]interface{}{ + "claim1": "value1", + "claim2": "value2", + }, + webhookExtraClaims: map[string]interface{}{ + "claim2": "newValue2", // Overwrites prior claim2 + "claim3": "value3", + }, + expected: map[string]interface{}{ + "claim1": "value1", + "claim2": "newValue2", + "claim3": "value3", + }, + }, + { + name: "Empty webhook claims", + priorExtraClaims: map[string]interface{}{ + "claim1": "value1", + }, + webhookExtraClaims: map[string]interface{}{}, + expected: map[string]interface{}{ + "claim1": "value1", + }, + }, + { + name: "Empty prior claims", + priorExtraClaims: map[string]interface{}{}, + webhookExtraClaims: map[string]interface{}{ + "claim1": "value1", + }, + expected: map[string]interface{}{ + "claim1": "value1", + }, + }, + { + name: "Both maps empty", + priorExtraClaims: map[string]interface{}{}, + webhookExtraClaims: map[string]interface{}{}, + expected: map[string]interface{}{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("updateExtraClaims() = %v, want %v", result, tt.expected) + } + }) + } +} From fe14b383c3ae41cd9d34247fdea060b5d5de9bb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Schwartz?= <40674593+3schwartz@users.noreply.github.com> Date: Fri, 15 Nov 2024 19:53:41 +0100 Subject: [PATCH 2/4] avoid additional map allocation --- oauth2/token_hook.go | 14 ++++---------- oauth2/token_hook_test.go | 12 +++++++++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go index 2d3b84ae77f..28ce149bf3c 100644 --- a/oauth2/token_hook.go +++ b/oauth2/token_hook.go @@ -144,22 +144,16 @@ func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, } // Update existing session data (extra claims). - session.Extra = updateExtraClaims(session.Extra, respBody.Session.AccessToken) + updateExtraClaims(session.Extra, respBody.Session.AccessToken) idTokenClaims := session.IDTokenClaims() - idTokenClaims.Extra = updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) + updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) return nil } -func updateExtraClaims(priorExtraClaims, webhookExtraClaims map[string]interface{}) map[string]interface{} { - updatedClaims := make(map[string]interface{}) - for key, value := range priorExtraClaims { - updatedClaims[key] = value - } +func updateExtraClaims(claimsToUpdate, webhookExtraClaims map[string]interface{}) { for key, value := range webhookExtraClaims { - updatedClaims[key] = value + claimsToUpdate[key] = value } - - return updatedClaims } // TokenHook is an AccessRequestHook called for all grant types. diff --git a/oauth2/token_hook_test.go b/oauth2/token_hook_test.go index 2af6d6acc79..a313badc0f2 100644 --- a/oauth2/token_hook_test.go +++ b/oauth2/token_hook_test.go @@ -1,3 +1,6 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package oauth2 import ( @@ -75,9 +78,12 @@ func TestUpdateExtraClaims(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) - if !reflect.DeepEqual(result, tt.expected) { - t.Errorf("updateExtraClaims() = %v, want %v", result, tt.expected) + // Act + updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) + + // Assert + if !reflect.DeepEqual(tt.priorExtraClaims, tt.expected) { + t.Errorf("claimsToUpdate = %v, want %v", tt.priorExtraClaims, tt.expected) } }) } From 070101c281dd4813acfc6e77d5835e680a4d9081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Schwartz?= <40674593+3schwartz@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:01:29 +0100 Subject: [PATCH 3/4] fix snapshot tests --- ...-should_call_refresh_token_hook_if_configured-hook=new.json | 3 ++- ...-should_call_refresh_token_hook_if_configured-hook=new.json | 3 ++- ...-should_call_refresh_token_hook_if_configured-hook=new.json | 3 ++- ...-should_call_refresh_token_hook_if_configured-hook=new.json | 3 ++- ...-should_call_refresh_token_hook_if_configured-hook=new.json | 3 ++- ...-should_call_refresh_token_hook_if_configured-hook=new.json | 3 ++- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f1..e1381a9044c 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f1..e1381a9044c 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f1..e1381a9044c 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f1..e1381a9044c 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f1..e1381a9044c 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f1..e1381a9044c 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { From 6c25348f78c5844fc7b5fb9dd1716f16c09ab073 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Schwartz?= <40674593+3schwartz@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:45:39 +0100 Subject: [PATCH 4/4] fix nil case --- oauth2/token_hook.go | 13 ++++++++++--- oauth2/token_hook_test.go | 25 +++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go index 28ce149bf3c..83c7491aadd 100644 --- a/oauth2/token_hook.go +++ b/oauth2/token_hook.go @@ -144,16 +144,23 @@ func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, } // Update existing session data (extra claims). - updateExtraClaims(session.Extra, respBody.Session.AccessToken) + session.Extra = updateExtraClaims(session.Extra, respBody.Session.AccessToken) idTokenClaims := session.IDTokenClaims() - updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) + idTokenClaims.Extra = updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) return nil } -func updateExtraClaims(claimsToUpdate, webhookExtraClaims map[string]interface{}) { +func updateExtraClaims(claimsToUpdate, webhookExtraClaims map[string]interface{}) map[string]interface{} { + if webhookExtraClaims == nil { + return claimsToUpdate + } + if claimsToUpdate == nil { + claimsToUpdate = make(map[string]interface{}) + } for key, value := range webhookExtraClaims { claimsToUpdate[key] = value } + return claimsToUpdate } // TokenHook is an AccessRequestHook called for all grant types. diff --git a/oauth2/token_hook_test.go b/oauth2/token_hook_test.go index a313badc0f2..d6c3dd24d66 100644 --- a/oauth2/token_hook_test.go +++ b/oauth2/token_hook_test.go @@ -74,15 +74,36 @@ func TestUpdateExtraClaims(t *testing.T) { webhookExtraClaims: map[string]interface{}{}, expected: map[string]interface{}{}, }, + { + name: "Nil webhook claims", + priorExtraClaims: map[string]interface{}{"claim1": "value1"}, + webhookExtraClaims: nil, + expected: map[string]interface{}{"claim1": "value1"}, + }, + { + name: "Nil prior claims", + priorExtraClaims: nil, + webhookExtraClaims: map[string]interface{}{"claim1": "value1"}, + expected: map[string]interface{}{"claim1": "value1"}, + }, + { + name: "Both maps nil", + priorExtraClaims: nil, + webhookExtraClaims: nil, + expected: nil, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Act - updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) + if tt.priorExtraClaims == nil { + tt.priorExtraClaims = nil // Explicitly ensure nil for this test case + } + actual := updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) // Assert - if !reflect.DeepEqual(tt.priorExtraClaims, tt.expected) { + if !reflect.DeepEqual(actual, tt.expected) { t.Errorf("claimsToUpdate = %v, want %v", tt.priorExtraClaims, tt.expected) } })