diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 5d8885b..6d61ea1 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -32,9 +32,10 @@ use crate::{ }, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg, - StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, - HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE, + NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard, + PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl, + StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, + STATEMENT_ARG_F_LEN, VALUE_SIZE, }, }; @@ -46,6 +47,22 @@ pub struct ValueTarget { pub elements: [Target; VALUE_SIZE], } +impl From for HashOutTarget { + fn from(v: ValueTarget) -> HashOutTarget { + HashOutTarget { + elements: v.elements, + } + } +} + +impl From for ValueTarget { + fn from(h: HashOutTarget) -> ValueTarget { + ValueTarget { + elements: h.elements, + } + } +} + impl ValueTarget { pub fn zero(builder: &mut CircuitBuilder) -> Self { Self { @@ -524,18 +541,112 @@ impl StatementTmplArgTarget { } } +#[derive(Clone, Serialize, Deserialize)] +pub struct PredicateHashOrWildcardTarget { + /// layout: `prefix | [data]`, where data is predicate_hash or wildcard_index + pub elements: [Target; Params::pred_hash_or_wc_size()], +} + +impl PredicateHashOrWildcardTarget { + pub fn new(prefix: Target, data: ValueTarget) -> Self { + let v = data.elements; + Self { + elements: [prefix, v[0], v[1], v[2], v[3]], + } + } + pub fn new_pred_hash(builder: &mut CircuitBuilder, pred_hash: HashOutTarget) -> Self { + Self::new( + builder.constant(F::from(PredicateOrWildcardPrefix::Predicate)), + ValueTarget::from(pred_hash), + ) + } + pub fn is_pred(&self, builder: &mut CircuitBuilder) -> BoolTarget { + let prefix_pred = builder.constant(F::from(PredicateOrWildcardPrefix::Predicate)); + builder.is_equal(self.elements[0], prefix_pred) + } + pub fn data(&self) -> ValueTarget { + ValueTarget { + elements: self.elements[1..].try_into().expect("4 elements"), + } + } + pub fn pred_hash(&self) -> HashOutTarget { + HashOutTarget::from(self.data()) + } + pub fn wc_index(&self) -> Target { + self.elements[1] + } + pub fn set_targets_raw( + &self, + pw: &mut PartialWitness, + prefix: PredicateOrWildcardPrefix, + data: RawValue, + ) -> Result<()> { + pw.set_target(self.elements[0], F::from(prefix))?; + pw.set_target_arr(&self.elements[1..], &data.0)?; + Ok(()) + } + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + pred: &PredicateOrWildcard, + ) -> Result<()> { + match pred { + PredicateOrWildcard::Predicate(pred) => { + self.set_targets_raw( + pw, + PredicateOrWildcardPrefix::Predicate, + RawValue::from(pred.hash(params)), + )?; + } + PredicateOrWildcard::Wildcard(wc) => { + self.set_targets_raw( + pw, + PredicateOrWildcardPrefix::Wildcard, + RawValue([F::from_canonical_usize(wc.index), F::ZERO, F::ZERO, F::ZERO]), + )?; + } + } + Ok(()) + } +} + +impl Flattenable for PredicateHashOrWildcardTarget { + fn flatten(&self) -> Vec { + self.elements.to_vec() + } + fn from_flattened(_params: &Params, vs: &[Target]) -> Self { + Self { + elements: vs.try_into().expect("5 elements"), + } + } + fn size(_params: &Params) -> usize { + Params::pred_hash_or_wc_size() + } +} + #[derive(Clone, Serialize, Deserialize)] pub struct StatementTmplTarget { + /// The preimage of the predicate_hash. This predicate is needed only to build the custom + /// predicate table because it needs to normalize statement templates with predicates that + /// refer to self into content-addressed predicates (using the batch id and index). The + /// predicate type is inspected to do this normalization. After the table is built we only use + /// the predicate hash for equality checks. pred: Option, - pred_hash: HashOutTarget, + /// This is constrained to be `hash(pred)` through the type constructor when we have `pred` + /// and the template uses a predicate and not a wildcard. + pred_hash_or_wc: PredicateHashOrWildcardTarget, pub args: Vec, } impl StatementTmplTarget { - pub fn new(pred_hash: HashOutTarget, args: Vec) -> Self { + pub fn new( + pred_hash_or_wc: PredicateHashOrWildcardTarget, + args: Vec, + ) -> Self { Self { pred: None, - pred_hash, + pred_hash_or_wc, args, } } @@ -546,9 +657,22 @@ impl StatementTmplTarget { st_tmpl: &StatementTmpl, ) -> Result<()> { if let Some(pred) = &self.pred { - pred.set_targets(pw, params, &st_tmpl.pred)?; + match &st_tmpl.pred_or_wc { + PredicateOrWildcard::Predicate(p) => { + // We store a predicate (not a wildcard) and we have it available. In this + // case the hash will be calculated by constraints later on and we should not + // rely on the original data. + pred.set_targets(pw, params, p)? + } + PredicateOrWildcard::Wildcard(_wc) => { + // Fill in with a recognizable constant for better debugging; this value is + // not supposed to be used. + pw.set_target_arr(&pred.elements, &[F(0xdead); Params::predicate_size()])? + } + } } - pw.set_hash_target(self.pred_hash, HashOut::from(st_tmpl.pred.hash(params)))?; + self.pred_hash_or_wc + .set_targets(pw, params, &st_tmpl.pred_or_wc)?; let arg_pad = StatementTmplArg::None; for (i, arg) in st_tmpl .args @@ -564,8 +688,8 @@ impl StatementTmplTarget { pub fn pred(&self) -> Option<&PredicateTarget> { self.pred.as_ref() } - pub fn pred_hash(&self) -> &HashOutTarget { - &self.pred_hash + pub fn pred_hash_or_wc(&self) -> &PredicateHashOrWildcardTarget { + &self.pred_hash_or_wc } } @@ -603,6 +727,8 @@ impl CustomPredicateTarget { } } +/// This type is used to build the custom predicate table, which exposes the custom predicates with +/// normalized statement templates indexed by batch_id and custom_predicate_index. #[derive(Clone, Serialize, Deserialize)] pub struct CustomPredicateBatchTarget { pub predicates: Vec, @@ -660,15 +786,17 @@ impl CustomPredicateEntryTarget { .clone() .into_iter() .map(|st_tmpl| { - let pred = match st_tmpl.pred { - Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { - batch: batch.clone(), - index: i, - }), - p => p, + let pred_or_wc = match st_tmpl.pred_or_wc { + PredicateOrWildcard::Predicate(Predicate::BatchSelf(i)) => { + PredicateOrWildcard::Predicate(Predicate::Custom(CustomPredicateRef { + batch: batch.clone(), + index: i, + })) + } + x => x.clone(), }; StatementTmpl { - pred, + pred_or_wc, args: st_tmpl.args, } }) @@ -724,7 +852,7 @@ pub struct CustomPredicateVerifyEntryTarget { } impl CustomPredicateVerifyEntryTarget { - pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder, with_pred: bool) -> Self { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { let custom_predicate_table_len = params.max_custom_predicate_batches * params.max_custom_batch_size; CustomPredicateVerifyEntryTarget { @@ -732,7 +860,7 @@ impl CustomPredicateVerifyEntryTarget { custom_predicate_table_len, builder, ), - custom_predicate: builder.add_virtual_custom_predicate_entry(params, with_pred), + custom_predicate: builder.add_virtual_custom_predicate_entry(params), args: (0..params.max_custom_predicate_wildcards) .map(|_| builder.add_virtual_value()) .collect(), @@ -1062,7 +1190,7 @@ impl Flattenable for CustomPredicateTarget { impl Flattenable for StatementTmplTarget { fn flatten(&self) -> Vec { - self.pred_hash + self.pred_hash_or_wc .flatten() .into_iter() .chain(self.args.iter().flat_map(|sta| sta.flatten())) @@ -1071,24 +1199,27 @@ impl Flattenable for StatementTmplTarget { fn from_flattened(params: &Params, v: &[Target]) -> Self { assert_eq!(v.len(), Self::size(params)); - let pred_hash_end = HASH_SIZE; - let pred_hash = HashOutTarget::from_flattened(params, &v[..pred_hash_end]); + let pred_hash_or_wc_end = Params::pred_hash_or_wc_size(); + let pred_hash_or_wc = + PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]); let sta_size = Params::statement_tmpl_arg_size(); let args = (0..params.max_statement_args) .map(|i| { - let sta_v = &v[pred_hash_end + sta_size * i..pred_hash_end + sta_size * (i + 1)]; + let sta_v = &v + [pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)]; StatementTmplArgTarget::from_flattened(params, sta_v) }) .collect(); Self { pred: None, - pred_hash, + pred_hash_or_wc, args, } } fn size(params: &Params) -> usize { - HASH_SIZE + params.max_statement_args * StatementTmplArgTarget::size(params) + Params::pred_hash_or_wc_size() + + params.max_statement_args * StatementTmplArgTarget::size(params) } } @@ -1168,11 +1299,8 @@ pub trait CircuitBuilderPod, const D: usize> { params: &Params, with_pred: bool, ) -> CustomPredicateBatchTarget; - fn add_virtual_custom_predicate_entry( - &mut self, - params: &Params, - with_pred: bool, - ) -> CustomPredicateEntryTarget; + fn add_virtual_custom_predicate_entry(&mut self, params: &Params) + -> CustomPredicateEntryTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_statement_arg( &mut self, @@ -1320,24 +1448,32 @@ impl CircuitBuilderPod for CircuitBuilder { } } - /// If `with_pred = true` a predicate is included and its hash constrained. + /// If `with_pred = true` a predicate is included. /// If `with_pred = false` only the predicate hash is included. + /// The pred_hash is constrained to be hash(pred) conditionally on the template using a + /// predicate and not a wildcard. fn add_virtual_statement_tmpl( &mut self, params: &Params, with_pred: bool, ) -> StatementTmplTarget { - let (pred, pred_hash) = if with_pred { + let pred_hash_or_wc = + PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value()); + let pred = if with_pred { let pred = self.add_virtual_predicate(); let pred_hash = pred.hash(self); - (Some(pred), pred_hash) + let is_pred = pred_hash_or_wc.is_pred(self); + let data = pred_hash_or_wc.data(); + for i in 0..VALUE_SIZE { + self.conditional_assert_eq(is_pred.target, data.elements[i], pred_hash.elements[i]); + } + Some(pred) } else { - let pred_hash = self.add_virtual_hash(); - (None, pred_hash) + None }; StatementTmplTarget { pred, - pred_hash, + pred_hash_or_wc, args: (0..params.max_statement_args) .map(|_| self.add_virtual_statement_tmpl_arg()) .collect(), @@ -1377,12 +1513,11 @@ impl CircuitBuilderPod for CircuitBuilder { fn add_virtual_custom_predicate_entry( &mut self, params: &Params, - with_pred: bool, ) -> CustomPredicateEntryTarget { CustomPredicateEntryTarget { id: self.add_virtual_hash(), index: self.add_virtual_target(), - predicate: self.add_virtual_custom_predicate(params, with_pred), + predicate: self.add_virtual_custom_predicate(params, false), } } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 86d72d4..fe23403 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -25,8 +25,8 @@ use crate::{ CustomPredicateTarget, CustomPredicateVerifyEntryTarget, CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget, MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget, - PredicateTarget, StatementArgTarget, StatementTarget, StatementTmplArgTarget, - StatementTmplTarget, ValueTarget, + PredicateHashOrWildcardTarget, PredicateTarget, StatementArgTarget, + StatementTarget, StatementTmplArgTarget, StatementTmplTarget, ValueTarget, }, hash::{hash_from_state_circuit, precompute_hash_state}, mux_table::{MuxTableTarget, TableEntryTarget}, @@ -341,12 +341,7 @@ fn build_operation_aux_table_circuit( .chain(signed_by.pk.u.components) .collect(), ); - let entry: MsgPubKeyTarget = HashPairTarget( - HashOutTarget { - elements: signed_by.msg.elements, - }, - pk_hash, - ); + let entry: MsgPubKeyTarget = HashPairTarget(HashOutTarget::from(signed_by.msg), pk_hash); table.push(builder, OperationAuxTableTag::SignedBy as u32, &entry); measure_gates_end!(builder, measure); @@ -1381,6 +1376,26 @@ fn make_statement_arg_from_template_circuit( StatementArgTarget::new(first, second) } +fn make_predicate_from_template_circuit( + params: &Params, + builder: &mut CircuitBuilder, + pred_hash_or_wc: &PredicateHashOrWildcardTarget, + args: &[ValueTarget], +) -> HashOutTarget { + let zero = builder.zero(); + let is_pred = pred_hash_or_wc.is_pred(builder); + // If the index is not used, use a 0 instead to still pass the range constraints from + // vec_ref + let index = builder.select(is_pred, zero, pred_hash_or_wc.wc_index()); + let resolved_pred_hash = HashOutTarget::from(builder.vec_ref_small(params, args, index)); + builder.select_flattenable( + params, + is_pred, + &pred_hash_or_wc.pred_hash(), + &resolved_pred_hash, + ) +} + fn make_statement_from_template_circuit( params: &Params, builder: &mut CircuitBuilder, @@ -1388,7 +1403,7 @@ fn make_statement_from_template_circuit( args: &[ValueTarget], ) -> StatementTarget { let measure = measure_gates_begin!(builder, "StArgFromTmpl"); - let args = st_tmpl + let st_args = st_tmpl .args .iter() .map(|st_tmpl_arg| { @@ -1396,7 +1411,11 @@ fn make_statement_from_template_circuit( }) .collect(); measure_gates_end!(builder, measure); - StatementTarget::new(*st_tmpl.pred_hash(), args) + let measure = measure_gates_begin!(builder, "PredFromTmpl"); + let pred_hash = + make_predicate_from_template_circuit(params, builder, st_tmpl.pred_hash_or_wc(), args); + measure_gates_end!(builder, measure); + StatementTarget::new(pred_hash, st_args) } /// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard @@ -1527,13 +1546,29 @@ fn normalize_st_tmpl_circuit( st_tmpl: &StatementTmplTarget, id: HashOutTarget, ) -> StatementTmplTarget { - let pred = st_tmpl.pred().expect("StatementTmpl contains predicate"); + // If the custom predicate is self, we normalize it and then hash it. + let old_pred = st_tmpl.pred().expect("StatementTmpl contains predicate"); let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf)); - let is_batch_self = builder.is_equal(pred.elements[0], prefix_batch_self); - let pred_index = pred.elements[1]; - let custom_pred = PredicateTarget::new_custom(builder, id, pred_index); - let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, pred); - StatementTmplTarget::new(pred.hash(builder), st_tmpl.args.clone()) + let is_batch_self = builder.is_equal(old_pred.elements[0], prefix_batch_self); + + let pred_index = old_pred.elements[1]; + let normalized_custom_pred = PredicateTarget::new_custom(builder, id, pred_index); + let normalized_custom_pred_hash = normalized_custom_pred.hash(builder); + + // If the template is using a predicate and it is batch self we use the freshly computed + // normalized predicate hash, otherwise we keep the original data. + let old_data = st_tmpl.pred_hash_or_wc().data(); + let is_pred = st_tmpl.pred_hash_or_wc().is_pred(builder); + let is_pred_batch_self = builder.and(is_pred, is_batch_self); + let data = builder.select_flattenable( + params, + is_pred_batch_self, + &ValueTarget::from(normalized_custom_pred_hash), + &old_data, + ); + let pred_hash_or_wc = + PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data); + StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone()) } /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as @@ -1773,7 +1808,7 @@ impl MainPodVerifyTarget { .map(|_| builder.add_virtual_custom_predicate_batch(params, true)) .collect(), custom_predicate_verifications: (0..params.max_custom_predicate_verifications) - .map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder, false)) + .map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder)) .collect(), } } @@ -2012,8 +2047,8 @@ mod tests { dict, frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, middleware::{ - hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, RawValue, StatementArg, - StatementTmpl, StatementTmplArg, Wildcard, + hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, + RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, }, }; @@ -3124,7 +3159,7 @@ mod tests { let dict = Hash([F(6), F(7), F(8), F(9)]); let st_tmpl = StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)), args: vec![ StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), StatementTmplArg::Literal(Value::from("value")), @@ -3137,6 +3172,21 @@ mod tests { ); helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + let st_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)), + args: vec![ + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(¶ms); + let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; + let expected_st = Statement::not_equal( + AnchoredKey::new(dict, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + Ok(()) } @@ -3150,7 +3200,7 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params, false); + let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params); let op_args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_statement(params, false)) .collect(); diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 9f1b38f..be40a90 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -7,7 +7,7 @@ use crate::{ frontend::{AnchoredKey, Error, Result, Statement, StatementArg}, middleware::{ self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params, - Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard, + Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard, }, }; @@ -217,7 +217,8 @@ impl CustomPredicateBatchBuilder { }) .collect::>()?; Ok(StatementTmpl { - pred: stb.predicate.clone(), + // TODO: Support wildcard + pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()), args, }) }) @@ -319,7 +320,10 @@ mod tests { // Check that the desugared predicate is the same as the one in the statement template assert_eq!( desugared_gt.predicate(), - *batch_clone.predicates()[0].statements[0].pred() + *batch_clone.predicates()[0].statements[0] + .pred_or_wc() + .as_pred() + .unwrap() ); // Check that our custom predicate matches the statement template @@ -366,7 +370,10 @@ mod tests { ); assert_eq!( set_contains.predicate(), - *batch_clone.predicates()[0].statements[0].pred() + *batch_clone.predicates()[0].statements[0] + .pred_or_wc() + .as_pred() + .unwrap() ); let set_contains_custom_pred = CustomPredicateRef::new(batch, 0); diff --git a/src/frontend/pod_request.rs b/src/frontend/pod_request.rs index 2de199c..c804610 100644 --- a/src/frontend/pod_request.rs +++ b/src/frontend/pod_request.rs @@ -3,7 +3,9 @@ use std::{collections::HashMap, fmt::Display}; use crate::{ frontend::{Error, Result}, lang::PrettyPrint, - middleware::{Pod, Statement, StatementArg, StatementTmpl, StatementTmplArg, Value}, + middleware::{ + Pod, PredicateOrWildcard, Statement, StatementArg, StatementTmpl, StatementTmplArg, Value, + }, }; /// Represents a request for a POD, in terms of a set of statement templates. @@ -76,7 +78,8 @@ impl PodRequest { statement: &Statement, current_bindings: &HashMap, ) -> Option> { - if template.pred != statement.predicate() { + // TODO: Support wildcard + if template.pred_or_wc != PredicateOrWildcard::Predicate(statement.predicate()) { return None; } diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index a8863d0..55d4591 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -18,8 +18,8 @@ use crate::{ }, middleware::{ self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params, - Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, - Wildcard, + Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl, + StatementTmplArg as MWStatementTmplArg, Wildcard, }, }; @@ -201,7 +201,8 @@ impl<'a> Lowerer<'a> { } Ok(MWStatementTmpl { - pred: predicate, + // TODO: Support wildcard + pred_or_wc: PredicateOrWildcard::Predicate(predicate), args: mw_args, }) } @@ -596,7 +597,10 @@ mod tests { let stmt = &pred2.statements()[0]; // Should be BatchSelf(0) referring to pred1 - assert!(matches!(stmt.pred, Predicate::BatchSelf(0))); + assert!(matches!( + stmt.pred_or_wc, + PredicateOrWildcard::Predicate(Predicate::BatchSelf(0)) + )); } #[test] @@ -632,8 +636,8 @@ mod tests { // Should desugar to the Contains predicate assert!(matches!( - stmt.pred, - Predicate::Native(NativePredicate::Contains) + stmt.pred_or_wc, + PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains)) )); } diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 053062d..da21569 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -63,8 +63,8 @@ mod tests { backends::plonky2::primitives::ec::schnorr::SecretKey, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, - Params, Predicate, RawValue, StatementTmpl, StatementTmplArg, Value, Wildcard, - EMPTY_HASH, + Params, Predicate, PredicateOrWildcard, RawValue, StatementTmpl, StatementTmplArg, + Value, Wildcard, EMPTY_HASH, }, }; @@ -89,6 +89,10 @@ mod tests { names.iter().map(|s| s.to_string()).collect() } + fn pred_lit(pred: Predicate) -> PredicateOrWildcard { + PredicateOrWildcard::Predicate(pred) + } + #[test] fn test_e2e_simple_predicate() -> Result<(), LangError> { let input = r#" @@ -109,7 +113,7 @@ mod tests { // Expected structure let expected_statements = vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak(("PodA", 0), "the_key"), // PodA["the_key"] -> Wildcard(0), Key("the_key") sta_ak(("PodB", 1), "the_key"), // PodB["the_key"] -> Wildcard(1), Key("the_key") @@ -153,14 +157,14 @@ mod tests { // Expected structure let expected_templates = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val") sta_lit(RawValue::from(1)), ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Lt), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)), args: vec![ sta_ak(("GovPod", 1), "dob"), // GovPod["dob"] -> Wildcard(1), Key("dob") sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val") @@ -195,14 +199,14 @@ mod tests { // Expected structure: Public args: A (index 0). Private args: Temp (index 1) let expected_statements = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak(("A", 0), "input_key"), // A["input_key"] -> Wildcard(0), Key("input_key") sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key") ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key") sta_lit("some_value"), // Literal("some_value") @@ -251,7 +255,7 @@ mod tests { // Expected Batch structure let expected_pred_statements = vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val") sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val") @@ -275,7 +279,10 @@ mod tests { // Expected Request structure // Pod1 -> Wildcard 0, Pod2 -> Wildcard 1 let expected_request_templates = vec![StatementTmpl { - pred: Predicate::Custom(CustomPredicateRef::new(expected_batch, 0)), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( + expected_batch, + 0, + ))), args: vec![ StatementTmplArg::Wildcard(wc("Pod1", 0)), StatementTmplArg::Wildcard(wc("Pod2", 1)), @@ -317,7 +324,7 @@ mod tests { // Expected structure let expected_templates = vec![ StatementTmpl { - pred: Predicate::Custom(CustomPredicateRef::new(batch_result, 0)), // Refers to some_pred + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch_result, 0))), // Refers to some_pred args: vec![ StatementTmplArg::Wildcard(wc("Var1", 0)), // Var1 StatementTmplArg::Literal(Value::from(12345i64)), // 12345 @@ -325,7 +332,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ // AnotherPod["another_key"] -> Wildcard(1), Key("another_key") sta_ak(("AnotherPod", 1), "another_key"), @@ -362,15 +369,15 @@ mod tests { let expected_templates = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::LtEq), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::LtEq)), args: vec![sta_ak(("B", 1), "bar"), sta_ak(("A", 0), "foo")], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Lt), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)), args: vec![sta_ak(("D", 3), "qux"), sta_ak(("C", 2), "baz")], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Contains), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)), args: vec![ sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar"), @@ -378,11 +385,11 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::NotContains), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)), args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Contains), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)), args: vec![ sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar"), @@ -439,7 +446,7 @@ mod tests { let expected_templates = vec![ // 1. NotContains(sanctions["sanctionList"], gov["idNumber"]) StatementTmpl { - pred: Predicate::Native(NativePredicate::NotContains), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)), args: vec![ sta_ak( (wc_sanctions.name.as_str(), wc_sanctions.index), @@ -450,7 +457,7 @@ mod tests { }, // 2. Lt(gov["dateOfBirth"], SELF_HOLDER_18Y["const_18y"]) StatementTmpl { - pred: Predicate::Native(NativePredicate::Lt), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)), args: vec![ sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key), sta_ak( @@ -461,7 +468,7 @@ mod tests { }, // 3. Equal(pay["startDate"], SELF_HOLDER_1Y["const_1y"]) StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key), sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key), @@ -469,7 +476,7 @@ mod tests { }, // 4. Equal(gov["socialSecurityNumber"], pay["socialSecurityNumber"]) StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key), sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key), @@ -477,7 +484,7 @@ mod tests { }, // 5. Equal(SELF_HOLDER_18Y["const_18y"], 1169909388) StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak( (wc_self_18y.name.as_str(), wc_self_18y.index), @@ -488,7 +495,7 @@ mod tests { }, // 6. Equal(SELF_HOLDER_1Y["const_1y"], 1706367566) StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key), sta_lit(now_minus_1y_val.clone()), @@ -563,11 +570,11 @@ mod tests { // eth_friend (Index 0) let expected_friend_stmts = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::SignedBy), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SignedBy)), args: vec![sta_wc_lit("attestation_dict", 2), sta_wc_lit("src", 0)], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak(("attestation_dict", 2), "attestation"), sta_wc_lit("dst", 1), // Pub arg 1 @@ -586,11 +593,11 @@ mod tests { // eth_dos_distance_base (Index 1) let expected_base_stmts = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_wc_lit("src", 0), sta_wc_lit("dst", 1)], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_wc_lit("distance", 2), sta_lit(0i64)], }, ]; @@ -608,7 +615,7 @@ mod tests { // Private args indices: 3-4 (shorter_distance, intermed) let expected_ind_stmts = vec![ StatementTmpl { - pred: Predicate::BatchSelf(3), // Calls eth_dos_distance (index 3) + pred_or_wc: pred_lit(Predicate::BatchSelf(3)), // Calls eth_dos_distance (index 3) args: vec![ // WildcardLiteral args sta_wc_lit("src", 0), @@ -617,7 +624,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::SumOf), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SumOf)), args: vec![ sta_wc_lit("distance", 2), // public arg sta_wc_lit("shorter_distance", 3), // private arg @@ -625,7 +632,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::BatchSelf(0), // Calls eth_friend (index 0) + pred_or_wc: pred_lit(Predicate::BatchSelf(0)), // Calls eth_friend (index 0) args: vec![ // WildcardLiteral args sta_wc_lit("intermed", 4), // private arg @@ -645,7 +652,7 @@ mod tests { // eth_dos_distance (Index 3) let expected_dist_stmts = vec![ StatementTmpl { - pred: Predicate::BatchSelf(1), // Calls eth_dos_distance_base (index 1) + pred_or_wc: pred_lit(Predicate::BatchSelf(1)), // Calls eth_dos_distance_base (index 1) args: vec![ // WildcardLiteral args sta_wc_lit("src", 0), @@ -654,7 +661,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::BatchSelf(2), // Calls eth_dos_distance_ind (index 2) + pred_or_wc: pred_lit(Predicate::BatchSelf(2)), // Calls eth_dos_distance_ind (index 2) args: vec![ // WildcardLiteral args sta_wc_lit("src", 0), @@ -697,7 +704,7 @@ mod tests { // 1. Create a batch to be imported let imported_pred_stmts = vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ sta_ak(("A", 0), "foo"), // A["foo"] sta_ak(("B", 1), "bar"), // B["bar"] @@ -739,7 +746,10 @@ mod tests { // 4. Check the resulting request template let expected_request_templates = vec![StatementTmpl { - pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 0)), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( + available_batch, + 0, + ))), args: vec![ StatementTmplArg::Wildcard(wc("Pod1", 0)), StatementTmplArg::Wildcard(wc("Pod2", 1)), @@ -788,11 +798,17 @@ mod tests { // 4. Check the resulting request templates let expected_templates = vec![ StatementTmpl { - pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( + available_batch.clone(), + 0, + ))), args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))], }, StatementTmpl { - pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 2)), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( + available_batch, + 2, + ))), args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))], }, ]; @@ -808,7 +824,7 @@ mod tests { // 1. Create a batch with a predicate to be imported let imported_pred_stmts = vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")], }]; let imported_predicate = CustomPredicate::and( @@ -855,7 +871,10 @@ mod tests { assert_eq!(defined_pred.statements.len(), 1); let expected_statement = StatementTmpl { - pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)), + pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new( + available_batch.clone(), + 0, + ))), args: vec![ StatementTmplArg::Wildcard(wc("X", 0)), StatementTmplArg::Wildcard(wc("Y", 1)), @@ -886,7 +905,9 @@ mod tests { let request_templates = processed.request.templates(); assert_eq!(request_templates.len(), 1); - if let Predicate::Intro(intro_ref) = &request_templates[0].pred { + if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) = + &request_templates[0].pred_or_wc + { assert_eq!(intro_ref.name, "empty"); assert_eq!(intro_ref.args_len, 0); assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH); @@ -944,27 +965,27 @@ mod tests { let expected_templates = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("B", 1), "raw"), sta_lit(Value::from(raw))], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("C", 2), "string"), sta_lit(Value::from(string))], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("D", 3), "int"), sta_lit(Value::from(int))], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("E", 4), "bool"), sta_lit(Value::from(bool))], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![sta_ak(("F", 5), "sk"), sta_lit(Value::from(sk))], }, ]; diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs index 5ba3470..7440825 100644 --- a/src/lang/pretty_print.rs +++ b/src/lang/pretty_print.rs @@ -5,7 +5,8 @@ use std::fmt::Write; use crate::{ frontend::PodRequest, middleware::{ - CustomPredicate, CustomPredicateBatch, Predicate, StatementTmpl, StatementTmplArg, Value, + CustomPredicate, CustomPredicateBatch, Predicate, PredicateOrWildcard, StatementTmpl, + StatementTmplArg, Value, }, }; @@ -57,26 +58,32 @@ impl StatementTmpl { w: &mut dyn Write, batch_context: Option<&CustomPredicateBatch>, ) -> std::fmt::Result { - match &self.pred { - Predicate::Native(native_pred) => { - write!(w, "{}", native_pred)?; - } - Predicate::Custom(custom_ref) => { - write!(w, "{}", custom_ref.predicate().name)?; - } - Predicate::Intro(intro_ref) => { - write!(w, "{}", intro_ref.name)?; - } - Predicate::BatchSelf(index) => { - if let Some(batch) = batch_context { - if let Some(predicate) = batch.predicates.get(*index) { - write!(w, "{}", predicate.name)?; + match &self.pred_or_wc { + PredicateOrWildcard::Predicate(pred) => match pred { + Predicate::Native(native_pred) => { + write!(w, "{}", native_pred)?; + } + Predicate::Custom(custom_ref) => { + write!(w, "{}", custom_ref.predicate().name)?; + } + Predicate::Intro(intro_ref) => { + write!(w, "{}", intro_ref.name)?; + } + Predicate::BatchSelf(index) => { + if let Some(batch) = batch_context { + if let Some(predicate) = batch.predicates.get(*index) { + write!(w, "{}", predicate.name)?; + } else { + write!(w, "batch_self_{}", index)?; + } } else { write!(w, "batch_self_{}", index)?; } - } else { - write!(w, "batch_self_{}", index)?; } + }, + PredicateOrWildcard::Wildcard(wc) => { + // TODO: Decide the syntax for a wildcard predicate + write!(w, "?{}", wc.name)?; } } @@ -223,13 +230,17 @@ mod tests { Wildcard::new(name.to_string(), index) } + fn pred_lit(pred: Predicate) -> PredicateOrWildcard { + PredicateOrWildcard::Predicate(pred) + } + #[test] fn test_simple_predicate_pretty_print() { let params = Params::default(); // Create a simple predicate: is_equal(PodA, PodB) = AND(Equal(PodA["key"], PodB["key"])) let statements = vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ StatementTmplArg::AnchoredKey( create_test_wildcard("PodA", 0), @@ -265,7 +276,7 @@ mod tests { // Create: uses_private(A, private: Temp) = AND(Equal(A["input"], Temp["const"])) let statements = vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ StatementTmplArg::AnchoredKey( create_test_wildcard("A", 0), @@ -301,7 +312,7 @@ mod tests { // Create: check_value(Pod) = AND(Equal(Pod["field"], 42)) let statements = vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ StatementTmplArg::AnchoredKey( create_test_wildcard("Pod", 0), @@ -335,7 +346,7 @@ mod tests { // Create: either_or(A, B) = OR(Equal(A["x"], 1), Equal(B["y"], 2)) let statements = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ StatementTmplArg::AnchoredKey( create_test_wildcard("A", 0), @@ -345,7 +356,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::Equal), + pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)), args: vec![ StatementTmplArg::AnchoredKey( create_test_wildcard("B", 1), diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 86bb2f4..bfb8ce4 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -1,5 +1,6 @@ use std::{fmt, iter, sync::Arc}; +use itertools::Itertools; use plonky2::field::types::Field; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -70,36 +71,28 @@ impl ToFields for StatementTmplArg { // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements match self { - StatementTmplArg::None => { - let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::None)) - .chain(iter::repeat(F::ZERO)) - .take(Params::statement_tmpl_arg_size()) - .collect(); - fields - } - StatementTmplArg::Literal(v) => { - let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::Literal)) - .chain(v.raw().to_fields(params)) - .chain(iter::repeat(F::ZERO)) - .take(Params::statement_tmpl_arg_size()) - .collect(); - fields - } + StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None)) + .chain(iter::repeat(F::ZERO)) + .take(Params::statement_tmpl_arg_size()) + .collect_vec(), + StatementTmplArg::Literal(v) => iter::once(F::from(StatementTmplArgPrefix::Literal)) + .chain(v.raw().to_fields(params)) + .chain(iter::repeat(F::ZERO)) + .take(Params::statement_tmpl_arg_size()) + .collect_vec(), StatementTmplArg::AnchoredKey(wc1, kw2) => { - let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey)) + iter::once(F::from(StatementTmplArgPrefix::AnchoredKey)) .chain(wc1.to_fields(params)) .chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1)) .chain(kw2.to_fields(params)) - .collect(); - fields + .collect_vec() } StatementTmplArg::Wildcard(wc) => { - let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral)) + iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral)) .chain(wc.to_fields(params)) .chain(iter::repeat(F::ZERO)) .take(Params::statement_tmpl_arg_size()) - .collect(); - fields + .collect_vec() } } } @@ -121,16 +114,79 @@ impl fmt::Display for StatementTmplArg { } } +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] +pub enum PredicateOrWildcard { + Predicate(Predicate), + Wildcard(Wildcard), +} + +impl PredicateOrWildcard { + pub fn as_pred(&self) -> Option<&Predicate> { + match self { + Self::Predicate(pred) => Some(pred), + _ => None, + } + } + pub fn as_wc(&self) -> Option<&Wildcard> { + match self { + Self::Wildcard(wc) => Some(wc), + _ => None, + } + } +} + +impl fmt::Display for PredicateOrWildcard { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Predicate(pred) => pred.fmt(f), + Self::Wildcard(wc) => { + write!(f, "?")?; + wc.fmt(f) + } + } + } +} + +#[derive(Clone, Copy)] +pub enum PredicateOrWildcardPrefix { + Predicate = 0, + Wildcard = 1, +} + +impl From for F { + fn from(prefix: PredicateOrWildcardPrefix) -> Self { + Self::from_canonical_usize(prefix as usize) + } +} + +impl ToFields for PredicateOrWildcard { + fn to_fields(&self, params: &Params) -> Vec { + // Encoding: + // Predicate(pred) => (0, [hash(pred) ]) + // Wildcard(wc) => (1, wc_index, 0...) + match self { + Self::Predicate(pred) => iter::once(F::from(PredicateOrWildcardPrefix::Predicate)) + .chain(pred.hash(params).to_fields(params)) + .collect_vec(), + Self::Wildcard(wc) => iter::once(F::from(PredicateOrWildcardPrefix::Wildcard)) + .chain(wc.to_fields(params)) + .chain(iter::repeat(F::ZERO)) + .take(Params::pred_hash_or_wc_size()) + .collect_vec(), + } + } +} + /// Statement Template for a Custom Predicate #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub struct StatementTmpl { - pub pred: Predicate, + pub pred_or_wc: PredicateOrWildcard, pub args: Vec, } impl StatementTmpl { - pub fn pred(&self) -> &Predicate { - &self.pred + pub fn pred_or_wc(&self) -> &PredicateOrWildcard { + &self.pred_or_wc } pub fn args(&self) -> &[StatementTmplArg] { &self.args @@ -139,7 +195,7 @@ impl StatementTmpl { impl fmt::Display for StatementTmpl { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.pred.fmt(f)?; + self.pred_or_wc.fmt(f)?; write!(f, "(")?; for (i, arg) in self.args.iter().enumerate() { if i != 0 { @@ -154,7 +210,7 @@ impl fmt::Display for StatementTmpl { impl ToFields for StatementTmpl { fn to_fields(&self, params: &Params) -> Vec { // serialize as: - // predicate (6 field elements) + // predicate (4 field elements) // then the StatementTmplArgs // TODO think if this check should go into the StatementTmpl creation, @@ -168,15 +224,13 @@ impl ToFields for StatementTmpl { ); } - let mut fields: Vec = self - .pred - .hash(params) + self.pred_or_wc .to_fields(params) .into_iter() .chain(self.args.iter().flat_map(|sta| sta.to_fields(params))) - .collect(); - fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0)); - fields + .chain(iter::repeat(F::ZERO)) + .take(params.statement_tmpl_size()) + .collect_vec() } } @@ -203,7 +257,9 @@ impl CustomPredicate { name: "empty".to_string(), conjunction: false, statements: vec![StatementTmpl { - pred: Predicate::Native(NativePredicate::None), + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native( + NativePredicate::None, + )), args: vec![], }], args_len: 0, @@ -276,11 +332,11 @@ impl CustomPredicate { } pub fn pad_statement_tmpl(&self) -> StatementTmpl { StatementTmpl { - pred: Predicate::Native(if self.conjunction { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(if self.conjunction { NativePredicate::None } else { NativePredicate::False - }), + })), args: vec![], } } @@ -318,7 +374,7 @@ impl ToFields for CustomPredicate { } let pad_st = self.pad_statement_tmpl(); - let fields: Vec = iter::once(F::from_bool(self.conjunction)) + iter::once(F::from_bool(self.conjunction)) .chain(iter::once(F::from_canonical_usize(self.args_len))) .chain( self.statements @@ -327,8 +383,7 @@ impl ToFields for CustomPredicate { .take(params.max_custom_predicate_arity) .flat_map(|st| st.to_fields(params)), ) - .collect(); - fields + .collect_vec() } } @@ -350,7 +405,7 @@ impl fmt::Display for CustomPredicate { writeln!(f, ") = {}(", if self.conjunction { "AND" } else { "OR" })?; for st in &self.statements { write!(f, " ")?; - st.pred.fmt(f)?; + st.pred_or_wc.fmt(f)?; write!(f, "(")?; for (i, arg) in st.args.iter().enumerate() { if i != 0 { @@ -382,14 +437,12 @@ impl ToFields for CustomPredicateBatch { fn to_fields(&self, params: &Params) -> Vec { // all the custom predicates in order let pad_pred = CustomPredicate::empty(); - let fields: Vec = self - .predicates + self.predicates .iter() .chain(iter::repeat(&pad_pred)) .take(params.max_custom_batch_size) .flat_map(|p| p.to_fields(params)) - .collect(); - fields + .collect_vec() } } @@ -418,7 +471,6 @@ impl CustomPredicateBatch { // NOTE: This implementation just hashes the concatenation of all the custom predicates, // but ideally we want to use the root of a merkle tree built from the custom predicates. let input = self.to_fields(params); - hash_fields(&input) } @@ -470,7 +522,10 @@ mod tests { }; fn st(p: Predicate, args: Vec) -> StatementTmpl { - StatementTmpl { pred: p, args } + StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(p), + args, + } } fn key(name: &str) -> Key { diff --git a/src/middleware/error.rs b/src/middleware/error.rs index b1e5b02..e71544b 100644 --- a/src/middleware/error.rs +++ b/src/middleware/error.rs @@ -29,6 +29,8 @@ pub enum MiddlewareInnerError { MismatchedStatementTmplArg(StatementTmplArg, StatementArg), #[error("Expected a statement of type {0}, got {1}")] MismatchedStatementType(Predicate, Predicate), + #[error("Expected a statement with hash(predicate) {0}, got {1} ({2})")] + MismatchedStatementWildcardPredicate(Value, Value, Predicate), #[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")] MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate), #[error( @@ -111,6 +113,15 @@ impl Error { pub(crate) fn mismatched_statement_type(expected: Predicate, seen: Predicate) -> Self { new!(MismatchedStatementType(expected, seen)) } + pub(crate) fn mismatched_statement_wc_pred( + expected: Value, + seen: Value, + seen_pred: Predicate, + ) -> Self { + new!(MismatchedStatementWildcardPredicate( + expected, seen, seen_pred + )) + } pub(crate) fn mismatched_wildcard_value_and_statement_arg( wc_value: Value, st_arg: Value, diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index a0b55ae..db28f5a 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -820,8 +820,12 @@ impl Params { HASH_SIZE + STATEMENT_ARG_F_LEN * self.max_statement_args } + pub const fn pred_hash_or_wc_size() -> usize { + 1 + HASH_SIZE + } + pub const fn statement_tmpl_size(&self) -> usize { - HASH_SIZE + self.max_statement_args * Self::statement_tmpl_arg_size() + Self::pred_hash_or_wc_size() + self.max_statement_args * Self::statement_tmpl_arg_size() } pub fn custom_predicate_size(&self) -> usize { diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 2200513..6435a92 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -15,8 +15,9 @@ use crate::{ }, middleware::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, - MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg, - StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F, + MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result, + Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, + ValueRef, Wildcard, F, }, }; @@ -550,6 +551,22 @@ impl Operation { } } +// Check that the value `v` at wildcard `wc` exists in the map or set it. +fn wc_check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option]) -> Result<()> { + if let Some(prev) = &wildcard_map[wc.index] { + if *prev != v { + return Err(Error::invalid_wildcard_assignment( + wc.clone(), + v, + prev.clone(), + )); + } + } else { + wildcard_map[wc.index] = Some(v); + } + Ok(()) +} + /// Check that a StatementArg follows a StatementTmplArg based on the currently mapped wildcards. /// Update the wildcard map with newly found wildcards. pub fn check_st_tmpl( @@ -558,22 +575,6 @@ pub fn check_st_tmpl( // Map from wildcards to values that we have seen so far. wildcard_map: &mut [Option], ) -> Result<()> { - // Check that the value `v` at wildcard `wc` exists in the map or set it. - fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option]) -> Result<()> { - if let Some(prev) = &wildcard_map[wc.index] { - if *prev != v { - return Err(Error::invalid_wildcard_assignment( - wc.clone(), - v, - prev.clone(), - )); - } - } else { - wildcard_map[wc.index] = Some(v); - } - Ok(()) - } - match (st_tmpl_arg, st_arg) { (StatementTmplArg::None, StatementArg::None) => Ok(()), (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()), @@ -581,7 +582,7 @@ pub fn check_st_tmpl( StatementTmplArg::AnchoredKey(root_wc, key_tmpl), StatementArg::Key(AnchoredKey { root, key }), ) => { - let root_ok = check_or_set(Value::from(*root), root_wc, wildcard_map); + let root_ok = wc_check_or_set(Value::from(*root), root_wc, wildcard_map); root_ok.and_then(|_| { (key_tmpl == key).then_some(()).ok_or( Error::mismatched_anchored_key_in_statement_tmpl_arg( @@ -594,7 +595,7 @@ pub fn check_st_tmpl( }) } (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { - check_or_set(v.clone(), wc, wildcard_map) + wc_check_or_set(v.clone(), wc, wildcard_map) } _ => Err(Error::mismatched_statement_tmpl_arg( st_tmpl_arg.clone(), @@ -604,12 +605,16 @@ pub fn check_st_tmpl( } pub fn fill_wildcard_values( + params: &Params, pred: &CustomPredicate, args: &[Statement], wildcard_map: &mut [Option], ) -> Result<()> { for (st_tmpl, st) in pred.statements.iter().zip(args) { let st_args = st.args(); + if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc { + wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?; + } st_tmpl .args .iter() @@ -633,7 +638,7 @@ pub fn wildcard_values_from_op_st( .chain(core::iter::repeat(None)) .take(params.max_custom_predicate_wildcards) .collect_vec(); - fill_wildcard_values(pred, op_args, &mut wildcard_map)?; + fill_wildcard_values(params, pred, op_args, &mut wildcard_map)?; // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because // they are beyond the number of used wildcards in this custom predicate, or they could be // private arguments that are unused in a particular disjunction. @@ -644,22 +649,38 @@ pub fn wildcard_values_from_op_st( } fn check_custom_pred_argument( + params: &Params, custom_pred_ref: &CustomPredicateRef, template: &StatementTmpl, statement: &Statement, + wc_values: &[Value], ) -> Result<()> { - let template_pred = match &template.pred { - &Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { - batch: custom_pred_ref.batch.clone(), - index: i, - }), - p => p.clone(), - }; - if template_pred != statement.predicate() { - return Err(Error::mismatched_statement_type( - template_pred, - statement.predicate(), - )); + match &template.pred_or_wc { + PredicateOrWildcard::Predicate(pred) => { + let template_pred = match pred { + &Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { + batch: custom_pred_ref.batch.clone(), + index: i, + }), + p => p.clone(), + }; + if template_pred != statement.predicate() { + return Err(Error::mismatched_statement_type( + template_pred, + statement.predicate(), + )); + } + } + PredicateOrWildcard::Wildcard(wc) => { + let pred_hash = Value::from(statement.predicate().hash(params)); + if wc_values[wc.index] != pred_hash { + return Err(Error::mismatched_statement_wc_pred( + wc_values[wc.index].clone(), + pred_hash, + statement.predicate(), + )); + } + } } let st_args_len = statement.args().len(); if template.args.len() != st_args_len { @@ -697,17 +718,42 @@ pub(crate) fn check_custom_pred( )); } + // Check that the resolved wildcards match the statement arguments. + let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) { + Ok(wc_values) => wc_values, + Err(Error::Inner { inner, backtrace }) => match *inner { + MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev) + if wc.index <= s_args.len() => + { + return Err(Error::mismatched_wildcard_value_and_statement_arg( + v, + prev, + wc.index, + pred.clone(), + )) + } + _ => return Err(Error::Inner { inner, backtrace }), + }, + _ => unreachable!(), + }; + let mut match_exists = false; for (st_tmpl, st) in pred.statements.iter().zip(args) { // For `or` predicates, only one statement needs to match the template. // The rest of the statements can be `None`. - if !pred.conjunction - && matches!(st, Statement::None) - && st_tmpl.pred != Predicate::Native(NativePredicate::None) - { + let expected_pred_is_none = match &st_tmpl.pred_or_wc { + PredicateOrWildcard::Predicate(st_tmpl_pred) => { + *st_tmpl_pred == Predicate::Native(NativePredicate::None) + } + PredicateOrWildcard::Wildcard(wc) => { + wc_values[wc.index] + == Value::from(Predicate::Native(NativePredicate::None).hash(params)) + } + }; + if !pred.conjunction && matches!(st, Statement::None) && !expected_pred_is_none { continue; } - check_custom_pred_argument(custom_pred_ref, st_tmpl, st)?; + check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?; match_exists = true; } @@ -716,25 +762,7 @@ pub(crate) fn check_custom_pred( pred.clone(), )); } - - // Check that the resolved wildcards match the statement arguments. - match wildcard_values_from_op_st(params, pred, args, s_args) { - Ok(_) => Ok(()), - Err(Error::Inner { inner, backtrace }) => match *inner { - MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev) - if wc.index <= s_args.len() => - { - Err(Error::mismatched_wildcard_value_and_statement_arg( - v, - prev, - wc.index, - pred.clone(), - )) - } - _ => Err(Error::Inner { inner, backtrace }), - }, - _ => unreachable!(), - } + Ok(()) } impl ToFields for Operation {