Skip to content

Commit

Permalink
add custom equivalence functions
Browse files Browse the repository at this point in the history
  • Loading branch information
evanrittenhouse committed Jan 6, 2025
1 parent 912e02d commit 6d879e0
Showing 1 changed file with 189 additions and 17 deletions.
206 changes: 189 additions & 17 deletions h3i/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::cmp;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;

use multimap::MultiMap;
use quiche;
Expand Down Expand Up @@ -433,33 +434,102 @@ impl Serialize for SerializableQFrame<'_> {
}
}

/// A combination of stream ID and [`H3iFrame`] which is used to instruct h3i to
/// watch for specific frames.
#[derive(Debug, PartialEq, Eq, Serialize, Clone)]
type CustomEquivalenceHandler =
Box<dyn for<'f> Fn(&'f H3iFrame) -> bool + Send + Sync + 'static>;

#[derive(Clone)]
enum Comparator {
Frame(H3iFrame),
/// Specifies how to compare an incoming [`H3iFrame`] with this
/// [`ExpectedFrame`]. Typically, the validation attempts to fuzzy-match
/// the [`ExpectedFrame`] against the incoming [`H3iFrame`], but there
/// are times where other behavior is desired (for example, checking
/// deserialized JSON payloads in a headers frame, or ensuring a random
/// value matches a regex).
///
/// See [`ExpectedFrame::is_equivalent`] for more on how frames are
/// compared.
Fn(Arc<CustomEquivalenceHandler>),
}

impl Serialize for Comparator {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Self::Fn(_) => serializer.serialize_str("<comparator_fn>"),
Self::Frame(f) => {
let mut frame_ser = serializer.serialize_struct("frame", 1)?;
frame_ser.serialize_field("frame", f)?;
frame_ser.end()
},
}
}
}

/// Instructs h3i to watch for certain incoming [`H3iFrame`]s. The incoming
/// frames can either be supplied directly via [`ExpectedFrame::new`], or via a
/// verification callback passed to [`ExpectedFrame::new_with_comparator`].
#[derive(Serialize, Clone)]
pub struct ExpectedFrame {
stream_id: u64,
frame: H3iFrame,
comparator: Comparator,
}

impl ExpectedFrame {
/// Create a new [`ExpectedFrame`] which should watch for the provided
/// [`H3iFrame`].
///
/// # Note
///
/// For [QuicheH3] and [ResetStream] variants, equivalence is the same as
/// equality.
///
/// For Headers variants, this [`ExpectedFrame`] is equivalent to the
/// incoming [`H3iFrame`] if the [`H3iFrame`] contains all [`Header`]s
/// in _this_ frame. In other words, `this` can be considered equivalent
/// to `other` if `other` contains a superset of `this`'s [`Headers`].
///
/// This allows users for fuzzy-matching on header frames without needing to
/// supply every individual header on the frame.
///
/// [ResetStream]: H3iFrame::ResetStream
/// [QuicheH3]: H3iFrame::QuicheH3
pub fn new(stream_id: u64, frame: impl Into<H3iFrame>) -> Self {
Self {
stream_id,
frame: frame.into(),
comparator: Comparator::Frame(frame.into()),
}
}

/// Create a new [`ExpectedFrame`] which will match incoming [`H3iFrame`]s
/// according to the passed `comparator_fn`.
///
/// The `comparator_fn` will be called with every incoming [`H3iFrame`]. It
/// should return `true` if the incoming frame is expected, and `false`
/// if it is not.
pub fn new_with_comparator<F>(stream_id: u64, comparator_fn: F) -> Self
where
F: Fn(&H3iFrame) -> bool + Send + Sync + 'static,
{
Self {
stream_id,
comparator: Comparator::Fn(Arc::new(Box::new(comparator_fn))),
}
}

pub(crate) fn stream_id(&self) -> u64 {
self.stream_id
}

/// Check if this [`ExpectedFrame`] is equivalent to another [`H3iFrame`].
/// For QuicheH3/ResetStream variants, equivalence is the same as
/// equality. For Headers variants, this [`ExpectedFrame`] is equivalent
/// to another if the other frame contains all [`Header`]s in _this_
/// frame.
pub(crate) fn is_equivalent(&self, other: &H3iFrame) -> bool {
match &self.frame {
let frame = match &self.comparator {
Comparator::Fn(compare) => return compare(other),
Comparator::Frame(frame) => frame,
};

match frame {
H3iFrame::Headers(me) => {
let H3iFrame::Headers(other) = other else {
return false;
Expand All @@ -485,25 +555,127 @@ impl ExpectedFrame {
}
}

impl Debug for ExpectedFrame {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let repr = match &self.comparator {
Comparator::Frame(frame) => format!("{frame:?}"),
Comparator::Fn(_) => "closure".to_string(),
};

write!(
f,
"{}",
format!(
"ExpectedFrame {{ stream_id: {}, comparator: {repr} }}",
self.stream_id
)
)
}
}

impl PartialEq for ExpectedFrame {
fn eq(&self, other: &Self) -> bool {
match (&self.comparator, &other.comparator) {
(Comparator::Frame(this_frame), Comparator::Frame(other_frame)) =>
self.stream_id == other.stream_id && this_frame == other_frame,
_ => false,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use quiche::h3::frame::Frame;

#[test]
fn test_equivalence() {
fn test_header_equivalence() {
let this = ExpectedFrame::new(0, vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
]);
let other = ExpectedFrame::new(0, vec![
let other: H3iFrame = vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
Header::new(b"go", b"devils"),
]
.into();

assert!(this.is_equivalent(&other));
}

#[test]
fn test_header_non_equivalence() {
let this = ExpectedFrame::new(0, vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
Header::new(b"go", b"devils"),
]);
let other: H3iFrame =
vec![Header::new(b"hello", b"world"), Header::new(b"go", b"jets")]
.into();

// `other` does not contain the `go: devils` header, so it's not
// equivalent to `this.
assert!(!this.is_equivalent(&other));
}

#[test]
fn test_rst_stream_equivalence() {
let mut rs = ResetStream {
stream_id: 0,
error_code: 57,
};

let this = ExpectedFrame::new(0, H3iFrame::ResetStream(rs.clone()));
let incoming = H3iFrame::ResetStream(rs.clone());
assert!(this.is_equivalent(&incoming));

rs.stream_id = 57;
let incoming = H3iFrame::ResetStream(rs);
assert!(!this.is_equivalent(&incoming));
}

#[test]
fn test_frame_equivalence() {
let mut d = Frame::Data {
payload: b"57".to_vec(),
};

let this = ExpectedFrame::new(0, H3iFrame::QuicheH3(d.clone()));
let incoming = H3iFrame::QuicheH3(d.clone());
assert!(this.is_equivalent(&incoming));

d = Frame::Data {
payload: b"go jets".to_vec(),
};
let incoming = H3iFrame::QuicheH3(d.clone());
assert!(!this.is_equivalent(&incoming));
}

assert!(this.is_equivalent(&other.frame));
// `this` does not contain the `go: devils` header, so `other` is not
// equivalent to `this`.
assert!(!other.is_equivalent(&this.frame));
#[test]
fn test_comparator() {
let this = ExpectedFrame::new_with_comparator(0, |frame| {
if let H3iFrame::Headers(..) = frame {
frame
.to_enriched_headers()
.unwrap()
.header_map()
.get(&b"cookie".to_vec())
.is_some_and(|v| {
std::str::from_utf8(v)
.map(|s| s.to_lowercase())
.unwrap()
.contains("cookie")
})
} else {
false
}
});

let incoming: H3iFrame =
vec![Header::new(b"cookie", b"SomeRandomCookie1234")].into();

assert!(this.is_equivalent(&incoming));
}
}

0 comments on commit 6d879e0

Please sign in to comment.