Skip to content

Commit

Permalink
resolve a bunch of todos
Browse files Browse the repository at this point in the history
  • Loading branch information
antonilol committed Dec 8, 2023
1 parent 1c8d288 commit 603ec14
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 51 deletions.
4 changes: 2 additions & 2 deletions integration_tests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod util;
use bitcoincore_rpc::Client;
use bitcoincore_zmq::{
subscribe_async, subscribe_async_monitor, subscribe_async_wait_handshake, subscribe_blocking,
subscribe_receiver, EventMessage, Message, SocketEvent, SocketMessage,
subscribe_receiver, MonitorMessage, Message, SocketEvent, SocketMessage,
};
use core::{assert_eq, ops::ControlFlow, time::Duration};
use futures::{executor::block_on, StreamExt};
Expand Down Expand Up @@ -145,7 +145,7 @@ fn test_monitor(rpc: &Client) {
println!("received real message: {msg}");
break;
}
SocketMessage::Event(EventMessage { event, source_url }) => {
SocketMessage::Event(MonitorMessage { event, source_url }) => {
// TODO remove debug printlns before merging
println!("received monitor message: {event:?} from {source_url}");
if event == SocketEvent::HandshakeSucceeded {
Expand Down
20 changes: 18 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::message::{DATA_MAX_LEN, SEQUENCE_LEN, TOPIC_MAX_LEN};
use crate::{
message::{DATA_MAX_LEN, SEQUENCE_LEN, TOPIC_MAX_LEN},
monitor::MonitorMessageError,
};
use bitcoin::consensus;
use core::{cmp::min, fmt};

Expand All @@ -15,6 +18,8 @@ pub enum Error {
Invalid256BitHashLength(usize),
BitcoinDeserialization(consensus::encode::Error),
Zmq(zmq::Error),
MonitorMessage(MonitorMessageError),
Disconnected(String),
}

impl Error {
Expand Down Expand Up @@ -69,6 +74,13 @@ impl From<consensus::encode::Error> for Error {
}
}

impl From<MonitorMessageError> for Error {
#[inline]
fn from(value: MonitorMessageError) -> Self {
Self::MonitorMessage(value)
}
}

impl fmt::Display for Error {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -115,6 +127,8 @@ impl fmt::Display for Error {
write!(f, "bitcoin consensus deserialization error: {e}")
}
Self::Zmq(e) => write!(f, "ZMQ Error: {e}"),
Self::MonitorMessage(err) => write!(f, "unable to parse monitor message: {err}"),
Self::Disconnected(url) => write!(f, "disconnected from {url}"),
}
}
}
Expand All @@ -125,13 +139,15 @@ impl std::error::Error for Error {
Some(match self {
Self::BitcoinDeserialization(e) => e,
Self::Zmq(e) => e,
Self::MonitorMessage(e) => e,
Self::InvalidMutlipartLength(_)
| Self::InvalidTopic(_, _)
| Self::InvalidDataLength(_)
| Self::InvalidSequenceLength(_)
| Self::InvalidSequenceMessageLength(_)
| Self::InvalidSequenceMessageLabel(_)
| Self::Invalid256BitHashLength(_) => return None,
| Self::Invalid256BitHashLength(_)
| Self::Disconnected(_) => return None,
})
}
}
9 changes: 6 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

mod error;
mod event;
mod message;
mod monitor;
mod sequence_message;
mod subscribe;

pub use crate::{
error::Error,
event::{EventMessage, HandshakeFailure, SocketEvent},
message::{Message, DATA_MAX_LEN, SEQUENCE_LEN, TOPIC_MAX_LEN},
monitor::{
event::{HandshakeFailure, SocketEvent},
MonitorMessage,
},
sequence_message::SequenceMessage,
subscribe::{blocking::subscribe_blocking, receiver::subscribe_receiver},
};

#[cfg(feature = "async")]
pub use crate::subscribe::stream::{
subscribe_async, subscribe_async_monitor, subscribe_async_wait_handshake,
subscribe_async_wait_handshake_timeout, FiniteMessageStream, MessageStream, SocketMessage,
subscribe_async_wait_handshake_timeout, CheckedMessageStream, MessageStream, SocketMessage,
SocketMessageStream,
};

Expand Down
28 changes: 10 additions & 18 deletions src/event.rs → src/monitor/event.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use super::MonitorMessageError;

/// Convenience trait to be able to use `from_raw` and `to_raw` on any value that either defines it
/// or is a `u32`. It doesn't matter that others don't implement this trait, rustc is smart enough
/// to find that out.
Expand Down Expand Up @@ -126,27 +128,17 @@ define_socket_event_enum! {
HandshakeFailedAuth(error_code) = ZMQ_EVENT_HANDSHAKE_FAILED_AUTH,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EventMessage {
pub event: SocketEvent,
pub source_url: String,
}
impl SocketEvent {
pub fn parse_from(msg: &zmq::Message) -> Result<Self, MonitorMessageError> {
let bytes = &**msg;

impl EventMessage {
pub fn parse_from(msg: Vec<zmq::Message>) -> Self {
// TODO properly handle errors (review uses of unwrap, expect, unreachable)
let [a, b] = &msg[..] else {
unreachable!("monitor message is always 2 frames")
};
let event: [u8; 6] = (**a)
let event: [u8; 6] = bytes
.try_into()
.expect("monitor message's first frame is always 6 bytes");
.map_err(|_| MonitorMessageError::InvalidEventFrameLength(bytes.len()))?;
let event_type = u16::from_ne_bytes(event[0..2].try_into().unwrap());
let data = u32::from_ne_bytes(event[2..6].try_into().unwrap());
let source_url: String = String::from_utf8_lossy(b).into();
EventMessage {
event: SocketEvent::from_raw(event_type, data).unwrap(),
source_url,
}

SocketEvent::from_raw(event_type, data)
.ok_or(MonitorMessageError::InvalidEventData(event_type, data))
}
}
51 changes: 51 additions & 0 deletions src/monitor/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use std::fmt::Display;

use event::SocketEvent;

pub mod event;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MonitorMessage {
pub event: SocketEvent,
pub source_url: String,
}

impl MonitorMessage {
pub fn parse_from(msg: &[zmq::Message]) -> Result<Self, MonitorMessageError> {
let [event_message, url_message] = msg else {
return Err(MonitorMessageError::InvalidMutlipartLength(msg.len()));
};

Ok(MonitorMessage {
event: SocketEvent::parse_from(event_message)?,
source_url: String::from_utf8_lossy(url_message).into(),
})
}
}

#[derive(Debug)]
// currently all variants have the same prefix: `Invalid`, which is correct and intended
#[allow(clippy::enum_variant_names)]
pub enum MonitorMessageError {
InvalidMutlipartLength(usize),
InvalidEventFrameLength(usize),
InvalidEventData(u16, u32),
}

impl Display for MonitorMessageError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::InvalidMutlipartLength(len) => {
write!(f, "invalid multipart message length: {len} (expected 2)")
}
Self::InvalidEventFrameLength(len) => {
write!(f, "invalid event frame length: {len} (expected 6)")
}
Self::InvalidEventData(event_type, event_data) => {
write!(f, "invalid event data {event_data} for event {event_type}")
}
}
}
}

impl std::error::Error for MonitorMessageError {}
74 changes: 48 additions & 26 deletions src/subscribe/stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use super::{new_socket_internal, recv_internal};
use crate::{error::Result, event::SocketEvent, message::Message, EventMessage, DATA_MAX_LEN};
use crate::{
error::Result,
message::Message,
monitor::{event::SocketEvent, MonitorMessage, MonitorMessageError},
Error, DATA_MAX_LEN,
};
use async_zmq::{Stream, StreamExt, Subscribe};
use core::{
future::Future,
Expand Down Expand Up @@ -61,23 +66,39 @@ impl FusedStream for MessageStream {
}
}

// TODO move, name
// TODO move?
pub enum SocketMessage {
Message(Message),
Event(EventMessage),
Event(MonitorMessage),
}

enum Empty {}

impl Iterator for Empty {
type Item = Empty;

fn next(&mut self) -> Option<Self::Item> {
None
}
}

impl Into<async_zmq::Message> for Empty {
fn into(self) -> async_zmq::Message {
match self {}
}
}

// The generic type params don't matter as this will only be used for receiving
type Pair = async_zmq::Pair<std::vec::IntoIter<&'static [u8]>, &'static [u8]>;
// Better to use an empty type to not waste precious bytes
type RecvOnlyPair = async_zmq::Pair<Empty, Empty>;

// TODO name?
pub struct SocketMessageStream {
messages: MessageStream,
monitor: Pair,
monitor: RecvOnlyPair,
}

impl SocketMessageStream {
fn new(messages: MessageStream, monitor: Pair) -> Self {
fn new(messages: MessageStream, monitor: RecvOnlyPair) -> Self {
Self { messages, monitor }
}
}
Expand All @@ -88,10 +109,9 @@ impl Stream for SocketMessageStream {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
match self.monitor.poll_next_unpin(cx) {
Poll::Ready(msg) => {
// TODO properly handle errors (review uses of unwrap, expect, unreachable)
return Poll::Ready(Some(Ok(SocketMessage::Event(EventMessage::parse_from(
msg.unwrap()?,
)))));
return Poll::Ready(Some(Ok(SocketMessage::Event(MonitorMessage::parse_from(
&msg.unwrap()?,
)?))));
}
Poll::Pending => {}
}
Expand All @@ -108,18 +128,17 @@ impl FusedStream for SocketMessageStream {
}
}

// TODO name, disconnect on failure?
pub struct FiniteMessageStream {
pub struct CheckedMessageStream {
inner: Option<SocketMessageStream>,
}

impl FiniteMessageStream {
impl CheckedMessageStream {
pub fn new(inner: SocketMessageStream) -> Self {
Self { inner: Some(inner) }
}
}

impl Stream for FiniteMessageStream {
impl Stream for CheckedMessageStream {
type Item = Result<Message>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
Expand All @@ -128,11 +147,11 @@ impl Stream for FiniteMessageStream {
match inner.poll_next_unpin(cx) {
Poll::Ready(opt) => match opt.unwrap()? {
SocketMessage::Message(msg) => return Poll::Ready(Some(Ok(msg))),
SocketMessage::Event(EventMessage { event, .. }) => {
SocketMessage::Event(MonitorMessage { event, source_url }) => {
if let SocketEvent::Disconnected { .. } = event {
// drop to disconnect
self.inner = None;
return Poll::Ready(None);
return Poll::Ready(Some(Err(Error::Disconnected(source_url))));
} else {
// only here it loops
}
Expand All @@ -147,7 +166,7 @@ impl Stream for FiniteMessageStream {
}
}

impl FusedStream for FiniteMessageStream {
impl FusedStream for CheckedMessageStream {
fn is_terminated(&self) -> bool {
self.inner.is_none()
}
Expand Down Expand Up @@ -271,17 +290,20 @@ pub fn subscribe_async_monitor(endpoints: &[&str]) -> Result<SocketMessageStream
/// this should be used with the timeout function of your async runtime, this function will wait
/// indefinitely. to runtime independently return after some timeout, a second thread is needed
/// which is inefficient
pub async fn subscribe_async_wait_handshake(endpoints: &[&str]) -> Result<FiniteMessageStream> {
pub async fn subscribe_async_wait_handshake(endpoints: &[&str]) -> Result<CheckedMessageStream> {
let mut stream = subscribe_async_monitor(endpoints)?;
let mut connecting = endpoints.len();

if connecting == 0 {
return Ok(FiniteMessageStream::new(stream));
return Ok(CheckedMessageStream::new(stream));
}

loop {
// TODO only decode first frame, the second frame (source address) is unused here but a String is allocated for it
match EventMessage::parse_from(stream.monitor.next().await.unwrap()?).event {
let msg: &[zmq::Message] = &stream.monitor.next().await.unwrap()?;
let [event_message, _] = msg else {
return Err(MonitorMessageError::InvalidMutlipartLength(msg.len()).into());
};
match SocketEvent::parse_from(event_message)? {
SocketEvent::HandshakeSucceeded => {
connecting -= 1;
}
Expand All @@ -293,7 +315,7 @@ pub async fn subscribe_async_wait_handshake(endpoints: &[&str]) -> Result<Finite
}
}
if connecting == 0 {
return Ok(FiniteMessageStream::new(stream));
return Ok(CheckedMessageStream::new(stream));
}
}
}
Expand All @@ -302,13 +324,13 @@ pub async fn subscribe_async_wait_handshake(endpoints: &[&str]) -> Result<Finite
pub async fn subscribe_async_wait_handshake_timeout(
endpoints: &[&str],
timeout: Duration,
) -> Option<Result<FiniteMessageStream>> {
) -> core::result::Result<Result<CheckedMessageStream>, ()> {
let subscribe = subscribe_async_wait_handshake(endpoints);
let timeout = sleep(timeout);

match select(pin!(subscribe), timeout).await {
Either::Left((res, _)) => Some(res),
Either::Right(_) => None,
Either::Left((res, _)) => Ok(res),
Either::Right(_) => Err(()),
}
}

Expand Down

0 comments on commit 603ec14

Please sign in to comment.