diff --git a/crates/cli/src/sync.rs b/crates/cli/src/sync.rs index 124ceb53..b61cc693 100644 --- a/crates/cli/src/sync.rs +++ b/crates/cli/src/sync.rs @@ -235,14 +235,16 @@ pub async fn config_sync( } }; - let response_mode = match provider.response_mode { - mas_config::UpstreamOAuth2ResponseMode::Query => { - mas_data_model::UpstreamOAuthProviderResponseMode::Query - } - mas_config::UpstreamOAuth2ResponseMode::FormPost => { - mas_data_model::UpstreamOAuthProviderResponseMode::FormPost - } - }; + let response_mode = provider + .response_mode + .map(|response_mode| match response_mode { + mas_config::UpstreamOAuth2ResponseMode::Query => { + mas_data_model::UpstreamOAuthProviderResponseMode::Query + } + mas_config::UpstreamOAuth2ResponseMode::FormPost => { + mas_data_model::UpstreamOAuthProviderResponseMode::FormPost + } + }); if discovery_mode.is_disabled() { if provider.authorization_endpoint.is_none() { diff --git a/crates/config/src/sections/upstream_oauth2.rs b/crates/config/src/sections/upstream_oauth2.rs index e6e30ee7..6fc47f1e 100644 --- a/crates/config/src/sections/upstream_oauth2.rs +++ b/crates/config/src/sections/upstream_oauth2.rs @@ -114,12 +114,11 @@ impl ConfigurationSection for UpstreamOAuth2Config { } /// The response mode we ask the provider to use for the callback -#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, JsonSchema)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum ResponseMode { /// `query`: The provider will send the response as a query string in the /// URL search parameters - #[default] Query, /// `form_post`: The provider will send the response as a POST request with @@ -129,13 +128,6 @@ pub enum ResponseMode { FormPost, } -impl ResponseMode { - #[allow(clippy::trivially_copy_pass_by_ref)] - const fn is_default(&self) -> bool { - matches!(self, ResponseMode::Query) - } -} - /// Authentication methods used against the OAuth 2.0 provider #[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -561,8 +553,8 @@ pub struct Provider { pub jwks_uri: Option, /// The response mode we ask the provider to use for the callback - #[serde(default, skip_serializing_if = "ResponseMode::is_default")] - pub response_mode: ResponseMode, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_mode: Option, /// How claims should be imported from the `id_token` provided by the /// provider diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 55bfb41b..414c25c5 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -236,7 +236,7 @@ pub struct UpstreamOAuthProvider { pub token_endpoint_signing_alg: Option, pub token_endpoint_auth_method: TokenAuthMethod, pub id_token_signed_response_alg: JsonWebSignatureAlg, - pub response_mode: ResponseMode, + pub response_mode: Option, pub created_at: DateTime, pub disabled_at: Option>, pub claims_imports: ClaimsImports, diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 67d4d9a7..c1dd3f34 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -83,12 +83,15 @@ pub(crate) async fn get( let redirect_uri = url_builder.upstream_oauth_callback(provider.id); - let data = AuthorizationRequestData::new( + let mut data = AuthorizationRequestData::new( provider.client_id.clone(), provider.scope.clone(), redirect_uri, - ) - .with_response_mode(provider.response_mode.into()); + ); + + if let Some(response_mode) = provider.response_mode { + data = data.with_response_mode(response_mode.into()); + } let data = if let Some(methods) = lazy_metadata.pkce_methods().await? { data.with_code_challenge_methods_supported(methods) diff --git a/crates/handlers/src/upstream_oauth2/cache.rs b/crates/handlers/src/upstream_oauth2/cache.rs index 27b6c509..cac97a41 100644 --- a/crates/handlers/src/upstream_oauth2/cache.rs +++ b/crates/handlers/src/upstream_oauth2/cache.rs @@ -417,8 +417,8 @@ mod tests { encrypted_client_secret: None, token_endpoint_signing_alg: None, token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, - response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, id_token_signed_response_alg: JsonWebSignatureAlg::Rs256, + response_mode: None, created_at: clock.now(), disabled_at: None, claims_imports: UpstreamOAuthProviderClaimsImports::default(), diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index 1d831a35..7a6e0f21 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -109,7 +109,7 @@ pub(crate) enum RouteError { MissingFormParams, #[error("Invalid response mode, expected '{expected}'")] - InvalidParamsMode { + InvalidResponseMode { expected: UpstreamOAuthProviderResponseMode, }, @@ -185,8 +185,7 @@ pub(crate) async fn handler( // the query parameters for GET requests. We need to then look at the method do // make sure it matches the expected `response_mode` match (provider.response_mode, method) { - (UpstreamOAuthProviderResponseMode::Query, Method::GET) => {} - (UpstreamOAuthProviderResponseMode::FormPost, Method::POST) => { + (Some(UpstreamOAuthProviderResponseMode::FormPost) | None, Method::POST) => { // We set the cookies with a `Same-Site` policy set to `Lax`, so because this is // usually a cross-site form POST, we need to render a form with the // same values, which posts back to the same URL. However, there are @@ -202,7 +201,8 @@ pub(crate) async fn handler( return Ok(Html(html).into_response()); } } - (expected, _) => return Err(RouteError::InvalidParamsMode { expected }), + (None, _) | (Some(UpstreamOAuthProviderResponseMode::Query), Method::GET) => {} + (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }), } let (session_id, _post_auth_action) = sessions_cookie diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 6614e4aa..e48ff190 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -934,7 +934,7 @@ mod tests { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, - response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, + response_mode: None, additional_authorization_parameters: Vec::new(), }, ) diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index f7a2b150..b0d6991c 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -416,7 +416,7 @@ mod test { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, - response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, + response_mode: None, additional_authorization_parameters: Vec::new(), }, ) @@ -456,7 +456,7 @@ mod test { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, - response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, + response_mode: None, additional_authorization_parameters: Vec::new(), }, ) diff --git a/crates/storage-pg/.sqlx/query-1d758df58ccfead4cb39ee8f88f60b382b7881e9c4ead31ff257ff5ff4414b6e.json b/crates/storage-pg/.sqlx/query-1d758df58ccfead4cb39ee8f88f60b382b7881e9c4ead31ff257ff5ff4414b6e.json index 31d16172..65b97215 100644 --- a/crates/storage-pg/.sqlx/query-1d758df58ccfead4cb39ee8f88f60b382b7881e9c4ead31ff257ff5ff4414b6e.json +++ b/crates/storage-pg/.sqlx/query-1d758df58ccfead4cb39ee8f88f60b382b7881e9c4ead31ff257ff5ff4414b6e.json @@ -146,7 +146,7 @@ true, false, false, - false, + true, true ] }, diff --git a/crates/storage-pg/.sqlx/query-27d6f228a9a608b5d03d30cb4074be94dc893df9107e982583aa954b5067dfd1.json b/crates/storage-pg/.sqlx/query-27d6f228a9a608b5d03d30cb4074be94dc893df9107e982583aa954b5067dfd1.json index a866644a..938cab2b 100644 --- a/crates/storage-pg/.sqlx/query-27d6f228a9a608b5d03d30cb4074be94dc893df9107e982583aa954b5067dfd1.json +++ b/crates/storage-pg/.sqlx/query-27d6f228a9a608b5d03d30cb4074be94dc893df9107e982583aa954b5067dfd1.json @@ -144,7 +144,7 @@ true, false, false, - false, + true, true ] }, diff --git a/crates/storage-pg/migrations/20241212154426_oauth2_response_mode_null.sql b/crates/storage-pg/migrations/20241212154426_oauth2_response_mode_null.sql new file mode 100644 index 00000000..c6a6b7b4 --- /dev/null +++ b/crates/storage-pg/migrations/20241212154426_oauth2_response_mode_null.sql @@ -0,0 +1,7 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Drop not null requirement on response mode, so we can ignore this query parameter. +ALTER TABLE "upstream_oauth_providers" ALTER COLUMN "response_mode" DROP NOT NULL; diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index 89b81635..342f8f44 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -74,7 +74,7 @@ mod tests { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, - response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, + response_mode: None, additional_authorization_parameters: Vec::new(), }, ) @@ -319,7 +319,7 @@ mod tests { jwks_uri_override: None, discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto, - response_mode: mas_data_model::UpstreamOAuthProviderResponseMode::Query, + response_mode: None, additional_authorization_parameters: Vec::new(), }, ) diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 2a57b861..c9e2b0ec 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -68,7 +68,7 @@ struct ProviderLookup { userinfo_endpoint_override: Option, discovery_mode: String, pkce_mode: String, - response_mode: String, + response_mode: Option, additional_parameters: Option>>, } @@ -177,12 +177,16 @@ impl TryFrom for UpstreamOAuthProvider { .source(e) })?; - let response_mode = value.response_mode.parse().map_err(|e| { - DatabaseInconsistencyError::on("upstream_oauth_providers") - .column("response_mode") - .row(id) - .source(e) - })?; + let response_mode = value + .response_mode + .map(|x| x.parse()) + .transpose() + .map_err(|e| { + DatabaseInconsistencyError::on("upstream_oauth_providers") + .column("response_mode") + .row(id) + .source(e) + })?; let additional_authorization_parameters = value .additional_parameters @@ -370,7 +374,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> { params.jwks_uri_override.as_ref().map(ToString::to_string), params.discovery_mode.as_str(), params.pkce_mode.as_str(), - params.response_mode.as_str(), + params.response_mode.as_ref().map(ToString::to_string), created_at, ) .traced() @@ -576,7 +580,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> { params.jwks_uri_override.as_ref().map(ToString::to_string), params.discovery_mode.as_str(), params.pkce_mode.as_str(), - params.response_mode.as_str(), + params.response_mode.as_ref().map(ToString::to_string), Json(¶ms.additional_authorization_parameters) as _, created_at, ) diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index a7d62862..3489f9b5 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -91,7 +91,7 @@ pub struct UpstreamOAuthProviderParams { pub pkce_mode: UpstreamOAuthProviderPkceMode, /// What response mode it should ask - pub response_mode: UpstreamOAuthProviderResponseMode, + pub response_mode: Option, /// Additional parameters to include in the authorization request pub additional_authorization_parameters: Vec<(String, String)>, diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 377c2bfd..70fc16d0 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -22,8 +22,8 @@ use mas_data_model::{ AuthorizationGrant, BrowserSession, Client, CompatSsoLogin, CompatSsoLoginState, DeviceCodeGrant, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode, - UpstreamOAuthProviderResponseMode, UpstreamOAuthProviderTokenAuthMethod, User, UserAgent, - UserEmail, UserEmailVerification, UserRecoverySession, + UpstreamOAuthProviderTokenAuthMethod, User, UserAgent, UserEmail, UserEmailVerification, + UserRecoverySession, }; use mas_i18n::DataLocale; use mas_iana::jose::JsonWebSignatureAlg; @@ -1408,7 +1408,7 @@ impl TemplateContext for UpstreamRegister { userinfo_signed_response_alg: None, discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc, pkce_mode: UpstreamOAuthProviderPkceMode::Auto, - response_mode: UpstreamOAuthProviderResponseMode::Query, + response_mode: None, additional_authorization_parameters: Vec::new(), created_at: now, disabled_at: None,