From 02b668f8e8588599dd0478204faa0fe19f799827 Mon Sep 17 00:00:00 2001 From: Xavier Basty Date: Mon, 10 Jul 2023 08:42:16 +0200 Subject: [PATCH] fix: fix CORS headers (#58) --- .env.example | 3 --- src/config.rs | 11 ---------- src/lib.rs | 48 +++++++++++++++-------------------------- terraform/ecs/main.tf | 1 - tests/context/server.rs | 1 - tests/context/store.rs | 1 - 6 files changed, 17 insertions(+), 48 deletions(-) diff --git a/.env.example b/.env.example index 7c9826a..8f0c7ce 100644 --- a/.env.example +++ b/.env.example @@ -7,8 +7,5 @@ MONGO_ADDRESS=mongodb://admin:admin@localhost:27017/gilgamesh?authSource=admin # HTTP clients e.g. curl, insomnia, postman, etc VALIDATE_SIGNATURES=false -# CORS -CORS_ALLOWED_ORIGINS=* - # Telemetry TELEMETRY_PROMETHEUS_PORT=3001 diff --git a/src/config.rs b/src/config.rs index 779e0a9..dd9df3c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,7 +4,6 @@ const DEFAULT_PORT_NUMBER: u16 = 3001; const DEFAULT_LOG_LEVEL: &str = "WARN"; const DEFAULT_RELAY_URL: &str = "https://relay.walletconnect.com"; const DEFAULT_VALIDATE_SIGNATURES: bool = true; -const DEFAULT_CORS_ALLOWED_ORIGINS: &[&str] = &["*"]; /// The server configuration. #[derive(Deserialize, Debug, Clone, Eq, PartialEq)] @@ -26,9 +25,6 @@ pub struct Configuration { /// An internal flag to disable logging, cannot be defined by user. #[serde(default = "default_is_test", skip)] pub is_test: bool, - // CORS - #[serde(default = "default_cors_allowed_origins")] - pub cors_allowed_origins: Vec, pub otel_exporter_otlp_endpoint: Option, pub telemetry_prometheus_port: Option, @@ -65,13 +61,6 @@ fn default_is_test() -> bool { false } -fn default_cors_allowed_origins() -> Vec { - DEFAULT_CORS_ALLOWED_ORIGINS - .iter() - .map(|s| s.to_string()) - .collect::>() -} - /// Create a new configuration from the environment variables. pub fn get_config() -> error::Result { let config = envy::from_env::()?; diff --git a/src/lib.rs b/src/lib.rs index 181cb9c..fe5c147 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ use { state::{MessagesStorageArc, RegistrationStorageArc}, }, axum::{ - http::{HeaderValue, Method}, + http, routing::{get, post}, Router, }, @@ -16,7 +16,7 @@ use { tokio::{select, sync::broadcast}, tower::ServiceBuilder, tower_http::{ - cors::{AllowOrigin, CorsLayer}, + cors::{Any, CorsLayer}, trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer}, }, }; @@ -92,37 +92,22 @@ pub async fn bootstrap( let port = state.config.port; let private_port = state.config.telemetry_prometheus_port.unwrap_or(3001); - let allowed_origins = state.config.cors_allowed_origins.clone(); - let state_arc = Arc::new(state); - let global_middleware = ServiceBuilder::new() - .layer( - TraceLayer::new_for_http() - .make_span_with(DefaultMakeSpan::new().include_headers(true)) - .on_request(DefaultOnRequest::new().level(config.log_level())) - .on_response( - DefaultOnResponse::new() - .level(config.log_level()) - .include_headers(true), - ), - ) - .layer(if allowed_origins == vec!["*".to_string()] { - info!("CORS is disabled"); - CorsLayer::new() - .allow_methods([Method::GET, Method::POST, Method::DELETE]) - .allow_origin(AllowOrigin::any()) - } else { - info!("CORS is enabled for {:?}", allowed_origins); - CorsLayer::new() - .allow_methods([Method::GET, Method::POST, Method::DELETE]) - .allow_origin( - allowed_origins - .iter() - .map(|v| v.parse::().unwrap()) - .collect::>(), - ) - }); + let global_middleware = ServiceBuilder::new().layer( + TraceLayer::new_for_http() + .make_span_with(DefaultMakeSpan::new().include_headers(true)) + .on_request(DefaultOnRequest::new().level(config.log_level())) + .on_response( + DefaultOnResponse::new() + .level(config.log_level()) + .include_headers(true), + ), + ); + + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION]); let app = Router::new() .route("/health", get(handlers::health::handler)) @@ -131,6 +116,7 @@ pub async fn bootstrap( .route("/register", get(handlers::get_registration::handler)) .route("/register", post(handlers::register::handler)) .layer(global_middleware) + .layer(cors) .with_state(state_arc.clone()); let private_app = Router::new() diff --git a/terraform/ecs/main.tf b/terraform/ecs/main.tf index 5d789c4..d02e771 100644 --- a/terraform/ecs/main.tf +++ b/terraform/ecs/main.tf @@ -64,7 +64,6 @@ resource "aws_ecs_task_definition" "app_task_definition" { { name = "PUBLIC_URL", value = "http://localhost:8080" }, // TODO: Change this to the actual public URL { name = "LOG_LEVEL", value = var.log_level }, { name = "MONGO_ADDRESS", value = var.docdb-connection_url }, - { name = "CORS_ALLOWED_ORIGINS", value = "*" }, { name = "TELEMETRY_PROMETHEUS_PORT", value = "8081" } ], dependsOn = [ diff --git a/tests/context/server.rs b/tests/context/server.rs index 68f59f9..f7471e9 100644 --- a/tests/context/server.rs +++ b/tests/context/server.rs @@ -55,7 +55,6 @@ impl Gilgamesh { validate_signatures: false, mongo_address, is_test: true, - cors_allowed_origins: vec!["*".to_string()], otel_exporter_otlp_endpoint: None, telemetry_prometheus_port: Some(get_random_port()), }; diff --git a/tests/context/store.rs b/tests/context/store.rs index 7ef7c5f..27f22c3 100644 --- a/tests/context/store.rs +++ b/tests/context/store.rs @@ -23,7 +23,6 @@ impl PersistentStorage { validate_signatures: false, mongo_address, is_test: true, - cors_allowed_origins: vec!["*".to_string()], otel_exporter_otlp_endpoint: None, telemetry_prometheus_port: Some(get_random_port()), };