From 63fa14f6fceb37630cc2efd123ca224db8f96a91 Mon Sep 17 00:00:00 2001 From: Paul Loyd Date: Sun, 11 Feb 2024 20:54:01 +0100 Subject: [PATCH] perf(core/mailbox): use an intrusive list of envelopes --- elfo-core/Cargo.toml | 1 + elfo-core/src/context.rs | 62 ++- elfo-core/src/envelope.rs | 373 +++++++++++++---- elfo-core/src/lib.rs | 6 +- elfo-core/src/macros.rs | 18 +- elfo-core/src/mailbox.rs | 132 ++++-- elfo-core/src/message.rs | 665 ++++++++++++++++++++++++------ elfo-core/src/request_table.rs | 2 +- elfo-core/src/signal.rs | 2 +- elfo-core/src/stream.rs | 6 +- elfo-core/src/time/delay.rs | 2 +- elfo-core/src/time/interval.rs | 2 +- elfo-macros-impl/src/message.rs | 71 +--- elfo-macros-impl/src/msg.rs | 18 +- elfo-network/src/codec/format.rs | 1 + elfo-network/src/codec/mod.rs | 12 +- elfo-network/src/discovery/mod.rs | 4 +- elfo-network/src/socket/mod.rs | 4 +- elfo-network/src/worker/mod.rs | 60 ++- elfo-test/src/proxy.rs | 2 +- elfo-test/src/utils.rs | 9 +- elfo-utils/src/time.rs | 6 +- elfo/Cargo.toml | 3 - elfo/tests/msg_macro.rs | 10 +- elfo/tests/protocol_evolution.rs | 2 +- 25 files changed, 1052 insertions(+), 421 deletions(-) diff --git a/elfo-core/Cargo.toml b/elfo-core/Cargo.toml index 879d80c4..82c863a1 100644 --- a/elfo-core/Cargo.toml +++ b/elfo-core/Cargo.toml @@ -26,6 +26,7 @@ elfo-utils = { version = "0.2.5", path = "../elfo-utils" } tokio = { version = "1.16", features = ["rt", "sync", "time", "signal", "macros"] } idr-ebr = "0.2" futures-intrusive = "0.5" +cordyceps = "0.3.2" parking_lot = "0.12" smallbox = "0.8.0" # TODO: avoid the `rc` feature here? diff --git a/elfo-core/src/context.rs b/elfo-core/src/context.rs index 2b07c82b..c9f1f8da 100644 --- a/elfo-core/src/context.rs +++ b/elfo-core/src/context.rs @@ -15,7 +15,7 @@ use crate::{ coop, demux::Demux, dumping::{Direction, Dump, Dumper, INTERNAL_CLASS}, - envelope::{AnyMessageBorrowed, AnyMessageOwned, Envelope, EnvelopeOwned, MessageKind}, + envelope::{Envelope, MessageKind}, errors::{RequestError, SendError, TryRecvError, TrySendError}, mailbox::RecvResult, message::{Message, Request}, @@ -179,14 +179,14 @@ impl Context { sender: self.actor_addr, }; - self.stats.on_sent_message(&message); + self.stats.on_sent_message(&message); // TODO: only if successful? trace!("> {:?}", message); if let Some(permit) = DUMPER.acquire_m(&message) { permit.record(Dump::message(&message, &kind, Direction::Out)); } - let envelope = Envelope::new(message, kind).upcast(); + let envelope = Envelope::new(message, kind); let addrs = self.demux.filter(&envelope); if addrs.is_empty() { @@ -214,15 +214,15 @@ impl Context { Ok(()) => success = true, Err(err) => { has_full |= err.is_full(); - forget_and_replace(&mut unused, Some(err.into_inner())); + replace_unused(&mut unused, Some(err.into_inner())); } }, - None => forget_and_replace(&mut unused, Some(envelope)), + None => replace_unused(&mut unused, Some(envelope)), }; } if success { - forget_and_replace(&mut unused, None); + replace_unused(&mut unused, None); Ok(()) } else if has_full { Err(TrySendError::Full(e2m(unused.unwrap()))) @@ -270,14 +270,14 @@ impl Context { } async fn do_send(&self, message: M, kind: MessageKind) -> Result<(), SendError> { - self.stats.on_sent_message(&message); + self.stats.on_sent_message(&message); // TODO: only if successful? trace!("> {:?}", message); if let Some(permit) = DUMPER.acquire_m(&message) { permit.record(Dump::message(&message, &kind, Direction::Out)); } - let envelope = Envelope::new(message, kind).upcast(); + let envelope = Envelope::new(message, kind); let addrs = self.demux.filter(&envelope); if addrs.is_empty() { @@ -305,7 +305,7 @@ impl Context { let guard = EbrGuard::new(); let entry = self.book.get(recipient, &guard); let object = ward!(entry, { - forget_and_replace(&mut unused, Some(envelope)); + replace_unused(&mut unused, Some(envelope)); continue; }); Object::send(object, Addr::NULL, envelope) @@ -314,14 +314,14 @@ impl Context { .err() .map(|err| err.0); - forget_and_replace(&mut unused, returned_envelope); + replace_unused(&mut unused, returned_envelope); if unused.is_none() { success = true; } } if success { - forget_and_replace(&mut unused, None); + replace_unused(&mut unused, None); Ok(()) } else { Err(SendError(e2m(unused.unwrap()))) @@ -362,7 +362,7 @@ impl Context { message: M, kind: MessageKind, ) -> Result<(), SendError> { - self.stats.on_sent_message(&message); + self.stats.on_sent_message(&message); // TODO: only if successful? trace!(to = %recipient, "> {:?}", message); if let Some(permit) = DUMPER.acquire_m(&message) { @@ -374,7 +374,7 @@ impl Context { let entry = self.book.get(recipient, &guard); let object = ward!(entry, return Err(SendError(message))); let envelope = Envelope::new(message, kind); - Object::send(object, recipient, envelope.upcast()) + Object::send(object, recipient, envelope) } .await .map_err(|err| SendError(e2m(err.0))) @@ -402,7 +402,7 @@ impl Context { recipient: Addr, message: M, ) -> Result<(), TrySendError> { - self.stats.on_sent_message(&message); + self.stats.on_sent_message(&message); // TODO: only if successful? let kind = MessageKind::Regular { sender: self.actor_addr, @@ -419,7 +419,7 @@ impl Context { let envelope = Envelope::new(message, kind); object - .try_send(recipient, envelope.upcast()) + .try_send(recipient, envelope) .map_err(|err| err.map(e2m)) } @@ -442,7 +442,7 @@ impl Context { let token = token.into_untyped(); let recipient = token.sender(); let message = R::Wrapper::from(message); - self.stats.on_sent_message(&message); + self.stats.on_sent_message(&message); // TODO: only if successful? let kind = MessageKind::Response { sender: self.addr(), @@ -454,7 +454,7 @@ impl Context { permit.record(Dump::message(&message, &kind, Direction::Out)); } - let envelope = Envelope::new(message, kind).upcast(); + let envelope = Envelope::new(message, kind); let guard = EbrGuard::new(); let object = ward!(self.book.get(recipient, &guard)); object.respond(token, Ok(envelope)); @@ -696,7 +696,7 @@ impl Context { let kind = MessageKind::Regular { sender: self.actor_addr, }; - let envelope = Envelope::new(message, kind).upcast(); + let envelope = Envelope::new(message, kind); self.respond(token, Ok(())); envelope } @@ -705,9 +705,9 @@ impl Context { let message = envelope.message(); trace!("< {:?}", message); - if let Some(permit) = DUMPER.acquire_m(message) { + if let Some(permit) = DUMPER.acquire_m(&*message) { let kind = envelope.message_kind(); - permit.record(Dump::message(message, kind, Direction::In)); + permit.record(Dump::message(&*message, kind, Direction::In)); } // We should change the status after dumping the original message @@ -826,11 +826,9 @@ impl Context { } } +#[cold] fn e2m(envelope: Envelope) -> M { - envelope - .unpack_regular() - .downcast() - .expect("invalid message") + envelope.unpack().expect("invalid message").0 } #[cold] @@ -862,10 +860,9 @@ fn addrs_with_envelope( }) } -fn forget_and_replace(dest: &mut Option, value: Option) { - if let Some(old_value) = dest.take() { - let (_, token) = old_value.unpack_request(); - token.forget(); +fn replace_unused(dest: &mut Option, value: Option) { + if let Some(old) = dest.take() { + old.drop_as_unused(); } *dest = value; } @@ -1019,14 +1016,13 @@ fn prepare_response( response: Result, ) -> Result { let envelope = response?; - let message = envelope.message().downcast2::(); + let (message, kind) = envelope.unpack::().expect("invalid response"); // TODO: increase a counter. trace!("< {:?}", message); - if let Some(permit) = DUMPER.acquire_m(message) { - let kind = envelope.message_kind(); - permit.record(Dump::message(message, kind, Direction::In)); + if let Some(permit) = DUMPER.acquire_m(&message) { + permit.record(Dump::message(&message, &kind, Direction::In)); } - Ok(envelope.unpack_regular().downcast2::().into()) + Ok(message.into()) } diff --git a/elfo-core/src/envelope.rs b/elfo-core/src/envelope.rs index aabcdbcf..77fd0a12 100644 --- a/elfo-core/src/envelope.rs +++ b/elfo-core/src/envelope.rs @@ -1,26 +1,46 @@ +use std::{alloc, fmt, mem, ptr, ptr::NonNull}; + use elfo_utils::time::Instant; use crate::{ - message::{AnyMessage, Message}, + mailbox, + message::{AnyMessageRef, Message, MessageRepr, MessageTypeId, Request}, request_table::{RequestId, ResponseToken}, tracing::TraceId, Addr, }; -// TODO: use granular messages instead of `SmallBox`. -#[derive(Debug)] -pub struct Envelope { +pub struct Envelope(NonNull); + +assert_impl_all!(Envelope: Send); +assert_not_impl_any!(Envelope: Sync); +assert_eq_size!(Envelope, usize); + +// TODO: describe following: +// + add comment to network about kanal +// link: 8 +// created_time: 8 +// trace_id: 8 -> 16 +// kind: 24 -> 32 +// offset: 4 +// --- +// 8+8+8+24+4^=56 +// ^ padding (can be removed) +// +// 48 -> 64 +pub(crate) struct EnvelopeHeader { + pub(crate) link: mailbox::Link, created_time: Instant, // Now used also as a sent time. trace_id: TraceId, kind: MessageKind, - message: M, + message_offset: u32, } -assert_impl_all!(Envelope: Send); -assert_eq_size!(Envelope, [u8; 256]); +unsafe impl Send for Envelope {} +// TODO: Sync?? +// TODO: Pin? // Reexported in `elfo::_priv`. -#[derive(Debug)] pub enum MessageKind { Regular { sender: Addr }, RequestAny(ResponseToken), @@ -28,88 +48,124 @@ pub enum MessageKind { Response { sender: Addr, request_id: RequestId }, } -impl Envelope { +impl Drop for Envelope { + fn drop(&mut self) { + let message = self.message(); + let message_layout = message._repr_layout(); + let (layout, message_offset) = envelope_repr_layout(message_layout); + debug_assert_eq!(message_offset, self.header().message_offset); + + unsafe { message.drop_in_place() }; + unsafe { ptr::drop_in_place(self.0.as_ptr()) } + unsafe { alloc::dealloc(self.0.as_ptr().cast(), layout) }; + } +} + +impl Envelope { // This is private API. Do not use it. #[doc(hidden)] #[inline] - pub fn new(message: M, kind: MessageKind) -> Self { + pub fn new(message: M, kind: MessageKind) -> Self { Self::with_trace_id(message, kind, crate::scope::trace_id()) } // This is private API. Do not use it. #[doc(hidden)] #[inline] - pub fn with_trace_id(message: M, kind: MessageKind, trace_id: TraceId) -> Self { - Self { + pub fn with_trace_id(message: M, kind: MessageKind, trace_id: TraceId) -> Self { + let message_layout = message._repr_layout(); + let (layout, message_offset) = envelope_repr_layout(message_layout); + + let header = EnvelopeHeader { + link: <_>::default(), created_time: Instant::now(), trace_id, kind, - message, - } + message_offset, + }; + + let ptr = unsafe { alloc::alloc(layout) }; + + let Some(ptr) = NonNull::new(ptr) else { + alloc::handle_alloc_error(layout); + }; + + unsafe { ptr::write(ptr.cast().as_ptr(), header) }; + + let this = Self(ptr.cast()); + + let message_ptr = this.message_repr_ptr(); + unsafe { message._write(message_ptr) }; + + this + } + + pub(crate) fn stub() -> Self { + Self::with_trace_id( + crate::messages::Ping, + MessageKind::Regular { sender: Addr::NULL }, + TraceId::try_from(1).unwrap(), + ) + } + + fn header(&self) -> &EnvelopeHeader { + unsafe { self.0.as_ref() } } #[inline] pub fn trace_id(&self) -> TraceId { - self.trace_id + self.header().trace_id } #[inline] - pub fn message(&self) -> &M { - &self.message + pub fn message(&self) -> AnyMessageRef<'_> { + let message_repr = self.message_repr_ptr(); + unsafe { AnyMessageRef::new(message_repr) } } /// Part of private API. Do not use it. #[doc(hidden)] pub fn message_kind(&self) -> &MessageKind { - &self.kind + &self.header().kind } pub(crate) fn created_time(&self) -> Instant { - self.created_time + self.header().created_time } #[inline] pub fn sender(&self) -> Addr { - match &self.kind { + match self.message_kind() { MessageKind::Regular { sender } => *sender, MessageKind::RequestAny(token) => token.sender(), MessageKind::RequestAll(token) => token.sender(), MessageKind::Response { sender, .. } => *sender, } } -} -impl Envelope { - // This is private API. Do not use it. - #[doc(hidden)] - pub fn upcast(self) -> Envelope { - Envelope { - created_time: self.created_time, - trace_id: self.trace_id, - kind: self.kind, - message: self.message.upcast(), - } - } -} - -impl Envelope { #[inline] - pub fn is(&self) -> bool { - self.message.is::() + pub fn type_id(&self) -> MessageTypeId { + self.message().type_id() } #[inline] - pub fn type_id(&self) -> std::any::TypeId { - self.message.type_id() + pub fn is(&self) -> bool { + self.message().is::() } #[doc(hidden)] - #[stability::unstable] pub fn duplicate(&self) -> Self { - Self { - created_time: self.created_time, - trace_id: self.trace_id, - kind: match &self.kind { + let header = self.header(); + let message = self.message(); + let message_layout = message._repr_layout(); + let (layout, message_offset) = envelope_repr_layout(message_layout); + debug_assert_eq!(message_offset, header.message_offset); + + let out_header = EnvelopeHeader { + link: <_>::default(), + created_time: header.created_time, + trace_id: header.trace_id, + kind: match &header.kind { MessageKind::Regular { sender } => MessageKind::Regular { sender: *sender }, MessageKind::RequestAny(token) => MessageKind::RequestAny(token.duplicate()), MessageKind::RequestAll(token) => MessageKind::RequestAll(token.duplicate()), @@ -118,90 +174,239 @@ impl Envelope { request_id: *request_id, }, }, - message: self.message.clone(), - } + message_offset, + }; + + let out_ptr = unsafe { alloc::alloc(layout) }; + + let Some(out_ptr) = NonNull::new(out_ptr) else { + alloc::handle_alloc_error(layout); + }; + + unsafe { ptr::write(out_ptr.cast().as_ptr(), out_header) }; + + let out = Self(out_ptr.cast()); + + let out_message_ptr = out.message_repr_ptr(); + unsafe { message.clone_into(out_message_ptr) }; + + out } - // TODO: remove the method? + // TODO: remove the method pub(crate) fn set_message(&mut self, message: M) { - self.message = message.upcast(); + assert!(self.is::() && M::_type_id() != crate::message::AnyMessage::_type_id()); + + // TODO: rewrite without `MessageRepr`? + let repr_ptr = self.message_repr_ptr().cast::>().as_ptr(); + unsafe { ptr::replace(repr_ptr, MessageRepr::new(message)) }; + } + + fn message_repr_ptr(&self) -> NonNull { + let message_offset = self.header().message_offset; + let ptr = unsafe { self.0.as_ptr().cast::().add(message_offset as usize) }; + unsafe { NonNull::new_unchecked(ptr.cast()) } + } + + #[doc(hidden)] + #[inline] + pub fn unpack(self) -> Option<(M, MessageKind)> { + self.is::().then(|| unsafe { self.unpack_unchecked() }) + } + + unsafe fn unpack_unchecked(self) -> (M, MessageKind) { + let message_layout = self.message()._repr_layout(); + let (layout, message_offset) = envelope_repr_layout(message_layout); + debug_assert_eq!(message_offset, self.header().message_offset); + + let message = M::_read(self.message_repr_ptr()); + let kind = ptr::read(&self.0.as_ref().kind); + + alloc::dealloc(self.0.as_ptr().cast(), layout); + mem::forget(self); + (message, kind) + } + + pub(crate) fn drop_as_unused(mut self) { + let header = unsafe { self.0.as_mut() }; + + if let MessageKind::RequestAny(token) | MessageKind::RequestAll(token) = &mut header.kind { + // TODO FIXME: invalid for ALL requests, need to decrement remainder. + token.forget(); + } + } + + pub(crate) fn into_header_ptr(self) -> NonNull { + let ptr = self.0; + mem::forget(self); + ptr + } + + pub(crate) unsafe fn from_header_ptr(ptr: NonNull) -> Self { + Self(ptr) + } +} + +fn envelope_repr_layout(message_layout: alloc::Layout) -> (alloc::Layout, u32) { + let (layout, message_offset) = alloc::Layout::new::() + .extend(message_layout) + .expect("impossible envelope layout"); + + let message_offset = + u32::try_from(message_offset).expect("message requires too large alignment"); + + (layout.pad_to_align(), message_offset) +} + +impl fmt::Debug for MessageKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MessageKind::Regular { sender: _ } => f.debug_struct("Regular").finish(), + MessageKind::RequestAny(token) => f + .debug_tuple("RequestAny") + .field(&token.request_id()) + .finish(), + MessageKind::RequestAll(token) => f + .debug_tuple("RequestAll") + .field(&token.request_id()) + .finish(), + MessageKind::Response { + sender: _, + request_id, + } => f.debug_tuple("Response").field(request_id).finish(), + } + } +} + +impl fmt::Debug for Envelope { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Envelope") + .field("trace_id", &self.trace_id()) + .field("sender", &self.sender()) + .field("kind", &self.message_kind()) + .field("message", &self.message()) + .finish() } } // Extra traits to support both owned and borrowed usages of `msg!(..)`. pub trait EnvelopeOwned { - fn unpack_regular(self) -> AnyMessage; - fn unpack_request(self) -> (AnyMessage, ResponseToken); + unsafe fn unpack_regular_unchecked(self) -> M; + unsafe fn unpack_request_unchecked(self) -> (R, ResponseToken); } pub trait EnvelopeBorrowed { - fn unpack_regular(&self) -> &AnyMessage; + unsafe fn unpack_regular_unchecked(&self) -> &M; } +// TODO: AnyMessage impl EnvelopeOwned for Envelope { #[inline] - fn unpack_regular(self) -> AnyMessage { + unsafe fn unpack_regular_unchecked(self) -> M { + let (message, kind) = self.unpack_unchecked(); + #[cfg(feature = "network")] - if let MessageKind::RequestAny(token) | MessageKind::RequestAll(token) = self.kind { + if let MessageKind::RequestAny(token) | MessageKind::RequestAll(token) = kind { // The sender thought this is a request, but for the current node it isn't. // Mark the token as received to return `RequestError::Ignored` to the sender. let _ = token.into_received::<()>(); } + // TODO: about assert in `msg!` #[cfg(not(feature = "network"))] debug_assert!(!matches!( - self.kind, + kind, MessageKind::RequestAny(_) | MessageKind::RequestAll(_) )); - self.message + message } #[inline] - fn unpack_request(self) -> (AnyMessage, ResponseToken) { - match self.kind { - MessageKind::RequestAny(token) | MessageKind::RequestAll(token) => { - (self.message, token) - } + unsafe fn unpack_request_unchecked(self) -> (R, ResponseToken) { + let (message, kind) = self.unpack_unchecked(); + + let token = match kind { + MessageKind::RequestAny(token) | MessageKind::RequestAll(token) => token, // A request sent by using `ctx.send()` ("fire and forget"). // Also it's useful for the protocol evolution between remote nodes. - _ => (self.message, ResponseToken::forgotten()), - } + _ => ResponseToken::forgotten(), + }; + + (message, token.into_received()) } } impl EnvelopeBorrowed for Envelope { #[inline] - fn unpack_regular(&self) -> &AnyMessage { - &self.message + unsafe fn unpack_regular_unchecked(&self) -> &M { + self.message().downcast_ref_unchecked() } } -pub trait AnyMessageOwned { - fn downcast2(self) -> M; -} +#[cfg(test)] +mod tests { + use std::sync::Arc; -pub trait AnyMessageBorrowed { - fn downcast2(&self) -> &M; -} + use elfo_utils::time; -impl AnyMessageOwned for AnyMessage { - #[inline] - fn downcast2(self) -> M { - match self.downcast::() { - Ok(message) => message, - Err(message) => panic!("unexpected message: {message:?}"), + use super::*; + use crate::message; + + #[message] + #[derive(PartialEq)] + struct Sample { + value: u128, + counter: Arc<()>, + } + + impl Sample { + fn new(value: u128) -> (Arc<()>, Self) { + let this = Self { + value, + counter: Arc::new(()), + }; + + (this.counter.clone(), this) } } -} -impl AnyMessageBorrowed for AnyMessage { - #[inline] - fn downcast2(&self) -> &M { - ward!( - self.downcast_ref::(), - panic!("unexpected message: {self:?}") - ) + #[test] + fn miri_duplicate() { + let (counter, message) = Sample::new(42); + + // Miri doesn't support asm, so mock the time. + let envelope = time::with_instant_mock(|_mock| { + Envelope::with_trace_id( + message, + MessageKind::Regular { sender: Addr::NULL }, + TraceId::try_from(1).unwrap(), + ) + }); + + assert_eq!(Arc::strong_count(&counter), 2); + let envelope2 = envelope.duplicate(); + assert_eq!(Arc::strong_count(&counter), 3); + assert!(envelope2.is::()); + let envelope3 = envelope2.duplicate(); + assert_eq!(Arc::strong_count(&counter), 4); + assert!(envelope3.is::()); + + drop(envelope2); + assert_eq!(Arc::strong_count(&counter), 3); + + drop(envelope3); + assert_eq!(Arc::strong_count(&counter), 2); + + let envelope4 = envelope.duplicate(); + assert_eq!(Arc::strong_count(&counter), 3); + assert!(envelope4.is::()); + + drop(envelope); + assert_eq!(Arc::strong_count(&counter), 2); + + drop(envelope4); + assert_eq!(Arc::strong_count(&counter), 1); } } diff --git a/elfo-core/src/lib.rs b/elfo-core/src/lib.rs index 54e51739..119deca9 100644 --- a/elfo-core/src/lib.rs +++ b/elfo-core/src/lib.rs @@ -18,7 +18,7 @@ pub use crate::{ envelope::Envelope, group::{ActorGroup, Blueprint, TerminationPolicy}, local::{Local, MoveOwnership}, - message::{Message, Request}, + message::{AnyMessage, AnyMessageRef, Message, Request}, request_table::ResponseToken, restarting::{RestartParams, RestartPolicy}, source::{SourceHandle, UnattachedSource}, @@ -87,9 +87,7 @@ pub mod _priv { pub use crate::addr::{GroupNo, NodeLaunchId, NodeNo}; pub use crate::{ address_book::AddressBook, - envelope::{ - AnyMessageBorrowed, AnyMessageOwned, EnvelopeBorrowed, EnvelopeOwned, MessageKind, - }, + envelope::{EnvelopeBorrowed, EnvelopeOwned, MessageKind}, init::do_start, message::*, object::{GroupVisitor, Object, OwnedObject}, diff --git a/elfo-core/src/macros.rs b/elfo-core/src/macros.rs index 83bbc001..80134562 100644 --- a/elfo-core/src/macros.rs +++ b/elfo-core/src/macros.rs @@ -1,18 +1,16 @@ #[macro_export] macro_rules! assert_msg { ($envelope:expr, $pat:pat) => {{ - use $crate::_priv::{AnyMessageBorrowed, EnvelopeBorrowed}; - let envelope = &$envelope; - let msg = envelope.unpack_regular().downcast2(); + let message = envelope.message(); + // TODO: use `msg!` to support multiple messages in a pattern. #[allow(unreachable_patterns)] - match &msg { - &$pat => {} + match &message.downcast_ref() { + Some($pat) => {} _ => panic!( - "\na message doesn't match a pattern\npattern: {}\nmessage: {:#?}\n", + "\na message doesn't match a pattern\npattern: {}\nmessage: {message:#?}\n", stringify!($pat), - msg, ), } }}; @@ -21,11 +19,11 @@ macro_rules! assert_msg { #[macro_export] macro_rules! assert_msg_eq { ($envelope:expr, $expected:expr) => {{ - use $crate::_priv::{AnyMessageBorrowed, EnvelopeBorrowed}; - let envelope = &$envelope; - let actual = envelope.unpack_regular().downcast2(); + let Some(actual) = envelope.message().downcast_ref() else { + panic!("unexpected message: {:#?}", envelope.message()); + }; let expected = &$expected; fn unify(_rhs: &T, _lhs: &T) {} diff --git a/elfo-core/src/mailbox.rs b/elfo-core/src/mailbox.rs index 2cfac888..bc5d0f84 100644 --- a/elfo-core/src/mailbox.rs +++ b/elfo-core/src/mailbox.rs @@ -1,56 +1,118 @@ -use futures_intrusive::{ - buffer::GrowingHeapBuf, - channel::{self, GenericChannel}, +use std::ptr::{self, NonNull}; + +use cordyceps::{ + mpsc_queue::{Links, MpscQueue}, + Linked, }; -use parking_lot::{Mutex, RawMutex}; +use parking_lot::Mutex; +use tokio::sync::{Notify, Semaphore, TryAcquireError}; + +use elfo_utils::CachePadded; use crate::{ - envelope::Envelope, + envelope::{Envelope, EnvelopeHeader}, errors::{SendError, TrySendError}, tracing::TraceId, }; -// TODO: make mailboxes bounded by time instead of size. +pub(crate) type Link = Links; + +unsafe impl Linked> for EnvelopeHeader { + type Handle = Envelope; + + // TODO: Pin? + + fn into_ptr(handle: Self::Handle) -> NonNull { + handle.into_header_ptr() + } + + /// Convert a raw pointer back into an owned `Handle`. + unsafe fn from_ptr(ptr: NonNull) -> Self::Handle { + unsafe { Self::Handle::from_header_ptr(ptr) } + } + + /// Access an element's `Links`. + unsafe fn links(ptr: NonNull) -> NonNull> { + // Using `ptr::addr_of_mut!` permits us to avoid creating a temporary + // reference without using layout-dependent casts. + let links = ptr::addr_of_mut!((*ptr.as_ptr()).link); + + // `NonNull::new_unchecked` is safe to use here, because the pointer that + // we offset was not null, implying that the pointer produced by offsetting + // it will also not be null. + NonNull::new_unchecked(links) + } +} + +// TODO: make configurable && limit by time. const LIMIT: usize = 100_000; pub(crate) struct Mailbox { - queue: GenericChannel>, + queue: MpscQueue, + tx_semaphore: Semaphore, + rx_notify: CachePadded, closed_trace_id: Mutex>, } impl Mailbox { pub(crate) fn new() -> Self { Self { - queue: GenericChannel::with_capacity(LIMIT), + queue: MpscQueue::new_with_stub(Envelope::stub()), + tx_semaphore: Semaphore::new(LIMIT), + rx_notify: CachePadded(Notify::new()), closed_trace_id: Mutex::new(None), } } pub(crate) async fn send(&self, envelope: Envelope) -> Result<(), SendError> { - let fut = self.queue.send(envelope); - fut.await.map_err(|err| SendError(err.0)) + let permit = match self.tx_semaphore.acquire().await { + Ok(permit) => permit, + Err(_) => return Err(SendError(envelope)), + }; + + permit.forget(); + self.queue.enqueue(envelope); + self.rx_notify.notify_one(); + Ok(()) } pub(crate) fn try_send(&self, envelope: Envelope) -> Result<(), TrySendError> { - self.queue.try_send(envelope).map_err(|err| match err { - channel::TrySendError::Full(envelope) => TrySendError::Full(envelope), - channel::TrySendError::Closed(envelope) => TrySendError::Closed(envelope), - }) + match self.tx_semaphore.try_acquire() { + Ok(permit) => { + permit.forget(); + self.queue.enqueue(envelope); + self.rx_notify.notify_one(); + Ok(()) + } + Err(TryAcquireError::NoPermits) => Err(TrySendError::Full(envelope)), + Err(TryAcquireError::Closed) => Err(TrySendError::Closed(envelope)), + } } pub(crate) async fn recv(&self) -> RecvResult { - let fut = self.queue.receive(); - match fut.await { - Some(envelope) => RecvResult::Data(envelope), - None => self.on_close(), + loop { + if let Some(envelope) = self.queue.dequeue() { + // TODO: try_dequeue? + self.tx_semaphore.add_permits(1); + return RecvResult::Data(envelope); + } + + if self.tx_semaphore.is_closed() { + return self.on_close(); + } + + self.rx_notify.notified().await; } } pub(crate) fn try_recv(&self) -> Option { - match self.queue.try_receive() { - Ok(envelope) => Some(RecvResult::Data(envelope)), - Err(channel::TryReceiveError::Empty) => None, - Err(channel::TryReceiveError::Closed) => Some(self.on_close()), + match self.queue.dequeue() { + Some(envelope) => { + self.tx_semaphore.add_permits(1); + Some(RecvResult::Data(envelope)) + } + None if self.tx_semaphore.is_closed() => Some(self.on_close()), + None => None, } } @@ -61,23 +123,33 @@ impl Mailbox { // possible when we try to `recv()` after the channel is closed, but // before the `closed_trace_id` is assigned. let mut closed_trace_id = self.closed_trace_id.lock(); - if self.queue.close().is_newly_closed() { - *closed_trace_id = Some(trace_id); - true - } else { - false + + if self.tx_semaphore.is_closed() { + return false; } + + *closed_trace_id = Some(trace_id); + + self.tx_semaphore.close(); + self.rx_notify.notify_one(); + true } #[cold] pub(crate) fn drop_all(&self) { - while self.queue.try_receive().is_ok() {} + while self.queue.dequeue().is_some() {} } #[cold] fn on_close(&self) -> RecvResult { - let trace_id = self.closed_trace_id.lock().expect("called before close()"); - RecvResult::Closed(trace_id) + // Some messages may be in the queue after the channel is closed. + match self.queue.dequeue() { + Some(envelope) => RecvResult::Data(envelope), + None => { + let trace_id = self.closed_trace_id.lock().expect("called before close()"); + RecvResult::Closed(trace_id) + } + } } } diff --git a/elfo-core/src/message.rs b/elfo-core/src/message.rs index 6d20edff..969a7743 100644 --- a/elfo-core/src/message.rs +++ b/elfo-core/src/message.rs @@ -1,21 +1,29 @@ -use std::{any::Any, fmt, ops::Deref}; +use std::{ + alloc, fmt, + marker::PhantomData, + mem::{self, ManuallyDrop}, + ops::Deref, + ptr::{self, NonNull}, +}; use fxhash::{FxHashMap, FxHashSet}; use linkme::distributed_slice; use metrics::Label; -use once_cell::sync::Lazy; +use once_cell::sync::Lazy; // TODO: replace with std? use serde::{ - de::{DeserializeSeed, SeqAccess, Visitor}, - ser::{SerializeStruct as _, SerializeTuple as _}, - Deserialize, Deserializer, Serialize, + de, + ser::{self, SerializeStruct as _, SerializeTuple as _}, + Deserialize, Serialize, }; -use smallbox::{smallbox, SmallBox}; - -use elfo_utils::unlikely; +use smallbox::smallbox; use crate::{dumping, scope::SerdeMode}; -pub trait Message: fmt::Debug + Clone + Any + Send + Serialize + for<'de> Deserialize<'de> { +// === Message === + +pub trait Message: + fmt::Debug + Clone + Send + Serialize + for<'de> Deserialize<'de> + 'static +{ #[inline(always)] fn name(&self) -> &'static str { self._vtable().name @@ -38,65 +46,170 @@ pub trait Message: fmt::Debug + Clone + Any + Send + Serialize + for<'de> Deseri self._vtable().dumping_allowed } + #[deprecated(note = "use `AnyMessage::new` instead")] #[doc(hidden)] #[inline(always)] fn upcast(self) -> AnyMessage { - self._touch(); - AnyMessage { - vtable: self._vtable(), - data: smallbox!(self), - } + self._into_any() } // Private API. + #[doc(hidden)] + fn _type_id() -> MessageTypeId; + #[doc(hidden)] fn _vtable(&self) -> &'static MessageVTable; - // Called while upcasting/downcasting to avoid - // [rust#47384](https://github.com/rust-lang/rust/issues/47384). + #[doc(hidden)] + #[inline(always)] + fn _can_get_from(type_id: MessageTypeId) -> bool { + Self::_type_id() == type_id + } + + // Called in `_read()` and `_write()` to avoid + // * [rust#47384](https://github.com/rust-lang/rust/issues/47384) + // * [rust#99721](https://github.com/rust-lang/rust/issues/99721) #[doc(hidden)] fn _touch(&self); + #[doc(hidden)] + #[inline(always)] + fn _into_any(self) -> AnyMessage { + AnyMessage::from_real(self) + } + #[doc(hidden)] #[inline(always)] fn _erase(&self) -> dumping::ErasedMessage { smallbox!(self.clone()) } + + #[doc(hidden)] + #[inline(always)] + fn _repr_layout(&self) -> alloc::Layout { + self._vtable().repr_layout + + // let layout = alloc::Layout::new::>(); + // TODO: debug_assert_eq!(self._vtable().repr_layout, layout); + // where to check it? + } + + #[doc(hidden)] + #[inline(always)] + unsafe fn _read(ptr: NonNull) -> Self { + let data_ref = unsafe { &ptr.cast::>().as_ref().data }; + let data = unsafe { ptr::read(data_ref) }; + data._touch(); + data + } + + #[doc(hidden)] + #[inline(always)] + unsafe fn _write(self, ptr: NonNull) { + self._touch(); + let repr = MessageRepr::new(self); + unsafe { ptr::write(ptr.cast::>().as_ptr(), repr) }; + } } +// === Request === + pub trait Request: Message { - type Response: fmt::Debug + Clone + Send + Serialize; + type Response: fmt::Debug + Clone + Send + Serialize; // TODO #[doc(hidden)] type Wrapper: Message + Into + From; } -// === AnyMessage === +// === MessageTypeId === -// Reexported in `elfo::_priv`. -pub struct AnyMessage { +#[derive(Clone, Copy, Debug)] +pub struct MessageTypeId(*const ()); + +unsafe impl Send for MessageTypeId {} +unsafe impl Sync for MessageTypeId {} + +impl MessageTypeId { + #[inline] + pub const fn new(vtable: &'static MessageVTable) -> Self { + Self(vtable as *const _ as *const ()) + } +} + +impl PartialEq for MessageTypeId { + #[inline] + fn eq(&self, other: &Self) -> bool { + ptr::eq(self.0, other.0) + } +} + +// === MessageRepr === + +#[derive(Clone)] +#[repr(C)] +pub struct MessageRepr { vtable: &'static MessageVTable, - data: SmallBox, + data: M, } +impl MessageRepr +where + M: Message, +{ + pub(crate) fn new(message: M) -> Self { + debug_assert_ne!(M::_type_id(), AnyMessage::_type_id()); + + Self { + vtable: message._vtable(), + data: message, + } + } +} + +// === AnyMessage === + +pub struct AnyMessage(NonNull); + +assert_not_impl_any!(AnyMessage: Sync); + +unsafe impl Send for AnyMessage {} + +// TODO: Pin/Unpin? +// TODO: miri strict sptr + impl AnyMessage { + pub fn new(message: M) -> Self { + message._into_any() + } + + fn from_real(message: M) -> Self { + let ptr = unsafe { alloc_repr(message._vtable()) }; + unsafe { message._write(ptr) }; + Self(ptr) + } + #[inline] - pub fn is(&self) -> bool { - self.data.is::() + pub fn type_id(&self) -> MessageTypeId { + MessageTypeId::new(self._vtable()) } #[inline] - pub fn type_id(&self) -> std::any::TypeId { - (*self.data).type_id() + pub fn is(&self) -> bool { + M::_can_get_from(self.type_id()) } #[inline] pub fn downcast_ref(&self) -> Option<&M> { - self.data.downcast_ref::().map(|message| { - message._touch(); - message - }) + self.is::() + .then(|| unsafe { self.downcast_ref_unchecked() }) + } + + pub(crate) unsafe fn downcast_ref_unchecked(&self) -> &M { + // TODO: support `AnyMessage` + debug_assert_ne!(M::_type_id(), Self::_type_id()); + + &self.0.cast::>().as_ref().data } #[inline] @@ -105,60 +218,143 @@ impl AnyMessage { return Err(self); } - let message = self - .data - .downcast::() - .expect("cannot downcast") - .into_inner(); + let data = unsafe { M::_read(self.0) }; + + unsafe { dealloc_repr(self.0) }; + + mem::forget(self); + + Ok(data) + } + + pub(crate) unsafe fn clone_into(&self, out_ptr: NonNull) { + let vtable = self._vtable(); + (vtable.clone)(self.0, out_ptr); + } + + pub(crate) unsafe fn drop_in_place(&self) { + let vtable = self._vtable(); + (vtable.drop)(self.0); + } +} + +unsafe fn alloc_repr(vtable: &'static MessageVTable) -> NonNull { + let ptr = alloc::alloc(vtable.repr_layout); + + let Some(ptr) = NonNull::new(ptr) else { + alloc::handle_alloc_error(vtable.repr_layout); + }; + + ptr.cast() +} + +unsafe fn dealloc_repr(ptr: NonNull) { + let ptr = ptr.as_ptr(); + let vtable = (*ptr).vtable; + + alloc::dealloc(ptr.cast(), vtable.repr_layout); +} + +impl Drop for AnyMessage { + fn drop(&mut self) { + unsafe { self.drop_in_place() }; + + unsafe { dealloc_repr(self.0) }; + } +} + +impl Clone for AnyMessage { + fn clone(&self) -> Self { + let out_ptr = unsafe { alloc_repr(self._vtable()) }; + + unsafe { self.clone_into(out_ptr) }; + + Self(out_ptr) + } +} - message._touch(); - Ok(message) +impl fmt::Debug for AnyMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + unsafe { (self._vtable().debug)(self.0, f) } } } impl Message for AnyMessage { #[inline(always)] - fn upcast(self) -> AnyMessage { - self + fn _type_id() -> MessageTypeId { + MessageTypeId::new(&VTABLE_STUB) } #[inline(always)] fn _vtable(&self) -> &'static MessageVTable { - self.vtable + unsafe { (*self.0.as_ptr()).vtable } + } + + #[inline(always)] + fn _can_get_from(_: MessageTypeId) -> bool { + true } #[inline(always)] fn _touch(&self) {} - #[doc(hidden)] + #[inline(always)] + fn _into_any(self) -> AnyMessage { + self + } + #[inline(always)] fn _erase(&self) -> dumping::ErasedMessage { - (self.vtable.erase)(self) + let vtable = self._vtable(); + unsafe { (vtable.erase)(self.0) } } -} -impl Clone for AnyMessage { - #[inline] - fn clone(&self) -> Self { - (self.vtable.clone)(self) + #[inline(always)] + unsafe fn _read(ptr: NonNull) -> Self { + let vtable = (*ptr.as_ptr()).vtable; + let this = unsafe { alloc_repr(vtable) }; + + unsafe { + ptr::copy_nonoverlapping( + ptr.cast::().as_ptr(), + this.cast::().as_ptr(), + vtable.repr_layout.size(), + ) + }; + + Self(this) } -} -impl fmt::Debug for AnyMessage { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (self.vtable.debug)(self, f) + #[inline(always)] + unsafe fn _write(self, out_ptr: NonNull) { + unsafe { + ptr::copy_nonoverlapping( + self.0.cast::().as_ptr(), + out_ptr.cast::().as_ptr(), + self._vtable().repr_layout.size(), + ) + }; + + unsafe { dealloc_repr(self.0) }; + + mem::forget(self); } } // `Serialize` / `Deserialize` impls for `AnyMessage` are not used when sending -// it by itself, only when it's used in other messages. +// it by itself over network (e.g. using `ctx.send(msg)`) or dumping. +// However, it's used if +// * It's a part of another message (e.g. `struct Msg(AnyMessage)`). +// * It's serialized directly (e.g. `insta::assert_yaml_snapshot!(msg)`). impl Serialize for AnyMessage { fn serialize(&self, serializer: S) -> Result where - S: serde::ser::Serializer, + S: ser::Serializer, { - // TODO: avoid allocation here + // TODO: avoid allocation here (add `_erase_ref`) let erased_msg = self._erase(); + + // TODO: use compact form only for network? if crate::scope::serde_mode() == SerdeMode::Dumping { let mut fields = serializer.serialize_struct("AnyMessage", 3)?; fields.serialize_field("protocol", self.protocol())?; @@ -178,7 +374,7 @@ impl Serialize for AnyMessage { impl<'de> Deserialize<'de> for AnyMessage { fn deserialize(deserializer: D) -> Result where - D: serde::de::Deserializer<'de>, + D: de::Deserializer<'de>, { // We don't deserialize dumps, so we can assume it's a tuple. deserializer.deserialize_tuple(3, AnyMessageDeserializeVisitor) @@ -187,7 +383,7 @@ impl<'de> Deserialize<'de> for AnyMessage { struct AnyMessageDeserializeVisitor; -impl<'de> Visitor<'de> for AnyMessageDeserializeVisitor { +impl<'de> de::Visitor<'de> for AnyMessageDeserializeVisitor { type Value = AnyMessage; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -197,19 +393,16 @@ impl<'de> Visitor<'de> for AnyMessageDeserializeVisitor { #[inline] fn visit_seq(self, mut seq: A) -> Result where - A: SeqAccess<'de>, + A: de::SeqAccess<'de>, { - let protocol = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or( - serde::de::Error::invalid_length(0usize, &"tuple of 3 elements"), - )?; + let protocol = de::SeqAccess::next_element::<&str>(&mut seq)? + .ok_or(de::Error::invalid_length(0usize, &"tuple of 3 elements"))?; - let name = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or( - serde::de::Error::invalid_length(1usize, &"tuple of 3 elements"), - )?; + let name = de::SeqAccess::next_element::<&str>(&mut seq)? + .ok_or(de::Error::invalid_length(1usize, &"tuple of 3 elements"))?; - serde::de::SeqAccess::next_element_seed(&mut seq, MessageTag { protocol, name })?.ok_or( - serde::de::Error::invalid_length(2usize, &"tuple of 3 elements"), - ) + de::SeqAccess::next_element_seed(&mut seq, MessageTag { protocol, name })? + .ok_or(de::Error::invalid_length(2usize, &"tuple of 3 elements")) } } @@ -218,26 +411,32 @@ struct MessageTag<'a> { name: &'a str, } -impl<'de, 'tag> DeserializeSeed<'de> for MessageTag<'tag> { +impl<'de, 'tag> de::DeserializeSeed<'de> for MessageTag<'tag> { type Value = AnyMessage; fn deserialize(self, deserializer: D) -> Result where - D: Deserializer<'de>, + D: de::Deserializer<'de>, { - let deserialize_any = lookup_vtable(self.protocol, self.name) - .ok_or(serde::de::Error::custom( - "unknown protocol/name combination", - ))? - .deserialize_any; + let Self { protocol, name } = self; + + let vtable = lookup_vtable(protocol, name) + .ok_or_else(|| de::Error::custom(format_args!("unknown message: {protocol}/{name}")))?; + + let out_ptr = unsafe { alloc_repr(vtable) }; let mut deserializer = >::erase(deserializer); - deserialize_any(&mut deserializer).map_err(serde::de::Error::custom) + unsafe { (vtable.deserialize_any)(&mut deserializer, out_ptr) } + .map_err(de::Error::custom)?; + + Ok(AnyMessage(out_ptr)) } } cfg_network!({ - use rmp_serde as rmps; + use std::io; + + use rmp_serde::{decode, encode}; impl AnyMessage { #[doc(hidden)] @@ -246,10 +445,16 @@ cfg_network!({ buffer: &[u8], protocol: &str, name: &str, - ) -> Result, rmps::decode::Error> { - lookup_vtable(protocol, name) - .map(|vtable| (vtable.read_msgpack)(buffer)) - .transpose() + ) -> Result, decode::Error> { + let Some(vtable) = lookup_vtable(protocol, name) else { + return Ok(None); + }; + + let out_ptr = unsafe { alloc_repr(vtable) }; + + unsafe { (vtable.read_msgpack)(buffer, out_ptr) }?; + + Ok(Some(Self(out_ptr))) } #[doc(hidden)] @@ -258,36 +463,20 @@ cfg_network!({ &self, buffer: &mut Vec, limit: usize, - ) -> Result<(), rmps::encode::Error> { - (self.vtable.write_msgpack)(self, buffer, limit) + ) -> Result<(), encode::Error> { + let vtable = self._vtable(); + let out = LimitedWrite(buffer, limit); + unsafe { (vtable.write_msgpack)(self.0, out) } } } - // For monomorphization in the `#[message]` macro. - // Reexported in `elfo::_priv`. - #[inline] - pub fn read_msgpack(buffer: &[u8]) -> Result { - rmps::decode::from_slice(buffer) - } + // The compiler requires all arguments to be visible. + pub struct LimitedWrite(W, usize); - // For monomorphization in the `#[message]` macro. - // Reexported in `elfo::_priv`. - #[inline] - pub fn write_msgpack( - buffer: &mut Vec, - limit: usize, - message: &impl Message, - ) -> Result<(), rmps::encode::Error> { - let mut wr = LimitedWrite(buffer, limit); - rmps::encode::write_named(&mut wr, message) - } - - struct LimitedWrite(W, usize); - - impl std::io::Write for LimitedWrite { + impl io::Write for LimitedWrite { #[inline] - fn write(&mut self, buf: &[u8]) -> std::io::Result { - if unlikely(buf.len() > self.1) { + fn write(&mut self, buf: &[u8]) -> io::Result { + if buf.len() > self.1 { self.1 = 0; return Ok(0); } @@ -297,22 +486,75 @@ cfg_network!({ } #[inline] - fn flush(&mut self) -> std::io::Result<()> { + fn flush(&mut self) -> io::Result<()> { self.0.flush() } } }); +// === AnyMessageRef === + +// TODO: method to get AnyMessageRef from AnyMessage + +pub struct AnyMessageRef<'a> { + inner: ManuallyDrop, // never drop, borrows memory + marker: PhantomData<&'a AnyMessage>, +} + +impl<'a> AnyMessageRef<'a> { + pub(crate) unsafe fn new(ptr: NonNull) -> Self { + Self { + inner: ManuallyDrop::new(AnyMessage(ptr)), + marker: PhantomData, + } + } + + #[inline] + pub fn downcast_ref(&self) -> Option<&'a M> { + self.is::() + .then(|| unsafe { self.downcast_ref_unchecked() }) + } + + pub(crate) unsafe fn downcast_ref_unchecked(&self) -> &'a M { + &self.inner.0.cast::>().as_ref().data + } +} + +impl Deref for AnyMessageRef<'_> { + type Target = AnyMessage; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl fmt::Debug for AnyMessageRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl Serialize for AnyMessageRef<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: ser::Serializer, + { + (**self).serialize(serializer) + } +} + // === ProtocolExtractor === -// Reexported in `elfo::_priv`. // See https://github.com/GoldsteinE/gh-blog/blob/master/const_deref_specialization/src/lib.md +// Reexported in `elfo::_priv`. pub struct ProtocolExtractor; +// Reexported in `elfo::_priv`. pub trait ProtocolHolder { const PROTOCOL: Option<&'static str>; } +// Reexported in `elfo::_priv`. pub struct DefaultProtocolHolder; impl ProtocolHolder for DefaultProtocolHolder { @@ -339,31 +581,120 @@ impl DefaultProtocolHolder { // Reexported in `elfo::_priv`. /// Message Virtual Table. pub struct MessageVTable { - /// Just a message's name. + pub repr_layout: alloc::Layout, // of `MessageRepr` pub name: &'static str, - /// A protocol's name. - /// Usually, it's a crate name where the message is defined. pub protocol: &'static str, pub labels: &'static [Label], pub dumping_allowed: bool, // TODO: introduce `DumpingMode`. - pub clone: fn(&AnyMessage) -> AnyMessage, - pub debug: fn(&AnyMessage, &mut fmt::Formatter<'_>) -> fmt::Result, - pub erase: fn(&AnyMessage) -> dumping::ErasedMessage, - pub deserialize_any: - fn(&mut dyn erased_serde::Deserializer<'_>) -> Result, + // TODO: field ordering (better for cache) + // TODO: + // pub deserialize_any: fn(&mut dyn erased_serde::Deserializer<'_>) -> Result, + #[cfg(feature = "network")] + pub read_msgpack: unsafe fn(&[u8], NonNull) -> Result<(), decode::Error>, + #[cfg(feature = "network")] + #[allow(clippy::type_complexity)] + pub write_msgpack: + unsafe fn(NonNull, LimitedWrite<&mut Vec>) -> Result<(), encode::Error>, + pub debug: unsafe fn(NonNull, &mut fmt::Formatter<'_>) -> fmt::Result, + pub clone: unsafe fn(NonNull, NonNull), + pub erase: unsafe fn(NonNull) -> dumping::ErasedMessage, + pub deserialize_any: unsafe fn( + deserializer: &mut dyn erased_serde::Deserializer<'_>, + out_ptr: NonNull, + ) -> Result<(), erased_serde::Error>, + pub drop: unsafe fn(NonNull), +} + +// TODO: rename +static VTABLE_STUB: MessageVTable = MessageVTable { + repr_layout: alloc::Layout::new::<()>(), + name: "", + protocol: "", + labels: &[], + dumping_allowed: false, #[cfg(feature = "network")] - pub write_msgpack: fn(&AnyMessage, &mut Vec, usize) -> Result<(), rmps::encode::Error>, + read_msgpack: |_, _| unreachable!(), #[cfg(feature = "network")] - pub read_msgpack: fn(&[u8]) -> Result, + write_msgpack: |_, _| unreachable!(), + debug: |_, _| unreachable!(), + clone: |_, _| unreachable!(), + erase: |_| unreachable!(), + deserialize_any: |_, _| unreachable!(), + drop: |_| unreachable!(), +}; + +// For monomorphization in the `#[message]` macro. +// Reeexported in `elfo::_priv`. +pub mod vtablefns { + use super::*; + + pub unsafe fn drop(ptr: NonNull) { + ptr::drop_in_place(ptr.cast::>().as_ptr()); + } + + pub unsafe fn clone(ptr: NonNull, out_ptr: NonNull) { + ptr::write( + out_ptr.cast::>().as_ptr(), + ptr.cast::>().as_ref().clone(), + ); + } + + pub unsafe fn debug( + ptr: NonNull, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + let data = &ptr.cast::>().as_ref().data; + fmt::Debug::fmt(data, f) + } + + pub unsafe fn erase(ptr: NonNull) -> dumping::ErasedMessage { + let data = ptr.cast::>().as_ref().data.clone(); + smallbox!(data) + } + + pub unsafe fn deserialize_any( + deserializer: &mut dyn erased_serde::Deserializer<'_>, + out_ptr: NonNull, + ) -> Result<(), erased_serde::Error> { + let data = erased_serde::deserialize::(deserializer)?; + ptr::write( + out_ptr.cast::>().as_ptr(), + MessageRepr::new(data), + ); + Ok(()) + } + + cfg_network!({ + pub unsafe fn read_msgpack( + buffer: &[u8], + out_ptr: NonNull, + ) -> Result<(), decode::Error> { + let data = decode::from_slice(buffer)?; + ptr::write( + out_ptr.cast::>().as_ptr(), + MessageRepr::new(data), + ); + Ok(()) + } + + pub unsafe fn write_msgpack( + ptr: NonNull, + mut out: LimitedWrite<&mut Vec>, + ) -> Result<(), encode::Error> { + let data = &ptr.cast::>().as_ref().data; + encode::write_named(&mut out, data) + } + }); } // Reexported in `elfo::_priv`. #[distributed_slice] -pub static MESSAGE_LIST: [&'static MessageVTable] = [..]; +pub static MESSAGE_VTABLES_LIST: [&'static MessageVTable] = [..]; -static MESSAGES: Lazy> = +static MESSAGE_VTABLES_MAP: Lazy> = Lazy::new(|| { - MESSAGE_LIST + MESSAGE_VTABLES_LIST .iter() .map(|vtable| ((vtable.protocol, vtable.name), *vtable)) .collect() @@ -379,23 +710,19 @@ fn lookup_vtable(protocol: &str, name: &str) -> Option<&'static MessageVTable> { ) }; - MESSAGES.get(&(protocol, name)).copied() + MESSAGE_VTABLES_MAP.get(&(protocol, name)).copied() } pub(crate) fn check_uniqueness() -> Result<(), Vec<(String, String)>> { - if MESSAGES.len() == MESSAGE_LIST.len() { + if MESSAGE_VTABLES_MAP.len() == MESSAGE_VTABLES_LIST.len() { return Ok(()); } - fn vtable_eq(lhs: &'static MessageVTable, rhs: &'static MessageVTable) -> bool { - std::ptr::eq(lhs, rhs) - } - - Err(MESSAGE_LIST + Err(MESSAGE_VTABLES_LIST .iter() .filter(|vtable| { - let stored = MESSAGES.get(&(vtable.protocol, vtable.name)).unwrap(); - !vtable_eq(stored, vtable) + let stored = lookup_vtable(vtable.protocol, vtable.name).unwrap(); + MessageTypeId::new(stored) != MessageTypeId::new(vtable) }) .map(|vtable| (vtable.protocol.to_string(), vtable.name.to_string())) .collect::>() @@ -403,9 +730,85 @@ pub(crate) fn check_uniqueness() -> Result<(), Vec<(String, String)>> { .collect::>()) } +// TODO: add tests for `AnyMessageRef` + #[cfg(test)] mod tests { - use crate::{message, message::AnyMessage, scope::SerdeMode, Message}; + use std::sync::Arc; + + use super::*; + use crate::{message, scope::SerdeMode}; + + #[message] + #[derive(PartialEq)] + struct Unused; + + #[message] + #[derive(PartialEq)] + struct P0; + + #[message] + #[derive(PartialEq)] + struct P1(u8); + + #[message] + #[derive(PartialEq)] + struct P8(u64); + + #[message] + #[derive(PartialEq)] + struct P16(u128); + + fn check_ops(message: M) { + let message_box = AnyMessage::new(message.clone()); + + // Debug + assert_eq!(format!("{:?}", message_box), format!("{:?}", message)); + + // Clone + let message_box_2 = message_box.clone(); + assert_eq!(message_box_2.downcast::().unwrap(), message); + + // Downcast + assert!(message_box.is::()); + assert!(!message_box.is::()); + assert_eq!(message_box.downcast_ref::(), Some(&message)); + assert_eq!(message_box.downcast_ref::(), None); + + let message_box = message_box.downcast::().unwrap_err(); + assert_eq!(message_box.downcast::().unwrap(), message); + } + + #[test] + fn miri_ops() { + check_ops(P0); + check_ops(P1(42)); + check_ops(P8(424242)); + check_ops(P16(424242424242)); + } + + #[message] + struct WithImplicitDrop(Arc<()>); + + #[test] + fn miri_drop() { + let counter = Arc::new(()); + let message = WithImplicitDrop(counter.clone()); + + assert_eq!(Arc::strong_count(&counter), 2); + let message_box = AnyMessage::new(message); + assert_eq!(Arc::strong_count(&counter), 2); + let message_box_2 = message_box.clone(); + let message_box_3 = message_box.clone(); + assert_eq!(Arc::strong_count(&counter), 4); + + drop(message_box_2); + assert_eq!(Arc::strong_count(&counter), 3); + drop(message_box); + assert_eq!(Arc::strong_count(&counter), 2); + drop(message_box_3); + assert_eq!(Arc::strong_count(&counter), 1); + } #[message] #[derive(PartialEq)] @@ -428,7 +831,7 @@ mod tests { #[test] fn any_message_deserialize() { let msg = MyCoolMessage::example(); - let any_msg = msg.clone().upcast(); + let any_msg = AnyMessage::new(msg.clone()); let serialized = serde_json::to_string(&any_msg).unwrap(); let deserialized_any_msg: AnyMessage = serde_json::from_str(&serialized).unwrap(); @@ -439,7 +842,7 @@ mod tests { #[test] fn any_message_serialize() { - let any_msg = MyCoolMessage::example().upcast(); + let any_msg = AnyMessage::new(MyCoolMessage::example()); for mode in [SerdeMode::Normal, SerdeMode::Network] { let dump = crate::scope::with_serde_mode(mode, || serde_json::to_string(&any_msg).unwrap()); @@ -452,7 +855,7 @@ mod tests { #[test] fn any_message_dump() { - let any_msg = MyCoolMessage::example().upcast(); + let any_msg = AnyMessage::new(MyCoolMessage::example()); let dump = crate::scope::with_serde_mode(SerdeMode::Dumping, || { serde_json::to_string(&any_msg).unwrap() }); diff --git a/elfo-core/src/request_table.rs b/elfo-core/src/request_table.rs index afe7bad1..57444980 100644 --- a/elfo-core/src/request_table.rs +++ b/elfo-core/src/request_table.rs @@ -251,7 +251,7 @@ impl ResponseToken { #[doc(hidden)] #[inline] - pub fn forget(mut self) { + pub fn forget(&mut self) { self.data = None; } diff --git a/elfo-core/src/signal.rs b/elfo-core/src/signal.rs index d6f72c34..c5cc31c0 100644 --- a/elfo-core/src/signal.rs +++ b/elfo-core/src/signal.rs @@ -204,7 +204,7 @@ impl SourceStream for SignalSource { let message = this.message.clone(); let kind = MessageKind::Regular { sender: Addr::NULL }; let trace_id = TraceId::generate(); - let envelope = Envelope::with_trace_id(message, kind, trace_id).upcast(); + let envelope = Envelope::with_trace_id(message, kind, trace_id); Poll::Ready(Some(envelope)) } } diff --git a/elfo-core/src/stream.rs b/elfo-core/src/stream.rs index 9450b433..0d320c58 100644 --- a/elfo-core/src/stream.rs +++ b/elfo-core/src/stream.rs @@ -275,7 +275,8 @@ pub struct Emitter(mpsc::Sender); impl Emitter { /// Emits a message from the generated stream. pub async fn emit(&mut self, message: M) { - let _ = self.0.send(message.upcast()).await; + // TODO: create `Envelope` to avoid extra allocation. + let _ = self.0.send(AnyMessage::new(message)).await; } } @@ -293,7 +294,8 @@ impl StreamItem for M { /// This method is private. #[doc(hidden)] fn to_any_message(self) -> AnyMessage { - self.upcast() + // TODO: create `Envelope` to avoid extra allocation. + AnyMessage::new(self) } } diff --git a/elfo-core/src/time/delay.rs b/elfo-core/src/time/delay.rs index 7fd3371b..3f4c05b9 100644 --- a/elfo-core/src/time/delay.rs +++ b/elfo-core/src/time/delay.rs @@ -126,7 +126,7 @@ impl SourceStream for DelaySource { let message = this.message.take().unwrap(); let kind = MessageKind::Regular { sender: Addr::NULL }; let trace_id = this.trace_id.take().unwrap_or_else(TraceId::generate); - let envelope = Envelope::with_trace_id(message, kind, trace_id).upcast(); + let envelope = Envelope::with_trace_id(message, kind, trace_id); Poll::Ready(Some(envelope)) } diff --git a/elfo-core/src/time/interval.rs b/elfo-core/src/time/interval.rs index 1dd04293..f7dbfeee 100644 --- a/elfo-core/src/time/interval.rs +++ b/elfo-core/src/time/interval.rs @@ -248,7 +248,7 @@ impl SourceStream for IntervalSource { let message = this.message.clone(); let kind = MessageKind::Regular { sender: Addr::NULL }; let trace_id = TraceId::generate(); - let envelope = Envelope::with_trace_id(message, kind, trace_id).upcast(); + let envelope = Envelope::with_trace_id(message, kind, trace_id); Poll::Ready(Some(envelope)) } diff --git a/elfo-macros-impl/src/message.rs b/elfo-macros-impl/src/message.rs index c48184c4..57ce67ae 100644 --- a/elfo-macros-impl/src/message.rs +++ b/elfo-macros-impl/src/message.rs @@ -199,24 +199,11 @@ pub fn message_impl( let network_fns = cfg!(feature = "network").then(|| { quote! { - fn write_msgpack( - message: &#internal::AnyMessage, - buffer: &mut Vec, - limit: usize - ) -> ::std::result::Result<(), #internal::rmps::encode::Error> { - #internal::write_msgpack(buffer, limit, cast_ref(message)) - } - - fn read_msgpack(buffer: &[u8]) -> - ::std::result::Result<#internal::AnyMessage, #internal::rmps::decode::Error> - { - #internal::read_msgpack::<#name>(buffer).map(#crate_::Message::upcast) - } + read_msgpack: #internal::vtablefns::read_msgpack::<#name>, + write_msgpack: #internal::vtablefns::write_msgpack::<#name>, } }); - let network_fns_ref = cfg!(feature = "network").then(|| quote! { write_msgpack, read_msgpack }); - let protocol = if let Some(protocol) = &args.protocol { quote! { #protocol } } else { @@ -227,41 +214,23 @@ pub fn message_impl( quote! { impl #crate_::Message for #name { #[inline(always)] - fn _vtable(&self) -> &'static #internal::MessageVTable { - &VTABLE + fn _type_id() -> #internal::MessageTypeId { + #internal::MessageTypeId::new(VTABLE) } #[inline(always)] - fn _touch(&self) { - touch(); + fn _vtable(&self) -> &'static #internal::MessageVTable { + VTABLE } - } - - fn cast_ref(message: &#internal::AnyMessage) -> &#name { - message.downcast_ref::<#name>().expect("invalid vtable") - } - fn clone(message: &#internal::AnyMessage) -> #internal::AnyMessage { - #crate_::Message::upcast(Clone::clone(cast_ref(message))) + #[inline(never)] + fn _touch(&self) {} } - fn debug(message: &#internal::AnyMessage, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - ::std::fmt::Debug::fmt(cast_ref(message), f) - } - - fn erase(message: &#internal::AnyMessage) -> #crate_::dumping::ErasedMessage { - smallbox!(Clone::clone(cast_ref(message))) - } - - fn deserialize_any(deserializer: &mut dyn #internal::erased_serde::Deserializer<'_>) -> Result<#internal::AnyMessage, #internal::erased_serde::Error> { - #internal::erased_serde::deserialize::<#name>(deserializer).map(#crate_::Message::upcast) - } - - #network_fns - - #[linkme::distributed_slice(MESSAGE_LIST)] + #[linkme::distributed_slice(MESSAGE_VTABLES_LIST)] #[linkme(crate = linkme)] - static VTABLE: &'static #internal::MessageVTable = &#internal::MessageVTable { + static VTABLE: &#internal::MessageVTable = &#internal::MessageVTable { + repr_layout: ::std::alloc::Layout::new::<#internal::MessageRepr<#name>>(), name: #name_str, protocol: #protocol, labels: &[ @@ -269,17 +238,13 @@ pub fn message_impl( #internal::metrics::Label::from_static_parts("protocol", #protocol), ], dumping_allowed: #dumping_allowed, - clone, - debug, - erase, - deserialize_any, - #network_fns_ref + debug: #internal::vtablefns::debug::<#name>, + clone: #internal::vtablefns::clone::<#name>, + erase: #internal::vtablefns::erase::<#name>, + deserialize_any: #internal::vtablefns::deserialize_any::<#name>, + drop: #internal::vtablefns::drop::<#name>, + #network_fns }; - - // See [rust#47384](https://github.com/rust-lang/rust/issues/47384). - #[doc(hidden)] - #[inline(never)] - pub fn touch() {} } }); @@ -337,7 +302,7 @@ pub fn message_impl( const _: () = { // Keep this list as minimal as possible to avoid possible collisions with `#name`. // Especially avoid `PascalCase`. - use #internal::{MESSAGE_LIST, smallbox::smallbox, linkme}; + use #internal::{MESSAGE_VTABLES_LIST, linkme}; // TODO: remove #impl_message #impl_request diff --git a/elfo-macros-impl/src/msg.rs b/elfo-macros-impl/src/msg.rs index ce2f948b..af20d042 100644 --- a/elfo-macros-impl/src/msg.rs +++ b/elfo-macros-impl/src/msg.rs @@ -226,8 +226,8 @@ pub fn msg_impl(input: proc_macro::TokenStream, path_to_elfo: Path) -> proc_macr add_groups(&mut groups, arm); } + let type_id_ident = quote! { _elfo_type_id }; let envelope_ident = quote! { _elfo_envelope }; - let type_id_ident = quote! { _type_id_envelope }; // println!(">>> HERE {:#?}", groups); @@ -238,7 +238,7 @@ pub fn msg_impl(input: proc_macro::TokenStream, path_to_elfo: Path) -> proc_macr // - used the regular syntax while the request one is expected // - unexhaustive match (GroupKind::Regular(path), arms) => quote_spanned! { path.span()=> - else if #type_id_ident == std::any::TypeId::of::<#path>() { + else if #type_id_ident == <#path as #crate_::Message>::_type_id() { // Ensure it's not a request, or a request but only in a borrowed context. // We cannot use `static_assertions` here because it wraps the check into // a closure that forbids us to use generic `msg!`: (`msg!(match e { M => .. })`). @@ -255,18 +255,15 @@ pub fn msg_impl(input: proc_macro::TokenStream, path_to_elfo: Path) -> proc_macr match { // Support both owned and borrowed contexts, relying on the type inference. #[allow(unused_imports)] - use #internal::{ - EnvelopeOwned as _, EnvelopeBorrowed as _, - AnyMessageOwned as _, AnyMessageBorrowed as _, - }; - #envelope_ident.unpack_regular().downcast2::<#path>() + use #internal::{EnvelopeOwned as _, EnvelopeBorrowed as _}; + unsafe { #envelope_ident.unpack_regular_unchecked::<#path>() } } { #(#arms)* } } }, (GroupKind::Request(path), arms) => quote_spanned! { path.span()=> - else if #type_id_ident == std::any::TypeId::of::<#path>() { + else if #type_id_ident == <#path as #crate_::Message>::_type_id() { // Ensure it's a request. We cannot use `static_assertions` here // because it wraps the check into a closure that forbids us to // use generic `msg!`: (`msg!(match e { (R, token) => .. })`). @@ -279,9 +276,8 @@ pub fn msg_impl(input: proc_macro::TokenStream, path_to_elfo: Path) -> proc_macr match { // Only the owned context is supported. #[allow(unused_imports)] - use #internal::{EnvelopeOwned as _, AnyMessageOwned as _}; - let (message, token) = #envelope_ident.unpack_request(); - (message.downcast2::<#path>(), token.into_received::<#path>()) + use #internal::EnvelopeOwned as _; + unsafe { #envelope_ident.unpack_request_unchecked::<#path>() } } { #(#arms)* } diff --git a/elfo-network/src/codec/format.rs b/elfo-network/src/codec/format.rs index 444cc91e..a273c02a 100644 --- a/elfo-network/src/codec/format.rs +++ b/elfo-network/src/codec/format.rs @@ -123,6 +123,7 @@ impl<'de> serde::Deserialize<'de> for NetworkAddr { } } +// TODO: use `elfo::Envelope` to avoid extra allocation with `AnyMessage`. #[derive(Debug)] pub(crate) enum NetworkEnvelopePayload { Regular { diff --git a/elfo-network/src/codec/mod.rs b/elfo-network/src/codec/mod.rs index 588a9318..b8c69d50 100644 --- a/elfo-network/src/codec/mod.rs +++ b/elfo-network/src/codec/mod.rs @@ -21,12 +21,14 @@ mod tests { #[derive(PartialEq)] struct BigMessage(String); - fn make_envelope(message: AnyMessage, trace_index: u64) -> NetworkEnvelope { + fn make_envelope(message: impl Message, trace_index: u64) -> NetworkEnvelope { NetworkEnvelope { sender: NetworkAddr::NULL, recipient: NetworkAddr::NULL, trace_id: TraceId::try_from(trace_index).unwrap(), - payload: NetworkEnvelopePayload::Regular { message }, + payload: NetworkEnvelopePayload::Regular { + message: AnyMessage::new(message), + }, } } @@ -51,8 +53,8 @@ mod tests { let mut position = 0; for i in 1..5 { - let small_envelope = make_envelope(SmallMessage(i).upcast(), i); - let big_envelope = make_envelope(BigMessage("oops".repeat(100)).upcast(), i); + let small_envelope = make_envelope(SmallMessage(i), i); + let big_envelope = make_envelope(BigMessage("oops".repeat(100)), i); // Small message must fit into 100 bytes, but big message must not. const LIMIT: Option = Some(100); @@ -100,7 +102,7 @@ mod tests { fn test_decode_skip() { let mut bytes = Vec::new(); - let envelope = make_envelope(BigMessage("a".repeat(100)).upcast(), 1); + let envelope = make_envelope(BigMessage("a".repeat(100)), 1); // Encode two messages. encode(&envelope, &mut bytes, &mut Default::default(), None).unwrap(); diff --git a/elfo-network/src/discovery/mod.rs b/elfo-network/src/discovery/mod.rs index 16da04d3..4174df11 100644 --- a/elfo-network/src/discovery/mod.rs +++ b/elfo-network/src/discovery/mod.rs @@ -6,7 +6,7 @@ use tracing::{debug, error, info, warn}; use elfo_core::{ message, msg, scope, Envelope, Message, MoveOwnership, RestartPolicy, - _priv::{GroupNo, MessageKind}, + _priv::{AnyMessage, GroupNo, MessageKind}, messages::ConfigUpdated, stream::Stream, RestartParams, Topology, @@ -488,7 +488,7 @@ async fn send_regular(socket: &mut Socket, msg: M) -> Result<()> { recipient: NetworkAddr::NULL, // doesn't matter trace_id: scope::trace_id(), payload: NetworkEnvelopePayload::Regular { - message: msg.upcast(), + message: AnyMessage::new(msg), }, }; diff --git a/elfo-network/src/socket/mod.rs b/elfo-network/src/socket/mod.rs index 78c47fce..ed3e9f2b 100644 --- a/elfo-network/src/socket/mod.rs +++ b/elfo-network/src/socket/mod.rs @@ -276,7 +276,7 @@ mod tests { use futures::{future, stream::StreamExt}; use tracing::debug; - use elfo_core::{message, tracing::TraceId, Message}; + use elfo_core::{_priv::AnyMessage, message, tracing::TraceId}; use crate::codec::format::{NetworkAddr, NetworkEnvelopePayload}; @@ -361,7 +361,7 @@ mod tests { recipient: NetworkAddr::NULL, trace_id: TraceId::try_from(1).unwrap(), payload: NetworkEnvelopePayload::Regular { - message: TestSocketMessage("a".repeat(i * 10)).upcast(), + message: AnyMessage::new(TestSocketMessage("a".repeat(i * 10))), }, }; diff --git a/elfo-network/src/worker/mod.rs b/elfo-network/src/worker/mod.rs index f54911e3..16bea898 100644 --- a/elfo-network/src/worker/mod.rs +++ b/elfo-network/src/worker/mod.rs @@ -7,7 +7,7 @@ use tracing::{debug, error, info, trace, warn}; use elfo_core::{ message, Local, Message, - _priv::{EbrGuard, EnvelopeOwned, GroupVisitor, MessageKind, NodeNo, Object, OwnedObject}, + _priv::{AnyMessage, EbrGuard, GroupVisitor, MessageKind, NodeNo, Object, OwnedObject}, errors::{RequestError, SendError, TrySendError}, messages::{ConfigUpdated, Impossible}, msg, remote, scope, @@ -276,50 +276,41 @@ fn make_network_envelope( (Ok(envelope), None) => { let sender = envelope.sender(); let trace_id = envelope.trace_id(); + let (message, kind) = envelope.unpack::().expect("impossible"); - let (payload, token) = match envelope.message_kind() { - MessageKind::Regular { .. } => ( - NetworkEnvelopePayload::Regular { - message: envelope.unpack_regular(), + let (payload, token) = match kind { + MessageKind::Regular { .. } => (NetworkEnvelopePayload::Regular { message }, None), + MessageKind::RequestAny(token) => ( + NetworkEnvelopePayload::RequestAny { + request_id: token.request_id(), + message, }, - None, + Some(token), + ), + MessageKind::RequestAll(token) => ( + NetworkEnvelopePayload::RequestAll { + request_id: token.request_id(), + message, + }, + Some(token), ), - MessageKind::RequestAny(_) => { - let (message, token) = envelope.unpack_request(); - ( - NetworkEnvelopePayload::RequestAny { - request_id: token.request_id(), - message, - }, - Some(token), - ) - } - MessageKind::RequestAll(_) => { - let (message, token) = envelope.unpack_request(); - ( - NetworkEnvelopePayload::RequestAll { - request_id: token.request_id(), - message, - }, - Some(token), - ) - } MessageKind::Response { .. } => unreachable!(), }; (sender, trace_id, payload, token) } // Response - (Ok(envelope), Some(token)) => { + (Ok(envelope), Some(mut token)) => { let sender = envelope.sender(); let trace_id = envelope.trace_id(); + let (message, kind) = envelope.unpack::().expect("impossible"); - let payload = match envelope.message_kind() { + let payload = match kind { MessageKind::Response { request_id, .. } => { - debug_assert_eq!(*request_id, token.request_id()); + debug_assert_eq!(request_id, token.request_id()); NetworkEnvelopePayload::Response { - request_id: *request_id, - message: Ok(envelope.unpack_regular()), + request_id, + message: Ok(message), is_last: token.is_last(), } } @@ -332,7 +323,7 @@ fn make_network_envelope( (sender, trace_id, payload, None) } // Failed/Ignored Response - (Err(err), Some(token)) => { + (Err(err), Some(mut token)) => { let sender = Addr::NULL; let trace_id = token.trace_id(); @@ -721,10 +712,7 @@ impl SocketReader { } fn make_system_envelope(message: impl Message) -> Envelope { - Envelope::new( - message.upcast(), - MessageKind::Regular { sender: Addr::NULL }, - ) + Envelope::new(message, MessageKind::Regular { sender: Addr::NULL }) } // === Pusher === diff --git a/elfo-test/src/proxy.rs b/elfo-test/src/proxy.rs index 9f120dc5..1fd1c45f 100644 --- a/elfo-test/src/proxy.rs +++ b/elfo-test/src/proxy.rs @@ -254,7 +254,7 @@ fn testers(tx: shared::OneshotSender) -> Blueprint { (StealContext, token) => { ctx.pruned().respond(token, Local::from(ctx)); } - envelope => panic!("unexpected message: {envelope:?}"), + envelope => panic!("unexpected message: {:?}", envelope.message()), }); } diff --git a/elfo-test/src/utils.rs b/elfo-test/src/utils.rs index 257fc2e3..afdc8e5e 100644 --- a/elfo-test/src/utils.rs +++ b/elfo-test/src/utils.rs @@ -44,8 +44,7 @@ mod tests { #[test] fn extract_message_test() { create_scope().sync_within(|| { - let envelop = - Envelope::new(TestMessage, MessageKind::Regular { sender: Addr::NULL }).upcast(); + let envelop = Envelope::new(TestMessage, MessageKind::Regular { sender: Addr::NULL }); let resp = extract_message::(envelop); assert_eq!(resp, TestMessage); }); @@ -55,8 +54,7 @@ mod tests { #[should_panic(expected = "expected TestMessage, got TestRequest")] fn extract_message_panic_test() { create_scope().sync_within(|| { - let envelop = - Envelope::new(TestRequest, MessageKind::Regular { sender: Addr::NULL }).upcast(); + let envelop = Envelope::new(TestRequest, MessageKind::Regular { sender: Addr::NULL }); extract_message::(envelop); }); } @@ -64,8 +62,7 @@ mod tests { #[test] fn extract_request_test() { create_scope().sync_within(|| { - let envelop = - Envelope::new(TestRequest, MessageKind::Regular { sender: Addr::NULL }).upcast(); + let envelop = Envelope::new(TestRequest, MessageKind::Regular { sender: Addr::NULL }); let (resp, _token) = extract_request::(envelop); assert_eq!(resp, TestRequest); }); diff --git a/elfo-utils/src/time.rs b/elfo-utils/src/time.rs index 6d623761..7a4a5947 100644 --- a/elfo-utils/src/time.rs +++ b/elfo-utils/src/time.rs @@ -82,11 +82,13 @@ mod mock { } /// Mocks `Instant`, see [`InstantMock`]. - pub fn with_instant_mock(f: impl FnOnce(InstantMock)) { + pub fn with_instant_mock(f: impl FnOnce(InstantMock) -> R) -> R { let (clock, mock) = Clock::mock(); let mock = InstantMock(mock); CLOCK.with(|c| *c.borrow_mut() = Some(clock)); - f(mock); + let result = f(mock); + CLOCK.with(|c| *c.borrow_mut() = None); + result } /// Controllable time source for use in tests. diff --git a/elfo/Cargo.toml b/elfo/Cargo.toml index c56157ff..2436e191 100644 --- a/elfo/Cargo.toml +++ b/elfo/Cargo.toml @@ -45,12 +45,9 @@ futures = "0.3.12" derive_more = "0.99.11" tokio = { version = "1", features = ["full"] } tracing = "0.1.25" -metrics = "0.17" tracing-subscriber = "0.3" serde = { version = "1.0.120", features = ["derive"] } toml = "0.7" -humantime-serde = "1" -criterion = "0.5.1" static_assertions = "1.1.0" parking_lot = "0.12" libc = "0.2.97" diff --git a/elfo/tests/msg_macro.rs b/elfo/tests/msg_macro.rs index 5a8243b0..68a450a9 100644 --- a/elfo/tests/msg_macro.rs +++ b/elfo/tests/msg_macro.rs @@ -42,11 +42,19 @@ enum Type { fn sample() -> Blueprint { ActorGroup::new().exec(|mut ctx| async move { while let Some(envelope) = ctx.recv().await { - msg!(match &envelope { + // Borrowed usage. + let _flag = msg!(match &envelope { Unit | Tuple | Struct | ReqStruct => true, _ => false, }); + // Regression for "borrowed value does not live long enough". + let _a = msg!(match &envelope { + Struct { a } => Some(a), + _ => None, + }); + + // Owned usage. msg!(match envelope { // Unit. Unit => ctx.send(Type::Unit(0)).await.unwrap(), diff --git a/elfo/tests/protocol_evolution.rs b/elfo/tests/protocol_evolution.rs index 9f70472b..c235aff6 100644 --- a/elfo/tests/protocol_evolution.rs +++ b/elfo/tests/protocol_evolution.rs @@ -8,7 +8,7 @@ use elfo::{_priv::AnyMessage, prelude::*, Message}; fn parse(input: A, expected: B) -> Result<()> { let mut buf = Vec::new(); - let input = input.upcast(); + let input = AnyMessage::new(input); input.write_msgpack(&mut buf, 512)?; let actual = AnyMessage::read_msgpack(&buf, expected.protocol(), expected.name())? .ok_or_else(|| anyhow!("no such message"))?