Skip to content

Commit

Permalink
initial async support
Browse files Browse the repository at this point in the history
  • Loading branch information
antonilol committed Oct 21, 2023
1 parent 47b73f1 commit 4985791
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 12 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ repository = "https://github.com/antonilol/rust-bitcoincore-zmq"
keywords = ["bitcoin", "bitcoin-core", "zmq"]
categories = ["cryptography::cryptocurrencies", "network-programming"]

[features]
async = ["dep:async_zmq", "dep:futures-util"]

[dependencies]
async_zmq = { version = "0.4.0", optional = true }
bitcoin = "0.30.0"
futures-util = { version = "0.3.28", optional = true }
zmq = "0.10.0"
3 changes: 2 additions & 1 deletion integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ edition = "2021"
[dependencies]
bitcoin = "0.30.0"
bitcoincore-rpc = "0.17.0"
bitcoincore-zmq = { path = ".." }
bitcoincore-zmq = { path = "..", features = ["async"] }
futures = "0.3.28"
32 changes: 31 additions & 1 deletion integration_tests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ mod endpoints;
mod util;

use bitcoincore_rpc::Client;
use bitcoincore_zmq::{subscribe_multi, subscribe_single_blocking, Message};
use bitcoincore_zmq::{subscribe_multi, subscribe_multi_async, subscribe_single_blocking, Message};
use core::{assert_eq, ops::ControlFlow};
use futures::StreamExt;
use std::{sync::mpsc, thread};
use util::{generate, recv_timeout_2, setup_rpc, sleep, RECV_TIMEOUT};

Expand All @@ -13,6 +14,7 @@ fn main() {
test_hashblock(&rpc);
test_hashtx(&rpc);
test_sub_blocking(&rpc);
test_hashblock_async(&rpc);
}

fn test_hashblock(rpc: &Client) {
Expand Down Expand Up @@ -82,3 +84,31 @@ fn test_sub_blocking(rpc: &Client) {

assert_eq!(rpc_hash, zmq_hash);
}

fn test_hashblock_async(rpc: &Client) {
let mut stream = subscribe_multi_async(&[endpoints::HASHBLOCK, endpoints::RAWBLOCK])
.expect("failed to subscribe to Bitcoin Core's ZMQ subscriber");

let rpc_hash = generate(rpc, 1).expect("rpc call failed").0[0];

let (tx, rx) = mpsc::channel();

thread::spawn(move || {
futures::executor::block_on(async {
while let Some(msg) = stream.next().await {
tx.send(msg).unwrap();
}
})
});

match recv_timeout_2(&rx) {
(Message::Block(block, _), Message::HashBlock(blockhash, _))
| (Message::HashBlock(blockhash, _), Message::Block(block, _)) => {
assert_eq!(rpc_hash, block.block_hash());
assert_eq!(rpc_hash, blockhash);
}
(msg1, msg2) => {
panic!("invalid messages received: ({msg1}, {msg2})");
}
}
}
24 changes: 24 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,30 @@ impl From<zmq::Error> for Error {
}
}

#[cfg(feature = "async")]
impl From<async_zmq::SocketError> for Error {
#[inline]
fn from(value: async_zmq::SocketError) -> Self {
Self::Zmq(value.into())
}
}

#[cfg(feature = "async")]
impl From<async_zmq::SubscribeError> for Error {
#[inline]
fn from(value: async_zmq::SubscribeError) -> Self {
Self::Zmq(value.into())
}
}

#[cfg(feature = "async")]
impl From<async_zmq::RecvError> for Error {
#[inline]
fn from(value: async_zmq::RecvError) -> Self {
Self::Zmq(value.into())
}
}

impl From<consensus::encode::Error> for Error {
#[inline]
fn from(value: consensus::encode::Error) -> Self {
Expand Down
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod error;
mod message;
mod sequence_message;
mod subscribe;
#[cfg(feature = "async")]
mod subscribe_async;

pub use crate::{
error::Error,
Expand All @@ -11,3 +13,8 @@ pub use crate::{
subscribe_multi, subscribe_multi_blocking, subscribe_single, subscribe_single_blocking,
},
};

#[cfg(feature = "async")]
pub use crate::subscribe_async::{
subscribe_async, subscribe_multi_async, MessageStream, MultiMessageStream,
};
54 changes: 44 additions & 10 deletions src/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
message::{Message, SEQUENCE_LEN, TOPIC_MAX_LEN},
Error, DATA_MAX_LEN,
};
use core::{convert::Infallible, ops::ControlFlow};
use core::{cmp::min, convert::Infallible, ops::ControlFlow, slice};
use std::{
sync::mpsc::{channel, Receiver},
thread,
Expand Down Expand Up @@ -109,46 +109,80 @@ fn new_socket_internal(context: &Context, endpoint: &str) -> Result<Socket> {
Ok(socket)
}

pub(crate) trait ReceiveFrom {
fn has_next(&self) -> Result<bool>;

fn receive_into(&mut self, buf: &mut [u8]) -> Result<usize>;
}

impl ReceiveFrom for &Socket {
fn has_next(&self) -> Result<bool> {
Ok(self.get_rcvmore()?)
}

fn receive_into(&mut self, buf: &mut [u8]) -> Result<usize> {
Ok(self.recv_into(buf, 0)?)
}
}

impl ReceiveFrom for slice::Iter<'_, zmq::Message> {
fn has_next(&self) -> Result<bool> {
Ok(!self.as_slice().is_empty())
}

fn receive_into(&mut self, buf: &mut [u8]) -> Result<usize> {
// TODO better way to handle None than unwrap
let bytes = &**self.next().unwrap();
let len = bytes.len();
let copy_len = min(len, buf.len());
buf[0..copy_len].copy_from_slice(&bytes[0..copy_len]);
Ok(len)
}
}

#[inline]
fn recv_internal(socket: &Socket, data: &mut [u8; DATA_MAX_LEN]) -> Result<Message> {
pub(crate) fn recv_internal<R: ReceiveFrom>(
mut socket: R,
data: &mut [u8; DATA_MAX_LEN],
) -> Result<Message> {
let mut topic = [0u8; TOPIC_MAX_LEN];
let mut sequence = [0u8; SEQUENCE_LEN];

let topic_len = socket.recv_into(&mut topic, 0)?;
let topic_len = socket.receive_into(&mut topic)?;
if topic_len > TOPIC_MAX_LEN {
return Err(Error::InvalidTopic(topic_len, topic));
}

if !socket.get_rcvmore()? {
if !socket.has_next()? {
return Err(Error::InvalidMutlipartLength(1));
}

let data_len = socket.recv_into(data, 0)?;
let data_len = socket.receive_into(data)?;
if data_len > DATA_MAX_LEN {
return Err(Error::InvalidDataLength(data_len));
}

if !socket.get_rcvmore()? {
if !socket.has_next()? {
return Err(Error::InvalidMutlipartLength(2));
}

let sequence_len = socket.recv_into(&mut sequence, 0)?;
let sequence_len = socket.receive_into(&mut sequence)?;
if sequence_len != SEQUENCE_LEN {
return Err(Error::InvalidSequenceLength(sequence_len));
}

if !socket.get_rcvmore()? {
if !socket.has_next()? {
return Message::from_parts(&topic[0..topic_len], &data[0..data_len], sequence);
}

let mut len = 3;

loop {
socket.recv_into(&mut [], 0)?;
socket.receive_into(&mut [])?;

len += 1;

if !socket.get_rcvmore()? {
if !socket.has_next()? {
return Err(Error::InvalidMutlipartLength(len));
}
}
Expand Down
114 changes: 114 additions & 0 deletions src/subscribe_async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use crate::{error::Result, message::Message, subscribe::recv_internal, DATA_MAX_LEN};
use async_zmq::{subscribe, Stream, StreamExt, Subscribe};
use core::{
pin::Pin,
task::{Context as AsyncContext, Poll},
};
use futures_util::stream::Fuse;
use zmq::Context as ZmqContext;

/// Stream that asynchronously produces [`Message`]s using a ZMQ subscriber.
pub struct MessageStream {
zmq_stream: Subscribe,
data_cache: Box<[u8; DATA_MAX_LEN]>,
}

impl MessageStream {
fn new(zmq_stream: Subscribe) -> Self {
Self {
zmq_stream,
data_cache: vec![0; DATA_MAX_LEN].into_boxed_slice().try_into().unwrap(),
}
}
}

impl Stream for MessageStream {
type Item = Result<Message>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
self.as_mut().zmq_stream.poll_next_unpin(cx).map(|opt| {
opt.map(|res| match res {
Ok(mp) => recv_internal(mp.iter(), &mut self.data_cache),
Err(err) => Err(err.into()),
})
})
}
}

/// Stream that asynchronously produces [`Message`]s using multiple ZMQ subscriber.
pub struct MultiMessageStream {
streams: Vec<Fuse<MessageStream>>,
next: usize,
}

impl MultiMessageStream {
fn new(buf_capacity: usize) -> Self {
Self {
streams: Vec::with_capacity(buf_capacity),
next: 0,
}
}

fn push(&mut self, stream: Subscribe) {
// Not sure if fuse is needed, but has to prevent use of closed streams.
self.streams.push(MessageStream::new(stream).fuse());
}
}

impl Stream for MultiMessageStream {
type Item = Result<Message>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
let mut any_pending = false;

let mut index_iter = (self.next..self.streams.len()).chain(0..self.next);
while let Some(i) = index_iter.next() {
match self.as_mut().streams[i].poll_next_unpin(cx) {
msg @ Poll::Ready(Some(_)) => {
if let Some(next) = index_iter.next() {
self.next = next;
}
return msg;
}
Poll::Ready(None) => {
// continue
}
Poll::Pending => {
any_pending = true;
}
}
}

if any_pending {
Poll::Pending
} else {
Poll::Ready(None)
}
}
}

pub fn subscribe_multi_async(endpoints: &[&str]) -> Result<MultiMessageStream> {
let context = ZmqContext::new();
let mut res = MultiMessageStream::new(endpoints.len());

for endpoint in endpoints {
let socket = new_socket_internal(&context, endpoint)?;
res.push(socket);
}

Ok(res)
}

pub fn subscribe_async(endpoint: &str) -> Result<MessageStream> {
Ok(MessageStream::new(new_socket_internal(
&ZmqContext::new(),
endpoint,
)?))
}

fn new_socket_internal(context: &ZmqContext, endpoint: &str) -> Result<Subscribe> {
let socket = subscribe(endpoint)?.with_context(context).connect()?;
socket.set_subscribe("")?;

Ok(socket)
}

0 comments on commit 4985791

Please sign in to comment.