diff --git a/roles/tests-integration/lib/sniffer.rs b/roles/tests-integration/lib/sniffer.rs index 109934df6..f7b86612e 100644 --- a/roles/tests-integration/lib/sniffer.rs +++ b/roles/tests-integration/lib/sniffer.rs @@ -457,6 +457,38 @@ impl Sniffer { } } + pub async fn wait_for_message_type_with_remove( + &self, + message_direction: MessageDirection, + message_type: u8, + ) -> bool { + let now = std::time::Instant::now(); + loop { + let has_message_type = match message_direction { + MessageDirection::ToDownstream => self + .messages_from_upstream + .has_message_type_with_remove(message_type), + MessageDirection::ToUpstream => self + .messages_from_downstream + .has_message_type_with_remove(message_type), + }; + + // ready to unblock test runtime + if has_message_type { + return true; + } + + // 10 min timeout + // only for worst case, ideally should never be triggered + if now.elapsed().as_secs() > 10 * 60 { + panic!("Timeout waiting for message type"); + } + + // sleep to reduce async lock contention + sleep(Duration::from_secs(1)).await; + } + } + pub async fn includes_message_type( &self, message_direction: MessageDirection, @@ -672,6 +704,22 @@ impl MessagesAggregator { has_message } + fn has_message_type_with_remove(&self, message_type: u8) -> bool { + self.messages + .safe_lock(|messages| { + let mut cloned_messages = messages.clone(); + for (pos, (t, _)) in cloned_messages.iter().enumerate() { + if *t == message_type { + let drained = cloned_messages.drain(pos + 1..).collect(); + *messages = drained; + return true; + } + } + false + }) + .unwrap() + } + // The aggregator queues messages in FIFO order, so this function returns the oldest message in // the queue. // diff --git a/roles/tests-integration/tests/sniffer_integration.rs b/roles/tests-integration/tests/sniffer_integration.rs index 64eeaf649..480f19de7 100644 --- a/roles/tests-integration/tests/sniffer_integration.rs +++ b/roles/tests-integration/tests/sniffer_integration.rs @@ -1,4 +1,7 @@ -use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_ERROR; +use const_sv2::{ + MESSAGE_TYPE_SETUP_CONNECTION_ERROR, MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS, + MESSAGE_TYPE_SET_NEW_PREV_HASH, +}; use integration_tests_sv2::*; use roles_logic_sv2::{ common_messages_sv2::SetupConnectionError, @@ -10,7 +13,6 @@ use std::convert::TryInto; #[tokio::test] async fn test_sniffer_interrupter() { let (_tp, tp_addr) = start_template_provider(None).await; - use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS; let message = PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError { flags: 0, @@ -33,3 +35,36 @@ async fn test_sniffer_interrupter() { assert_common_message!(&sniffer.next_message_from_downstream(), SetupConnection); assert_common_message!(&sniffer.next_message_from_upstream(), SetupConnectionError); } + +#[tokio::test] +async fn test_sniffer_wait_for_message_type_with_remove() { + let (_tp, tp_addr) = start_template_provider(None).await; + let (sniffer, sniffer_addr) = start_sniffer("".to_string(), tp_addr, false, None).await; + let _ = start_pool(Some(sniffer_addr)).await; + assert!( + sniffer + .wait_for_message_type_with_remove( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SET_NEW_PREV_HASH, + ) + .await + ); + assert_eq!( + sniffer + .includes_message_type( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS + ) + .await, + false + ); + assert_eq!( + sniffer + .includes_message_type( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SET_NEW_PREV_HASH + ) + .await, + false + ); +}