Skip to content

Commit

Permalink
Merge pull request #9 from antonilol/reuse_socket
Browse files Browse the repository at this point in the history
allow reusing socket after disconnect
  • Loading branch information
antonilol authored Jan 22, 2024
2 parents 5e24416 + ca964dc commit 94e2c67
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 127 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ tokio = { version = "1.35.0", features = ["time", "rt-multi-thread", "macros"] }
[[example]]
name = "subscribe_async_timeout"
required-features = ["async"]
doc-scrape-examples = true

[[example]]
name = "subscribe_async"
Expand All @@ -44,3 +45,4 @@ name = "subscribe_receiver"
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"]
30 changes: 24 additions & 6 deletions examples/subscribe_async_timeout.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use bitcoincore_zmq::subscribe_async_wait_handshake;
use bitcoincore_zmq::{subscribe_async_wait_handshake, SocketEvent, SocketMessage};
use core::time::Duration;
use futures_util::StreamExt;
use tokio::time::timeout;
Expand Down Expand Up @@ -29,14 +29,32 @@ async fn main() {
}
};

// like in other examples, we have a stream we can get messages from
// but this one is different in that it will terminate on disconnection, and return an error just before that
// Like in other examples, we have a stream we can get messages from, this one also produces
// events, for a detailed description on the events, see
// https://linux.die.net/man/3/zmq_socket_monitor.
while let Some(msg) = stream.next().await {
match msg {
Ok(msg) => println!("Received message: {msg}"),
Ok(SocketMessage::Message(msg)) => println!("Received message: {msg}"),
Ok(SocketMessage::Event(event)) => {
println!("Received socket event: {event:?}");
match event.event {
SocketEvent::Disconnected { .. } => {
println!(
"disconnected from {}, ZMQ automatically tries to reconnect",
event.source_url
);
}
SocketEvent::HandshakeSucceeded => {
// We can say "reconnected" because subscribe_async_wait_handshake waits on
// the first connections of each endpoint before returning.
println!("reconnected to {}", event.source_url);
}
_ => {
// ignore other events
}
}
}
Err(err) => println!("Error receiving message: {err}"),
}
}

println!("stream terminated");
}
39 changes: 26 additions & 13 deletions integration_tests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod util;
use bitcoincore_rpc::Client;
use bitcoincore_zmq::{
subscribe_async, subscribe_async_monitor, subscribe_async_wait_handshake,
subscribe_async_wait_handshake_timeout, subscribe_blocking, subscribe_receiver, Error, Message,
subscribe_async_wait_handshake_timeout, subscribe_blocking, subscribe_receiver, Message,
MonitorMessage, SocketEvent, SocketMessage,
};
use core::{assert_eq, ops::ControlFlow, time::Duration};
Expand Down Expand Up @@ -243,23 +243,36 @@ fn test_disconnect(rpc: &'static Client) {

let rpc_hash = generate(rpc, 1).expect("rpc call failed").0[0];

match stream.next().await {
Some(Ok(Message::HashBlock(zmq_hash, _seq))) if rpc_hash == zmq_hash => {}
other => panic!("unexpected response: {other:?}"),
loop {
match stream.next().await {
Some(Ok(SocketMessage::Message(Message::HashBlock(zmq_hash, _seq))))
if rpc_hash == zmq_hash =>
{
break;
}
Some(Ok(SocketMessage::Event(_))) => {
// ignore events
}
other => panic!("unexpected response: {other:?}"),
}
}

// send the signal to close the proxy
tx.send(()).unwrap();

match stream.next().await {
Some(Err(Error::Disconnected(endpoint)))
if endpoint == "tcp://127.0.0.1:29999" => {}
other => panic!("unexpected response: {other:?}"),
}

match stream.next().await {
None => {}
other => panic!("unexpected response: {other:?}"),
loop {
match stream.next().await {
Some(Ok(SocketMessage::Event(MonitorMessage {
event: SocketEvent::Disconnected { .. },
source_url,
}))) if source_url == "tcp://127.0.0.1:29999" => {
break;
}
Some(Ok(SocketMessage::Event(_))) => {
// ignore other events
}
other => panic!("unexpected response: {other:?}"),
}
}
});

Expand Down
5 changes: 1 addition & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ pub enum Error {
BitcoinDeserialization(consensus::encode::Error),
Zmq(zmq::Error),
MonitorMessage(MonitorMessageError),
Disconnected(String),
}

impl Error {
Expand Down Expand Up @@ -128,7 +127,6 @@ impl fmt::Display for Error {
}
Self::Zmq(e) => write!(f, "ZMQ Error: {e}"),
Self::MonitorMessage(err) => write!(f, "unable to parse monitor message: {err}"),
Self::Disconnected(url) => write!(f, "disconnected from {url}"),
}
}
}
Expand All @@ -146,8 +144,7 @@ impl std::error::Error for Error {
| Self::InvalidSequenceLength(_)
| Self::InvalidSequenceMessageLength(_)
| Self::InvalidSequenceMessageLabel(_)
| Self::Invalid256BitHashLength(_)
| Self::Disconnected(_) => return None,
| Self::Invalid256BitHashLength(_) => return None,
})
}
}
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ pub use crate::{
pub use crate::subscribe::stream::{
subscribe_async, subscribe_async_monitor, subscribe_async_monitor_stream,
subscribe_async_stream::{self, MessageStream},
subscribe_async_wait_handshake, subscribe_async_wait_handshake_stream,
subscribe_async_wait_handshake_timeout, SocketMessage,
subscribe_async_wait_handshake, subscribe_async_wait_handshake_timeout, SocketMessage,
};

#[allow(deprecated)]
Expand Down
114 changes: 12 additions & 102 deletions src/subscribe/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::{
};

/// A [`Message`] or a [`MonitorMessage`].
#[derive(Debug, Clone)]
pub enum SocketMessage {
Message(Message),
Event(MonitorMessage),
Expand Down Expand Up @@ -269,111 +270,27 @@ pub fn subscribe_async_monitor(
))
}

pub mod subscribe_async_wait_handshake_stream {
use super::{subscribe_async_monitor_stream, SocketMessage};
use crate::{
error::{Error, Result},
message::Message,
monitor::{event::SocketEvent, MonitorMessage},
};
use async_zmq::Subscribe;
use core::{
pin::Pin,
task::{Context as AsyncContext, Poll},
};
use futures_util::{
stream::{FusedStream, StreamExt},
Stream,
};

/// Stream returned by [`subscribe_async_wait_handshake`][super::subscribe_async_wait_handshake].
pub struct MessageStream {
inner: Option<subscribe_async_monitor_stream::MessageStream>,
}

impl MessageStream {
pub fn new(inner: subscribe_async_monitor_stream::MessageStream) -> Self {
Self { inner: Some(inner) }
}

/// Returns a reference to the ZMQ socket used by this stream. To get the [`zmq::Socket`], use
/// [`as_raw_socket`] on the result. This is useful to set socket options or use other
/// functions provided by [`zmq`] or [`async_zmq`].
///
/// Returns [`None`] when the socket is not connected anymore.
///
/// [`as_raw_socket`]: Subscribe::as_raw_socket
pub fn as_zmq_socket(&self) -> Option<&Subscribe> {
self.inner
.as_ref()
.map(subscribe_async_monitor_stream::MessageStream::as_zmq_socket)
}
}

impl Stream for MessageStream {
type Item = Result<Message>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut AsyncContext<'_>,
) -> Poll<Option<Self::Item>> {
if let Some(inner) = &mut self.inner {
loop {
match inner.poll_next_unpin(cx) {
Poll::Ready(opt) => match opt.unwrap()? {
SocketMessage::Message(msg) => return Poll::Ready(Some(Ok(msg))),
SocketMessage::Event(MonitorMessage { event, source_url }) => {
match event {
SocketEvent::Disconnected { .. } => {
// drop to disconnect
self.inner = None;
return Poll::Ready(Some(Err(Error::Disconnected(
source_url,
))));
}
_ => {
// only here it loops
}
}
}
},
Poll::Pending => return Poll::Pending,
}
}
} else {
Poll::Ready(None)
}
}
}

impl FusedStream for MessageStream {
fn is_terminated(&self) -> bool {
self.inner.is_none()
}
}
}

// TODO have some way to extract connecting to which endpoints failed, now just a (unit) error is returned (by tokio::time::timeout)

/// Subscribes to multiple ZMQ endpoints and returns a stream that yields [`Message`]s after a
/// connection has been established. When the other end disconnects, an error
/// ([`SocketEvent::Disconnected`]) is returned by the stream and it terminates.
/// Subscribes to multiple ZMQ endpoints and returns a stream that yields [`Message`]s and events
/// (see [`MonitorMessage`]). This method will wait until a connection has been established to all
/// endpoints.
///
/// NOTE: This method will wait indefinitely until a connection has been established, but this is
/// See examples/subscribe_async_timeout.rs for a usage example.
///
/// **NOTE:** This method will wait indefinitely until a connection has been established, but this is
/// often undesirable. This method should therefore be used in combination with your async
/// runtime's timeout function. Currently, with the state of async Rust in December of 2023, it is
/// not yet possible do this without creating an extra thread per timeout or depending on specific
/// runtimes.
pub async fn subscribe_async_wait_handshake(
endpoints: &[&str],
) -> Result<subscribe_async_wait_handshake_stream::MessageStream> {
) -> Result<subscribe_async_monitor_stream::MessageStream> {
let mut stream = subscribe_async_monitor(endpoints)?;
let mut connecting = endpoints.len();

if connecting == 0 {
return Ok(subscribe_async_wait_handshake_stream::MessageStream::new(
stream,
));
return Ok(stream);
}

loop {
Expand All @@ -393,9 +310,7 @@ pub async fn subscribe_async_wait_handshake(
}
}
if connecting == 0 {
return Ok(subscribe_async_wait_handshake_stream::MessageStream::new(
stream,
));
return Ok(stream);
}
}
}
Expand All @@ -405,7 +320,7 @@ pub async fn subscribe_async_wait_handshake(
pub async fn subscribe_async_wait_handshake_timeout(
endpoints: &[&str],
timeout: Duration,
) -> core::result::Result<Result<subscribe_async_wait_handshake_stream::MessageStream>, Timeout> {
) -> core::result::Result<Result<subscribe_async_monitor_stream::MessageStream>, Timeout> {
let subscribe = subscribe_async_wait_handshake(endpoints);
let timeout = sleep(timeout);

Expand All @@ -417,14 +332,9 @@ pub async fn subscribe_async_wait_handshake_timeout(

/// Error returned by [`subscribe_async_wait_handshake_timeout`] when the connection times out.
/// Contains no information, but does have a Error, Debug and Display impl.
#[derive(Debug)]
pub struct Timeout(());

impl fmt::Debug for Timeout {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Timeout").finish()
}
}

impl fmt::Display for Timeout {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "connection timed out")
Expand Down

0 comments on commit 94e2c67

Please sign in to comment.