Skip to content

Commit

Permalink
closures: support for any/all, recursive ops, protobuf encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
divarvel committed Jan 2, 2024
1 parent 91c852a commit 3634700
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 132 deletions.
89 changes: 65 additions & 24 deletions biscuit-auth/src/datalog/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ pub enum Op {
Value(Term),
Unary(Unary),
Binary(Binary),
Param(String),
Closure(Vec<String>, Vec<Op>),
Closure(Vec<u32>, Vec<Op>),
}

/// Unary operation code
Expand Down Expand Up @@ -82,6 +81,9 @@ pub enum Binary {
BitwiseOr,
BitwiseXor,
NotEqual,
LazyAnd,
LazyOr,
All,
Any,
}

Expand All @@ -90,36 +92,55 @@ impl Binary {
&self,
left: Term,
mut right: Vec<Op>,
params: &[String],
params: &[u32],
ops: &mut Vec<Op>,
values: &HashMap<u32, Term>,
symbols: &mut TemporarySymbolTable,
) -> Result<Term, error::Expression> {
match (self, left, params) {
(Binary::Or, Term::Bool(true), []) => Ok(Term::Bool(true)),
(Binary::Or, Term::Bool(false), []) => {
(Binary::LazyOr, Term::Bool(true), []) => Ok(Term::Bool(true)),
(Binary::LazyOr, Term::Bool(false), []) => {
ops.push(Op::Binary(Binary::Or));
right.reverse();
for op in right {
ops.push(op);
}
Ok(Term::Bool(false))
}
(Binary::And, Term::Bool(false), []) => Ok(Term::Bool(false)),
(Binary::And, Term::Bool(true), []) => {
(Binary::LazyAnd, Term::Bool(false), []) => Ok(Term::Bool(false)),
(Binary::LazyAnd, Term::Bool(true), []) => {
ops.push(Op::Binary(Binary::And));
for op in right {
ops.push(op);
}
Ok(Term::Bool(true))
}
(Binary::Any, Term::Set(set_values), [param_name]) => {
(Binary::All, Term::Set(set_values), [param]) => {
for value in set_values.iter() {
let ops = right
.clone()
.iter()
.map(|op| match op {
Op::Param(p) if p == param_name => Op::Value(value.clone()),
Op::Value(Term::Variable(v)) if v == param => Op::Value(value.clone()),
_ => op.clone(),
})
.collect::<Vec<_>>();
let e = Expression { ops };
match e.evaluate(values, symbols)? {
Term::Bool(true) => {}
Term::Bool(false) => return Ok(Term::Bool(false)),
_ => return Err(error::Expression::InvalidType),
};
}
Ok(Term::Bool(false))
}
(Binary::Any, Term::Set(set_values), [param]) => {
for value in set_values.iter() {
let ops = right
.clone()
.iter()
.map(|op| match op {
Op::Value(Term::Variable(v)) if v == param => Op::Value(value.clone()),
_ => op.clone(),
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -290,14 +311,17 @@ impl Binary {
Binary::BitwiseAnd => format!("{} & {}", left, right),
Binary::BitwiseOr => format!("{} | {}", left, right),
Binary::BitwiseXor => format!("{} ^ {}", left, right),
Binary::Any => todo!(),
Binary::LazyAnd => format!("{left} && {right}"),
Binary::LazyOr => format!("{left} || {right}"),
Binary::All => format!("{left}.all({right})"),
Binary::Any => format!("{left}.any({right})"),
}
}
}

#[derive(Clone, Debug)]
enum StackElem {
Closure(Vec<String>, Vec<Op>),
Closure(Vec<u32>, Vec<Op>),
Term(Term),
}

Expand All @@ -313,8 +337,8 @@ impl Expression {
ops.reverse();

while let Some(op) = ops.pop() {
println!("ops: {ops:?}");
println!("op: {:?}\t| stack: {:?}", op, stack);
// println!("ops: {ops:?}");
// println!("op: {:?}\t| stack: {:?}", op, stack);

match op {
Op::Value(Term::Variable(i)) => match values.get(&i) {
Expand Down Expand Up @@ -351,10 +375,9 @@ impl Expression {
return Err(error::Expression::InvalidStack);
}
},
Op::Closure(param, ops) => {
stack.push(StackElem::Closure(param, ops));
Op::Closure(params, ops) => {
stack.push(StackElem::Closure(params, ops));
}
Op::Param(_) => todo!(),
}
}
println!("stack: {stack:?}");
Expand Down Expand Up @@ -384,8 +407,24 @@ impl Expression {
(Some(right), Some(left)) => stack.push(binary.print(left, right, symbols)),
_ => return None,
},
Op::Closure(_, _) => stack.push("todo".to_owned()),
Op::Param(_) => {}
Op::Closure(params, ops) => {
let exp_body = Expression { ops: ops.clone() };
let body = match exp_body.print(symbols) {
Some(c) => c,
_ => return None,
};

if params.is_empty() {
stack.push(body);
} else {
let param_group = params
.iter()
.map(|s| symbols.print_term(&Term::Variable(*s)))
.collect::<Vec<_>>()
.join(", ");
stack.push(format!("{param_group} -> {body}"));
}
}
}
}

Expand Down Expand Up @@ -581,10 +620,10 @@ mod tests {
vec![
Op::Value(Term::Bool(true)),
Op::Closure(vec![], vec![Op::Value(Term::Bool(true))]),
Op::Binary(Binary::And),
Op::Binary(Binary::LazyAnd),
],
),
Op::Binary(Binary::Or),
Op::Binary(Binary::LazyOr),
];
let e2 = Expression { ops: ops2 };

Expand All @@ -594,17 +633,19 @@ mod tests {

#[test]
fn any() {
let symbols = SymbolTable::new();
let mut symbols = TemporarySymbolTable::new(&symbols);
let mut symbols = SymbolTable::new();
let p = symbols.insert("param") as u32;
let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

let ops1 = vec![
Op::Value(Term::Set([Term::Bool(false), Term::Bool(true)].into())),
Op::Closure(vec!["0".to_owned()], vec![Op::Param("0".to_owned())]),
Op::Closure(vec![p], vec![Op::Value(Term::Variable(p))]),
Op::Binary(Binary::Any),
];
let e1 = Expression { ops: ops1 };
println!("{:?}", e1.print(&symbols));

let res1 = e1.evaluate(&HashMap::new(), &mut symbols).unwrap();
let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap();
assert_eq!(res1, Term::Bool(true));
}
}
Loading

0 comments on commit 3634700

Please sign in to comment.