diff --git a/h3i/src/frame.rs b/h3i/src/frame.rs index 04de876b6d..6ddfe77f02 100644 --- a/h3i/src/frame.rs +++ b/h3i/src/frame.rs @@ -30,6 +30,7 @@ use std::cmp; use std::convert::TryFrom; use std::error::Error; use std::fmt::Debug; +use std::sync::Arc; use multimap::MultiMap; use quiche; @@ -433,19 +434,88 @@ impl Serialize for SerializableQFrame<'_> { } } -/// A combination of stream ID and [`H3iFrame`] which is used to instruct h3i to -/// watch for specific frames. -#[derive(Debug, PartialEq, Eq, Serialize, Clone)] +type CustomEquivalenceHandler = + Box Fn(&'f H3iFrame) -> bool + Send + Sync + 'static>; + +#[derive(Clone)] +enum Comparator { + Frame(H3iFrame), + /// Specifies how to compare an incoming [`H3iFrame`] with this + /// [`ExpectedFrame`]. Typically, the validation attempts to fuzzy-match + /// the [`ExpectedFrame`] against the incoming [`H3iFrame`], but there + /// are times where other behavior is desired (for example, checking + /// deserialized JSON payloads in a headers frame, or ensuring a random + /// value matches a regex). + /// + /// See [`ExpectedFrame::is_equivalent`] for more on how frames are + /// compared. + Fn(Arc), +} + +impl Serialize for Comparator { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Self::Fn(_) => serializer.serialize_str(""), + Self::Frame(f) => { + let mut frame_ser = serializer.serialize_struct("frame", 1)?; + frame_ser.serialize_field("frame", f)?; + frame_ser.end() + }, + } + } +} + +/// Instructs h3i to watch for certain incoming [`H3iFrame`]s. The incoming +/// frames can either be supplied directly via [`ExpectedFrame::new`], or via a +/// verification callback passed to [`ExpectedFrame::new_with_comparator`]. +#[derive(Serialize, Clone)] pub struct ExpectedFrame { stream_id: u64, - frame: H3iFrame, + comparator: Comparator, } impl ExpectedFrame { + /// Create a new [`ExpectedFrame`] which should watch for the provided + /// [`H3iFrame`]. + /// + /// # Note + /// + /// For [QuicheH3] and [ResetStream] variants, equivalence is the same as + /// equality. + /// + /// For Headers variants, this [`ExpectedFrame`] is equivalent to the + /// incoming [`H3iFrame`] if the [`H3iFrame`] contains all [`Header`]s + /// in _this_ frame. In other words, `this` can be considered equivalent + /// to `other` if `other` contains a superset of `this`'s [`Headers`]. + /// + /// This allows users for fuzzy-matching on header frames without needing to + /// supply every individual header on the frame. + /// + /// [ResetStream]: H3iFrame::ResetStream + /// [QuicheH3]: H3iFrame::QuicheH3 pub fn new(stream_id: u64, frame: impl Into) -> Self { Self { stream_id, - frame: frame.into(), + comparator: Comparator::Frame(frame.into()), + } + } + + /// Create a new [`ExpectedFrame`] which will match incoming [`H3iFrame`]s + /// according to the passed `comparator_fn`. + /// + /// The `comparator_fn` will be called with every incoming [`H3iFrame`]. It + /// should return `true` if the incoming frame is expected, and `false` + /// if it is not. + pub fn new_with_comparator(stream_id: u64, comparator_fn: F) -> Self + where + F: Fn(&H3iFrame) -> bool + Send + Sync + 'static, + { + Self { + stream_id, + comparator: Comparator::Fn(Arc::new(Box::new(comparator_fn))), } } @@ -453,13 +523,13 @@ impl ExpectedFrame { self.stream_id } - /// Check if this [`ExpectedFrame`] is equivalent to another [`H3iFrame`]. - /// For QuicheH3/ResetStream variants, equivalence is the same as - /// equality. For Headers variants, this [`ExpectedFrame`] is equivalent - /// to another if the other frame contains all [`Header`]s in _this_ - /// frame. pub(crate) fn is_equivalent(&self, other: &H3iFrame) -> bool { - match &self.frame { + let frame = match &self.comparator { + Comparator::Fn(compare) => return compare(other), + Comparator::Frame(frame) => frame, + }; + + match frame { H3iFrame::Headers(me) => { let H3iFrame::Headers(other) = other else { return false; @@ -485,25 +555,127 @@ impl ExpectedFrame { } } +impl Debug for ExpectedFrame { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let repr = match &self.comparator { + Comparator::Frame(frame) => format!("{frame:?}"), + Comparator::Fn(_) => "closure".to_string(), + }; + + write!( + f, + "{}", + format!( + "ExpectedFrame {{ stream_id: {}, comparator: {repr} }}", + self.stream_id + ) + ) + } +} + +impl PartialEq for ExpectedFrame { + fn eq(&self, other: &Self) -> bool { + match (&self.comparator, &other.comparator) { + (Comparator::Frame(this_frame), Comparator::Frame(other_frame)) => + self.stream_id == other.stream_id && this_frame == other_frame, + _ => false, + } + } +} + #[cfg(test)] mod tests { use super::*; + use quiche::h3::frame::Frame; #[test] - fn test_equivalence() { + fn test_header_equivalence() { let this = ExpectedFrame::new(0, vec![ Header::new(b"hello", b"world"), Header::new(b"go", b"jets"), ]); - let other = ExpectedFrame::new(0, vec![ + let other: H3iFrame = vec![ + Header::new(b"hello", b"world"), + Header::new(b"go", b"jets"), + Header::new(b"go", b"devils"), + ] + .into(); + + assert!(this.is_equivalent(&other)); + } + + #[test] + fn test_header_non_equivalence() { + let this = ExpectedFrame::new(0, vec![ Header::new(b"hello", b"world"), Header::new(b"go", b"jets"), Header::new(b"go", b"devils"), ]); + let other: H3iFrame = + vec![Header::new(b"hello", b"world"), Header::new(b"go", b"jets")] + .into(); + + // `other` does not contain the `go: devils` header, so it's not + // equivalent to `this. + assert!(!this.is_equivalent(&other)); + } + + #[test] + fn test_rst_stream_equivalence() { + let mut rs = ResetStream { + stream_id: 0, + error_code: 57, + }; + + let this = ExpectedFrame::new(0, H3iFrame::ResetStream(rs.clone())); + let incoming = H3iFrame::ResetStream(rs.clone()); + assert!(this.is_equivalent(&incoming)); + + rs.stream_id = 57; + let incoming = H3iFrame::ResetStream(rs); + assert!(!this.is_equivalent(&incoming)); + } + + #[test] + fn test_frame_equivalence() { + let mut d = Frame::Data { + payload: b"57".to_vec(), + }; + + let this = ExpectedFrame::new(0, H3iFrame::QuicheH3(d.clone())); + let incoming = H3iFrame::QuicheH3(d.clone()); + assert!(this.is_equivalent(&incoming)); + + d = Frame::Data { + payload: b"go jets".to_vec(), + }; + let incoming = H3iFrame::QuicheH3(d.clone()); + assert!(!this.is_equivalent(&incoming)); + } - assert!(this.is_equivalent(&other.frame)); - // `this` does not contain the `go: devils` header, so `other` is not - // equivalent to `this`. - assert!(!other.is_equivalent(&this.frame)); + #[test] + fn test_comparator() { + let this = ExpectedFrame::new_with_comparator(0, |frame| { + if let H3iFrame::Headers(..) = frame { + frame + .to_enriched_headers() + .unwrap() + .header_map() + .get(&b"cookie".to_vec()) + .is_some_and(|v| { + std::str::from_utf8(v) + .map(|s| s.to_lowercase()) + .unwrap() + .contains("cookie") + }) + } else { + false + } + }); + + let incoming: H3iFrame = + vec![Header::new(b"cookie", b"SomeRandomCookie1234")].into(); + + assert!(this.is_equivalent(&incoming)); } }