diff --git a/migrations/20240227222554_index_subscription_watcher_account.sql b/migrations/20240227222554_index_subscription_watcher_account.sql new file mode 100644 index 00000000..266370b5 --- /dev/null +++ b/migrations/20240227222554_index_subscription_watcher_account.sql @@ -0,0 +1 @@ +CREATE INDEX subscription_watcher_address ON subscription_watcher (get_address_lower(account)); diff --git a/src/error.rs b/src/error.rs index a2341bde..b289d126 100644 --- a/src/error.rs +++ b/src/error.rs @@ -120,7 +120,7 @@ pub enum NotifyServerError { AccountNotAuthorized, #[error("sqlx error: {0}")] - SqlxError(#[from] sqlx::error::Error), + Sqlx(#[from] sqlx::error::Error), #[error("sqlx migration error: {0}")] SqlxMigrationError(#[from] sqlx::migrate::MigrateError), diff --git a/src/model/helpers.rs b/src/model/helpers.rs index d7c5825e..5b1f3622 100644 --- a/src/model/helpers.rs +++ b/src/model/helpers.rs @@ -23,6 +23,9 @@ use { x25519_dalek::StaticSecret, }; +// Import not part of group above because it breaks formatting: https://github.com/rust-lang/rustfmt/issues/4746 +use crate::services::public_http_server::handlers::relay_webhook::handlers::notify_watch_subscriptions::SUBSCRIPTION_WATCHER_LIMIT; + #[derive(Debug, FromRow)] pub struct ProjectWithPublicKeys { pub authentication_public_key: String, @@ -685,6 +688,15 @@ pub async fn get_subscriptions_by_account_and_maybe_app( result } +#[derive(Debug, thiserror::Error)] +pub enum UpsertSubscriptionWatcherError { + #[error("Subscription watcher limit reached")] + LimitReached, + + #[error("SQL error: {0}")] + Sqlx(#[from] sqlx::error::Error), +} + #[instrument(skip(postgres, metrics))] pub async fn upsert_subscription_watcher( account: AccountId, @@ -694,33 +706,48 @@ pub async fn upsert_subscription_watcher( expiry: DateTime, postgres: &PgPool, metrics: Option<&Metrics>, -) -> Result<(), sqlx::error::Error> { +) -> Result<(), UpsertSubscriptionWatcherError> { + let query = " + INSERT INTO subscription_watcher ( + account, + project, + did_key, + sym_key, + expiry + ) + SELECT $1, $2, $3, $4, $5 WHERE ( + SELECT COUNT(*) + FROM subscription_watcher + WHERE get_address_lower(account)=get_address_lower($1) + AND project=$2 + ) < $6 + ON CONFLICT (did_key) DO UPDATE SET + updated_at=now(), + account=$1, + project=$2, + sym_key=$4, + expiry=$5 + RETURNING * + "; let start = Instant::now(); - let _ = sqlx::query::( - " - INSERT INTO subscription_watcher ( - account, - project, - did_key, - sym_key, - expiry - ) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (did_key) DO UPDATE SET - updated_at=now(), - account=$1, - project=$2, - sym_key=$4, - expiry=$5 - ", - ) - .bind(account.as_ref()) - .bind(project) - .bind(did_key) - .bind(sym_key) - .bind(expiry) - .execute(postgres) - .await?; + let mut txn = postgres.begin().await?; + // https://stackoverflow.com/a/48730873 + sqlx::query::("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE") // TODO serialization errors not handled + .execute(&mut *txn) + .await?; + let result = sqlx::query_as::(query) + .bind(account.as_ref()) + .bind(project) + .bind(did_key) + .bind(sym_key) + .bind(expiry) + .bind(SUBSCRIPTION_WATCHER_LIMIT) + .fetch_optional(&mut *txn) + .await?; + if result.is_none() { + return Err(UpsertSubscriptionWatcherError::LimitReached); + } + txn.commit().await?; if let Some(metrics) = metrics { metrics.postgres_query("upsert_subscription_watcher", start); } diff --git a/src/services/public_http_server/handlers/relay_webhook/error.rs b/src/services/public_http_server/handlers/relay_webhook/error.rs index 27527f08..d9520e6e 100644 --- a/src/services/public_http_server/handlers/relay_webhook/error.rs +++ b/src/services/public_http_server/handlers/relay_webhook/error.rs @@ -16,6 +16,9 @@ pub enum RelayMessageClientError { #[error("Received 4010 on wrong topic: {0}")] WrongNotifyWatchSubscriptionsTopic(Topic), + #[error("Subscription watcher limit reached")] + SubscriptionWatcherLimitReached, + #[error("Received 4008 on unrecognized topic: {0}")] WrongNotifyUpdateTopic(Topic), diff --git a/src/services/public_http_server/handlers/relay_webhook/handlers/notify_watch_subscriptions.rs b/src/services/public_http_server/handlers/relay_webhook/handlers/notify_watch_subscriptions.rs index f3b2f7f6..6b921f83 100644 --- a/src/services/public_http_server/handlers/relay_webhook/handlers/notify_watch_subscriptions.rs +++ b/src/services/public_http_server/handlers/relay_webhook/handlers/notify_watch_subscriptions.rs @@ -11,7 +11,7 @@ use { helpers::{ get_project_by_app_domain, get_subscription_watchers_for_account_by_app_or_all_app, get_subscriptions_by_account_and_maybe_app, upsert_subscription_watcher, - SubscriberWithProject, SubscriptionWatcherQuery, + SubscriberWithProject, SubscriptionWatcherQuery, UpsertSubscriptionWatcherError, }, types::AccountId, }, @@ -46,6 +46,8 @@ use { x25519_dalek::PublicKey, }; +pub const SUBSCRIPTION_WATCHER_LIMIT: i32 = 25; + #[instrument(name = "wc_notifyWatchSubscriptions", skip_all)] pub async fn handle(msg: RelayIncomingMessage, state: &AppState) -> Result<(), RelayMessageError> { if msg.topic != state.notify_keys.key_agreement_topic { @@ -164,7 +166,14 @@ pub async fn handle(msg: RelayIncomingMessage, state: &AppState) -> Result<(), R state.metrics.as_ref(), ) .await - .map_err(|e| RelayMessageServerError::NotifyServerError(e.into()))?; // TODO change to client error? + .map_err(|e| match e { + UpsertSubscriptionWatcherError::LimitReached => { + RelayMessageError::Client(RelayMessageClientError::SubscriptionWatcherLimitReached) + } + UpsertSubscriptionWatcherError::Sqlx(e) => RelayMessageError::Server( + RelayMessageServerError::NotifyServerError(NotifyServerError::Sqlx(e)), + ), + })?; { let now = Utc::now(); diff --git a/tests/integration.rs b/tests/integration.rs index c363b8b7..324ff9d3 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -8,6 +8,7 @@ use { async_trait::async_trait, chrono::{DateTime, Duration, TimeZone, Utc}, futures::future::BoxFuture, + futures_util::StreamExt, hyper::StatusCode, itertools::Itertools, notify_server::{ @@ -59,6 +60,7 @@ use { self, notify_rate_limit, subscriber_rate_limit, subscriber_rate_limit_key, NotifyBodyNotification, }, + relay_webhook::handlers::notify_watch_subscriptions::SUBSCRIPTION_WATCHER_LIMIT, subscribe_topic::{SubscribeTopicRequestBody, SubscribeTopicResponseBody}, }, RELAY_WEBHOOK_ENDPOINT, @@ -9058,3 +9060,293 @@ async fn batch_receive_called(notify_server: &NotifyServerContext) { println!("fetch response: {response:?}"); assert_eq!(response.messages.len(), 0); } + +#[test_context(NotifyServerContext)] +#[tokio::test] +async fn subscription_watcher_limit(notify_server: &NotifyServerContext) { + let (account_signing_key, account) = generate_account(); + + let keys_server = MockServer::start().await; + let keys_server_url = keys_server.uri().parse::().unwrap(); + let keys_server = Arc::new(keys_server); + + let project_id = ProjectId::generate(); + let app_domain = DidWeb::from_domain(format!("{project_id}.walletconnect.com")); + + let (_key_agreement, _authentication, _client_id) = + subscribe_topic(&project_id, app_domain.clone(), ¬ify_server.url).await; + + futures_util::stream::iter(0..SUBSCRIPTION_WATCHER_LIMIT) + .map(|_| { + let keys_server = keys_server.clone(); + let keys_server_url = keys_server_url.clone(); + let account_signing_key = account_signing_key.clone(); + let app_domain = app_domain.clone(); + let account = account.clone(); + async move { + let (identity_signing_key, identity_public_key) = generate_identity_key(); + let identity_key_details = IdentityKeyDetails { + keys_server_url, + signing_key: identity_signing_key, + client_id: identity_public_key.clone(), + }; + register_mocked_identity_key( + &keys_server, + identity_public_key.clone(), + sign_cacao( + &app_domain, + &account, + STATEMENT_THIS_DOMAIN.to_owned(), + identity_public_key.clone(), + identity_key_details.keys_server_url.to_string(), + &account_signing_key, + ) + .await, + ) + .await; + let vars = get_vars(); + let mut relay_client = RelayClient::new( + vars.relay_url.parse().unwrap(), + vars.project_id.into(), + notify_server.url.clone(), + ) + .await; + watch_subscriptions( + &mut relay_client, + notify_server.url.clone(), + &identity_key_details, + Some(app_domain), + &account, + ) + .await + } + }) + .buffer_unordered(10) + .collect::>() + .await; + + let (identity_signing_key, identity_public_key) = generate_identity_key(); + let identity_key_details = IdentityKeyDetails { + keys_server_url: keys_server_url.clone(), + signing_key: identity_signing_key, + client_id: identity_public_key.clone(), + }; + register_mocked_identity_key( + &keys_server, + identity_public_key.clone(), + sign_cacao( + &app_domain, + &account, + STATEMENT_THIS_DOMAIN.to_owned(), + identity_public_key.clone(), + identity_key_details.keys_server_url.to_string(), + &account_signing_key, + ) + .await, + ) + .await; + let vars = get_vars(); + let mut relay_client = RelayClient::new( + vars.relay_url.parse().unwrap(), + vars.project_id.into(), + notify_server.url.clone(), + ) + .await; + let result = tokio::time::timeout( + RELAY_MESSAGE_DELIVERY_TIMEOUT / 2, + watch_subscriptions( + &mut relay_client, + notify_server.url.clone(), + &identity_key_details, + Some(app_domain), + &account, + ), + ) + .await; + assert!(result.is_err()); + + // Separate limit for different app domains + let project_id = ProjectId::generate(); + let app_domain = DidWeb::from_domain(format!("{project_id}.walletconnect.com")); + + let (_key_agreement, _authentication, _client_id) = + subscribe_topic(&project_id, app_domain.clone(), ¬ify_server.url).await; + + futures_util::stream::iter(0..SUBSCRIPTION_WATCHER_LIMIT) + .map(|_| { + let keys_server = keys_server.clone(); + let keys_server_url = keys_server_url.clone(); + let account_signing_key = account_signing_key.clone(); + let app_domain = app_domain.clone(); + let account = account.clone(); + async move { + let (identity_signing_key, identity_public_key) = generate_identity_key(); + let identity_key_details = IdentityKeyDetails { + keys_server_url, + signing_key: identity_signing_key, + client_id: identity_public_key.clone(), + }; + register_mocked_identity_key( + &keys_server, + identity_public_key.clone(), + sign_cacao( + &app_domain, + &account, + STATEMENT_THIS_DOMAIN.to_owned(), + identity_public_key.clone(), + identity_key_details.keys_server_url.to_string(), + &account_signing_key, + ) + .await, + ) + .await; + let vars = get_vars(); + let mut relay_client = RelayClient::new( + vars.relay_url.parse().unwrap(), + vars.project_id.into(), + notify_server.url.clone(), + ) + .await; + watch_subscriptions( + &mut relay_client, + notify_server.url.clone(), + &identity_key_details, + Some(app_domain), + &account, + ) + .await + } + }) + .buffer_unordered(10) + .collect::>() + .await; + + let (identity_signing_key, identity_public_key) = generate_identity_key(); + let identity_key_details = IdentityKeyDetails { + keys_server_url: keys_server_url.clone(), + signing_key: identity_signing_key, + client_id: identity_public_key.clone(), + }; + register_mocked_identity_key( + &keys_server, + identity_public_key.clone(), + sign_cacao( + &app_domain, + &account, + STATEMENT_THIS_DOMAIN.to_owned(), + identity_public_key.clone(), + identity_key_details.keys_server_url.to_string(), + &account_signing_key, + ) + .await, + ) + .await; + let vars = get_vars(); + let mut relay_client = RelayClient::new( + vars.relay_url.parse().unwrap(), + vars.project_id.into(), + notify_server.url.clone(), + ) + .await; + let result = tokio::time::timeout( + RELAY_MESSAGE_DELIVERY_TIMEOUT / 2, + watch_subscriptions( + &mut relay_client, + notify_server.url.clone(), + &identity_key_details, + Some(app_domain), + &account, + ), + ) + .await; + assert!(result.is_err()); + + // Separate limit for no app domain + futures_util::stream::iter(0..SUBSCRIPTION_WATCHER_LIMIT) + .map(|_| { + let keys_server = keys_server.clone(); + let keys_server_url = keys_server_url.clone(); + let account_signing_key = account_signing_key.clone(); + let account = account.clone(); + async move { + let (identity_signing_key, identity_public_key) = generate_identity_key(); + let identity_key_details = IdentityKeyDetails { + keys_server_url, + signing_key: identity_signing_key, + client_id: identity_public_key.clone(), + }; + register_mocked_identity_key( + &keys_server, + identity_public_key.clone(), + sign_cacao( + &DidWeb::from_domain("com.example.appbundle".to_owned()), + &account, + STATEMENT_ALL_DOMAINS.to_owned(), + identity_public_key.clone(), + identity_key_details.keys_server_url.to_string(), + &account_signing_key, + ) + .await, + ) + .await; + let vars = get_vars(); + let mut relay_client = RelayClient::new( + vars.relay_url.parse().unwrap(), + vars.project_id.into(), + notify_server.url.clone(), + ) + .await; + watch_subscriptions( + &mut relay_client, + notify_server.url.clone(), + &identity_key_details, + None, + &account, + ) + .await + } + }) + .buffer_unordered(10) + .collect::>() + .await; + + let (identity_signing_key, identity_public_key) = generate_identity_key(); + let identity_key_details = IdentityKeyDetails { + keys_server_url: keys_server_url.clone(), + signing_key: identity_signing_key, + client_id: identity_public_key.clone(), + }; + register_mocked_identity_key( + &keys_server, + identity_public_key.clone(), + sign_cacao( + &DidWeb::from_domain("com.example.appbundle".to_owned()), + &account, + STATEMENT_THIS_DOMAIN.to_owned(), + identity_public_key.clone(), + identity_key_details.keys_server_url.to_string(), + &account_signing_key, + ) + .await, + ) + .await; + let vars = get_vars(); + let mut relay_client = RelayClient::new( + vars.relay_url.parse().unwrap(), + vars.project_id.into(), + notify_server.url.clone(), + ) + .await; + let result = tokio::time::timeout( + RELAY_MESSAGE_DELIVERY_TIMEOUT / 2, + watch_subscriptions( + &mut relay_client, + notify_server.url.clone(), + &identity_key_details, + None, + &account, + ), + ) + .await; + assert!(result.is_err()); +}