diff --git a/book/src/values.md b/book/src/values.md index ea0d9b5..c446f12 100644 --- a/book/src/values.md +++ b/book/src/values.md @@ -40,9 +40,9 @@ The array, set and dictionary types are similar types. While all of them use [a - **array**: the elements are placed at the value field of each leaf, and the key field is just the array index (integer) - `leaf.key=i` - `leaf.value=original_value` -- **set**: the value field of the leaf is unused, and the key contains the hash of the element - - `leaf.key=hash(original_value)` - - `leaf.value=0` +- **set**: both the key and the value are set to the hash of the value. + - `leaf.key=hash(original_value)` + - `leaf.value=hash(original_value)` In the three types, the merkletree under the hood allows to prove inclusion & non-inclusion of the particular entry of the {dictionary/array/set} element. diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 708fe77..81ee86e 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -24,6 +24,7 @@ pub fn key(s: &str) -> KeyOrWildcardStr { } /// Builder Argument for the StatementTmplBuilder +#[derive(Clone)] pub enum BuilderArg { Literal(Value), /// Key: (origin, key), where origin is a Wildcard and key can be both Key or Wildcard @@ -64,6 +65,7 @@ pub fn literal(v: impl Into) -> BuilderArg { BuilderArg::Literal(v.into()) } +#[derive(Clone)] pub struct StatementTmplBuilder { predicate: Predicate, args: Vec, @@ -81,6 +83,48 @@ impl StatementTmplBuilder { self.args.push(a.into()); self } + + /// Desugar the predicate to a simpler form + /// Should mirror the logic in `MainPodBuilder::lower_op` + fn desugar(self) -> StatementTmplBuilder { + match self.predicate { + Predicate::Native(NativePredicate::Gt) => { + let mut stb = StatementTmplBuilder { + predicate: Predicate::Native(NativePredicate::Lt), + args: self.args, + }; + stb.args.swap(0, 1); + stb + } + Predicate::Native(NativePredicate::GtEq) => { + let mut stb = StatementTmplBuilder { + predicate: Predicate::Native(NativePredicate::LtEq), + args: self.args, + }; + stb.args.swap(0, 1); + stb + } + Predicate::Native(NativePredicate::ArrayContains) + | Predicate::Native(NativePredicate::DictContains) => StatementTmplBuilder { + predicate: Predicate::Native(NativePredicate::Contains), + args: self.args, + }, + Predicate::Native(NativePredicate::DictNotContains) + | Predicate::Native(NativePredicate::SetNotContains) => StatementTmplBuilder { + predicate: Predicate::Native(NativePredicate::NotContains), + args: self.args, + }, + Predicate::Native(NativePredicate::SetContains) => { + let mut new_args = self.args.clone(); + new_args.push(self.args[1].clone()); + StatementTmplBuilder { + predicate: Predicate::Native(NativePredicate::Contains), + args: new_args, + } + } + _ => self, + } + } } pub struct CustomPredicateBatchBuilder { @@ -147,7 +191,8 @@ impl CustomPredicateBatchBuilder { let statements = sts .iter() .map(|sb| { - let args = sb + let stb = sb.clone().desugar(); + let args = stb .args .iter() .map(|a| match a { @@ -162,7 +207,7 @@ impl CustomPredicateBatchBuilder { }) .collect(); StatementTmpl { - pred: sb.predicate.clone(), + pred: stb.predicate.clone(), args, } }) @@ -204,11 +249,15 @@ fn resolve_wildcard(args: &[&str], priv_args: &[&str], s: &str) -> Wildcard { #[cfg(test)] mod tests { + use std::collections::HashSet; + use super::*; use crate::{ + backends::plonky2::mock::mainpod::MockProver, examples::custom::{eth_dos_batch, eth_friend_batch}, - middleware, - middleware::{CustomPredicateRef, Params, PodType}, + frontend::MainPodBuilder, + middleware::{self, containers::Set, CustomPredicateRef, Params, PodType}, + op, }; #[test] @@ -237,4 +286,97 @@ mod tests { Ok(()) } + + #[test] + fn test_desugared_gt_custom_pred() -> Result<()> { + let params = Params::default(); + let mut builder = CustomPredicateBatchBuilder::new("gt_custom_pred".into()); + + let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt) + .arg(("s1_origin", "s1_key")) + .arg(("s2_origin", "s2_key")); + + builder.predicate_and( + "gt_custom_pred", + ¶ms, + &["s1_origin", "s1_key", "s2_origin", "s2_key"], + &[], + &[gt_stb], + )?; + let batch = builder.finish(); + let batch_clone = batch.clone(); + let gt_custom_pred = CustomPredicateRef::new(batch, 0); + + let mut mp_builder = MainPodBuilder::new(¶ms); + + // 2 > 1 + let s1 = mp_builder.literal(true, Value::from(2))?; + let s2 = mp_builder.literal(true, Value::from(1))?; + + // Adding a gt operation will produce a desugared lt operation + let desugared_gt = mp_builder.pub_op(op!(gt, s1, s2))?; + assert_eq!( + desugared_gt.predicate(), + Predicate::Native(NativePredicate::Lt) + ); + // Check that the desugared predicate is the same as the one in the statement template + assert_eq!( + desugared_gt.predicate(), + *batch_clone.predicates[0].statements[0].pred() + ); + + // Check that our custom predicate matches the statement template + // against the desugared gt statement (actually a lt statement) + mp_builder.pub_op(op!(custom, gt_custom_pred, desugared_gt))?; + + // Check that the POD builds + let mut prover = MockProver {}; + let proof = mp_builder.prove(&mut prover, ¶ms)?; + + Ok(()) + } + + #[test] + fn test_desugared_set_contains_custom_pred() -> Result<()> { + let params = Params::default(); + let mut builder = CustomPredicateBatchBuilder::new("set_contains_custom_pred".into()); + + let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains) + .arg(("s1_origin", "s1_key")) + .arg(("s2_origin", "s2_key")); + + builder.predicate_and( + "set_contains_custom_pred", + ¶ms, + &["s1_origin", "s1_key", "s2_origin", "s2_key"], + &[], + &[set_contains_stb], + )?; + let batch = builder.finish(); + let batch_clone = batch.clone(); + + let mut mp_builder = MainPodBuilder::new(¶ms); + + let set_values: HashSet = [1, 2, 3].iter().map(|i| Value::from(*i)).collect(); + let s1 = mp_builder.literal(true, Value::from(Set::new(set_values)?))?; + let s2 = mp_builder.literal(true, Value::from(1))?; + + let set_contains = mp_builder.pub_op(op!(set_contains, s1, s2))?; + assert_eq!( + set_contains.predicate(), + Predicate::Native(NativePredicate::Contains) + ); + assert_eq!( + set_contains.predicate(), + *batch_clone.predicates[0].statements[0].pred() + ); + + let set_contains_custom_pred = CustomPredicateRef::new(batch, 0); + mp_builder.pub_op(op!(custom, set_contains_custom_pred, set_contains))?; + + let mut prover = MockProver {}; + let proof = mp_builder.prove(&mut prover, ¶ms)?; + + Ok(()) + } } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 7113b72..f906887 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -9,8 +9,7 @@ use serde::{Deserialize, Serialize}; use crate::middleware::{ self, check_st_tmpl, hash_str, hash_values, AnchoredKey, Hash, Key, MainPodInputs, NativeOperation, NativePredicate, OperationAux, OperationType, Params, PodId, PodProver, - PodSigner, Predicate, Statement, StatementArg, Value, WildcardValue, EMPTY_VALUE, KEY_TYPE, - SELF, + PodSigner, Predicate, Statement, StatementArg, Value, WildcardValue, KEY_TYPE, SELF, }; mod custom; @@ -251,8 +250,11 @@ impl MainPodBuilder { } Native(SetContainsFromEntries) => { let [set, value] = op.1.try_into().unwrap(); // TODO: Error handling - let empty = OperationArg::Literal(Value::from(EMPTY_VALUE)); - Operation(Native(ContainsFromEntries), vec![set, value, empty], op.2) + Operation( + Native(ContainsFromEntries), + vec![set, value.clone(), value], + op.2, + ) } Native(SetNotContainsFromEntries) => { let [set, value] = op.1.try_into().unwrap(); // TODO: Error handling diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index 1a630a5..3861f54 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -11,7 +11,7 @@ use super::serialization::{ordered_map, ordered_set}; use crate::backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree}; use crate::{ constants::MAX_DEPTH, - middleware::{hash_value, Error, Hash, Key, RawValue, Result, Value, EMPTY_VALUE}, + middleware::{hash_value, Error, Hash, Key, RawValue, Result, Value}, }; /// Dictionary: the user original keys and values are hashed to be used in the leaf. @@ -129,7 +129,7 @@ impl Set { .iter() .map(|e| { let h = hash_value(&e.raw()); - (RawValue::from(h), EMPTY_VALUE) + (RawValue::from(h), RawValue::from(h)) }) .collect(); Ok(Self { @@ -159,7 +159,7 @@ impl Set { root, proof, &RawValue::from(h), - &EMPTY_VALUE, + &RawValue::from(h), )?) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> {