diff --git a/Cargo.toml b/Cargo.toml index b772502..61319b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,10 +17,16 @@ async_zmq = { version = "0.4.0", optional = true } bitcoin = "0.30.0" futures-util = { version = "0.3.28", optional = true } zmq = "0.10.0" +zmq-sys = "0.12.0" -# dev dependencies can be used in examples +# dependencies used in examples [dev-dependencies] futures = "0.3.28" +tokio = { version = "1.35.0", features = ["time", "rt-multi-thread", "macros"] } + +[[example]] +name = "subscribe_async_timeout" +required-features = ["async"] [[example]] name = "subscribe_async" diff --git a/examples/subscribe_async_timeout.rs b/examples/subscribe_async_timeout.rs new file mode 100644 index 0000000..70ca1d8 --- /dev/null +++ b/examples/subscribe_async_timeout.rs @@ -0,0 +1,42 @@ +use bitcoincore_zmq::subscribe_async_wait_handshake; +use core::time::Duration; +use futures_util::StreamExt; +use tokio::time::timeout; + +#[tokio::main] +async fn main() { + // In this example I use match instead of unwrap to clearly show where errors are produced. + // `timeout` here returns an `impl Future>>`. The outer + // Result is created by tokio's timeout function, and wraps the inner Result created by the + // subscribe function. + let mut stream = match timeout( + Duration::from_millis(2000), + subscribe_async_wait_handshake(&["tcp://127.0.0.1:28332"]), + ) + .await + { + Ok(Ok(stream)) => { + // Ok(Ok(_)), ok from both functions. + stream + } + Ok(Err(err)) => { + // Ok(Err(_)), ok from `timeout` but an error from the subscribe function. + panic!("subscribe error: {err}"); + } + Err(_) => { + // Err(_), err from `timeout` means that it timed out. + panic!("subscribe_async_wait_handshake timed out"); + } + }; + + // like in other examples, we have a stream we can get messages from + // but this one is different in that it will terminate on disconnection, and return an error just before that + while let Some(msg) = stream.next().await { + match msg { + Ok(msg) => println!("Received message: {msg}"), + Err(err) => println!("Error receiving message: {err}"), + } + } + + println!("stream terminated"); +} diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index c965ee3..87e5f49 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -8,3 +8,4 @@ bitcoin = "0.30.0" bitcoincore-rpc = "0.17.0" bitcoincore-zmq = { path = "..", features = ["async"] } futures = "0.3.28" +tokio = { version = "1.35.0", features = ["full"] } diff --git a/integration_tests/src/main.rs b/integration_tests/src/main.rs index 4c5e370..7f8ab41 100644 --- a/integration_tests/src/main.rs +++ b/integration_tests/src/main.rs @@ -2,18 +2,28 @@ mod endpoints; mod util; use bitcoincore_rpc::Client; -use bitcoincore_zmq::{subscribe_async, subscribe_blocking, subscribe_receiver, Message}; -use core::{assert_eq, ops::ControlFlow}; +use bitcoincore_zmq::{ + subscribe_async, subscribe_async_monitor, subscribe_async_wait_handshake, + subscribe_async_wait_handshake_timeout, subscribe_blocking, subscribe_receiver, Error, Message, + MonitorMessage, SocketEvent, SocketMessage, +}; +use core::{assert_eq, ops::ControlFlow, time::Duration}; use futures::{executor::block_on, StreamExt}; -use std::{sync::mpsc, thread}; -use util::{generate, recv_timeout_2, setup_rpc, sleep, RECV_TIMEOUT}; +use std::{net::SocketAddr, sync::mpsc, thread}; +use tokio::{ + io::AsyncWriteExt, + net::{TcpListener, TcpStream}, + runtime, + sync::mpsc::unbounded_channel, +}; +use util::{generate, recv_timeout_2, setup_rpc, sleep, static_ref_heap, RECV_TIMEOUT}; macro_rules! test { ($($function:ident,)*) => { - let rpc = setup_rpc(); + let rpc = static_ref_heap(setup_rpc()); $( println!(concat!("Running ", stringify!($function), "...")); - $function(&rpc); + $function(rpc); println!("ok"); )* }; @@ -25,6 +35,10 @@ fn main() { test_hashtx, test_sub_blocking, test_hashblock_async, + test_monitor, + test_subscribe_timeout_tokio, + test_subscribe_timeout_inefficient, + test_disconnect, } } @@ -126,3 +140,159 @@ fn test_hashblock_async(rpc: &Client) { h.join().unwrap(); } + +fn test_monitor(rpc: &Client) { + let mut stream = subscribe_async_monitor(&[endpoints::HASHBLOCK]) + .expect("failed to subscribe to Bitcoin Core's ZMQ publisher"); + + block_on(async { + while let Some(msg) = stream.next().await { + let msg = msg.unwrap(); + match msg { + SocketMessage::Message(_msg) => { + break; + } + SocketMessage::Event(MonitorMessage { event, .. }) => { + if event == SocketEvent::HandshakeSucceeded { + // there is a zmq publisher on the other side! + // generate a block to generate a message + generate(rpc, 1).expect("rpc call failed"); + } + } + } + } + }); +} + +fn test_subscribe_timeout_tokio(_rpc: &Client) { + const TIMEOUT: Duration = Duration::from_millis(500); + + runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let _ = tokio::time::timeout( + TIMEOUT, + subscribe_async_wait_handshake(&[endpoints::HASHBLOCK]), + ) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout( + TIMEOUT, + subscribe_async_wait_handshake(&["tcp://localhost:18443"]), + ) + .await + .map(|_| ()) + .expect_err("an http server will not make a zmtp handshake"); + + tokio::time::timeout( + TIMEOUT, + subscribe_async_wait_handshake(&[endpoints::HASHBLOCK, "tcp://localhost:18443"]), + ) + .await + .map(|_| ()) + .expect_err("an http server will not make a zmtp handshake"); + }); +} + +fn test_subscribe_timeout_inefficient(_rpc: &Client) { + const TIMEOUT: Duration = Duration::from_millis(500); + + block_on(async { + let _ = subscribe_async_wait_handshake_timeout(&[endpoints::HASHBLOCK], TIMEOUT) + .await + .unwrap() + .unwrap(); + + subscribe_async_wait_handshake_timeout(&["tcp://localhost:18443"], TIMEOUT) + .await + .map(|_| ()) + .expect_err("an http server will not make a zmtp handshake"); + + subscribe_async_wait_handshake_timeout( + &[endpoints::HASHBLOCK, "tcp://localhost:18443"], + TIMEOUT, + ) + .await + .map(|_| ()) + .expect_err("an http server will not make a zmtp handshake"); + }); +} + +fn test_disconnect(rpc: &'static Client) { + runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let (tx, mut rx) = unbounded_channel(); + + let h = tokio::spawn(async move { + let mut stream = tokio::time::timeout( + Duration::from_millis(2000), + subscribe_async_wait_handshake(&["tcp://127.0.0.1:29999"]), + ) + .await + .unwrap() + .unwrap(); + + tokio::time::sleep(Duration::from_millis(500)).await; + + let rpc_hash = generate(rpc, 1).expect("rpc call failed").0[0]; + + match stream.next().await { + Some(Ok(Message::HashBlock(zmq_hash, _seq))) if rpc_hash == zmq_hash => {} + other => panic!("unexpected response: {other:?}"), + } + + // send the signal to close the proxy + tx.send(()).unwrap(); + + match stream.next().await { + Some(Err(Error::Disconnected(endpoint))) + if endpoint == "tcp://127.0.0.1:29999" => {} + other => panic!("unexpected response: {other:?}"), + } + + match stream.next().await { + None => {} + other => panic!("unexpected response: {other:?}"), + } + }); + + // proxy endpoints::HASHBLOCK to 127.0.0.1:29999 to simulate a disconnect + // stopping bitcoin core is not a good idea as other tests may follow this one + // taken from https://github.com/tokio-rs/tokio/discussions/3173, it is not perfect but ok for this test + let ss = TcpListener::bind("127.0.0.1:29999".parse::().unwrap()) + .await + .unwrap(); + let (cs, _) = ss.accept().await.unwrap(); + // [6..] splits off "tcp://" + let g = TcpStream::connect(endpoints::HASHBLOCK[6..].parse::().unwrap()) + .await + .unwrap(); + let (mut gr, mut gw) = g.into_split(); + let (mut csr, mut csw) = cs.into_split(); + let h1 = tokio::spawn(async move { + let _ = tokio::io::copy(&mut gr, &mut csw).await; + let _ = csw.shutdown().await; + }); + let h2 = tokio::spawn(async move { + let _ = tokio::io::copy(&mut csr, &mut gw).await; + let _ = gw.shutdown().await; + }); + + // wait for the signal + rx.recv().await.unwrap(); + + // close the proxy + h1.abort(); + h2.abort(); + + // wait on other spawned tasks + h.await.unwrap(); + }); +} diff --git a/integration_tests/src/util.rs b/integration_tests/src/util.rs index 8693650..b19b817 100644 --- a/integration_tests/src/util.rs +++ b/integration_tests/src/util.rs @@ -13,6 +13,10 @@ pub fn setup_rpc() -> Client { .expect("unable to connect to Bitcoin Core regtest RPC") } +pub fn static_ref_heap(val: T) -> &'static T { + Box::leak(Box::new(val)) +} + fn get_cookie_path() -> String { env::var("BITCOIN_CORE_COOKIE_PATH").expect( "env var BITCOIN_CORE_COOKIE_PATH probably not set, \ diff --git a/src/error.rs b/src/error.rs index c350a25..4bd7907 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,7 @@ -use crate::message::{DATA_MAX_LEN, SEQUENCE_LEN, TOPIC_MAX_LEN}; +use crate::{ + message::{DATA_MAX_LEN, SEQUENCE_LEN, TOPIC_MAX_LEN}, + monitor::MonitorMessageError, +}; use bitcoin::consensus; use core::{cmp::min, fmt}; @@ -15,6 +18,8 @@ pub enum Error { Invalid256BitHashLength(usize), BitcoinDeserialization(consensus::encode::Error), Zmq(zmq::Error), + MonitorMessage(MonitorMessageError), + Disconnected(String), } impl Error { @@ -69,6 +74,13 @@ impl From for Error { } } +impl From for Error { + #[inline] + fn from(value: MonitorMessageError) -> Self { + Self::MonitorMessage(value) + } +} + impl fmt::Display for Error { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -115,6 +127,8 @@ impl fmt::Display for Error { write!(f, "bitcoin consensus deserialization error: {e}") } Self::Zmq(e) => write!(f, "ZMQ Error: {e}"), + Self::MonitorMessage(err) => write!(f, "unable to parse monitor message: {err}"), + Self::Disconnected(url) => write!(f, "disconnected from {url}"), } } } @@ -125,13 +139,15 @@ impl std::error::Error for Error { Some(match self { Self::BitcoinDeserialization(e) => e, Self::Zmq(e) => e, + Self::MonitorMessage(e) => e, Self::InvalidMutlipartLength(_) | Self::InvalidTopic(_, _) | Self::InvalidDataLength(_) | Self::InvalidSequenceLength(_) | Self::InvalidSequenceMessageLength(_) | Self::InvalidSequenceMessageLabel(_) - | Self::Invalid256BitHashLength(_) => return None, + | Self::Invalid256BitHashLength(_) + | Self::Disconnected(_) => return None, }) } } diff --git a/src/lib.rs b/src/lib.rs index 972a9d3..52bc903 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,18 +2,28 @@ mod error; mod message; +mod monitor; mod sequence_message; mod subscribe; pub use crate::{ error::Error, message::{Message, DATA_MAX_LEN, SEQUENCE_LEN, TOPIC_MAX_LEN}, + monitor::{ + event::{HandshakeFailure, SocketEvent}, + MonitorMessage, + }, sequence_message::SequenceMessage, subscribe::{blocking::subscribe_blocking, receiver::subscribe_receiver}, }; #[cfg(feature = "async")] -pub use crate::subscribe::stream::{subscribe_async, MessageStream}; +pub use crate::subscribe::stream::{ + subscribe_async, subscribe_async_monitor, subscribe_async_monitor_stream, + subscribe_async_stream::{self, MessageStream}, + subscribe_async_wait_handshake, subscribe_async_wait_handshake_stream, + subscribe_async_wait_handshake_timeout, SocketMessage, +}; #[allow(deprecated)] pub use crate::subscribe::{ diff --git a/src/monitor/event.rs b/src/monitor/event.rs new file mode 100644 index 0000000..2262244 --- /dev/null +++ b/src/monitor/event.rs @@ -0,0 +1,151 @@ +use super::MonitorMessageError; + +/// Convenience trait to be able to use `from_raw` and `to_raw` on any value that either defines it +/// or is a `u32`. It doesn't matter that others don't implement this trait, rustc is smart enough +/// to find that out. +trait U32Ext: Sized { + fn from_raw(value: u32) -> Option; + + fn to_raw(self) -> u32; +} + +impl U32Ext for u32 { + fn from_raw(value: u32) -> Option { + Some(value) + } + + fn to_raw(self) -> Self { + self + } +} + +macro_rules! type_or_u32 { + ($type:ty) => { + $type + }; + () => { + u32 + }; +} + +macro_rules! define_handshake_failure_enum { + ($($name:ident = $zmq_sys_name:ident,)*) => { + #[repr(u32)] + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum HandshakeFailure { + $( + $name = zmq_sys::$zmq_sys_name, + )* + } + + impl HandshakeFailure { + pub fn from_raw(data: u32) -> Option { + Some(match data { + $( + zmq_sys::$zmq_sys_name => Self::$name, + )* + _ => return None, + }) + } + + pub fn to_raw(self) -> u32 { + self as u32 + } + } + }; +} + +define_handshake_failure_enum! { + ZmtpUnspecified = ZMQ_PROTOCOL_ERROR_ZMTP_UNSPECIFIED, + ZmtpUnexpectedCommand = ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND, + ZmtpInvalidSequence = ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE, + ZmtpKeyExchange = ZMQ_PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE, + ZmtpMalformedCommandUnspecified = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED, + ZmtpMalformedCommandMessage = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE, + ZmtpMalformedCommandHello = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO, + ZmtpMalformedCommandInitiate = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE, + ZmtpMalformedCommandError = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR, + ZmtpMalformedCommandReady = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY, + ZmtpMalformedCommandWelcome = ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME, + ZmtpInvalidMetadata = ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA, + ZmtpCryptographic = ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC, + ZmtpMechanismMismatch = ZMQ_PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH, + ZapUnspecified = ZMQ_PROTOCOL_ERROR_ZAP_UNSPECIFIED, + ZapMalformedReply = ZMQ_PROTOCOL_ERROR_ZAP_MALFORMED_REPLY, + ZapBadRequestId = ZMQ_PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID, + ZapBadVersion = ZMQ_PROTOCOL_ERROR_ZAP_BAD_VERSION, + ZapInvalidStatusCode = ZMQ_PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE, + ZapInvalidMetadata = ZMQ_PROTOCOL_ERROR_ZAP_INVALID_METADATA, +} + +macro_rules! define_socket_event_enum { + (enum docs { $(#[$attr:meta])* } $($name:ident$(($value:ident $(: $type:ty)?))? = $zmq_sys_name:ident,)*) => { + $(#[$attr])* + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum SocketEvent { + $( + $name $({ $value: type_or_u32!($($type)?) })?, + )* + Unknown { event: u16, data: u32 }, + } + + impl SocketEvent { + pub fn from_raw(event: u16, data: u32) -> Option { + Some(match event as u32 { + $( + zmq_sys::$zmq_sys_name => Self::$name $({ $value: ::from_raw(data)? })?, + )* + _ => Self::Unknown { event, data }, + }) + } + + pub fn to_raw(self) -> (u16, Option) { + match self { + $( + Self::$name $({ $value })? => (zmq_sys::$zmq_sys_name as u16, ($(Some($value.to_raw()), )? None::,).0), + )* + Self::Unknown { event, data } => (event, Some(data)), + } + } + } + }; +} + +define_socket_event_enum! { + enum docs { + /// An event from one of the connected sockets. See the "SUPPORTED EVENTS" section in the + /// "zmq_socket_monitor" manual page (`man zmq_socket_monitor`) for the original + /// documentation. + } + + Connected(fd) = ZMQ_EVENT_CONNECTED, + ConnectDelayed = ZMQ_EVENT_CONNECT_DELAYED, + ConnectRetried(interval) = ZMQ_EVENT_CONNECT_RETRIED, + Listening(fd) = ZMQ_EVENT_LISTENING, + BindFailed(errno) = ZMQ_EVENT_BIND_FAILED, + Accepted(fd) = ZMQ_EVENT_ACCEPTED, + AcceptFailed(errno) = ZMQ_EVENT_ACCEPT_FAILED, + Closed(fd) = ZMQ_EVENT_CLOSED, + CloseFailed(errno) = ZMQ_EVENT_CLOSE_FAILED, + Disconnected(fd) = ZMQ_EVENT_DISCONNECTED, + MonitorStopped = ZMQ_EVENT_MONITOR_STOPPED, + HandshakeFailedNoDetail(fd) = ZMQ_EVENT_HANDSHAKE_FAILED_NO_DETAIL, + HandshakeSucceeded = ZMQ_EVENT_HANDSHAKE_SUCCEEDED, + HandshakeFailedProtocol(err: HandshakeFailure) = ZMQ_EVENT_HANDSHAKE_FAILED_PROTOCOL, + HandshakeFailedAuth(error_code) = ZMQ_EVENT_HANDSHAKE_FAILED_AUTH, +} + +impl SocketEvent { + pub fn parse_from(msg: &zmq::Message) -> Result { + let bytes = &**msg; + + let event: [u8; 6] = bytes + .try_into() + .map_err(|_| MonitorMessageError::InvalidEventFrameLength(bytes.len()))?; + let event_type = u16::from_ne_bytes(event[0..2].try_into().unwrap()); + let data = u32::from_ne_bytes(event[2..6].try_into().unwrap()); + + SocketEvent::from_raw(event_type, data) + .ok_or(MonitorMessageError::InvalidEventData(event_type, data)) + } +} diff --git a/src/monitor/mod.rs b/src/monitor/mod.rs new file mode 100644 index 0000000..8935ee7 --- /dev/null +++ b/src/monitor/mod.rs @@ -0,0 +1,51 @@ +pub mod event; + +use core::fmt; +use event::SocketEvent; + +/// A [`SocketEvent`] combined with its source (the url used when connecting). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MonitorMessage { + pub event: SocketEvent, + pub source_url: String, +} + +impl MonitorMessage { + pub fn parse_from(msg: &[zmq::Message]) -> Result { + let [event_message, url_message] = msg else { + return Err(MonitorMessageError::InvalidMutlipartLength(msg.len())); + }; + + Ok(MonitorMessage { + event: SocketEvent::parse_from(event_message)?, + source_url: String::from_utf8_lossy(url_message).into(), + }) + } +} + +#[derive(Debug)] +// currently all variants have the same prefix: `Invalid`, which is correct and intended +#[allow(clippy::enum_variant_names)] +pub enum MonitorMessageError { + InvalidMutlipartLength(usize), + InvalidEventFrameLength(usize), + InvalidEventData(u16, u32), +} + +impl fmt::Display for MonitorMessageError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidMutlipartLength(len) => { + write!(f, "invalid multipart message length: {len} (expected 2)") + } + Self::InvalidEventFrameLength(len) => { + write!(f, "invalid event frame length: {len} (expected 6)") + } + Self::InvalidEventData(event_type, event_data) => { + write!(f, "invalid event data {event_data} for event {event_type}") + } + } + } +} + +impl std::error::Error for MonitorMessageError {} diff --git a/src/subscribe/stream.rs b/src/subscribe/stream.rs index 3157df5..adff1e4 100644 --- a/src/subscribe/stream.rs +++ b/src/subscribe/stream.rs @@ -1,54 +1,31 @@ -use super::{new_socket_internal, recv_internal}; -use crate::{error::Result, message::Message, DATA_MAX_LEN}; -use async_zmq::{Stream, StreamExt, Subscribe}; +use super::new_socket_internal; +use crate::{ + error::Result, + message::Message, + monitor::{event::SocketEvent, MonitorMessage, MonitorMessageError}, +}; use core::{ - pin::Pin, + fmt, + future::Future, + mem, + pin::{pin, Pin}, slice, - task::{Context as AsyncContext, Poll}, + task::{Context as AsyncContext, Poll, Waker}, + time::Duration, +}; +use futures_util::{ + future::{select, Either}, + stream::{FusedStream, Stream, StreamExt}, +}; +use std::{ + sync::{Arc, Mutex}, + thread, }; -use futures_util::stream::FusedStream; - -/// Stream that asynchronously produces [`Message`]s using a ZMQ subscriber. -pub struct MessageStream { - zmq_stream: Subscribe, - data_cache: Box<[u8; DATA_MAX_LEN]>, -} - -impl MessageStream { - fn new(zmq_stream: Subscribe) -> Self { - Self { - zmq_stream, - data_cache: vec![0; DATA_MAX_LEN].into_boxed_slice().try_into().unwrap(), - } - } - - /// Returns a reference to the ZMQ socket used by this stream. To get the [`zmq::Socket`], use - /// [`as_raw_socket`] on the result. This is useful to set socket options or use other - /// functions provided by [`zmq`] or [`async_zmq`]. - /// - /// [`as_raw_socket`]: Subscribe::as_raw_socket - pub fn as_zmq_socket(&self) -> &Subscribe { - &self.zmq_stream - } -} - -impl Stream for MessageStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll> { - self.zmq_stream.poll_next_unpin(cx).map(|opt| { - opt.map(|res| match res { - Ok(mp) => recv_internal(mp.iter(), &mut self.data_cache), - Err(err) => Err(err.into()), - }) - }) - } -} -impl FusedStream for MessageStream { - fn is_terminated(&self) -> bool { - false - } +/// A [`Message`] or a [`MonitorMessage`]. +pub enum SocketMessage { + Message(Message), + Event(MonitorMessage), } /// Stream that asynchronously produces [`Message`]s using multiple ZMQ subscribers. The ZMQ @@ -57,21 +34,21 @@ impl FusedStream for MessageStream { since = "1.3.2", note = "This struct is only used by deprecated functions." )] -pub struct MultiMessageStream(pub MessageStream); +pub struct MultiMessageStream(pub subscribe_async_stream::MessageStream); #[allow(deprecated)] impl MultiMessageStream { - /// Returns a reference to the separate [`MessageStream`]s this [`MultiMessageStream`] is made + /// Returns a reference to the separate `MessageStream`s this [`MultiMessageStream`] is made /// of. This is useful to set socket options or use other functions provided by [`zmq`] or - /// [`async_zmq`]. (See [`MessageStream::as_zmq_socket`]) - pub fn as_streams(&self) -> &[MessageStream] { + /// [`async_zmq`]. (See `MessageStream::as_zmq_socket`) + pub fn as_streams(&self) -> &[subscribe_async_stream::MessageStream] { slice::from_ref(&self.0) } - /// Returns the separate [`MessageStream`]s this [`MultiMessageStream`] is made of. This is + /// Returns the separate `MessageStream`s this [`MultiMessageStream`] is made of. This is /// useful to set socket options or use other functions provided by [`zmq`] or [`async_zmq`]. - /// (See [`MessageStream::as_zmq_socket`]) - pub fn into_streams(self) -> Vec { + /// (See `MessageStream::as_zmq_socket`) + pub fn into_streams(self) -> Vec { vec![self.0] } } @@ -102,18 +79,397 @@ pub fn subscribe_multi_async(endpoints: &[&str]) -> Result { subscribe_async(endpoints).map(MultiMessageStream) } -/// Subscribes to a single ZMQ endpoint and returns a [`MessageStream`]. +/// Subscribes to a single ZMQ endpoint and returns a `MessageStream`. #[deprecated( since = "1.3.2", note = "Use subscribe_async. The name changed because there is no distinction made anymore between subscribing to 1 or more endpoints." )] -pub fn subscribe_single_async(endpoint: &str) -> Result { +pub fn subscribe_single_async(endpoint: &str) -> Result { subscribe_async(&[endpoint]) } -/// Subscribes to multiple ZMQ endpoints and returns a [`MessageStream`]. -pub fn subscribe_async(endpoints: &[&str]) -> Result { +pub mod subscribe_async_stream { + use crate::{ + error::Result, + message::{Message, DATA_MAX_LEN}, + subscribe::recv_internal, + }; + use async_zmq::Subscribe; + use core::{ + pin::Pin, + task::{Context as AsyncContext, Poll}, + }; + use futures_util::{ + stream::{FusedStream, StreamExt}, + Stream, + }; + + /// Stream returned by [`subscribe_async`][super::subscribe_async]. + pub struct MessageStream { + zmq_stream: Subscribe, + data_cache: Box<[u8; DATA_MAX_LEN]>, + } + + impl MessageStream { + pub(super) fn new(zmq_stream: Subscribe) -> Self { + Self { + zmq_stream, + data_cache: vec![0; DATA_MAX_LEN].into_boxed_slice().try_into().unwrap(), + } + } + + /// Returns a reference to the ZMQ socket used by this stream. To get the [`zmq::Socket`], use + /// [`as_raw_socket`] on the result. This is useful to set socket options or use other + /// functions provided by [`zmq`] or [`async_zmq`]. + /// + /// [`as_raw_socket`]: Subscribe::as_raw_socket + pub fn as_zmq_socket(&self) -> &Subscribe { + &self.zmq_stream + } + } + + impl Stream for MessageStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut AsyncContext<'_>, + ) -> Poll> { + self.zmq_stream.poll_next_unpin(cx).map(|opt| { + Some(match opt.unwrap() { + Ok(mp) => recv_internal(mp.iter(), &mut self.data_cache), + Err(err) => Err(err.into()), + }) + }) + } + } + + impl FusedStream for MessageStream { + fn is_terminated(&self) -> bool { + false + } + } +} + +/// Subscribes to multiple ZMQ endpoints and returns a stream that produces [`Message`]s. +pub fn subscribe_async(endpoints: &[&str]) -> Result { let (_context, socket) = new_socket_internal(endpoints)?; - Ok(MessageStream::new(socket.into())) + Ok(subscribe_async_stream::MessageStream::new(socket.into())) +} + +pub mod subscribe_async_monitor_stream { + use super::{subscribe_async_stream, SocketMessage}; + use crate::{error::Result, monitor::MonitorMessage}; + use async_zmq::Subscribe; + use core::{ + pin::Pin, + task::{Context as AsyncContext, Poll}, + }; + use futures_util::{ + stream::{FusedStream, StreamExt}, + Stream, + }; + use zmq::Socket; + + pub(super) enum Empty {} + + impl Iterator for Empty { + type Item = Empty; + + fn next(&mut self) -> Option { + None + } + } + + impl From for async_zmq::Message { + fn from(val: Empty) -> Self { + match val {} + } + } + + // The generic type params don't matter as this will only be used for receiving + // Better to use an empty type to not waste precious bytes + pub(super) type RecvOnlyPair = async_zmq::Pair; + + /// Stream returned by [`subscribe_async_monitor`][super::subscribe_async_monitor]. + pub struct MessageStream { + messages: subscribe_async_stream::MessageStream, + pub(super) monitor: RecvOnlyPair, + } + + impl MessageStream { + pub(super) fn new( + messages: subscribe_async_stream::MessageStream, + monitor: RecvOnlyPair, + ) -> Self { + Self { messages, monitor } + } + + /// Returns a reference to the ZMQ socket used by this stream. To get the [`zmq::Socket`], use + /// [`as_raw_socket`] on the result. This is useful to set socket options or use other + /// functions provided by [`zmq`] or [`async_zmq`]. + /// + /// [`as_raw_socket`]: Subscribe::as_raw_socket + pub fn as_zmq_socket(&self) -> &Subscribe { + self.messages.as_zmq_socket() + } + + /// Returns a reference to the ZMQ monitor socket used by this stream. This is useful to + /// set socket options or use other functions provided by [`zmq`]. + pub fn as_zmq_monitor_socket(&self) -> &Socket { + self.monitor.as_raw_socket() + } + } + + impl Stream for MessageStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut AsyncContext<'_>, + ) -> Poll> { + match self.monitor.poll_next_unpin(cx) { + Poll::Ready(msg) => { + return Poll::Ready(Some(Ok(SocketMessage::Event( + MonitorMessage::parse_from(&msg.unwrap()?)?, + )))); + } + Poll::Pending => {} + } + + self.messages + .poll_next_unpin(cx) + .map(|opt| Some(opt.unwrap().map(SocketMessage::Message))) + } + } + + impl FusedStream for MessageStream { + fn is_terminated(&self) -> bool { + false + } + } +} + +/// Subscribes to multiple ZMQ endpoints and returns a stream that yields [`Message`]s and events +/// (see [`MonitorMessage`]). +pub fn subscribe_async_monitor( + endpoints: &[&str], +) -> Result { + let (context, socket) = new_socket_internal(endpoints)?; + + socket.monitor("inproc://monitor", zmq::SocketEvent::ALL as i32)?; + + let monitor = context.socket(zmq::PAIR)?; + monitor.connect("inproc://monitor")?; + + Ok(subscribe_async_monitor_stream::MessageStream::new( + subscribe_async_stream::MessageStream::new(socket.into()), + monitor.into(), + )) +} + +pub mod subscribe_async_wait_handshake_stream { + use super::{subscribe_async_monitor_stream, SocketMessage}; + use crate::{ + error::{Error, Result}, + message::Message, + monitor::{event::SocketEvent, MonitorMessage}, + }; + use async_zmq::Subscribe; + use core::{ + pin::Pin, + task::{Context as AsyncContext, Poll}, + }; + use futures_util::{ + stream::{FusedStream, StreamExt}, + Stream, + }; + + /// Stream returned by [`subscribe_async_wait_handshake`][super::subscribe_async_wait_handshake]. + pub struct MessageStream { + inner: Option, + } + + impl MessageStream { + pub fn new(inner: subscribe_async_monitor_stream::MessageStream) -> Self { + Self { inner: Some(inner) } + } + + /// Returns a reference to the ZMQ socket used by this stream. To get the [`zmq::Socket`], use + /// [`as_raw_socket`] on the result. This is useful to set socket options or use other + /// functions provided by [`zmq`] or [`async_zmq`]. + /// + /// Returns [`None`] when the socket is not connected anymore. + /// + /// [`as_raw_socket`]: Subscribe::as_raw_socket + pub fn as_zmq_socket(&self) -> Option<&Subscribe> { + self.inner + .as_ref() + .map(subscribe_async_monitor_stream::MessageStream::as_zmq_socket) + } + } + + impl Stream for MessageStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut AsyncContext<'_>, + ) -> Poll> { + if let Some(inner) = &mut self.inner { + loop { + match inner.poll_next_unpin(cx) { + Poll::Ready(opt) => match opt.unwrap()? { + SocketMessage::Message(msg) => return Poll::Ready(Some(Ok(msg))), + SocketMessage::Event(MonitorMessage { event, source_url }) => { + match event { + SocketEvent::Disconnected { .. } => { + // drop to disconnect + self.inner = None; + return Poll::Ready(Some(Err(Error::Disconnected( + source_url, + )))); + } + _ => { + // only here it loops + } + } + } + }, + Poll::Pending => return Poll::Pending, + } + } + } else { + Poll::Ready(None) + } + } + } + + impl FusedStream for MessageStream { + fn is_terminated(&self) -> bool { + self.inner.is_none() + } + } +} + +// TODO have some way to extract connecting to which endpoints failed, now just a (unit) error is returned (by tokio::time::timeout) + +/// Subscribes to multiple ZMQ endpoints and returns a stream that yields [`Message`]s after a +/// connection has been established. When the other end disconnects, an error +/// ([`SocketEvent::Disconnected`]) is returned by the stream and it terminates. +/// +/// NOTE: This method will wait indefinitely until a connection has been established, but this is +/// often undesirable. This method should therefore be used in combination with your async +/// runtime's timeout function. Currently, with the state of async Rust in December of 2023, it is +/// not yet possible do this without creating an extra thread per timeout or depending on specific +/// runtimes. +pub async fn subscribe_async_wait_handshake( + endpoints: &[&str], +) -> Result { + let mut stream = subscribe_async_monitor(endpoints)?; + let mut connecting = endpoints.len(); + + if connecting == 0 { + return Ok(subscribe_async_wait_handshake_stream::MessageStream::new( + stream, + )); + } + + loop { + let msg: &[zmq::Message] = &stream.monitor.next().await.unwrap()?; + let [event_message, _] = msg else { + return Err(MonitorMessageError::InvalidMutlipartLength(msg.len()).into()); + }; + match SocketEvent::parse_from(event_message)? { + SocketEvent::HandshakeSucceeded => { + connecting -= 1; + } + SocketEvent::Disconnected { .. } => { + connecting += 1; + } + _ => { + continue; + } + } + if connecting == 0 { + return Ok(subscribe_async_wait_handshake_stream::MessageStream::new( + stream, + )); + } + } +} + +/// See [`subscribe_async_wait_handshake`]. This method implements the inefficient, but runtime +/// independent approach. +pub async fn subscribe_async_wait_handshake_timeout( + endpoints: &[&str], + timeout: Duration, +) -> core::result::Result, Timeout> { + let subscribe = subscribe_async_wait_handshake(endpoints); + let timeout = sleep(timeout); + + match select(pin!(subscribe), timeout).await { + Either::Left((res, _)) => Ok(res), + Either::Right(_) => Err(Timeout(())), + } +} + +/// Error returned by [`subscribe_async_wait_handshake_timeout`] when the connection times out. +/// Contains no information, but does have a Error, Debug and Display impl. +pub struct Timeout(()); + +impl fmt::Debug for Timeout { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Timeout").finish() + } +} + +impl fmt::Display for Timeout { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "connection timed out") + } +} + +impl std::error::Error for Timeout {} + +fn sleep(dur: Duration) -> Sleep { + let state = Arc::new(Mutex::new(SleepReadyState::Pending)); + { + let state = state.clone(); + thread::spawn(move || { + thread::sleep(dur); + let state = { + let mut g = state.lock().unwrap(); + mem::replace(&mut *g, SleepReadyState::Done) + }; + if let SleepReadyState::PendingPolled(waker) = state { + waker.wake(); + } + }); + } + + Sleep(state) +} + +enum SleepReadyState { + Pending, + PendingPolled(Waker), + Done, +} + +struct Sleep(Arc>); + +impl Future for Sleep { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll { + let mut g = self.0.lock().unwrap(); + if let SleepReadyState::Done = *g { + Poll::Ready(()) + } else { + *g = SleepReadyState::PendingPolled(cx.waker().clone()); + Poll::Pending + } + } }