Desugar statement templates (#226)

* Desugar statement templates

* Support desugaring of SetContains statement templates

* Update the book
This commit is contained in:
Rob Knight 2025-05-09 05:48:18 -07:00 committed by GitHub
parent 726f95483d
commit b2cb563eb6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 158 additions and 14 deletions

View file

@ -24,6 +24,7 @@ pub fn key(s: &str) -> KeyOrWildcardStr {
}
/// Builder Argument for the StatementTmplBuilder
#[derive(Clone)]
pub enum BuilderArg {
Literal(Value),
/// Key: (origin, key), where origin is a Wildcard and key can be both Key or Wildcard
@ -64,6 +65,7 @@ pub fn literal(v: impl Into<Value>) -> BuilderArg {
BuilderArg::Literal(v.into())
}
#[derive(Clone)]
pub struct StatementTmplBuilder {
predicate: Predicate,
args: Vec<BuilderArg>,
@ -81,6 +83,48 @@ impl StatementTmplBuilder {
self.args.push(a.into());
self
}
/// Desugar the predicate to a simpler form
/// Should mirror the logic in `MainPodBuilder::lower_op`
fn desugar(self) -> StatementTmplBuilder {
match self.predicate {
Predicate::Native(NativePredicate::Gt) => {
let mut stb = StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Lt),
args: self.args,
};
stb.args.swap(0, 1);
stb
}
Predicate::Native(NativePredicate::GtEq) => {
let mut stb = StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::LtEq),
args: self.args,
};
stb.args.swap(0, 1);
stb
}
Predicate::Native(NativePredicate::ArrayContains)
| Predicate::Native(NativePredicate::DictContains) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Contains),
args: self.args,
},
Predicate::Native(NativePredicate::DictNotContains)
| Predicate::Native(NativePredicate::SetNotContains) => StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::NotContains),
args: self.args,
},
Predicate::Native(NativePredicate::SetContains) => {
let mut new_args = self.args.clone();
new_args.push(self.args[1].clone());
StatementTmplBuilder {
predicate: Predicate::Native(NativePredicate::Contains),
args: new_args,
}
}
_ => self,
}
}
}
pub struct CustomPredicateBatchBuilder {
@ -147,7 +191,8 @@ impl CustomPredicateBatchBuilder {
let statements = sts
.iter()
.map(|sb| {
let args = sb
let stb = sb.clone().desugar();
let args = stb
.args
.iter()
.map(|a| match a {
@ -162,7 +207,7 @@ impl CustomPredicateBatchBuilder {
})
.collect();
StatementTmpl {
pred: sb.predicate.clone(),
pred: stb.predicate.clone(),
args,
}
})
@ -204,11 +249,15 @@ fn resolve_wildcard(args: &[&str], priv_args: &[&str], s: &str) -> Wildcard {
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use super::*;
use crate::{
backends::plonky2::mock::mainpod::MockProver,
examples::custom::{eth_dos_batch, eth_friend_batch},
middleware,
middleware::{CustomPredicateRef, Params, PodType},
frontend::MainPodBuilder,
middleware::{self, containers::Set, CustomPredicateRef, Params, PodType},
op,
};
#[test]
@ -237,4 +286,97 @@ mod tests {
Ok(())
}
#[test]
fn test_desugared_gt_custom_pred() -> Result<()> {
let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new("gt_custom_pred".into());
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
.arg(("s1_origin", "s1_key"))
.arg(("s2_origin", "s2_key"));
builder.predicate_and(
"gt_custom_pred",
&params,
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
&[],
&[gt_stb],
)?;
let batch = builder.finish();
let batch_clone = batch.clone();
let gt_custom_pred = CustomPredicateRef::new(batch, 0);
let mut mp_builder = MainPodBuilder::new(&params);
// 2 > 1
let s1 = mp_builder.literal(true, Value::from(2))?;
let s2 = mp_builder.literal(true, Value::from(1))?;
// Adding a gt operation will produce a desugared lt operation
let desugared_gt = mp_builder.pub_op(op!(gt, s1, s2))?;
assert_eq!(
desugared_gt.predicate(),
Predicate::Native(NativePredicate::Lt)
);
// Check that the desugared predicate is the same as the one in the statement template
assert_eq!(
desugared_gt.predicate(),
*batch_clone.predicates[0].statements[0].pred()
);
// Check that our custom predicate matches the statement template
// against the desugared gt statement (actually a lt statement)
mp_builder.pub_op(op!(custom, gt_custom_pred, desugared_gt))?;
// Check that the POD builds
let mut prover = MockProver {};
let proof = mp_builder.prove(&mut prover, &params)?;
Ok(())
}
#[test]
fn test_desugared_set_contains_custom_pred() -> Result<()> {
let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new("set_contains_custom_pred".into());
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
.arg(("s1_origin", "s1_key"))
.arg(("s2_origin", "s2_key"));
builder.predicate_and(
"set_contains_custom_pred",
&params,
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
&[],
&[set_contains_stb],
)?;
let batch = builder.finish();
let batch_clone = batch.clone();
let mut mp_builder = MainPodBuilder::new(&params);
let set_values: HashSet<Value> = [1, 2, 3].iter().map(|i| Value::from(*i)).collect();
let s1 = mp_builder.literal(true, Value::from(Set::new(set_values)?))?;
let s2 = mp_builder.literal(true, Value::from(1))?;
let set_contains = mp_builder.pub_op(op!(set_contains, s1, s2))?;
assert_eq!(
set_contains.predicate(),
Predicate::Native(NativePredicate::Contains)
);
assert_eq!(
set_contains.predicate(),
*batch_clone.predicates[0].statements[0].pred()
);
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);
mp_builder.pub_op(op!(custom, set_contains_custom_pred, set_contains))?;
let mut prover = MockProver {};
let proof = mp_builder.prove(&mut prover, &params)?;
Ok(())
}
}