From 7df1ae4b0e804956c6f36fd5cc11ff9ace4dcdb1 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Mon, 22 Jul 2024 22:04:35 -0400 Subject: [PATCH 1/8] impl: instrument action --- Cargo.lock | 7 ++ boltconn/Cargo.toml | 1 + boltconn/src/config/error.rs | 8 ++ boltconn/src/dispatch/action.rs | 4 + boltconn/src/dispatch/dispatching.rs | 3 +- boltconn/src/dispatch/instrument.rs | 152 +++++++++++++++++++++++++++ boltconn/src/dispatch/mod.rs | 1 + boltconn/src/dispatch/rule.rs | 12 ++- boltconn/src/dispatch/ruleset.rs | 6 +- boltconn/src/dispatch/temporary.rs | 1 + boltconn/src/platform/process/mod.rs | 9 ++ 11 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 boltconn/src/dispatch/instrument.rs diff --git a/Cargo.lock b/Cargo.lock index 3375f5f..00ec7c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -601,6 +601,7 @@ dependencies = [ "httparse", "hyper 1.2.0", "hyper-util", + "interpolator", "ioctl-sys", "ip_network", "ip_network_table", @@ -2004,6 +2005,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "interpolator" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71dd52191aae121e8611f1e8dc3e324dd0dd1dee1e6dd91d10ee07a3cfb4d9d8" + [[package]] name = "ioctl-sys" version = "0.8.0" diff --git a/boltconn/Cargo.toml b/boltconn/Cargo.toml index 40ee2ce..b50ab94 100644 --- a/boltconn/Cargo.toml +++ b/boltconn/Cargo.toml @@ -73,6 +73,7 @@ rusqlite = { version = "0.29.0", features = ["bundled"] } rustls-pemfile = "2.1.1" webpki-roots = "0.26.1" x25519-dalek = "2.0.0-pre.1" +interpolator = "0.5.0" # Proxies fast-socks5 = "0.9.1" boringtun = "0.6.0" diff --git a/boltconn/src/config/error.rs b/boltconn/src/config/error.rs index ff6cf11..feb566b 100644 --- a/boltconn/src/config/error.rs +++ b/boltconn/src/config/error.rs @@ -16,12 +16,20 @@ pub enum ConfigError { TaskJoin(#[from] tokio::task::JoinError), #[error("Script error: {0}")] Script(#[from] ScriptError), + #[error("Instrument error: {0}")] + Instrument(#[from] InstrumentConfigError), #[error("Interception error: {0}")] Intercept(#[from] InterceptConfigError), #[error("Internal error: {0}")] Internal(&'static str), } +#[derive(Error, Debug)] +pub enum InstrumentConfigError { + #[error("Bad template {0}: {1}")] + BadTemplate(String, String), +} + #[derive(Error, Debug)] pub enum InterceptConfigError { #[error("Unknown rule: {0}")] diff --git a/boltconn/src/dispatch/action.rs b/boltconn/src/dispatch/action.rs index 8b32f31..c7cfb60 100644 --- a/boltconn/src/dispatch/action.rs +++ b/boltconn/src/dispatch/action.rs @@ -1,3 +1,4 @@ +use crate::dispatch::instrument::Instrument; use crate::dispatch::rule::RuleImpl; use crate::dispatch::{ConnInfo, DispatchingSnippet, ProxyImpl}; use crate::network::dns::Dns; @@ -9,8 +10,10 @@ use std::sync::Arc; pub enum Action { LocalResolve(LocalResolve), SubDispatch(SubDispatch), + Instrument(Instrument), } +//---------------------------------------------------------------------- pub struct LocalResolve { dns: Arc, } @@ -31,6 +34,7 @@ impl LocalResolve { } } +//---------------------------------------------------------------------- pub struct SubDispatch { rule: RuleImpl, snippet: DispatchingSnippet, diff --git a/boltconn/src/dispatch/dispatching.rs b/boltconn/src/dispatch/dispatching.rs index 4bccaf3..d2be0f1 100644 --- a/boltconn/src/dispatch/dispatching.rs +++ b/boltconn/src/dispatch/dispatching.rs @@ -38,7 +38,7 @@ pub struct ConnInfo { } impl ConnInfo { - pub fn socketaddr(&self) -> Option<&SocketAddr> { + pub fn dst_addr(&self) -> Option<&SocketAddr> { if let NetworkAddr::Raw(s) = &self.dst { Some(s) } else { @@ -736,6 +736,7 @@ impl DispatchingSnippet { return r; } } + Action::Instrument(_) => unimplemented!("TODO: Instrument"), }, } } diff --git a/boltconn/src/dispatch/instrument.rs b/boltconn/src/dispatch/instrument.rs new file mode 100644 index 0000000..81a5ac3 --- /dev/null +++ b/boltconn/src/dispatch/instrument.rs @@ -0,0 +1,152 @@ +use crate::config::{ConfigError, InstrumentConfigError}; +use crate::dispatch::rule::RuleImpl; +use crate::dispatch::{ConnInfo, InboundInfo}; +use interpolator::Formattable; +use std::collections::HashMap; + +//---------------------------------------------------------------------- +pub struct Instrument { + rule: RuleImpl, + sub_id: u64, + fmt_obj: FormattingObject, +} + +struct FormattingObject { + usr_template: String, +} + +impl FormattingObject { + pub fn new(usr_template: String) -> Result { + let mock_info = ConnInfo { + src: std::net::SocketAddr::V4(std::net::SocketAddrV4::new( + std::net::Ipv4Addr::new(127, 0, 0, 1), + 8080, + )), + dst: crate::proxy::NetworkAddr::DomainName { + domain_name: "example.com".to_string(), + port: 443, + }, + local_ip: None, + inbound: InboundInfo::Tun, + resolved_dst: None, + connection_type: crate::platform::process::NetworkType::Tcp, + process_info: None, + }; + if let Err(e) = Self::format_inner(usr_template.as_str(), &mock_info) { + return Err(ConfigError::Instrument(InstrumentConfigError::BadTemplate( + usr_template.to_string(), + e.to_string(), + ))); + } + Ok(Self { usr_template }) + } + + pub fn format(&self, info: &ConnInfo) -> String { + Self::format_inner(self.usr_template.as_str(), info).expect("Infallible after check") + } + + fn format_inner(template: &str, info: &ConnInfo) -> Result { + let na_str = "N/A"; + + let local_ip = info + .local_ip + .map_or_else(|| na_str.to_string(), |ip| ip.to_string()); + let resolved_dst = info + .resolved_dst + .map_or_else(|| na_str.to_string(), |addr| addr.to_string()); + + // inbound info + let inbound_type = match &info.inbound { + InboundInfo::Tun => "tun", + InboundInfo::Http(_) => "http", + InboundInfo::Socks5(_) => "socks5", + }; + let inbound_port = match &info.inbound { + InboundInfo::Tun => None, + InboundInfo::Http(user) | InboundInfo::Socks5(user) => user.port, + } + .map_or_else(|| na_str.to_string(), |port| port.to_string()); + let inbound_username = match &info.inbound { + InboundInfo::Tun => None, + InboundInfo::Http(user) | InboundInfo::Socks5(user) => user.user.clone(), + } + .unwrap_or("N/A".to_string()); + + // process info + let process_name = info + .process_info + .as_ref() + .map_or_else(|| na_str.to_string(), |info| info.name.clone()); + let process_cmdline = info + .process_info + .as_ref() + .map_or_else(|| na_str.to_string(), |info| info.cmdline.clone()); + let process_path = info + .process_info + .as_ref() + .map_or_else(|| na_str.to_string(), |info| info.path.clone()); + let process_pid = info + .process_info + .as_ref() + .map_or_else(|| na_str.to_string(), |info| info.pid.to_string()); + let process_ppid = info + .process_info + .as_ref() + .map_or_else(|| na_str.to_string(), |info| info.ppid.to_string()); + let process_pname = info.process_info.as_ref().map_or_else( + || na_str.to_string(), + |info| { + info.parent_name + .clone() + .unwrap_or_else(|| na_str.to_string()) + }, + ); + + // Collect to hashmap; needed to be exported to end user, so consistency of key name is important here. + let mapping = [ + ("addr.src", Formattable::display(&info.src)), + ("addr.dst", Formattable::display(&info.dst)), + ("addr.resolved_dst", Formattable::display(&resolved_dst)), + ("ip.local", Formattable::display(&local_ip)), + ("inbound.type", Formattable::display(&inbound_type)), + ("inbound.port", Formattable::display(&inbound_port)), + ("inbound.user", Formattable::display(&inbound_username)), + ("conn.type", Formattable::display(&info.connection_type)), + ("process.name", Formattable::display(&process_name)), + ("process.cmdline", Formattable::display(&process_cmdline)), + ("process.path", Formattable::display(&process_path)), + ("process.pid", Formattable::display(&process_pid)), + ("process.ppid", Formattable::display(&process_ppid)), + ("process.parent_name", Formattable::display(&process_pname)), + ] + .into_iter() + .collect::>(); + interpolator::format(template, &mapping) + } +} + +#[test] +fn test_instrument_formatting() { + let template = "src: {addr.src}, dst: {addr.dst}, resolved_dst: {addr.resolved_dst}, \ + local_ip: {ip.local}, conn_type: {conn.type}, \ + inbound_type: {inbound.type}, inbound_port: {inbound.port}, inbound_user: {inbound.user}, \ + process_name: {process.name}, process_cmdline: {process.cmdline}, process_path: {process.path}, \ + process_pid: {process.pid}, process_ppid: {process.ppid}, process_parent_name: {process.parent_name}"; + let info = ConnInfo { + src: std::net::SocketAddr::V4(std::net::SocketAddrV4::new( + std::net::Ipv4Addr::new(192, 168, 0, 1), + 8080, + )), + dst: crate::proxy::NetworkAddr::DomainName { + domain_name: "example.com".to_string(), + port: 443, + }, + local_ip: None, + inbound: InboundInfo::Tun, + resolved_dst: None, + connection_type: crate::platform::process::NetworkType::Tcp, + process_info: None, + }; + let fmt_obj = FormattingObject::new(template.to_string()).unwrap(); + let _ = fmt_obj.format(&info); +} diff --git a/boltconn/src/dispatch/mod.rs b/boltconn/src/dispatch/mod.rs index 04780ed..573463a 100644 --- a/boltconn/src/dispatch/mod.rs +++ b/boltconn/src/dispatch/mod.rs @@ -1,6 +1,7 @@ mod action; mod dispatching; mod inbound; +mod instrument; mod proxy; mod rule; mod ruleset; diff --git a/boltconn/src/dispatch/rule.rs b/boltconn/src/dispatch/rule.rs index ebc83d1..7189691 100644 --- a/boltconn/src/dispatch/rule.rs +++ b/boltconn/src/dispatch/rule.rs @@ -70,6 +70,8 @@ pub enum RuleImpl { And(Vec), Or(Vec), Not(Box), + Always, + Never, } impl RuleImpl { @@ -104,12 +106,12 @@ impl RuleImpl { } RuleImpl::LocalIpCidr(net) => info.local_ip.as_ref().map_or(false, |s| net.contains(s)), RuleImpl::SrcIpCidr(net) => net.contains(&info.src.ip()), - RuleImpl::IpCidr(net) => info.socketaddr().is_some_and(|s| net.contains(&s.ip())), + RuleImpl::IpCidr(net) => info.dst_addr().is_some_and(|s| net.contains(&s.ip())), RuleImpl::GeoIP(mmdb, country) => info - .socketaddr() + .dst_addr() .is_some_and(|s| mmdb.search_country(s.ip()).is_some_and(|c| c == country)), RuleImpl::Asn(mmdb, asn) => info - .socketaddr() + .dst_addr() .is_some_and(|s| mmdb.search_asn(s.ip()).is_some_and(|a| a == *asn)), RuleImpl::SrcPort(port) => match port { PortRule::Tcp(p) => { @@ -168,6 +170,8 @@ impl RuleImpl { false })(), RuleImpl::Not(r) => !r.matches(info), + RuleImpl::Always => true, + RuleImpl::Never => false, } } } @@ -401,6 +405,8 @@ impl RuleBuilder<'_> { "RULE-SET" => rulesets .and_then(|table| table.get(&content)) .map(|rs| RuleImpl::RuleSet(rs.clone())), + "ALWAYS" => Some(RuleImpl::Always), + "NEVER" => Some(RuleImpl::Never), _ => None, } } diff --git a/boltconn/src/dispatch/ruleset.rs b/boltconn/src/dispatch/ruleset.rs index 9fd59d9..3eacfea 100644 --- a/boltconn/src/dispatch/ruleset.rs +++ b/boltconn/src/dispatch/ruleset.rs @@ -86,7 +86,7 @@ impl RuleSet { } } if let Some((mmdb, asn, countries)) = &self.mmdb { - if info.socketaddr().is_some_and(|s| { + if info.dst_addr().is_some_and(|s| { mmdb.search_asn(s.ip()).is_some_and(|a| asn.contains(&a)) || mmdb .search_country(s.ip()) @@ -243,7 +243,9 @@ impl RuleSetBuilder { | RuleImpl::And(..) | RuleImpl::Or(..) | RuleImpl::Not(_) - | RuleImpl::ProcCmdRegex(_) => return None, + | RuleImpl::ProcCmdRegex(_) + | RuleImpl::Always + | RuleImpl::Never => return None, } } Some(retval) diff --git a/boltconn/src/dispatch/temporary.rs b/boltconn/src/dispatch/temporary.rs index 25a41fd..b15ee18 100644 --- a/boltconn/src/dispatch/temporary.rs +++ b/boltconn/src/dispatch/temporary.rs @@ -40,6 +40,7 @@ impl TemporaryList { return Some(r); } } + Action::Instrument(_) => unimplemented!("TODO: Instrument"), }, } } diff --git a/boltconn/src/platform/process/mod.rs b/boltconn/src/platform/process/mod.rs index 2dea9d4..f03ec7a 100644 --- a/boltconn/src/platform/process/mod.rs +++ b/boltconn/src/platform/process/mod.rs @@ -16,6 +16,15 @@ pub enum NetworkType { Udp, } +impl std::fmt::Display for NetworkType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NetworkType::Tcp => write!(f, "tcp"), + NetworkType::Udp => write!(f, "udp"), + } + } +} + #[derive(Debug, Default, Clone)] pub struct ProcessInfo { pub pid: i32, From 3bf639029c4961bc3e64e700861c00130e3dd6c0 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Mon, 5 Aug 2024 23:57:45 -0400 Subject: [PATCH 2/8] impl: bus publisher & subscriber --- boltconn/src/dispatch/action.rs | 6 +- boltconn/src/dispatch/dispatching.rs | 2 +- boltconn/src/dispatch/mod.rs | 2 +- boltconn/src/dispatch/temporary.rs | 2 +- .../instrument.rs => instrument/action.rs} | 14 +++- boltconn/src/instrument/bus.rs | 81 +++++++++++++++++++ boltconn/src/instrument/mod.rs | 2 + boltconn/src/main.rs | 1 + 8 files changed, 102 insertions(+), 8 deletions(-) rename boltconn/src/{dispatch/instrument.rs => instrument/action.rs} (95%) create mode 100644 boltconn/src/instrument/bus.rs create mode 100644 boltconn/src/instrument/mod.rs diff --git a/boltconn/src/dispatch/action.rs b/boltconn/src/dispatch/action.rs index c7cfb60..0f91e93 100644 --- a/boltconn/src/dispatch/action.rs +++ b/boltconn/src/dispatch/action.rs @@ -1,6 +1,6 @@ -use crate::dispatch::instrument::Instrument; -use crate::dispatch::rule::RuleImpl; +use crate::dispatch::RuleImpl; use crate::dispatch::{ConnInfo, DispatchingSnippet, ProxyImpl}; +use crate::instrument::action::InstrumentAction; use crate::network::dns::Dns; use crate::proxy::NetworkAddr; use async_recursion::async_recursion; @@ -10,7 +10,7 @@ use std::sync::Arc; pub enum Action { LocalResolve(LocalResolve), SubDispatch(SubDispatch), - Instrument(Instrument), + Instrument(InstrumentAction), } //---------------------------------------------------------------------- diff --git a/boltconn/src/dispatch/dispatching.rs b/boltconn/src/dispatch/dispatching.rs index d2be0f1..05930d2 100644 --- a/boltconn/src/dispatch/dispatching.rs +++ b/boltconn/src/dispatch/dispatching.rs @@ -736,7 +736,7 @@ impl DispatchingSnippet { return r; } } - Action::Instrument(_) => unimplemented!("TODO: Instrument"), + Action::Instrument(r) => r.execute(info).await, }, } } diff --git a/boltconn/src/dispatch/mod.rs b/boltconn/src/dispatch/mod.rs index 573463a..752f1da 100644 --- a/boltconn/src/dispatch/mod.rs +++ b/boltconn/src/dispatch/mod.rs @@ -1,7 +1,6 @@ mod action; mod dispatching; mod inbound; -mod instrument; mod proxy; mod rule; mod ruleset; @@ -11,4 +10,5 @@ pub use dispatching::*; pub(crate) use inbound::*; pub use proxy::*; // expose this interface for performance +pub use rule::RuleImpl; pub use ruleset::*; diff --git a/boltconn/src/dispatch/temporary.rs b/boltconn/src/dispatch/temporary.rs index b15ee18..c9e7170 100644 --- a/boltconn/src/dispatch/temporary.rs +++ b/boltconn/src/dispatch/temporary.rs @@ -40,7 +40,7 @@ impl TemporaryList { return Some(r); } } - Action::Instrument(_) => unimplemented!("TODO: Instrument"), + Action::Instrument(r) => r.execute(info).await, }, } } diff --git a/boltconn/src/dispatch/instrument.rs b/boltconn/src/instrument/action.rs similarity index 95% rename from boltconn/src/dispatch/instrument.rs rename to boltconn/src/instrument/action.rs index 81a5ac3..fb36b29 100644 --- a/boltconn/src/dispatch/instrument.rs +++ b/boltconn/src/instrument/action.rs @@ -1,16 +1,26 @@ use crate::config::{ConfigError, InstrumentConfigError}; -use crate::dispatch::rule::RuleImpl; +use crate::dispatch::RuleImpl; use crate::dispatch::{ConnInfo, InboundInfo}; use interpolator::Formattable; use std::collections::HashMap; //---------------------------------------------------------------------- -pub struct Instrument { +pub struct InstrumentAction { rule: RuleImpl, sub_id: u64, fmt_obj: FormattingObject, } +impl InstrumentAction { + pub async fn execute(&self, info: &ConnInfo) { + if self.rule.matches(info) { + let _ = self.fmt_obj.format(info); + // TODO + todo!("InstrumentAction::execute"); + } + } +} + struct FormattingObject { usr_template: String, } diff --git a/boltconn/src/instrument/bus.rs b/boltconn/src/instrument/bus.rs new file mode 100644 index 0000000..e3f576d --- /dev/null +++ b/boltconn/src/instrument/bus.rs @@ -0,0 +1,81 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +pub type SubId = u64; + +pub struct Bus { + // Only used for cloning + ingress_sender_handle: flume::Sender, + ingress_receiver: flume::Receiver, + egress_senders: Mutex>>, +} + +impl Bus { + pub async fn run(&self) { + while let Ok(msg) = self.ingress_receiver.recv_async().await { + if let Some(sender) = self.egress_senders.lock().unwrap().get(&msg.sub_id) { + let _ = sender.try_send(msg); + } + } + } + + pub fn create_publisher(&self, sub_id: SubId) -> BusPublisher { + let sender = self.ingress_sender_handle.clone(); + BusPublisher::new(sub_id, sender) + } + + /// Returns None if any of the sub_ids already exists + pub fn create_subscriber(&self, sub_ids: I) -> Option + where + I: Iterator + Clone, + { + let (sender, receiver) = flume::unbounded(); + let iter2 = sub_ids.clone(); + let mut egress_senders = self.egress_senders.lock().unwrap(); + for sub_id in iter2 { + if egress_senders.contains_key(&sub_id) { + return None; + } + } + for sub_id in sub_ids { + egress_senders.insert(sub_id, sender.clone()); + } + Some(BusSubscriber::new(receiver)) + } +} + +#[derive(Debug, Clone)] +pub struct BusMessage { + sub_id: SubId, + msg: String, +} + +pub struct BusPublisher { + sub_id: SubId, + sender: flume::Sender, +} + +impl BusPublisher { + pub fn new(sub_id: SubId, sender: flume::Sender) -> Self { + Self { sub_id, sender } + } + + pub fn publish(&self, msg: BusMessage) { + // drop on full channel + let _ = self.sender.try_send(msg); + } +} + +pub struct BusSubscriber { + receiver: flume::Receiver, +} + +impl BusSubscriber { + pub fn new(receiver: flume::Receiver) -> Self { + Self { receiver } + } + + pub async fn recv(&self) -> Option { + self.receiver.recv_async().await.ok() + } +} diff --git a/boltconn/src/instrument/mod.rs b/boltconn/src/instrument/mod.rs new file mode 100644 index 0000000..638e50d --- /dev/null +++ b/boltconn/src/instrument/mod.rs @@ -0,0 +1,2 @@ +pub mod action; +mod bus; diff --git a/boltconn/src/main.rs b/boltconn/src/main.rs index 928f03a..5e7a943 100644 --- a/boltconn/src/main.rs +++ b/boltconn/src/main.rs @@ -21,6 +21,7 @@ mod common; mod config; mod dispatch; mod external; +mod instrument; mod intercept; mod network; mod platform; From 089d6ca2bd414b54b0d59f2dd8654ccf3ace5b25 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Wed, 7 Aug 2024 19:00:10 -0400 Subject: [PATCH 3/8] impl: instrument action --- boltconn/src/app.rs | 29 +++++++++++++++++---- boltconn/src/config/rule.rs | 10 +++++++ boltconn/src/dispatch/dispatching.rs | 20 ++++++++++++-- boltconn/src/instrument/action.rs | 23 +++++++++++++--- boltconn/src/instrument/bus.rs | 20 ++++++++++++-- boltconn/src/instrument/mod.rs | 2 +- boltconn/src/intercept/intercept_manager.rs | 4 ++- 7 files changed, 94 insertions(+), 14 deletions(-) diff --git a/boltconn/src/app.rs b/boltconn/src/app.rs index 89e4688..5528ceb 100644 --- a/boltconn/src/app.rs +++ b/boltconn/src/app.rs @@ -7,6 +7,7 @@ use crate::external::{ Controller, DatabaseHandle, MmdbReader, SharedDispatching, StreamLoggerSend, UdsController, UnixListenerGuard, WebController, }; +use crate::instrument::bus::MessageBus; use crate::intercept::{InterceptModifier, InterceptionManager}; use crate::network::configure::TunConfigure; use crate::network::dns::{new_bootstrap_resolver, parse_dns_config, Dns, NameserverPolicies}; @@ -41,6 +42,7 @@ pub struct App { speedtest_url: Arc>, receiver: tokio::sync::mpsc::Receiver<()>, uds_socket: Arc, + msg_bus: Arc, } impl App { @@ -173,6 +175,9 @@ impl App { 9961, ); + // initialize instrumentation + let msg_bus = Arc::new(MessageBus::new()); + // dispatch let mut ruleset = HashMap::new(); for (name, schema) in &loaded_config.rule_schema { @@ -182,9 +187,15 @@ impl App { ruleset.insert(name.clone(), Arc::new(builder.build()?)); } let dispatching = Arc::new( - DispatchingBuilder::new(dns.clone(), mmdb.clone(), &loaded_config, &ruleset) - .and_then(|b| b.build(&loaded_config)) - .map_err(|e| anyhow!("Parse routing rules failed: {}", e))?, + DispatchingBuilder::new( + dns.clone(), + mmdb.clone(), + &loaded_config, + &ruleset, + msg_bus.clone(), + ) + .and_then(|b| b.build(&loaded_config)) + .map_err(|e| anyhow!("Parse routing rules failed: {}", e))?, ); let dispatcher = { // tls mitm @@ -196,6 +207,7 @@ impl App { dns.clone(), mmdb.clone(), &ruleset, + msg_bus.clone(), ) .map_err(|e| anyhow!("Load intercept rules failed: {}", e))?, ); @@ -317,6 +329,7 @@ impl App { speedtest_url, receiver: reload_receiver, uds_socket: uds_listener, + msg_bus, }) } @@ -383,8 +396,13 @@ impl App { ) .await?; let dispatching = { - let builder = - DispatchingBuilder::new(self.dns.clone(), mmdb.clone(), &loaded_config, &ruleset)?; + let builder = DispatchingBuilder::new( + self.dns.clone(), + mmdb.clone(), + &loaded_config, + &ruleset, + self.msg_bus.clone(), + )?; Arc::new(builder.build(&loaded_config)?) }; @@ -394,6 +412,7 @@ impl App { self.dns.clone(), mmdb.clone(), &ruleset, + self.msg_bus.clone(), ) .map_err(|e| anyhow!("Load intercept rules failed: {}", e))?, ); diff --git a/boltconn/src/config/rule.rs b/boltconn/src/config/rule.rs index 4a2f4db..9ec780c 100644 --- a/boltconn/src/config/rule.rs +++ b/boltconn/src/config/rule.rs @@ -6,12 +6,22 @@ pub struct SubDispatchConfig { pub subrules: Vec, } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct InstrumentConfig { + #[serde(alias = "sub-id")] + pub id: u64, + pub matches: String, + pub message: String, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub enum RuleAction { #[serde(alias = ".LOCAL-RESOLVE")] LocalResolve, #[serde(alias = ".SUB-DISPATCH")] SubDispatch(SubDispatchConfig), + #[serde(alias = ".INSTRUMENT")] + Instrument(InstrumentConfig), } // Warning: order matters here; changing order may result in break diff --git a/boltconn/src/dispatch/dispatching.rs b/boltconn/src/dispatch/dispatching.rs index 05930d2..e89b587 100644 --- a/boltconn/src/dispatch/dispatching.rs +++ b/boltconn/src/dispatch/dispatching.rs @@ -11,6 +11,8 @@ use crate::dispatch::ruleset::RuleSet; use crate::dispatch::temporary::TemporaryList; use crate::dispatch::{GeneralProxy, InboundInfo, Proxy, ProxyGroup, RuleSetTable}; use crate::external::MmdbReader; +use crate::instrument::action::InstrumentAction; +use crate::instrument::bus::MessageBus; use crate::network::dns::Dns; use crate::platform::process::{NetworkType, ProcessInfo}; use crate::proxy::NetworkAddr; @@ -103,10 +105,11 @@ pub struct DispatchingBuilder { group_order: Vec, dns: Arc, mmdb: Option>, + msg_bus: Arc, } impl DispatchingBuilder { - pub fn empty(dns: Arc, mmdb: Option>) -> Self { + pub fn empty(dns: Arc, mmdb: Option>, msg_bus: Arc) -> Self { let mut builder = Self { proxies: Default::default(), groups: Default::default(), @@ -114,6 +117,7 @@ impl DispatchingBuilder { group_order: Default::default(), dns, mmdb, + msg_bus, }; builder.proxies.insert( "DIRECT".into(), @@ -135,8 +139,9 @@ impl DispatchingBuilder { mmdb: Option>, loaded_config: &LoadedConfig, ruleset: &RuleSetTable, + msg_bus: Arc, ) -> Result { - let mut builder = Self::empty(dns, mmdb); + let mut builder = Self::empty(dns, mmdb, msg_bus); // start init let LoadedConfig { config, @@ -239,6 +244,17 @@ impl DispatchingBuilder { ), ))) } + RuleAction::Instrument(ins) => { + let matches = rule_builder.parse_incomplete(ins.matches.as_str())?; + rule_builder.append(RuleOrAction::Action(Action::Instrument( + InstrumentAction::new( + matches, + ins.id, + ins.message.clone(), + self.msg_bus.create_publisher(ins.id), + )?, + ))) + } }, RuleConfigLine::Simple(r) => { if idx == rules.len() - 1 { diff --git a/boltconn/src/instrument/action.rs b/boltconn/src/instrument/action.rs index fb36b29..a0df521 100644 --- a/boltconn/src/instrument/action.rs +++ b/boltconn/src/instrument/action.rs @@ -1,6 +1,7 @@ use crate::config::{ConfigError, InstrumentConfigError}; use crate::dispatch::RuleImpl; use crate::dispatch::{ConnInfo, InboundInfo}; +use crate::instrument::bus::{BusMessage, BusPublisher}; use interpolator::Formattable; use std::collections::HashMap; @@ -9,14 +10,30 @@ pub struct InstrumentAction { rule: RuleImpl, sub_id: u64, fmt_obj: FormattingObject, + bus_publisher: BusPublisher, } impl InstrumentAction { + pub fn new( + rule: RuleImpl, + sub_id: u64, + fmt_template: String, + bus_publisher: BusPublisher, + ) -> Result { + let fmt_obj = FormattingObject::new(fmt_template)?; + Ok(Self { + rule, + sub_id, + fmt_obj, + bus_publisher, + }) + } + pub async fn execute(&self, info: &ConnInfo) { if self.rule.matches(info) { - let _ = self.fmt_obj.format(info); - // TODO - todo!("InstrumentAction::execute"); + let str = self.fmt_obj.format(info); + self.bus_publisher + .publish(BusMessage::new(self.sub_id, str)); } } } diff --git a/boltconn/src/instrument/bus.rs b/boltconn/src/instrument/bus.rs index e3f576d..247f6ba 100644 --- a/boltconn/src/instrument/bus.rs +++ b/boltconn/src/instrument/bus.rs @@ -3,16 +3,26 @@ use std::sync::Mutex; pub type SubId = u64; -pub struct Bus { +pub struct MessageBus { // Only used for cloning ingress_sender_handle: flume::Sender, ingress_receiver: flume::Receiver, egress_senders: Mutex>>, } -impl Bus { +impl MessageBus { + pub fn new() -> Self { + let (ingress_sender, ingress_receiver) = flume::bounded(4096); + Self { + ingress_sender_handle: ingress_sender, + ingress_receiver, + egress_senders: Mutex::new(HashMap::new()), + } + } + pub async fn run(&self) { while let Ok(msg) = self.ingress_receiver.recv_async().await { + tracing::debug!("[Message Bus {}] {:?} ", msg.sub_id, msg.msg); if let Some(sender) = self.egress_senders.lock().unwrap().get(&msg.sub_id) { let _ = sender.try_send(msg); } @@ -50,6 +60,12 @@ pub struct BusMessage { msg: String, } +impl BusMessage { + pub fn new(sub_id: SubId, msg: String) -> Self { + Self { sub_id, msg } + } +} + pub struct BusPublisher { sub_id: SubId, sender: flume::Sender, diff --git a/boltconn/src/instrument/mod.rs b/boltconn/src/instrument/mod.rs index 638e50d..67b13c5 100644 --- a/boltconn/src/instrument/mod.rs +++ b/boltconn/src/instrument/mod.rs @@ -1,2 +1,2 @@ pub mod action; -mod bus; +pub mod bus; diff --git a/boltconn/src/intercept/intercept_manager.rs b/boltconn/src/intercept/intercept_manager.rs index 8d3485d..9e8ce9b 100644 --- a/boltconn/src/intercept/intercept_manager.rs +++ b/boltconn/src/intercept/intercept_manager.rs @@ -1,6 +1,7 @@ use crate::config::{ActionConfig, ConfigError, InterceptConfigError, InterceptionConfig}; use crate::dispatch::{ConnInfo, Dispatching, DispatchingBuilder, ProxyImpl, RuleSetTable}; use crate::external::MmdbReader; +use crate::instrument::bus::MessageBus; use crate::intercept::{HeaderEngine, ScriptEngine, UrlEngine}; use crate::network::dns::Dns; use std::sync::Arc; @@ -107,13 +108,14 @@ impl InterceptionManager { dns: Arc, mmdb: Option>, rulesets: &RuleSetTable, + msg_bus: Arc, ) -> Result { let mut res = vec![]; for i in entries.iter() { if !i.enabled { continue; } - let filters = DispatchingBuilder::empty(dns.clone(), mmdb.clone()) + let filters = DispatchingBuilder::empty(dns.clone(), mmdb.clone(), msg_bus.clone()) .build_filter(i.filters.as_slice(), rulesets)?; let payload = InterceptionPayload::parse_actions(i.actions.as_slice())?; res.push(InterceptionEntry { From 2aeb6dfaa409bd4bdad267beaa1f736ebd50fe2b Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Thu, 8 Aug 2024 00:43:16 -0400 Subject: [PATCH 4/8] fix: bus future not running --- boltconn/src/app.rs | 3 +++ boltconn/src/instrument/action.rs | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/boltconn/src/app.rs b/boltconn/src/app.rs index 5528ceb..3f15f6f 100644 --- a/boltconn/src/app.rs +++ b/boltconn/src/app.rs @@ -265,6 +265,9 @@ impl App { tokio::spawn(async move { tun_inbound_udp.run().await }); tokio::spawn(async move { tun.run(nat_addr).await }); + let msg_bus2 = msg_bus.clone(); + tokio::spawn(async move { msg_bus2.run().await }); + // start http/socks5 inbound for (sock_addr, http_auth, socks_auth) in parse_two_inbound_service(&config.inbound.http, &config.inbound.socks5) diff --git a/boltconn/src/instrument/action.rs b/boltconn/src/instrument/action.rs index a0df521..b734d08 100644 --- a/boltconn/src/instrument/action.rs +++ b/boltconn/src/instrument/action.rs @@ -75,6 +75,12 @@ impl FormattingObject { fn format_inner(template: &str, info: &ConnInfo) -> Result { let na_str = "N/A"; + let now = chrono::Local::now(); + let time_rfc3389 = now.to_rfc3339(); + let time_hms_ms = now.format("%H:%M:%S%.3f").to_string(); + let time_datetime = now.format("%Y-%m-%d %H:%M:%S").to_string(); + let time_datetime_ms = now.format("%Y-%m-%d %H:%M:%S%.3f").to_string(); + let local_ip = info .local_ip .map_or_else(|| na_str.to_string(), |ip| ip.to_string()); @@ -145,6 +151,10 @@ impl FormattingObject { ("process.pid", Formattable::display(&process_pid)), ("process.ppid", Formattable::display(&process_ppid)), ("process.parent_name", Formattable::display(&process_pname)), + ("time.rfc3389", Formattable::display(&time_rfc3389)), + ("time.hms_ms", Formattable::display(&time_hms_ms)), + ("time.datetime", Formattable::display(&time_datetime)), + ("time.datetime_ms", Formattable::display(&time_datetime_ms)), ] .into_iter() .collect::>(); @@ -158,7 +168,8 @@ fn test_instrument_formatting() { local_ip: {ip.local}, conn_type: {conn.type}, \ inbound_type: {inbound.type}, inbound_port: {inbound.port}, inbound_user: {inbound.user}, \ process_name: {process.name}, process_cmdline: {process.cmdline}, process_path: {process.path}, \ - process_pid: {process.pid}, process_ppid: {process.ppid}, process_parent_name: {process.parent_name}"; + process_pid: {process.pid}, process_ppid: {process.ppid}, process_parent_name: {process.parent_name}\ + time: [{time.rfc3389}, {time.hms_ms}, {time.datetime}, {time.datetime_ms}]"; let info = ConnInfo { src: std::net::SocketAddr::V4(std::net::SocketAddrV4::new( std::net::Ipv4Addr::new(192, 168, 0, 1), From 4695712c93af52879b0a36468674b5abedc0b7c8 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Thu, 8 Aug 2024 19:47:08 -0400 Subject: [PATCH 5/8] impl: WebSocket-based instrument subscriber --- boltapi/src/instrument.rs | 34 +++++++ boltapi/src/lib.rs | 1 + boltconn/src/external/instrument_server.rs | 103 +++++++++++++++++++++ boltconn/src/external/mod.rs | 2 + boltconn/src/external/web_common.rs | 87 +++++++++++++++++ boltconn/src/external/web_controller.rs | 94 ++----------------- boltconn/src/instrument/bus.rs | 4 +- boltconn/src/proxy/error.rs | 2 + 8 files changed, 238 insertions(+), 89 deletions(-) create mode 100644 boltapi/src/instrument.rs create mode 100644 boltconn/src/external/instrument_server.rs create mode 100644 boltconn/src/external/web_common.rs diff --git a/boltapi/src/instrument.rs b/boltapi/src/instrument.rs new file mode 100644 index 0000000..727e768 --- /dev/null +++ b/boltapi/src/instrument.rs @@ -0,0 +1,34 @@ +/// Wire format for instrumentation data. +#[derive(Debug, Clone)] +pub struct InstrumentData { + /// Unique identifier for the instrument. + pub id: u64, + /// Message to be sent. + pub message: String, +} + +impl InstrumentData { + pub fn encode_string(&self) -> String { + format!("{:x}:{}", self.id, self.message) + } + + pub fn decode_string(encoded: &str) -> Option { + let mut parts = encoded.splitn(2, ':'); + let id = u64::from_str_radix(parts.next()?, 16).ok()?; + let message = parts.next()?.to_string(); + Some(Self { id, message }) + } +} + +#[test] +fn test_instrument_data() { + let data = InstrumentData { + id: 0x1234, + message: "hello".to_string(), + }; + let encoded = data.encode_string(); + assert_eq!(encoded, "1234:hello"); + let decoded = InstrumentData::decode_string(&encoded).unwrap(); + assert_eq!(decoded.id, data.id); + assert_eq!(decoded.message, data.message); +} diff --git a/boltapi/src/lib.rs b/boltapi/src/lib.rs index 6c4d2d8..f0408ba 100644 --- a/boltapi/src/lib.rs +++ b/boltapi/src/lib.rs @@ -1,3 +1,4 @@ +pub mod instrument; pub mod multiplex; pub mod rpc; mod schema; diff --git a/boltconn/src/external/instrument_server.rs b/boltconn/src/external/instrument_server.rs new file mode 100644 index 0000000..8218568 --- /dev/null +++ b/boltconn/src/external/instrument_server.rs @@ -0,0 +1,103 @@ +use crate::common::as_io_err; +use crate::external::web_common::{get_cors_layer, parse_cors_allow, web_auth}; +use crate::instrument::bus::{BusSubscriber, MessageBus}; +use crate::proxy::error::{RuntimeError, SystemError}; +use axum::extract::ws::WebSocket; +use axum::extract::{ws, Path, State, WebSocketUpgrade}; +use axum::middleware::map_request; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::Router; +use boltapi::instrument::InstrumentData; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::TcpListener; + +#[derive(Clone)] +pub struct InstrumentServer { + secret: Option, + msg_bus: Arc, +} + +impl InstrumentServer { + pub async fn run( + self, + listen_addr: SocketAddr, + cors_allowed_list: &[String], + ) -> Result<(), RuntimeError> { + let secret = Arc::new(self.secret.clone()); + let cors_vec = parse_cors_allow(cors_allowed_list); + let auth_wrapper = move |r| web_auth(secret.clone(), r, cors_vec.clone()); + + let mut app = Router::new() + .route("/subscribe", get(Self::subscribe)) + .route_layer(map_request(auth_wrapper)) + .with_state(self); + if let Some(origin) = super::web_controller::parse_api_cors_origin(cors_allowed_list) { + app = app.layer(get_cors_layer(origin)); + } + + let listener = TcpListener::bind(&listen_addr) + .await + .map_err(SystemError::InstrumentServer)?; + axum::serve(listener, app.into_make_service()) + .await + .map_err(|e| SystemError::InstrumentServer(as_io_err(e)))?; + Ok(()) + } + + async fn subscribe( + State(server): State, + Path(params): Path>, + ws: WebSocketUpgrade, + ) -> impl IntoResponse { + if let Some(secret) = server.secret.as_ref() { + if params.get("secret") != Some(secret) { + return refusal_resp(http::StatusCode::UNAUTHORIZED); + } + } + // parse hex-encoded topics from url params + let ids = { + let mut arr = Vec::new(); + let Some(s) = params.get("id") else { + return refusal_resp(http::StatusCode::BAD_REQUEST); + }; + for id in s.split(',') { + if let Ok(val) = u64::from_str_radix(id, 16) { + arr.push(val); + } else { + return refusal_resp(http::StatusCode::BAD_REQUEST); + } + } + arr + }; + let Some(sub) = server.msg_bus.create_subscriber(ids.iter().copied()) else { + return refusal_resp(http::StatusCode::CONFLICT); + }; + ws.on_upgrade(move |socket| Self::subscribe_inner(socket, sub, ids)) + } + + async fn subscribe_inner(mut socket: WebSocket, sub: BusSubscriber, ids: Vec) { + while let Some(msg) = sub.recv().await { + let wire_msg = InstrumentData { + id: msg.sub_id, + message: msg.msg, + }; + if let Err(e) = socket + .send(ws::Message::Text(wire_msg.encode_string())) + .await + { + tracing::warn!("Subscriber for {:?} failed to send: {}", ids, e); + break; + } + } + } +} + +fn refusal_resp(code: http::StatusCode) -> http::Response { + http::Response::builder() + .status(code) + .body(axum::body::Body::empty()) + .unwrap() +} diff --git a/boltconn/src/external/mod.rs b/boltconn/src/external/mod.rs index 6aa809e..7503f06 100644 --- a/boltconn/src/external/mod.rs +++ b/boltconn/src/external/mod.rs @@ -1,8 +1,10 @@ mod controller; mod database; +mod instrument_server; mod logger; mod mmdb; mod uds_controller; +mod web_common; mod web_controller; pub use controller::*; diff --git a/boltconn/src/external/web_common.rs b/boltconn/src/external/web_common.rs new file mode 100644 index 0000000..0169385 --- /dev/null +++ b/boltconn/src/external/web_common.rs @@ -0,0 +1,87 @@ +use http::Method; +use std::collections::HashSet; +use std::sync::Arc; +use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer}; + +pub(super) async fn web_auth( + auth: Arc>, + request: http::Request, + cors_allow: CorsAllow, +) -> Result, http::StatusCode> { + // Validate websocket origin + // The `origin` header will be set automatically by browser + if request.headers().contains_key("Upgrade") + && request.headers().contains_key("origin") + && !cors_allow.validate( + request + .headers() + .get("origin") + .unwrap() + .to_str() + .map_err(|_| http::StatusCode::UNAUTHORIZED)?, + ) + { + return Err(http::StatusCode::UNAUTHORIZED); + } + + if let Some(auth) = auth.as_ref() { + let auth_header = request + .headers() + .get("api-key") + .and_then(|h| h.to_str().ok()); + match auth_header { + Some(header_val) if header_val == auth => Ok(request), + _ => Err(http::StatusCode::UNAUTHORIZED), + } + } else { + Ok(request) + } +} + +#[derive(Debug, Clone)] +pub(super) enum CorsAllow { + Any, + None, + Some(Arc>), +} + +impl CorsAllow { + pub fn validate(&self, source: &str) -> bool { + match self { + CorsAllow::Any => true, + CorsAllow::None => Self::is_local(source), + CorsAllow::Some(set) => set.contains(source) || Self::is_local(source), + } + } + + pub fn is_local(source: &str) -> bool { + source.starts_with("http://localhost") + || source.starts_with("http://127.0.0.1") + || source.starts_with("file://") + || source.starts_with("https://localhost") + || source.starts_with("https://127.0.0.1") + } +} + +pub(super) fn parse_cors_allow(cors_allowed_list: &[String]) -> CorsAllow { + if !cors_allowed_list.is_empty() { + let mut list = HashSet::new(); + for i in cors_allowed_list.iter() { + if i == "*" { + return CorsAllow::Any; + } else { + list.insert(i.clone()); + } + } + CorsAllow::Some(Arc::new(list)) + } else { + CorsAllow::None + } +} + +pub(super) fn get_cors_layer(origin: AllowOrigin) -> CorsLayer { + CorsLayer::new() + .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) + .allow_origin(origin) + .allow_headers(AllowHeaders::any()) +} diff --git a/boltconn/src/external/web_controller.rs b/boltconn/src/external/web_controller.rs index 135a87d..9089174 100644 --- a/boltconn/src/external/web_controller.rs +++ b/boltconn/src/external/web_controller.rs @@ -1,5 +1,6 @@ use crate::common::as_io_err; use crate::dispatch::Dispatching; +use crate::external::web_common::{get_cors_layer, parse_cors_allow, web_auth}; use crate::external::{Controller, StreamLoggerRecv}; use crate::proxy::error::SystemError; use arc_swap::ArcSwap; @@ -10,14 +11,14 @@ use axum::response::IntoResponse; use axum::routing::{delete, get, post}; use axum::{Json, Router}; use boltapi::{GetInterceptRangeReq, SetGroupReqSchema, TrafficResp, TunStatusSchema}; -use http::{HeaderValue, Method}; +use http::HeaderValue; use serde_json::json; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; -use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer}; +use tower_http::cors::AllowOrigin; pub type SharedDispatching = Arc>; @@ -39,7 +40,7 @@ impl WebController { ) -> Result<(), SystemError> { let secret = Arc::new(self.secret.clone()); let cors_vec = parse_cors_allow(cors_allowed_list); - let wrapper = move |r| Self::auth(secret.clone(), r, cors_vec.clone()); + let wrapper = move |r| web_auth(secret.clone(), r, cors_vec.clone()); let mut app = Router::new() .route("/ws/traffic", get(Self::ws_get_traffic)) @@ -75,12 +76,7 @@ impl WebController { .route_layer(map_request(wrapper)) .with_state(self); if let Some(origin) = parse_api_cors_origin(cors_allowed_list) { - app = app.layer( - CorsLayer::new() - .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) - .allow_origin(origin) - .allow_headers(AllowHeaders::any()), - ); + app = app.layer(get_cors_layer(origin)); } let listener = TcpListener::bind(&listen_addr) @@ -92,41 +88,6 @@ impl WebController { Ok(()) } - async fn auth( - auth: Arc>, - request: http::Request, - cors_allow: CorsAllow, - ) -> Result, http::StatusCode> { - // Validate websocket origin - // The `origin` header will be set automatically by browser - if request.headers().contains_key("Upgrade") - && request.headers().contains_key("origin") - && !cors_allow.validate( - request - .headers() - .get("origin") - .unwrap() - .to_str() - .map_err(|_| http::StatusCode::UNAUTHORIZED)?, - ) - { - return Err(http::StatusCode::UNAUTHORIZED); - } - - if let Some(auth) = auth.as_ref() { - let auth_header = request - .headers() - .get("api-key") - .and_then(|h| h.to_str().ok()); - match auth_header { - Some(header_val) if header_val == auth => Ok(request), - _ => Err(http::StatusCode::UNAUTHORIZED), - } - } else { - Ok(request) - } - } - async fn get_tun_configure(State(server): State) -> Json { Json(json!(server.controller.get_tun())) } @@ -356,7 +317,7 @@ impl WebController { } } -fn parse_api_cors_origin(cors_allowed_list: &[String]) -> Option { +pub(super) fn parse_api_cors_origin(cors_allowed_list: &[String]) -> Option { if !cors_allowed_list.is_empty() { let mut list = vec![]; for i in cors_allowed_list.iter() { @@ -371,44 +332,3 @@ fn parse_api_cors_origin(cors_allowed_list: &[String]) -> Option { None } } - -#[derive(Debug, Clone)] -enum CorsAllow { - Any, - None, - Some(Arc>), -} - -impl CorsAllow { - fn validate(&self, source: &str) -> bool { - match self { - CorsAllow::Any => true, - CorsAllow::None => Self::is_local(source), - CorsAllow::Some(set) => set.contains(source) || Self::is_local(source), - } - } - - fn is_local(source: &str) -> bool { - source.starts_with("http://localhost") - || source.starts_with("http://127.0.0.1") - || source.starts_with("file://") - || source.starts_with("https://localhost") - || source.starts_with("https://127.0.0.1") - } -} - -fn parse_cors_allow(cors_allowed_list: &[String]) -> CorsAllow { - if !cors_allowed_list.is_empty() { - let mut list = HashSet::new(); - for i in cors_allowed_list.iter() { - if i == "*" { - return CorsAllow::Any; - } else { - list.insert(i.clone()); - } - } - CorsAllow::Some(Arc::new(list)) - } else { - CorsAllow::None - } -} diff --git a/boltconn/src/instrument/bus.rs b/boltconn/src/instrument/bus.rs index 247f6ba..dbd40dc 100644 --- a/boltconn/src/instrument/bus.rs +++ b/boltconn/src/instrument/bus.rs @@ -56,8 +56,8 @@ impl MessageBus { #[derive(Debug, Clone)] pub struct BusMessage { - sub_id: SubId, - msg: String, + pub sub_id: SubId, + pub msg: String, } impl BusMessage { diff --git a/boltconn/src/proxy/error.rs b/boltconn/src/proxy/error.rs index 7ce24ad..03dc45f 100644 --- a/boltconn/src/proxy/error.rs +++ b/boltconn/src/proxy/error.rs @@ -18,6 +18,8 @@ pub enum RuntimeError { pub enum SystemError { #[error("Controller error: {0}")] Controller(std::io::Error), + #[error("Instrument server error: {0}")] + InstrumentServer(std::io::Error), } #[derive(Error, Debug)] From 6a6aaa216b3026292c4e90e7bfbd07b386634496 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Fri, 9 Aug 2024 02:35:31 -0400 Subject: [PATCH 6/8] refactor: simplify App::create() --- boltconn/src/app.rs | 307 +++++++++++++++++++--------------- boltconn/src/config/config.rs | 12 ++ 2 files changed, 183 insertions(+), 136 deletions(-) diff --git a/boltconn/src/app.rs b/boltconn/src/app.rs index 3f15f6f..c4c54f4 100644 --- a/boltconn/src/app.rs +++ b/boltconn/src/app.rs @@ -1,8 +1,9 @@ use crate::config::{ default_inbound_ip_addr, safe_join_path, LinkedState, LoadedConfig, PortOrSocketAddr, - RawInboundServiceConfig, SingleOrVec, + RawDnsConfig, RawInboundConfig, RawInboundServiceConfig, RawRootCfg, RawWebControllerConfig, + SingleOrVec, }; -use crate::dispatch::{DispatchingBuilder, RuleSetBuilder}; +use crate::dispatch::{DispatchingBuilder, RuleSet, RuleSetBuilder}; use crate::external::{ Controller, DatabaseHandle, MmdbReader, SharedDispatching, StreamLoggerSend, UdsController, UnixListenerGuard, WebController, @@ -20,6 +21,7 @@ use crate::proxy::{ use crate::{external, platform}; use anyhow::anyhow; use arc_swap::ArcSwap; +use bytes::Bytes; use ipnet::Ipv4Net; use rcgen::{Certificate, CertificateParams, KeyPair}; use std::collections::HashMap; @@ -65,28 +67,11 @@ impl App { .await .map_err(|e| anyhow!("Load config from {:?} failed: {}", &config_path, e))?; let config = &loaded_config.config; - let mmdb = match config.geoip_db.as_ref() { - None => None, - Some(p) => { - let path = safe_join_path(&config_path, p)?; - Some(Arc::new(MmdbReader::read_from_file(path)?)) - } - }; + let mmdb = load_mmdb(config.geoip_db.as_ref(), &config_path)?; - let fake_dns_server = "198.18.99.88".parse().unwrap(); - - let outbound_iface = if config.interface != "auto" { - tracing::info!("Use pre-configured interface: {}", config.interface); - config.interface.clone() - } else { - let (_, real_iface_name) = get_default_v4_route() - .map_err(|e| anyhow!("Failed to get default route: {}", e))?; - tracing::info!("Auto detected interface: {}", real_iface_name); - real_iface_name - }; + let outbound_iface = detect_interface(config)?; - let manager = Arc::new(SessionManager::new()); - let (stat_center, http_capturer) = { + let (ctx_manager, http_capturer) = { let conn_handle = if config.enable_dump { Some(open_database_handle(data_path.as_path())?) } else { @@ -103,39 +88,14 @@ impl App { }; // initialize resources - let dns = { - let bootstrap = - new_bootstrap_resolver(outbound_iface.as_str(), config.dns.bootstrap.as_slice()); - let group = match parse_dns_config(config.dns.nameserver.iter(), Some(&bootstrap)).await - { - Ok(g) => { - if g.is_empty() { - return Err(anyhow!("No DNS specified")); - } - g - } - Err(e) => return Err(anyhow!("Parse dns config failed: {e}")), - }; - let ns_policy = NameserverPolicies::new( - &config.dns.nameserver_policy, - Some(&bootstrap), - outbound_iface.as_str(), - ) - .await - .map_err(|e| anyhow!("Parse nameserver policy failed: {e}"))?; - Arc::new(Dns::with_config( - outbound_iface.as_str(), - config.dns.preference, - &config.dns.hosts, - ns_policy, - group, - )) - }; + let dns = initialize_dns(&config.dns, outbound_iface.as_str()).await?; + let manager = Arc::new(SessionManager::new()); // Create TUN let will_enable_tun = enable_tun.unwrap_or(config.inbound.enable_tun); let (tun_udp_tx, tun_udp_rx) = flume::bounded(4096); let (udp_tun_tx, udp_tun_rx) = flume::bounded(4096); + let fake_dns_server = "198.18.99.88".parse().unwrap(); let tun = { let mut tun = TunDevice::open( manager.clone(), @@ -179,13 +139,7 @@ impl App { let msg_bus = Arc::new(MessageBus::new()); // dispatch - let mut ruleset = HashMap::new(); - for (name, schema) in &loaded_config.rule_schema { - let Some(builder) = RuleSetBuilder::new(name.as_str(), schema) else { - return Err(anyhow!("Filter: failed to parse provider {}", name)); - }; - ruleset.insert(name.clone(), Arc::new(builder.build()?)); - } + let ruleset = load_rulesets(&loaded_config)?; let dispatching = Arc::new( DispatchingBuilder::new( dns.clone(), @@ -215,7 +169,7 @@ impl App { Arc::new(Dispatcher::new( outbound_iface.as_str(), dns.clone(), - stat_center.clone(), + ctx_manager.clone(), dispatching.clone(), cert, Box::new(move |result, proc_info| { @@ -232,7 +186,7 @@ impl App { let controller = Arc::new(Controller::new( manager.clone(), dns.clone(), - stat_center, + ctx_manager, Some(http_capturer.clone()), dispatcher.clone(), api_dispatching_handler.clone(), @@ -246,80 +200,28 @@ impl App { speedtest_url.clone(), )); - // start tun - let tun_inbound_tcp = Arc::new(TunTcpInbound::new( + // start tun & L7 inbound services + start_tun_services( nat_addr, manager.clone(), dispatcher.clone(), dns.clone(), - )); - let tun_inbound_udp = TunUdpInbound::new( + tun, tun_udp_rx, udp_tun_tx, - dispatcher.clone(), - manager.clone(), - dns.clone(), ); - manager.flush_with_interval(Duration::from_secs(30)); - tokio::spawn(async move { tun_inbound_tcp.run().await }); - tokio::spawn(async move { tun_inbound_udp.run().await }); - tokio::spawn(async move { tun.run(nat_addr).await }); + start_inbound_services(&config.inbound, dispatcher.clone()); + + // start controller service + start_controller_services( + config.web_controller.as_ref(), + controller, + uds_listener.clone(), + ); let msg_bus2 = msg_bus.clone(); tokio::spawn(async move { msg_bus2.run().await }); - // start http/socks5 inbound - for (sock_addr, http_auth, socks_auth) in - parse_two_inbound_service(&config.inbound.http, &config.inbound.socks5) - { - let dispatcher = dispatcher.clone(); - match (http_auth, socks_auth) { - (Some(http_auth), Some(socks_auth)) => { - tokio::spawn(async move { - MixedInbound::new(sock_addr, http_auth, socks_auth, dispatcher) - .await? - .run() - .await; - Ok::<(), io::Error>(()) - }); - } - (Some(auth), None) => { - tokio::spawn(async move { - HttpInbound::new(sock_addr, auth, dispatcher) - .await? - .run() - .await; - Ok::<(), io::Error>(()) - }); - } - (None, Some(auth)) => { - tokio::spawn(async move { - Socks5Inbound::new(sock_addr, auth, dispatcher) - .await? - .run() - .await; - Ok::<(), io::Error>(()) - }); - } - _ => unreachable!(), - } - } - - let uds_controller = UdsController::new(controller.clone()); - let uds_listener2 = uds_listener.clone(); - tokio::spawn(async move { uds_controller.run(uds_listener2).await }); - - // start web controller - if let Some(web_cfg) = &config.web_controller { - let api_addr = match web_cfg.api_addr { - PortOrSocketAddr::Port(p) => SocketAddr::new(default_inbound_ip_addr(), p), - PortOrSocketAddr::SocketAddr(s) => s, - }; - let api_server = WebController::new(web_cfg.api_key.clone(), controller); - let cors_domains = web_cfg.cors_allowed_list.clone(); - tokio::spawn(async move { api_server.run(api_addr, cors_domains.as_slice()).await }); - } - Ok(Self { config_path, data_path, @@ -374,20 +276,8 @@ impl App { // reload parsing let loaded_config = LoadedConfig::load_config(&self.config_path, &self.data_path).await?; let config = &loaded_config.config; - let mmdb = match config.geoip_db.as_ref() { - None => None, - Some(p) => { - let path = safe_join_path(&self.config_path, p)?; - Some(Arc::new(MmdbReader::read_from_file(path)?)) - } - }; - let mut ruleset = HashMap::new(); - for (name, schema) in &loaded_config.rule_schema { - let Some(builder) = RuleSetBuilder::new(name.as_str(), schema) else { - return Err(anyhow!("Filter: failed to parse provider {}", name)); - }; - ruleset.insert(name.clone(), Arc::new(builder.build()?)); - } + let mmdb = load_mmdb(config.geoip_db.as_ref(), &self.config_path)?; + let ruleset = load_rulesets(&loaded_config)?; let bootstrap = new_bootstrap_resolver(&self.outbound_iface, config.dns.bootstrap.as_slice()); @@ -441,6 +331,151 @@ impl App { } } +fn load_mmdb(db_path: Option<&String>, cfg_path: &Path) -> anyhow::Result>> { + Ok(match db_path { + None => None, + Some(p) => { + let path = safe_join_path(cfg_path, p)?; + Some(Arc::new(MmdbReader::read_from_file(path)?)) + } + }) +} + +fn detect_interface(config: &RawRootCfg) -> anyhow::Result { + Ok(if config.interface != "auto" { + tracing::info!("Use pre-configured interface: {}", config.interface); + config.interface.clone() + } else { + let (_, real_iface_name) = + get_default_v4_route().map_err(|e| anyhow!("Failed to get default route: {}", e))?; + tracing::info!("Auto detected interface: {}", real_iface_name); + real_iface_name + }) +} + +async fn initialize_dns(config: &RawDnsConfig, outbound_iface: &str) -> anyhow::Result> { + Ok({ + let bootstrap = new_bootstrap_resolver(outbound_iface, config.bootstrap.as_slice()); + let group = match parse_dns_config(config.nameserver.iter(), Some(&bootstrap)).await { + Ok(g) => { + if g.is_empty() { + return Err(anyhow!("No DNS specified")); + } + g + } + Err(e) => return Err(anyhow!("Parse dns config failed: {e}")), + }; + let ns_policy = + NameserverPolicies::new(&config.nameserver_policy, Some(&bootstrap), outbound_iface) + .await + .map_err(|e| anyhow!("Parse nameserver policy failed: {e}"))?; + Arc::new(Dns::with_config( + outbound_iface, + config.preference, + &config.hosts, + ns_policy, + group, + )) + }) +} + +fn start_tun_services( + nat_addr: SocketAddr, + manager: Arc, + dispatcher: Arc, + dns: Arc, + tun: TunDevice, + tun_udp_rx: flume::Receiver, + udp_tun_tx: flume::Sender, +) { + let tun_inbound_tcp = Arc::new(TunTcpInbound::new( + nat_addr, + manager.clone(), + dispatcher.clone(), + dns.clone(), + )); + let tun_inbound_udp = TunUdpInbound::new( + tun_udp_rx, + udp_tun_tx, + dispatcher.clone(), + manager.clone(), + dns.clone(), + ); + manager.flush_with_interval(Duration::from_secs(30)); + tokio::spawn(async move { tun_inbound_tcp.run().await }); + tokio::spawn(async move { tun_inbound_udp.run().await }); + tokio::spawn(async move { tun.run(nat_addr).await }); +} + +fn start_inbound_services(config: &RawInboundConfig, dispatcher: Arc) { + for (sock_addr, http_auth, socks_auth) in + parse_two_inbound_service(&config.http, &config.socks5) + { + let dispatcher = dispatcher.clone(); + match (http_auth, socks_auth) { + (Some(http_auth), Some(socks_auth)) => { + tokio::spawn(async move { + MixedInbound::new(sock_addr, http_auth, socks_auth, dispatcher) + .await? + .run() + .await; + Ok::<(), io::Error>(()) + }); + } + (Some(auth), None) => { + tokio::spawn(async move { + HttpInbound::new(sock_addr, auth, dispatcher) + .await? + .run() + .await; + Ok::<(), io::Error>(()) + }); + } + (None, Some(auth)) => { + tokio::spawn(async move { + Socks5Inbound::new(sock_addr, auth, dispatcher) + .await? + .run() + .await; + Ok::<(), io::Error>(()) + }); + } + _ => unreachable!(), + } + } +} + +fn start_controller_services( + config: Option<&RawWebControllerConfig>, + controller: Arc, + uds_listener: Arc, +) { + let uds_controller = UdsController::new(controller.clone()); + let uds_listener2 = uds_listener.clone(); + tokio::spawn(async move { uds_controller.run(uds_listener2).await }); + + if let Some(web_cfg) = config { + let api_addr = match web_cfg.api_addr { + PortOrSocketAddr::Port(p) => SocketAddr::new(default_inbound_ip_addr(), p), + PortOrSocketAddr::SocketAddr(s) => s, + }; + let api_server = WebController::new(web_cfg.api_key.clone(), controller); + let cors_domains = web_cfg.cors_allowed_list.clone(); + tokio::spawn(async move { api_server.run(api_addr, cors_domains.as_slice()).await }); + } +} + +fn load_rulesets(loaded_config: &LoadedConfig) -> anyhow::Result>> { + let mut ruleset = HashMap::new(); + for (name, schema) in &loaded_config.rule_schema { + let Some(builder) = RuleSetBuilder::new(name.as_str(), schema) else { + return Err(anyhow!("Filter: failed to parse provider {}", name)); + }; + ruleset.insert(name.clone(), Arc::new(builder.build()?)); + } + Ok(ruleset) +} + fn load_cert_and_key(cert_path: &Path) -> anyhow::Result { let cert_str = fs::read_to_string(cert_path.join("crt.pem"))?; let key_str = fs::read_to_string(cert_path.join("key.pem"))?; diff --git a/boltconn/src/config/config.rs b/boltconn/src/config/config.rs index d681876..b4099c6 100644 --- a/boltconn/src/config/config.rs +++ b/boltconn/src/config/config.rs @@ -17,6 +17,8 @@ pub struct RawRootCfg { pub inbound: RawInboundConfig, #[serde(alias = "web-controller")] pub web_controller: Option, + #[serde(alias = "instrument")] + pub instrument: Option, #[serde(default = "default_false")] pub enable_dump: bool, // From now on, all the configs should be reloaded properly @@ -89,6 +91,16 @@ pub struct RawWebControllerConfig { pub cors_allowed_list: Vec, } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RawInstrumentConfig { + #[serde(alias = "api-port", alias = "api-addr")] + pub api_addr: PortOrSocketAddr, + #[serde(alias = "api-key")] + pub api_key: Option, + #[serde(alias = "cors-allowed-list", default = "default_str_vec")] + pub cors_allowed_list: Vec, +} + #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(deny_unknown_fields, tag = "type")] pub enum RawProxyLocalCfg { From b5d05b05af4ddb55dfdc0233edea61d322021c69 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:38:13 -0400 Subject: [PATCH 7/8] proto: change instrument wire format --- boltapi/src/instrument.rs | 6 ++-- boltconn/src/app.rs | 40 +++++++++++----------- boltconn/src/config/mod.rs | 14 +++++++- boltconn/src/external/instrument_server.rs | 18 +++++++--- boltconn/src/external/mod.rs | 1 + 5 files changed, 51 insertions(+), 28 deletions(-) diff --git a/boltapi/src/instrument.rs b/boltapi/src/instrument.rs index 727e768..0e0dd90 100644 --- a/boltapi/src/instrument.rs +++ b/boltapi/src/instrument.rs @@ -9,12 +9,12 @@ pub struct InstrumentData { impl InstrumentData { pub fn encode_string(&self) -> String { - format!("{:x}:{}", self.id, self.message) + format!("{}:{}", self.id, self.message) } pub fn decode_string(encoded: &str) -> Option { let mut parts = encoded.splitn(2, ':'); - let id = u64::from_str_radix(parts.next()?, 16).ok()?; + let id = parts.next()?.parse::().ok()?; let message = parts.next()?.to_string(); Some(Self { id, message }) } @@ -23,7 +23,7 @@ impl InstrumentData { #[test] fn test_instrument_data() { let data = InstrumentData { - id: 0x1234, + id: 1234, message: "hello".to_string(), }; let encoded = data.encode_string(); diff --git a/boltconn/src/app.rs b/boltconn/src/app.rs index c4c54f4..41c77e7 100644 --- a/boltconn/src/app.rs +++ b/boltconn/src/app.rs @@ -1,12 +1,12 @@ use crate::config::{ - default_inbound_ip_addr, safe_join_path, LinkedState, LoadedConfig, PortOrSocketAddr, - RawDnsConfig, RawInboundConfig, RawInboundServiceConfig, RawRootCfg, RawWebControllerConfig, - SingleOrVec, + default_inbound_ip_addr, safe_join_path, LinkedState, LoadedConfig, RawDnsConfig, + RawInboundConfig, RawInboundServiceConfig, RawInstrumentConfig, RawRootCfg, + RawWebControllerConfig, SingleOrVec, }; use crate::dispatch::{DispatchingBuilder, RuleSet, RuleSetBuilder}; use crate::external::{ - Controller, DatabaseHandle, MmdbReader, SharedDispatching, StreamLoggerSend, UdsController, - UnixListenerGuard, WebController, + Controller, DatabaseHandle, InstrumentServer, MmdbReader, SharedDispatching, StreamLoggerSend, + UdsController, UnixListenerGuard, WebController, }; use crate::instrument::bus::MessageBus; use crate::intercept::{InterceptModifier, InterceptionManager}; @@ -90,6 +90,8 @@ impl App { // initialize resources let dns = initialize_dns(&config.dns, outbound_iface.as_str()).await?; let manager = Arc::new(SessionManager::new()); + // initialize instrumentation + let msg_bus = Arc::new(MessageBus::new()); // Create TUN let will_enable_tun = enable_tun.unwrap_or(config.inbound.enable_tun); @@ -135,9 +137,6 @@ impl App { 9961, ); - // initialize instrumentation - let msg_bus = Arc::new(MessageBus::new()); - // dispatch let ruleset = load_rulesets(&loaded_config)?; let dispatching = Arc::new( @@ -219,8 +218,7 @@ impl App { uds_listener.clone(), ); - let msg_bus2 = msg_bus.clone(); - tokio::spawn(async move { msg_bus2.run().await }); + start_instrument_services(msg_bus.clone(), config.instrument.as_ref()); Ok(Self { config_path, @@ -379,6 +377,16 @@ async fn initialize_dns(config: &RawDnsConfig, outbound_iface: &str) -> anyhow:: }) } +fn start_instrument_services(bus: Arc, config: Option<&RawInstrumentConfig>) { + if let Some(config) = config { + let web_server = InstrumentServer::new(config.api_key.clone(), bus.clone()); + let addr = config.api_addr.as_socket_addr(default_inbound_ip_addr); + let cors_allowed_list = config.cors_allowed_list.clone(); + tokio::spawn(async move { web_server.run(addr, cors_allowed_list.as_slice()).await }); + } + tokio::spawn(async move { bus.run().await }); +} + fn start_tun_services( nat_addr: SocketAddr, manager: Arc, @@ -455,10 +463,7 @@ fn start_controller_services( tokio::spawn(async move { uds_controller.run(uds_listener2).await }); if let Some(web_cfg) = config { - let api_addr = match web_cfg.api_addr { - PortOrSocketAddr::Port(p) => SocketAddr::new(default_inbound_ip_addr(), p), - PortOrSocketAddr::SocketAddr(s) => s, - }; + let api_addr = web_cfg.api_addr.as_socket_addr(default_inbound_ip_addr); let api_server = WebController::new(web_cfg.api_key.clone(), controller); let cors_domains = web_cfg.cors_allowed_list.clone(); tokio::spawn(async move { api_server.run(api_addr, cors_domains.as_slice()).await }); @@ -517,12 +522,7 @@ fn parse_two_inbound_service( .into_iter() .map(|c| match c { RawInboundServiceConfig::Simple(e) => ( - match e { - PortOrSocketAddr::Port(p) => { - SocketAddr::new(default_inbound_ip_addr(), p) - } - PortOrSocketAddr::SocketAddr(s) => s, - }, + e.as_socket_addr(default_inbound_ip_addr), HashMap::default(), ), RawInboundServiceConfig::Complex { host, port, auth } => { diff --git a/boltconn/src/config/mod.rs b/boltconn/src/config/mod.rs index ef147fe..92550bd 100644 --- a/boltconn/src/config/mod.rs +++ b/boltconn/src/config/mod.rs @@ -26,7 +26,7 @@ use serde::{Deserialize, Serialize}; pub use state::*; use std::collections::HashMap; use std::fmt::Debug; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::{fs, io}; @@ -194,6 +194,18 @@ pub enum PortOrSocketAddr { SocketAddr(SocketAddr), } +impl PortOrSocketAddr { + pub fn as_socket_addr(&self, default_fn: F) -> SocketAddr + where + F: FnOnce() -> IpAddr, + { + match self { + PortOrSocketAddr::Port(port) => SocketAddr::from((default_fn(), *port)), + PortOrSocketAddr::SocketAddr(addr) => *addr, + } + } +} + pub(in crate::config) fn default_true() -> bool { true } diff --git a/boltconn/src/external/instrument_server.rs b/boltconn/src/external/instrument_server.rs index 8218568..52eceee 100644 --- a/boltconn/src/external/instrument_server.rs +++ b/boltconn/src/external/instrument_server.rs @@ -21,6 +21,10 @@ pub struct InstrumentServer { } impl InstrumentServer { + pub fn new(secret: Option, msg_bus: Arc) -> Self { + Self { secret, msg_bus } + } + pub async fn run( self, listen_addr: SocketAddr, @@ -38,9 +42,15 @@ impl InstrumentServer { app = app.layer(get_cors_layer(origin)); } - let listener = TcpListener::bind(&listen_addr) - .await - .map_err(SystemError::InstrumentServer)?; + let listener = TcpListener::bind(&listen_addr).await.map_err(|e| { + tracing::error!( + "[Instrument] service failed to bind to {}: {}", + listen_addr, + e + ); + SystemError::InstrumentServer(e) + })?; + tracing::info!("[Instrument] Listening on {}", listen_addr); axum::serve(listener, app.into_make_service()) .await .map_err(|e| SystemError::InstrumentServer(as_io_err(e)))?; @@ -64,7 +74,7 @@ impl InstrumentServer { return refusal_resp(http::StatusCode::BAD_REQUEST); }; for id in s.split(',') { - if let Ok(val) = u64::from_str_radix(id, 16) { + if let Ok(val) = u64::from_str_radix(id, 10) { arr.push(val); } else { return refusal_resp(http::StatusCode::BAD_REQUEST); diff --git a/boltconn/src/external/mod.rs b/boltconn/src/external/mod.rs index 7503f06..43450bd 100644 --- a/boltconn/src/external/mod.rs +++ b/boltconn/src/external/mod.rs @@ -9,6 +9,7 @@ mod web_controller; pub use controller::*; pub use database::*; +pub use instrument_server::*; pub use logger::*; pub use mmdb::*; pub use uds_controller::*; From 04c6da1a7b94fec14464c83d01d9635713e56b6b Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:21:14 -0400 Subject: [PATCH 8/8] feat: instrument server --- boltconn/src/external/instrument_server.rs | 34 ++++++++++++++-------- boltconn/src/instrument/bus.rs | 13 +++++++-- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/boltconn/src/external/instrument_server.rs b/boltconn/src/external/instrument_server.rs index 52eceee..aa59994 100644 --- a/boltconn/src/external/instrument_server.rs +++ b/boltconn/src/external/instrument_server.rs @@ -3,7 +3,7 @@ use crate::external::web_common::{get_cors_layer, parse_cors_allow, web_auth}; use crate::instrument::bus::{BusSubscriber, MessageBus}; use crate::proxy::error::{RuntimeError, SystemError}; use axum::extract::ws::WebSocket; -use axum::extract::{ws, Path, State, WebSocketUpgrade}; +use axum::extract::{ws, Query, State, WebSocketUpgrade}; use axum::middleware::map_request; use axum::response::IntoResponse; use axum::routing::get; @@ -13,6 +13,7 @@ use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::TcpListener; +use tokio::select; #[derive(Clone)] pub struct InstrumentServer { @@ -59,7 +60,7 @@ impl InstrumentServer { async fn subscribe( State(server): State, - Path(params): Path>, + Query(params): Query>, ws: WebSocketUpgrade, ) -> impl IntoResponse { if let Some(secret) = server.secret.as_ref() { @@ -71,12 +72,14 @@ impl InstrumentServer { let ids = { let mut arr = Vec::new(); let Some(s) = params.get("id") else { + tracing::debug!("No id parameter in request"); return refusal_resp(http::StatusCode::BAD_REQUEST); }; for id in s.split(',') { - if let Ok(val) = u64::from_str_radix(id, 10) { + if let Ok(val) = id.parse::() { arr.push(val); } else { + tracing::debug!("Invalid id parameter in request: {} in {}", id, s); return refusal_resp(http::StatusCode::BAD_REQUEST); } } @@ -85,22 +88,29 @@ impl InstrumentServer { let Some(sub) = server.msg_bus.create_subscriber(ids.iter().copied()) else { return refusal_resp(http::StatusCode::CONFLICT); }; - ws.on_upgrade(move |socket| Self::subscribe_inner(socket, sub, ids)) + ws.on_upgrade(move |socket| Self::subscribe_inner(socket, sub)) } - async fn subscribe_inner(mut socket: WebSocket, sub: BusSubscriber, ids: Vec) { - while let Some(msg) = sub.recv().await { + async fn subscribe_inner(mut socket: WebSocket, sub: BusSubscriber) { + loop { + let msg = select! { + r = socket.recv() => { + // client disconnected + if r.is_none(){ + return; + } + // some messages, maybe error + continue; + } + Some(msg) = sub.recv() => msg, + }; let wire_msg = InstrumentData { id: msg.sub_id, message: msg.msg, }; - if let Err(e) = socket + let _ = socket .send(ws::Message::Text(wire_msg.encode_string())) - .await - { - tracing::warn!("Subscriber for {:?} failed to send: {}", ids, e); - break; - } + .await; } } } diff --git a/boltconn/src/instrument/bus.rs b/boltconn/src/instrument/bus.rs index dbd40dc..41d1e86 100644 --- a/boltconn/src/instrument/bus.rs +++ b/boltconn/src/instrument/bus.rs @@ -1,3 +1,4 @@ +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Mutex; @@ -22,7 +23,6 @@ impl MessageBus { pub async fn run(&self) { while let Ok(msg) = self.ingress_receiver.recv_async().await { - tracing::debug!("[Message Bus {}] {:?} ", msg.sub_id, msg.msg); if let Some(sender) = self.egress_senders.lock().unwrap().get(&msg.sub_id) { let _ = sender.try_send(msg); } @@ -44,7 +44,16 @@ impl MessageBus { let mut egress_senders = self.egress_senders.lock().unwrap(); for sub_id in iter2 { if egress_senders.contains_key(&sub_id) { - return None; + match egress_senders.entry(sub_id) { + Entry::Occupied(e) => { + if e.get().is_disconnected() { + e.remove(); + } else { + return None; + } + } + Entry::Vacant(_) => unreachable!(), + } } } for sub_id in sub_ids {