diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 254aff4a7f..c84726cce5 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -501,6 +501,43 @@ impl Methods { } } +#[derive(Default, Debug, Clone)] +/// Wraps an arbitrary number of [`Methods`] instances, pointing to the current [`Methods`] instance +pub struct MethodsPicker { + inner: Arc>, + current: Methods, +} + +impl From for MethodsPicker { + fn from(m: Methods) -> Self { + Self { inner: Arc::new(vec![m.clone()]), current: m } + } +} + +impl From> for MethodsPicker { + fn from(v: Vec) -> Self { + let current = if v.is_empty() { Methods::default() } else { v[0].clone() }; + Self { inner: Arc::new(v), current } + } +} + +impl MethodsPicker { + /// Instruct the picker which [`Methods`] instance to use for the current request + pub fn pick(&mut self, f: F) + where + F: FnOnce(&[Methods]) -> &Methods, + { + let current = f(&self.inner); + self.current = current.clone(); + } + + /// Points to the currently picked [`Methods`] set. + /// Returns [`Methods::default()`] if the internal collection is empty. + pub fn current(&self) -> Methods { + self.current.clone() + } +} + impl Deref for RpcModule { type Target = Methods; diff --git a/examples/examples/method_router.rs b/examples/examples/method_router.rs new file mode 100644 index 0000000000..432b4fadb8 --- /dev/null +++ b/examples/examples/method_router.rs @@ -0,0 +1,112 @@ +//! This example sets a custom tower service middleware which picks a variant +//! of rpc methods depending on the uri path. +//! +//! It works with both `WebSocket` and `HTTP` which is done in the example. + +use jsonrpsee::rpc_params; +use std::net::SocketAddr; + +use jsonrpsee::core::client::ClientT; +use jsonrpsee::http_client::HttpClientBuilder; +use jsonrpsee::server::{logger::Logger, RpcModule, ServerBuilder, TowerService}; +use jsonrpsee::ws_client::WsClientBuilder; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let addr = run_server().await?; + + // HTTP. + { + let client = HttpClientBuilder::default().build(format!("http://{}/v1", addr))?; + let response: String = client.request("say_hello", rpc_params![]).await?; + println!("[main]: http response: {:?}", response); + } + { + let client = HttpClientBuilder::default().build(format!("http://{}/v2", addr))?; + let response: String = client.request("say_hello", rpc_params![]).await?; + println!("[main]: http response: {:?}", response); + } + { + let client = HttpClientBuilder::default().build(format!("http://{}", addr))?; + let response = client.request::("say_hello", rpc_params![]).await.expect_err("404"); + println!("[main]: http response: {:}", response); + } + + // WebSocket. + { + let client = WsClientBuilder::default().build(format!("ws://{}/v1", addr)).await?; + let response: String = client.request("say_hello", rpc_params![]).await?; + println!("[main]: ws response: {:?}", response); + } + { + let client = WsClientBuilder::default().build(format!("ws://{}/v2", addr)).await?; + let response: String = client.request("say_hello", rpc_params![]).await?; + println!("[main]: ws response: {:?}", response); + } + { + let error = WsClientBuilder::default().build(format!("ws://{}", addr)).await.expect_err("404"); + println!("[main]: ws response: {:}", error); + } + + Ok(()) +} + +/// Wraps the ultimate core service of the jsonrpsee server in order to access its RPC method picker. +struct MethodRouter(TowerService); + +impl tower::Service> for MethodRouter +where + L: Logger, +{ + type Response = as hyper::service::Service>>::Response; + type Error = as hyper::service::Service>>::Error; + type Future = as hyper::service::Service>>::Future; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, req: hyper::Request) -> Self::Future { + let idx = match req.uri().path() { + "/v1" => 0, + "/v2" => 1, + _ => return Box::pin(std::future::ready(Ok(jsonrpsee::server::http_response::not_found()))), + }; + + self.0.inner.methods.pick(|all_methods| &all_methods[idx]); + self.0.call(req) + } +} + +struct MethodRouterLayer; + +impl tower::Layer> for MethodRouterLayer { + type Service = MethodRouter; + + fn layer(&self, inner: TowerService) -> Self::Service { + MethodRouter(inner) + } +} + +async fn run_server() -> anyhow::Result { + let service_builder = tower::ServiceBuilder::new().layer(MethodRouterLayer); + + let server = + ServerBuilder::new().set_middleware(service_builder).build("127.0.0.1:0".parse::()?).await?; + + let addr = server.local_addr()?; + + let mut module_v1 = RpcModule::new(()); + module_v1.register_method("say_hello", |_, _| Ok("lo v1")).unwrap(); + let mut module_v2 = RpcModule::new(()); + module_v2.register_method("say_hello", |_, _| Ok("lo v2")).unwrap(); + + // Serve different apis on different paths + let handle = server.start_with_methods_variants([module_v1.into(), module_v2.into()])?; + + // In this example we don't care about doing shutdown so let's it run forever. + // You may use the `ServerHandle` to shut it down or manage it yourself. + tokio::spawn(handle.stopped()); + + Ok(addr) +} diff --git a/server/src/lib.rs b/server/src/lib.rs index a002131b86..b7fb3c68f6 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -47,3 +47,6 @@ pub use jsonrpsee_core::{id_providers::*, traits::IdProvider}; pub use jsonrpsee_types as types; pub use server::{Builder as ServerBuilder, Server}; pub use tracing; + +pub use server::{ServiceData, TowerService}; +pub use transport::http::response as http_response; diff --git a/server/src/server.rs b/server/src/server.rs index a89841d3da..76f8b13dd1 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -45,7 +45,7 @@ use jsonrpsee_core::id_providers::RandomIntegerIdProvider; use jsonrpsee_core::server::helpers::MethodResponse; use jsonrpsee_core::server::host_filtering::AllowHosts; use jsonrpsee_core::server::resource_limiting::Resources; -use jsonrpsee_core::server::rpc_module::Methods; +use jsonrpsee_core::server::rpc_module::{Methods, MethodsPicker}; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES}; @@ -106,7 +106,7 @@ where /// /// This will run on the tokio runtime until the server is stopped or the `ServerHandle` is dropped. pub fn start(mut self, methods: impl Into) -> Result { - let methods = methods.into().initialize_resources(&self.resources)?; + let methods = methods.into().initialize_resources(&self.resources)?.into(); let (stop_tx, stop_rx) = watch::channel(()); let stop_handle = StopHandle::new(stop_rx); @@ -119,7 +119,36 @@ where Ok(ServerHandle::new(stop_tx)) } - async fn start_inner(self, methods: Methods, stop_handle: StopHandle) { + /// Start responding to connections requests. + /// + /// By default utilizes the first item in `methods` unless instructed otherwise via middleware. + /// See `examples/method_router` for more details. + /// + /// Replies with [`Error::MethodNotFound`] if the `methods` collection is empty. + /// + /// This will run on the tokio runtime until the server is stopped or the `ServerHandle` is dropped. + pub fn start_with_methods_variants( + mut self, + methods: impl IntoIterator, + ) -> Result { + let methods = methods + .into_iter() + .map(|methods| methods.initialize_resources(&self.resources)) + .collect::, Error>>()? + .into(); + let (stop_tx, stop_rx) = watch::channel(()); + + let stop_handle = StopHandle::new(stop_rx); + + match self.cfg.tokio_runtime.take() { + Some(rt) => rt.spawn(self.start_inner(methods, stop_handle)), + None => tokio::spawn(self.start_inner(methods, stop_handle)), + }; + + Ok(ServerHandle::new(stop_tx)) + } + + async fn start_inner(self, methods: MethodsPicker, stop_handle: StopHandle) { let max_request_body_size = self.cfg.max_request_body_size; let max_response_body_size = self.cfg.max_response_body_size; let max_log_length = self.cfg.max_log_length; @@ -535,11 +564,12 @@ impl MethodResult { /// Data required by the server to handle requests. #[derive(Debug, Clone)] -pub(crate) struct ServiceData { +pub struct ServiceData { /// Remote server address. pub(crate) remote_addr: SocketAddr, /// Registered server methods. - pub(crate) methods: Methods, + /// FIXME making it public for the example only + pub methods: MethodsPicker, /// Access control. pub(crate) allow_hosts: AllowHosts, /// Tracker for currently used resources on the server. @@ -576,7 +606,8 @@ pub(crate) struct ServiceData { /// This is similar to [`hyper::service::service_fn`]. #[derive(Debug)] pub struct TowerService { - inner: ServiceData, + /// FIXME making it public for the example only + pub inner: ServiceData, } impl hyper::service::Service> for TowerService { @@ -647,7 +678,7 @@ impl hyper::service::Service> for TowerSe } else { // The request wasn't an upgrade request; let's treat it as a standard HTTP request: let data = http::HandleRequest { - methods: self.inner.methods.clone(), + methods: self.inner.methods.current(), resources: self.inner.resources.clone(), max_request_body_size: self.inner.max_request_body_size, max_response_body_size: self.inner.max_response_body_size, diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index 8193950013..a6451d81a7 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -337,7 +337,8 @@ pub(crate) async fn handle_request( res } -pub(crate) mod response { +/// FIXME making it public for the example only +pub mod response { use jsonrpsee_types::error::reject_too_big_request; use jsonrpsee_types::error::{ErrorCode, ErrorResponse}; use jsonrpsee_types::Id; @@ -420,4 +421,9 @@ pub(crate) mod response { TEXT, ) } + + /// 404 + pub fn not_found() -> hyper::Response { + from_template(hyper::StatusCode::NOT_FOUND, hyper::StatusCode::NOT_FOUND.to_string(), TEXT) + } } diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index e8a4deab13..067897838b 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -279,6 +279,8 @@ pub(crate) async fn background_task( .. } = svc; + let methods = methods.current(); + let (tx, rx) = mpsc::unbounded::(); let bounded_subscriptions = BoundedSubscriptions::new(max_subscriptions_per_connection); let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);