Make PlatformRef::connect_* instantaneously return #1279

Merged · 4 commits · Nov 1, 2023
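The gist of the change, as the three file diffs below show: the platform's `connect_stream` and `connect_multistream` methods are now expected to return as soon as possible, without waiting for the connection to actually be established, and `WithBuffers` accepts a still-pending socket future that it resolves as part of its normal polling. A toy sketch of that contract, using hypothetical stand-in names rather than smoldot's real `PlatformRef` API:

```rust
// Toy sketch; `connect_stream` and the socket type are hypothetical stand-ins.
// Creating the future is instantaneous; the dialing work only happens once the
// returned future is polled.
use std::future::Future;

fn connect_stream(addr: &str) -> impl Future<Output = Result<Vec<u8>, std::io::Error>> {
    let addr = addr.to_owned();
    async move {
        // ... the actual dialing would happen here, lazily ...
        let _ = addr;
        Ok(Vec::new()) // stand-in for an established socket
    }
}

fn main() {
    // Returns immediately; nothing has been dialed yet.
    let _pending_socket = connect_stream("10.0.0.1:30333");
}
```

The apparent benefit is that the tasks that open connections never block on a slow dial before they can start processing coordinator messages; dial failures instead surface later, through the same path as ordinary socket errors.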
full-node/src/network_service/tasks.rs (28 changes: 2 additions & 26 deletions)
```diff
@@ -52,32 +52,8 @@ pub(super) async fn connection_task(
     mut coordinator_to_connection: channel::Receiver<service::CoordinatorToConnection>,
     connection_to_coordinator: channel::Sender<super::ToBackground>,
 ) {
-    // Finishing ongoing connection process.
-    let socket = match socket.await.map_err(|_| ()) {
-        Ok(s) => s,
-        Err(_err) => {
-            // TODO: log
-            connection_task.reset();
-            loop {
-                let (task_update, opaque_message) = connection_task.pull_message_to_coordinator();
-                let _ = connection_to_coordinator
-                    .send(super::ToBackground::FromConnectionTask {
-                        connection_id,
-                        opaque_message,
-                        connection_now_dead: true,
-                    })
-                    .await;
-                if let Some(task_update) = task_update {
-                    connection_task = task_update;
-                } else {
-                    return;
-                }
-            }
-        }
-    };
-
-    // The socket is wrapped around an object containing a read buffer and a write buffer and
-    // allowing easier usage.
+    // The socket future is wrapped around an object containing a read buffer and a write buffer
+    // and allowing easier usage.
     let mut socket = pin::pin!(with_buffers::WithBuffers::new(socket));
 
     // Future that sends a message to the coordinator. Only one message is sent to the coordinator
```
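The block deleted above was the only place that treated a failed dial specially: the task awaited the socket future up front and, on error, ran a bespoke shutdown dialogue with the coordinator. With this change the buffered wrapper owns the still-pending future, so the construction never blocks and needs no error branch. A compilable stub of the new shape (not the real `with_buffers::WithBuffers`, which the next diff shows):

```rust
// Stub illustration; the real `WithBuffers` lives in lib/src/libp2p/with_buffers.rs.
use std::{io, pin::pin};

#[allow(dead_code)]
struct WithBuffers<TSocketFut> {
    socket_future: TSocketFut, // the still-pending dial, resolved during polling
    read_buffer: Vec<u8>,
    write_buffers: Vec<Vec<u8>>,
}

impl<TSocketFut> WithBuffers<TSocketFut> {
    fn new(socket_future: TSocketFut) -> Self {
        WithBuffers {
            socket_future,
            read_buffer: Vec::new(),
            write_buffers: Vec::new(),
        }
    }
}

async fn dial() -> Result<Vec<u8>, io::Error> {
    Ok(Vec::new()) // stand-in for an established socket
}

fn main() {
    // Mirrors `let mut socket = pin::pin!(with_buffers::WithBuffers::new(socket));`
    // from the diff above: the future goes in unresolved, and the task needs no
    // separate "dial failed" branch before entering its main loop.
    let _socket = pin!(WithBuffers::new(dial()));
}
```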
lib/src/libp2p/with_buffers.rs (212 changes: 121 additions & 91 deletions)
```diff
@@ -38,10 +38,10 @@ use std::io;
 /// Holds an implementation of `AsyncRead` and `AsyncWrite`, alongside with a read buffer and a
 /// write buffer.
 #[pin_project::pin_project]
-pub struct WithBuffers<TSocket, TNow> {
+pub struct WithBuffers<TSocketFut, TSocket, TNow> {
     /// Actual socket to read from/write to.
     #[pin]
-    socket: TSocket,
+    socket: Socket<TSocketFut, TSocket>,
     /// Error that has happened on the socket, if any.
     error: Option<io::Error>,
     /// Storage for data read from the socket. The first [`WithBuffers::read_buffer_valid`] bytes
@@ -73,18 +73,22 @@ pub struct WithBuffers<TSocket, TNow> {
     read_write_wake_up_after: Option<TNow>,
 }
 
-impl<TSocket, TNow> WithBuffers<TSocket, TNow>
+#[pin_project::pin_project(project = SocketProj)]
+enum Socket<TSocketFut, TSocket> {
+    Pending(#[pin] TSocketFut),
+    Resolved(#[pin] TSocket),
+}
+
+impl<TSocketFut, TSocket, TNow> WithBuffers<TSocketFut, TSocket, TNow>
 where
     TNow: Clone + Ord,
 {
-    /// Initializes a new [`WithBuffers`] with the given socket.
-    ///
-    /// The socket must still be open in both directions.
-    pub fn new(socket: TSocket) -> Self {
+    /// Initializes a new [`WithBuffers`] with the given socket-yielding future.
+    pub fn new(socket: TSocketFut) -> Self {
         let read_buffer_reasonable_capacity = 65536; // TODO: make configurable?
 
         WithBuffers {
-            socket,
+            socket: Socket::Pending(socket),
             error: None,
             read_buffer: Vec::with_capacity(read_buffer_reasonable_capacity),
             read_buffer_valid: 0,
@@ -123,6 +127,8 @@
 
         this.read_buffer.truncate(*this.read_buffer_valid);
 
+        let is_resolved = matches!(*this.socket, Socket::Resolved(_));
+
         let write_bytes_queued = this.write_buffers.iter().map(Vec::len).sum();
 
         Ok(ReadWriteAccess {
@@ -135,7 +141,9 @@
             read_bytes: 0,
             write_bytes_queued,
             write_buffers: mem::take(this.write_buffers),
-            write_bytes_queueable: if !*this.write_closed {
+            write_bytes_queueable: if !is_resolved {
+                Some(0)
+            } else if !*this.write_closed {
                 // Limit outgoing buffer size to 128kiB.
                 // TODO: make configurable?
                 Some((128 * 1024usize).saturating_sub(write_bytes_queued))
@@ -155,9 +163,10 @@
     }
 }
 
-impl<TSocket, TNow> WithBuffers<TSocket, TNow>
+impl<TSocketFut, TSocket, TNow> WithBuffers<TSocketFut, TSocket, TNow>
 where
     TSocket: AsyncRead + AsyncWrite,
+    TSocketFut: future::Future<Output = Result<TSocket, io::Error>>,
     TNow: Clone + Ord,
 {
     /// Waits until [`WithBuffers::read_write_access`] should be called again.
@@ -213,100 +222,115 @@
                 }
             }
 
-            if !*this.read_closed {
-                let read_result = AsyncRead::poll_read(
-                    this.socket.as_mut(),
-                    cx,
-                    &mut this.read_buffer[*this.read_buffer_valid..],
-                );
-
-                match read_result {
-                    Poll::Pending => {}
-                    Poll::Ready(Ok(0)) => {
-                        *this.read_closed = true;
-                        pending = false;
-                    }
-                    Poll::Ready(Ok(n)) => {
-                        *this.read_buffer_valid += n;
-                        // TODO: consider waking up only if the expected bytes of the consumer are exceeded
-                        pending = false;
-                    }
-                    Poll::Ready(Err(err)) => {
-                        *this.error = Some(err);
-                        return Poll::Ready(());
-                    }
-                };
-            }
-
-            loop {
-                if this.write_buffers.iter().any(|b| !b.is_empty()) {
-                    let write_result = {
-                        let buffers = this
-                            .write_buffers
-                            .iter()
-                            .map(|buf| io::IoSlice::new(buf))
-                            .collect::<Vec<_>>();
-                        AsyncWrite::poll_write_vectored(this.socket.as_mut(), cx, &buffers)
-                    };
-
-                    match write_result {
-                        Poll::Ready(Ok(0)) => {
-                            // It is not legal for `poll_write` to return 0 bytes written.
-                            unreachable!();
-                        }
-                        Poll::Ready(Ok(mut n)) => {
-                            *this.flush_pending = true;
-                            while n > 0 {
-                                let first_buf = this.write_buffers.first_mut().unwrap();
-                                if first_buf.len() <= n {
-                                    n -= first_buf.len();
-                                    this.write_buffers.remove(0);
-                                } else {
-                                    // TODO: consider keeping the buffer as is but starting the next write at a later offset
-                                    first_buf.copy_within(n.., 0);
-                                    first_buf.truncate(first_buf.len() - n);
-                                    break;
-                                }
-                            }
-                            // Wake up if the write buffers switch from non-empty to empty.
-                            if this.write_buffers.is_empty() {
-                                pending = false;
-                            }
-                        }
-                        Poll::Ready(Err(err)) => {
-                            *this.error = Some(err);
-                            return Poll::Ready(());
-                        }
-                        Poll::Pending => break,
-                    };
-                } else if *this.flush_pending {
-                    match AsyncWrite::poll_flush(this.socket.as_mut(), cx) {
-                        Poll::Ready(Ok(())) => {
-                            *this.flush_pending = false;
-                        }
-                        Poll::Ready(Err(err)) => {
-                            *this.error = Some(err);
-                            return Poll::Ready(());
-                        }
-                        Poll::Pending => break,
-                    }
-                } else if *this.close_pending {
-                    match AsyncWrite::poll_close(this.socket.as_mut(), cx) {
-                        Poll::Ready(Ok(())) => {
-                            *this.close_pending = false;
-                            pending = false;
-                            break;
-                        }
-                        Poll::Ready(Err(err)) => {
-                            *this.error = Some(err);
-                            return Poll::Ready(());
-                        }
-                        Poll::Pending => break,
-                    }
-                } else {
-                    break;
-                }
-            }
+            match this.socket.as_mut().project() {
+                SocketProj::Pending(future) => match future::Future::poll(future, cx) {
+                    Poll::Pending => {}
+                    Poll::Ready(Ok(socket)) => {
+                        this.socket.set(Socket::Resolved(socket));
+                        pending = false;
+                    }
+                    Poll::Ready(Err(err)) => {
+                        *this.error = Some(err);
+                        return Poll::Ready(());
+                    }
+                },
+                SocketProj::Resolved(mut socket) => {
+                    if !*this.read_closed {
+                        let read_result = AsyncRead::poll_read(
+                            socket.as_mut(),
+                            cx,
+                            &mut this.read_buffer[*this.read_buffer_valid..],
+                        );
+
+                        match read_result {
+                            Poll::Pending => {}
+                            Poll::Ready(Ok(0)) => {
+                                *this.read_closed = true;
+                                pending = false;
+                            }
+                            Poll::Ready(Ok(n)) => {
+                                *this.read_buffer_valid += n;
+                                // TODO: consider waking up only if the expected bytes of the consumer are exceeded
+                                pending = false;
+                            }
+                            Poll::Ready(Err(err)) => {
+                                *this.error = Some(err);
+                                return Poll::Ready(());
+                            }
+                        };
+                    }
+
+                    loop {
+                        if this.write_buffers.iter().any(|b| !b.is_empty()) {
+                            let write_result = {
+                                let buffers = this
+                                    .write_buffers
+                                    .iter()
+                                    .map(|buf| io::IoSlice::new(buf))
+                                    .collect::<Vec<_>>();
+                                AsyncWrite::poll_write_vectored(socket.as_mut(), cx, &buffers)
+                            };
+
+                            match write_result {
+                                Poll::Ready(Ok(0)) => {
+                                    // It is not legal for `poll_write` to return 0 bytes written.
+                                    unreachable!();
+                                }
+                                Poll::Ready(Ok(mut n)) => {
+                                    *this.flush_pending = true;
+                                    while n > 0 {
+                                        let first_buf = this.write_buffers.first_mut().unwrap();
+                                        if first_buf.len() <= n {
+                                            n -= first_buf.len();
+                                            this.write_buffers.remove(0);
+                                        } else {
+                                            // TODO: consider keeping the buffer as is but starting the next write at a later offset
+                                            first_buf.copy_within(n.., 0);
+                                            first_buf.truncate(first_buf.len() - n);
+                                            break;
+                                        }
+                                    }
+                                    // Wake up if the write buffers switch from non-empty to empty.
+                                    if this.write_buffers.is_empty() {
+                                        pending = false;
+                                    }
+                                }
+                                Poll::Ready(Err(err)) => {
+                                    *this.error = Some(err);
+                                    return Poll::Ready(());
+                                }
+                                Poll::Pending => break,
+                            };
+                        } else if *this.flush_pending {
+                            match AsyncWrite::poll_flush(socket.as_mut(), cx) {
+                                Poll::Ready(Ok(())) => {
+                                    *this.flush_pending = false;
+                                }
+                                Poll::Ready(Err(err)) => {
+                                    *this.error = Some(err);
+                                    return Poll::Ready(());
+                                }
+                                Poll::Pending => break,
+                            }
+                        } else if *this.close_pending {
+                            match AsyncWrite::poll_close(socket.as_mut(), cx) {
+                                Poll::Ready(Ok(())) => {
+                                    *this.close_pending = false;
+                                    pending = false;
+                                    break;
+                                }
+                                Poll::Ready(Err(err)) => {
+                                    *this.error = Some(err);
+                                    return Poll::Ready(());
+                                }
+                                Poll::Pending => break,
+                            }
+                        } else {
+                            break;
+                        }
+                    }
+                }
+            };
 
         if !pending {
             Poll::Ready(())
@@ ... @@
     }
 }
 
-impl<TSocket: fmt::Debug, TNow> fmt::Debug for WithBuffers<TSocket, TNow> {
+impl<TSocketFut, TSocket: fmt::Debug, TNow> fmt::Debug for WithBuffers<TSocketFut, TSocket, TNow> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        f.debug_tuple("WithBuffers").field(&self.socket).finish()
+        let mut t = f.debug_tuple("WithBuffers");
+        if let Socket::Resolved(socket) = &self.socket {
+            t.field(socket);
+        } else {
+            t.field(&"<pending>");
+        }
+        t.finish()
     }
 }
```
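The heart of the change is the two-state pinned enum above: the wrapper starts out as `Socket::Pending(future)` and, inside `poll`, swaps itself in place for `Socket::Resolved(socket)` once the future yields. A self-contained sketch of that technique (simplified from the diff; `poll_read_lazy` is a made-up method, only the read half is shown, and the `pin-project` and `futures` crates are assumed as dependencies):

```rust
// Sketch of the Pending/Resolved technique; not smoldot's actual code.
use std::{
    future::Future,
    io,
    pin::Pin,
    task::{Context, Poll},
};

use futures::io::AsyncRead;

#[pin_project::pin_project(project = SocketProj)]
enum Socket<TSocketFut, TSocket> {
    /// Still dialing; polling drives the connection attempt.
    Pending(#[pin] TSocketFut),
    /// Connected; polling performs actual I/O.
    Resolved(#[pin] TSocket),
}

impl<TSocketFut, TSocket> Socket<TSocketFut, TSocket>
where
    TSocketFut: Future<Output = Result<TSocket, io::Error>>,
    TSocket: AsyncRead,
{
    fn poll_read_lazy(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            match self.as_mut().project() {
                SocketProj::Pending(future) => match future.poll(cx) {
                    // Still connecting; the waker fires when that changes.
                    Poll::Pending => return Poll::Pending,
                    // Connected: replace the future with the socket, in place,
                    // then loop around and immediately attempt the read.
                    Poll::Ready(Ok(socket)) => self.set(Socket::Resolved(socket)),
                    // A failed dial surfaces exactly like a read error.
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                },
                SocketProj::Resolved(socket) => return socket.poll_read(cx, buf),
            }
        }
    }
}
```

Because the swap happens during polling, construction never blocks, which is what allows `read_write_access` above to simply report `write_bytes_queueable` of `Some(0)` while the connection is still pending.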
light-base/src/network_service.rs (22 changes: 12 additions & 10 deletions)
```diff
@@ -1845,7 +1845,11 @@ async fn background_task<TPlat: PlatformRef>(mut task: BackgroundTask<TPlat>) {
             let task_name = format!("connection-{}-{}", peer_id, multiaddr);
 
             let connection_id = match address {
-                address_parse::AddressOrMultiStreamAddress::Address(_) => {
+                address_parse::AddressOrMultiStreamAddress::Address(address) => {
+                    // As documented in the `PlatformRef` trait, `connect_stream` must
+                    // return as soon as possible.
+                    let connection = task.platform.connect_stream(address).await;
+
                     let (connection_id, connection_task) =
                         task.network.add_single_stream_connection(
                             task.platform.now(),
@@ ... @@
                     task.platform.spawn_task(
                         task_name.into(),
                         tasks::single_stream_connection_task::<TPlat>(
-                            multiaddr,
+                            connection,
+                            multiaddr.to_string(),
                             task.platform.clone(),
                             connection_id,
                             connection_task,
@@ ... @@
                             remote_certificate_sha256,
                         },
                     ) => {
-                        // TODO: we unfortunately need to know the local TLS certificate in order to
-                        // insert the connection, and this local TLS certificate can only be given
-                        // to us by the platform implementation, leading to this `await` here which
-                        // really shouldn't exist. For the moment it's fine because the only implementations
-                        // of multistream connections returns very quickly, but in theory this `await`
-                        // could block for a long time.
+                        // We need to know the local TLS certificate in order to insert the
+                        // connection, and as such we need to call `connect_multistream` here.
+                        // As documented in the `PlatformRef` trait, `connect_multistream` must
+                        // return as soon as possible.
                         let connection = task
                             .platform
                             .connect_multistream(platform::MultiStreamAddress::WebRtc {
                                 ip,
                                 port,
                                 remote_certificate_sha256,
                             })
-                            .await
-                            .unwrap_or_else(|_| unreachable!()); // TODO: don't unwrap, again we know that the only implementation that exists never unwraps here, but in theory it's possible
+                            .await;
 
                         // Convert the SHA256 hashes into multihashes.
                         let local_tls_certificate_multihash = [12u8, 32]
```
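One note on the "Convert the SHA256 hashes into multihashes" step visible at the end of the diff: a multihash prefixes the digest with a hash-algorithm code and the digest length. A sketch under the assumption that the two-byte prefix encodes the registered SHA-256 multihash code (0x12) and the 32-byte digest length; the helper name is made up:

```rust
// Sketch of multihash framing for a WebRTC certhash; the helper name is
// hypothetical, and 0x12/32 (the registered SHA-256 multihash code and the
// digest length) are assumptions about what the prefix in the diff encodes.
fn sha256_multihash(digest: [u8; 32]) -> Vec<u8> {
    [0x12u8, 32].into_iter().chain(digest).collect()
}

fn main() {
    let mh = sha256_multihash([0xab; 32]);
    assert_eq!(mh.len(), 34); // 2-byte prefix + 32-byte digest
    assert_eq!(mh[..2], [0x12, 32]); // code, then length
}
```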