From a7a30176a79d0b96880b0319789c559a5260f9b0 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Mon, 2 Feb 2026 16:23:32 +0100 Subject: [PATCH] Split Params into base and developer-defined (#458) I thought it would be nice to have a Predicate for the typed value so that the developer can work with predicates as values comfortably. Then I noticed that hashing a predicate required `Params` which would have been annoying for converting a `TypedValue::Predicate` to `RawValue` and this led to a small refactor over how `Params` work. We already had some fields in the `Params` struct that determine compatibility between encoded data. They can be seen as determining a kind of ABI compatibility. In general it's better if those parameters don't change so that different circuit configurations can still verify proofs from each other. So I decided to force those parameters to be constant in the code base and not allow the user of our library to change them. Many field element serialization/deserialization functions in our code depended on those parameters, and since now they are constant many functions get rid of the `Params` argument, which simplifies the code. This includes the serialization of a `Predicate` which was required to calculate its hash. --- src/backends/plonky2/circuits/common.rs | 222 +++++++--------------- src/backends/plonky2/circuits/mainpod.rs | 178 ++++++++--------- src/backends/plonky2/emptypod.rs | 33 ++-- src/backends/plonky2/mainpod/mod.rs | 49 ++--- src/backends/plonky2/mainpod/statement.rs | 10 +- src/backends/plonky2/mock/emptypod.rs | 4 +- src/backends/plonky2/mock/mainpod.rs | 2 +- src/examples/mod.rs | 10 +- src/frontend/custom.rs | 11 +- src/frontend/mod.rs | 2 +- src/lang/frontend_ast_batch.rs | 14 +- src/lang/frontend_ast_lower.rs | 8 +- src/lang/frontend_ast_split.rs | 48 ++--- src/lang/mod.rs | 3 - src/middleware/basetypes.rs | 6 +- src/middleware/custom.rs | 63 +++--- src/middleware/error.rs | 2 +- src/middleware/mod.rs | 121 +++++++----- src/middleware/operation.rs | 34 ++-- src/middleware/statement.rs | 24 +-- 20 files changed, 376 insertions(+), 468 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 6d61ea1..5260c96 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -101,13 +101,8 @@ pub struct StatementArgTarget { } impl StatementArgTarget { - pub fn set_targets( - &self, - pw: &mut PartialWitness, - params: &Params, - arg: &StatementArg, - ) -> Result<()> { - Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?) + pub fn set_targets(&self, pw: &mut PartialWitness, arg: &StatementArg) -> Result<()> { + Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?) } pub fn new(first: ValueTarget, second: ValueTarget) -> Self { @@ -190,7 +185,7 @@ impl StatementTarget { .iter() .cloned() .chain(iter::repeat_with(|| StatementArgTarget::none(builder))) - .take(params.max_statement_args) + .take(Params::max_statement_args()) .collect(), } } @@ -205,24 +200,19 @@ impl StatementTarget { Self::new_with_pred(builder, params, pred, args) } - pub fn set_targets( - &self, - pw: &mut PartialWitness, - params: &Params, - st: &Statement, - ) -> Result<()> { + pub fn set_targets(&self, pw: &mut PartialWitness, st: &Statement) -> Result<()> { if let Some(pred) = &self.pred { - pred.set_targets(pw, params, &st.predicate())?; + pred.set_targets(pw, &st.predicate())?; } - pw.set_hash_target(self.pred_hash, HashOut::from(st.predicate().hash(params)))?; + pw.set_hash_target(self.pred_hash, HashOut::from(st.predicate().hash()))?; for (i, arg) in st .args() .iter() .chain(iter::repeat(&StatementArg::None)) - .take(params.max_statement_args) + .take(Params::max_statement_args()) .enumerate() { - self.args[i].set_targets(pw, params, arg)?; + self.args[i].set_targets(pw, arg)?; } Ok(()) } @@ -235,14 +225,9 @@ impl StatementTarget { builder.is_equal_flattenable(&self.pred_hash, &blank_intro) } - pub fn has_native_type( - &self, - builder: &mut CircuitBuilder, - params: &Params, - t: NativePredicate, - ) -> BoolTarget { + pub fn has_native_type(&self, builder: &mut CircuitBuilder, t: NativePredicate) -> BoolTarget { let expected_predicate_hash = - builder.constant_hash(HashOut::from(Predicate::Native(t).hash(params))); + builder.constant_hash(HashOut::from(Predicate::Native(t).hash())); builder.is_equal_flattenable(&self.pred_hash, &expected_predicate_hash) } } @@ -252,8 +237,8 @@ pub trait Build { } impl Build for NativePredicate { - fn build(self, builder: &mut CircuitBuilder, params: &Params) -> NativePredicateTarget { - NativePredicateTarget::constant(builder, params, self) + fn build(self, builder: &mut CircuitBuilder, _params: &Params) -> NativePredicateTarget { + NativePredicateTarget::constant(builder, self) } } @@ -301,13 +286,8 @@ impl OperationTypeTarget { builder.and(op_is_native, op_code_matches) } - pub fn set_targets( - &self, - pw: &mut PartialWitness, - params: &Params, - op_type: &OperationType, - ) -> Result<()> { - Ok(pw.set_target_arr(&self.elements, &op_type.to_fields(params))?) + pub fn set_targets(&self, pw: &mut PartialWitness, op_type: &OperationType) -> Result<()> { + Ok(pw.set_target_arr(&self.elements, &op_type.to_fields())?) } fn size(_params: &Params) -> usize { @@ -330,7 +310,7 @@ impl OperationTarget { params: &Params, op: &Operation, ) -> Result<()> { - self.op_type.set_targets(pw, params, &op.op_type())?; + self.op_type.set_targets(pw, &op.op_type())?; for (i, arg) in op .args() .iter() @@ -354,12 +334,8 @@ impl OperationTarget { pub struct NativePredicateTarget(Target); impl NativePredicateTarget { - pub fn constant( - builder: &mut CircuitBuilder, - params: &Params, - native_predicate: NativePredicate, - ) -> Self { - let id = native_predicate.to_fields(params); + pub fn constant(builder: &mut CircuitBuilder, native_predicate: NativePredicate) -> Self { + let id = native_predicate.to_fields(); assert_eq!(1, id.len()); Self(builder.constant(id[0])) } @@ -367,10 +343,9 @@ impl NativePredicateTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - params: &Params, native_predicate: NativePredicate, ) -> Result<()> { - let id = native_predicate.to_fields(params); + let id = native_predicate.to_fields(); assert_eq!(1, id.len()); Ok(pw.set_target(self.0, id[0])?) } @@ -431,13 +406,8 @@ impl PredicateTarget { builder.is_equal(prefix, self.elements[0]) } - pub fn set_targets( - &self, - pw: &mut PartialWitness, - params: &Params, - predicate: &Predicate, - ) -> Result<()> { - Ok(pw.set_target_arr(&self.elements, &predicate.to_fields(params))?) + pub fn set_targets(&self, pw: &mut PartialWitness, predicate: &Predicate) -> Result<()> { + Ok(pw.set_target_arr(&self.elements, &predicate.to_fields())?) } pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { @@ -534,10 +504,9 @@ impl StatementTmplArgTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - params: &Params, st_tmpl_arg: &StatementTmplArg, ) -> Result<()> { - Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields(params))?) + Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields())?) } } @@ -588,7 +557,6 @@ impl PredicateHashOrWildcardTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - params: &Params, pred: &PredicateOrWildcard, ) -> Result<()> { match pred { @@ -596,7 +564,7 @@ impl PredicateHashOrWildcardTarget { self.set_targets_raw( pw, PredicateOrWildcardPrefix::Predicate, - RawValue::from(pred.hash(params)), + RawValue::from(pred.hash()), )?; } PredicateOrWildcard::Wildcard(wc) => { @@ -650,19 +618,14 @@ impl StatementTmplTarget { args, } } - pub fn set_targets( - &self, - pw: &mut PartialWitness, - params: &Params, - st_tmpl: &StatementTmpl, - ) -> Result<()> { + pub fn set_targets(&self, pw: &mut PartialWitness, st_tmpl: &StatementTmpl) -> Result<()> { if let Some(pred) = &self.pred { match &st_tmpl.pred_or_wc { PredicateOrWildcard::Predicate(p) => { // We store a predicate (not a wildcard) and we have it available. In this // case the hash will be calculated by constraints later on and we should not // rely on the original data. - pred.set_targets(pw, params, p)? + pred.set_targets(pw, p)? } PredicateOrWildcard::Wildcard(_wc) => { // Fill in with a recognizable constant for better debugging; this value is @@ -671,17 +634,16 @@ impl StatementTmplTarget { } } } - self.pred_hash_or_wc - .set_targets(pw, params, &st_tmpl.pred_or_wc)?; + self.pred_hash_or_wc.set_targets(pw, &st_tmpl.pred_or_wc)?; let arg_pad = StatementTmplArg::None; for (i, arg) in st_tmpl .args .iter() .chain(iter::repeat(&arg_pad)) - .take(params.max_statement_args) + .take(Params::max_statement_args()) .enumerate() { - self.args[i].set_targets(pw, params, arg)?; + self.args[i].set_targets(pw, arg)?; } Ok(()) } @@ -705,7 +667,6 @@ impl CustomPredicateTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - params: &Params, custom_pred: &CustomPredicate, ) -> Result<()> { pw.set_target( @@ -717,10 +678,10 @@ impl CustomPredicateTarget { .statements .iter() .chain(iter::repeat(&st_tmpl_pad)) - .take(params.max_custom_predicate_arity) + .take(Params::max_custom_predicate_arity()) .enumerate() { - self.statements[i].set_targets(pw, params, st_tmpl)?; + self.statements[i].set_targets(pw, st_tmpl)?; } pw.set_target(self.args_len, F::from_canonical_usize(custom_pred.args_len))?; Ok(()) @@ -743,7 +704,6 @@ impl CustomPredicateBatchTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - params: &Params, custom_predicate_batch: &CustomPredicateBatch, ) -> Result<()> { let pad_predicate = CustomPredicate::empty(); @@ -751,10 +711,10 @@ impl CustomPredicateBatchTarget { .predicates() .iter() .chain(iter::repeat(&pad_predicate)) - .take(params.max_custom_batch_size) + .take(Params::max_custom_batch_size()) .enumerate() { - self.predicates[i].set_targets(pw, params, predicate)?; + self.predicates[i].set_targets(pw, predicate)?; } Ok(()) } @@ -772,7 +732,6 @@ impl CustomPredicateEntryTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - params: &Params, predicate: &CustomPredicateRef, ) -> Result<()> { pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; @@ -808,7 +767,7 @@ impl CustomPredicateEntryTarget { args_len: predicate.args_len, wildcard_names: predicate.wildcard_names.clone(), }; - self.predicate.set_targets(pw, params, &predicate)?; + self.predicate.set_targets(pw, &predicate)?; Ok(()) } } @@ -854,18 +813,18 @@ pub struct CustomPredicateVerifyEntryTarget { impl CustomPredicateVerifyEntryTarget { pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { let custom_predicate_table_len = - params.max_custom_predicate_batches * params.max_custom_batch_size; + params.max_custom_predicate_batches * Params::max_custom_batch_size(); CustomPredicateVerifyEntryTarget { custom_predicate_table_index: IndexTarget::new_virtual( custom_predicate_table_len, builder, ), - custom_predicate: builder.add_virtual_custom_predicate_entry(params), + custom_predicate: builder.add_virtual_custom_predicate_entry(), args: (0..params.max_custom_predicate_wildcards) .map(|_| builder.add_virtual_value()) .collect(), op_args: (0..params.max_operation_args) - .map(|_| builder.add_virtual_statement(params, false)) + .map(|_| builder.add_virtual_statement(false)) .collect(), } } @@ -879,7 +838,7 @@ impl CustomPredicateVerifyEntryTarget { .set_targets(pw, cpv.custom_predicate_table_index)?; // Replace statement templates of batch-self with (id,index) self.custom_predicate - .set_targets(pw, params, &cpv.custom_predicate)?; + .set_targets(pw, &cpv.custom_predicate)?; let pad_arg = Value::from(0); for (arg_target, arg) in self.args.iter().zip_eq( cpv.args @@ -896,7 +855,7 @@ impl CustomPredicateVerifyEntryTarget { .chain(iter::repeat(&pad_op_arg)) .take(params.max_operation_args), ) { - op_arg_target.set_targets(pw, params, op_arg)? + op_arg_target.set_targets(pw, op_arg)? } Ok(()) } @@ -1138,7 +1097,7 @@ impl Flattenable for StatementTarget { fn from_flattened(params: &Params, v: &[Target]) -> Self { assert_eq!(v.len(), Self::size(params)); let predicate_hash = HashOutTarget::from_flattened(params, &v[..HASH_SIZE]); - let args = (0..params.max_statement_args) + let args = (0..Params::max_statement_args()) .map(|i| StatementArgTarget { elements: array::from_fn(|j| v[HASH_SIZE + i * STATEMENT_ARG_F_LEN + j]), }) @@ -1152,7 +1111,7 @@ impl Flattenable for StatementTarget { } fn size(params: &Params) -> usize { - HASH_SIZE + params.max_statement_args * StatementArgTarget::size(params) + HASH_SIZE + Params::max_statement_args() * StatementArgTarget::size(params) } } @@ -1170,8 +1129,8 @@ impl Flattenable for CustomPredicateTarget { // this `BoolTarget` should actually safe. let conjunction = BoolTarget::new_unsafe(v[0]); let args_len = v[1]; - let st_tmpl_size = params.statement_tmpl_size(); - let statements = (0..params.max_custom_predicate_arity) + let st_tmpl_size = Params::statement_tmpl_size(); + let statements = (0..Params::max_custom_predicate_arity()) .map(|i| { let st_v = &v[2 + st_tmpl_size * i..2 + st_tmpl_size * (i + 1)]; StatementTmplTarget::from_flattened(params, st_v) @@ -1184,7 +1143,7 @@ impl Flattenable for CustomPredicateTarget { } } fn size(params: &Params) -> usize { - 2 + params.max_custom_predicate_arity * StatementTmplTarget::size(params) + 2 + Params::max_custom_predicate_arity() * StatementTmplTarget::size(params) } } @@ -1203,7 +1162,7 @@ impl Flattenable for StatementTmplTarget { let pred_hash_or_wc = PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]); let sta_size = Params::statement_tmpl_arg_size(); - let args = (0..params.max_statement_args) + let args = (0..Params::max_statement_args()) .map(|i| { let sta_v = &v [pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)]; @@ -1219,7 +1178,7 @@ impl Flattenable for StatementTmplTarget { fn size(params: &Params) -> usize { Params::pred_hash_or_wc_size() - + params.max_statement_args * StatementTmplArgTarget::size(params) + + Params::max_statement_args() * StatementTmplArgTarget::size(params) } } @@ -1278,29 +1237,17 @@ pub trait CircuitBuilderPod, const D: usize> { fn connect_values(&mut self, x: ValueTarget, y: ValueTarget); fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); fn add_virtual_value(&mut self) -> ValueTarget; - fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget; + fn add_virtual_statement(&mut self, with_pred: bool) -> StatementTarget; fn add_virtual_statement_arg(&mut self) -> StatementArgTarget; fn add_virtual_predicate(&mut self) -> PredicateTarget; fn add_virtual_operation_type(&mut self) -> OperationTypeTarget; fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget; - fn add_virtual_statement_tmpl( - &mut self, - params: &Params, - with_pred: bool, - ) -> StatementTmplTarget; - fn add_virtual_custom_predicate( - &mut self, - params: &Params, - with_pred: bool, - ) -> CustomPredicateTarget; - fn add_virtual_custom_predicate_batch( - &mut self, - params: &Params, - with_pred: bool, - ) -> CustomPredicateBatchTarget; - fn add_virtual_custom_predicate_entry(&mut self, params: &Params) - -> CustomPredicateEntryTarget; + fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget; + fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget; + fn add_virtual_custom_predicate_batch(&mut self, with_pred: bool) + -> CustomPredicateBatchTarget; + fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_statement_arg( &mut self, @@ -1396,7 +1343,7 @@ impl CircuitBuilderPod for CircuitBuilder { /// If `with_pred = true` a predicate is included and its hash constrained. /// If `with_pred = false` only the predicate hash is included. - fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget { + fn add_virtual_statement(&mut self, with_pred: bool) -> StatementTarget { let (pred, pred_hash) = if with_pred { let pred = self.add_virtual_predicate(); let pred_hash = pred.hash(self); @@ -1408,7 +1355,7 @@ impl CircuitBuilderPod for CircuitBuilder { StatementTarget { pred, pred_hash, - args: (0..params.max_statement_args) + args: (0..Params::max_statement_args()) .map(|_| self.add_virtual_statement_arg()) .collect(), } @@ -1452,11 +1399,7 @@ impl CircuitBuilderPod for CircuitBuilder { /// If `with_pred = false` only the predicate hash is included. /// The pred_hash is constrained to be hash(pred) conditionally on the template using a /// predicate and not a wildcard. - fn add_virtual_statement_tmpl( - &mut self, - params: &Params, - with_pred: bool, - ) -> StatementTmplTarget { + fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget { let pred_hash_or_wc = PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value()); let pred = if with_pred { @@ -1474,20 +1417,16 @@ impl CircuitBuilderPod for CircuitBuilder { StatementTmplTarget { pred, pred_hash_or_wc, - args: (0..params.max_statement_args) + args: (0..Params::max_statement_args()) .map(|_| self.add_virtual_statement_tmpl_arg()) .collect(), } } /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. - fn add_virtual_custom_predicate( - &mut self, - params: &Params, - with_pred: bool, - ) -> CustomPredicateTarget { - let statements = (0..params.max_custom_predicate_arity) - .map(|_| self.add_virtual_statement_tmpl(params, with_pred)) + fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget { + let statements = (0..Params::max_custom_predicate_arity()) + .map(|_| self.add_virtual_statement_tmpl(with_pred)) .collect(); CustomPredicateTarget { conjunction: self.add_virtual_bool_target_safe(), @@ -1499,25 +1438,21 @@ impl CircuitBuilderPod for CircuitBuilder { /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. fn add_virtual_custom_predicate_batch( &mut self, - params: &Params, with_pred: bool, ) -> CustomPredicateBatchTarget { CustomPredicateBatchTarget { - predicates: (0..params.max_custom_batch_size) - .map(|_| self.add_virtual_custom_predicate(params, with_pred)) + predicates: (0..Params::max_custom_batch_size()) + .map(|_| self.add_virtual_custom_predicate(with_pred)) .collect(), } } /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. - fn add_virtual_custom_predicate_entry( - &mut self, - params: &Params, - ) -> CustomPredicateEntryTarget { + fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget { CustomPredicateEntryTarget { id: self.add_virtual_hash(), index: self.add_virtual_target(), - predicate: self.add_virtual_custom_predicate(params, false), + predicate: self.add_virtual_custom_predicate(false), } } @@ -1998,7 +1933,7 @@ pub(crate) mod tests { for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() { let mut builder = CircuitBuilder::::new(config.clone()); - let flattened = cp.to_fields(¶ms); + let flattened = cp.to_fields(); let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec(); let cp_target = CustomPredicateTarget::from_flattened(¶ms, &flatteend_target); // Round trip of from_flattened to flattened @@ -2018,20 +1953,18 @@ pub(crate) mod tests { } fn helper_custom_predicate_batch_target_id( - params: &Params, custom_predicate_batch: &CustomPredicateBatch, ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let custom_predicate_batch_target = - builder.add_virtual_custom_predicate_batch(params, false); + let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(false); // Calculate the id in constraints and compare it against the id calculated natively let id_target = custom_predicate_batch_target.id(&mut builder); let mut pw = PartialWitness::::new(); - custom_predicate_batch_target.set_targets(&mut pw, params, custom_predicate_batch)?; + custom_predicate_batch_target.set_targets(&mut pw, custom_predicate_batch)?; let id = custom_predicate_batch.id(); pw.set_target_arr(&id_target.elements, &id.0)?; @@ -2046,7 +1979,6 @@ pub(crate) mod tests { #[test] fn test_custom_predicate_batch_target_id() -> frontend::Result<()> { let params = Params { - max_statement_args: 6, max_custom_predicate_wildcards: 12, ..Default::default() }; @@ -2055,15 +1987,15 @@ pub(crate) mod tests { let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); _ = cpb_builder.predicate_and("empty", &[], &[], &[])?; let custom_predicate_batch = cpb_builder.finish(); - helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); + helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap(); // Some cases from the examples let custom_predicate_batch = eth_dos_batch(¶ms)?; - helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); + helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap(); let custom_predicate_batch = CustomPredicateBatch::new(¶ms, "empty".to_string(), vec![CustomPredicate::empty()]); - helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); + helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap(); Ok(()) } @@ -2079,17 +2011,13 @@ pub(crate) mod tests { let sum_target = builder.i64_add(x_target, y_target); let data = builder.build::(); - let params = Params::default(); I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| { let mut pw = PartialWitness::::new(); let (sum, overflow) = x.overflowing_add(y); - pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(¶ms))?; - pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(¶ms))?; - pw.set_target_arr( - &sum_target.elements, - &RawValue::from(sum).to_fields(¶ms), - )?; + pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields())?; + pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields())?; + pw.set_target_arr(&sum_target.elements, &RawValue::from(sum).to_fields())?; let proof = data.prove(pw); @@ -2113,18 +2041,14 @@ pub(crate) mod tests { let prod_target = builder.i64_mul(x_target, y_target); let data = builder.build::(); - let params = Params::default(); I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| { println!("{}, {}", x, y); let mut pw = PartialWitness::::new(); let (prod, overflow) = x.overflowing_mul(y); - pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(¶ms))?; - pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(¶ms))?; - pw.set_target_arr( - &prod_target.elements, - &RawValue::from(prod).to_fields(¶ms), - )?; + pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields())?; + pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields())?; + pw.set_target_arr(&prod_target.elements, &RawValue::from(prod).to_fields())?; let proof = data.prove(pw); diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 2bb6ee5..845f445 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -104,13 +104,12 @@ impl StatementCache { .collect::>() }; assert!(params.max_operation_args >= MAX_VALUE_ARGS); - assert!(params.max_statement_args >= MAX_VALUE_ARGS); + assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); let equations = array::from_fn(|i| { - let pred_is_none = op_args[i].has_native_type(builder, params, NativePredicate::None); + let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None); let arg_is_value = builder.statement_arg_is_value(&st.args[i]); let is_literal = builder.and(pred_is_none, arg_is_value); - let pred_is_contains = - op_args[i].has_native_type(builder, params, NativePredicate::Contains); + let pred_is_contains = op_args[i].has_native_type(builder, NativePredicate::Contains); let ref_is_value_arg: [_; 3] = array::from_fn(|j| builder.statement_arg_is_value(&op_args[i].args[j])); let ref_is_value = builder.and(ref_is_value_arg[0], ref_is_value_arg[1]); @@ -435,8 +434,8 @@ fn verify_operation_circuit( if !cache.op_args.is_empty() { op_checks.extend_from_slice(&[ verify_copy_circuit(builder, st, &op.op_type, &cache.op_args), - verify_eq_neq_from_entries_circuit(params, builder, st, &op.op_type, &cache), - verify_lt_lteq_from_entries_circuit(params, builder, st, &op.op_type, &cache), + verify_eq_neq_from_entries_circuit(builder, st, &op.op_type, &cache), + verify_lt_lteq_from_entries_circuit(builder, st, &op.op_type, &cache), verify_transitive_eq_circuit(params, builder, st, &op.op_type, &cache.op_args), verify_lt_to_neq_circuit(params, builder, st, &op.op_type, &cache.op_args), verify_hash_of_circuit(params, builder, st, &op.op_type, &cache), @@ -881,7 +880,6 @@ fn verify_custom_circuit( /// Carries out the checks necessary for EqualFromEntries and /// NotEqualFromEntries. fn verify_eq_neq_from_entries_circuit( - params: &Params, builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, @@ -890,12 +888,12 @@ fn verify_eq_neq_from_entries_circuit( let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); let eq_op_st_code_ok = { let op_code_ok = op_type.has_native(builder, NativeOperation::EqualFromEntries); - let st_code_ok = st.has_native_type(builder, params, NativePredicate::Equal); + let st_code_ok = st.has_native_type(builder, NativePredicate::Equal); builder.and(op_code_ok, st_code_ok) }; let neq_op_st_code_ok = { let op_code_ok = op_type.has_native(builder, NativeOperation::NotEqualFromEntries); - let st_code_ok = st.has_native_type(builder, params, NativePredicate::NotEqual); + let st_code_ok = st.has_native_type(builder, NativePredicate::NotEqual); builder.and(op_code_ok, st_code_ok) }; let op_st_code_ok = builder.or(eq_op_st_code_ok, neq_op_st_code_ok); @@ -911,7 +909,7 @@ fn verify_eq_neq_from_entries_circuit( let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] .into_iter() .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) - .take(params.max_statement_args) + .take(Params::max_statement_args()) .flat_map(|arg| arg.elements) .collect(); @@ -931,7 +929,6 @@ fn verify_eq_neq_from_entries_circuit( /// Carries out the checks necessary for LtFromEntries and /// LtEqFromEntries. fn verify_lt_lteq_from_entries_circuit( - params: &Params, builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, @@ -943,12 +940,12 @@ fn verify_lt_lteq_from_entries_circuit( let lt_op_st_code_ok = { let op_code_ok = op_type.has_native(builder, NativeOperation::LtFromEntries); - let st_code_ok = st.has_native_type(builder, params, NativePredicate::Lt); + let st_code_ok = st.has_native_type(builder, NativePredicate::Lt); builder.and(op_code_ok, st_code_ok) }; let lteq_op_st_code_ok = { let op_code_ok = op_type.has_native(builder, NativeOperation::LtEqFromEntries); - let st_code_ok = st.has_native_type(builder, params, NativePredicate::LtEq); + let st_code_ok = st.has_native_type(builder, NativePredicate::LtEq); builder.and(op_code_ok, st_code_ok) }; let op_st_code_ok = builder.or(lt_op_st_code_ok, lteq_op_st_code_ok); @@ -981,7 +978,7 @@ fn verify_lt_lteq_from_entries_circuit( let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] .into_iter() .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) - .take(params.max_statement_args) + .take(Params::max_statement_args()) .flat_map(|arg| arg.elements) .collect(); @@ -1233,8 +1230,8 @@ fn verify_transitive_eq_circuit( let measure = measure_gates_begin!(builder, "OpTransitiveEq"); let op_code_ok = op_type.has_native(builder, NativeOperation::TransitiveEqualFromStatements); - let arg1_type_ok = resolved_op_args[0].has_native_type(builder, params, NativePredicate::Equal); - let arg2_type_ok = resolved_op_args[1].has_native_type(builder, params, NativePredicate::Equal); + let arg1_type_ok = resolved_op_args[0].has_native_type(builder, NativePredicate::Equal); + let arg2_type_ok = resolved_op_args[1].has_native_type(builder, NativePredicate::Equal); let arg_types_ok = builder.all([arg1_type_ok, arg2_type_ok]); let arg1_lhs = &resolved_op_args[0].args[0]; @@ -1285,7 +1282,7 @@ fn verify_lt_to_neq_circuit( let measure = measure_gates_begin!(builder, "OpLtToNeq"); let op_code_ok = op_type.has_native(builder, NativeOperation::LtToNotEqual); - let arg_type_ok = resolved_op_args[0].has_native_type(builder, params, NativePredicate::Lt); + let arg_type_ok = resolved_op_args[0].has_native_type(builder, NativePredicate::Lt); let arg1_expected = resolved_op_args[0].args[0].clone(); let arg2_expected = resolved_op_args[0].args[1].clone(); @@ -1442,7 +1439,7 @@ fn make_custom_statement_circuit( let st_predicate = PredicateTarget::new_custom(builder, batch_id, index); let arg_none = ValueTarget::zero(builder); let lt_mask = builder.lt_mask( - params.max_statement_args, + Params::max_statement_args(), custom_predicate.predicate.args_len, ); let st_args = std::iter::zip(lt_mask, args) @@ -1466,7 +1463,7 @@ fn make_custom_statement_circuit( .collect(); // expected_sts.len() == params.max_custom_predicate_arity // op_args.len() == params.max_operation_args; - assert!(params.max_custom_predicate_arity <= params.max_operation_args); + assert!(Params::max_custom_predicate_arity() <= params.max_operation_args); let sts_eq: Vec<_> = expected_sts .iter() @@ -1508,19 +1505,18 @@ fn normalize_statement_circuit( /// statements reversed. The part of the hash from the front-padded none-statements is /// precomputed. pub fn calculate_statements_hash_circuit( - params: &Params, builder: &mut CircuitBuilder, // These statements will be padded to reach `num_statements` statements: &[StatementTarget], ) -> HashOutTarget { - assert!(statements.len() <= params.num_public_statements_hash); + assert!(statements.len() <= Params::num_public_statements_hash()); let measure = measure_gates_begin!(builder, "CalculateStsHash"); let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten()); let mut none_st = mainpod::Statement::from(Statement::None); - pad_statement(params, &mut none_st); + pad_statement(&mut none_st); let front_pad_elts = iter::repeat(&none_st) - .take(params.num_public_statements_hash - statements.len()) - .flat_map(|s| s.to_fields(params)) + .take(Params::num_public_statements_hash() - statements.len()) + .flat_map(|s| s.to_fields()) .collect_vec(); let (perm, front_pad_elts_rem) = precompute_hash_state::>(&front_pad_elts); @@ -1581,7 +1577,7 @@ fn build_custom_predicate_table_circuit( ) -> Result> { let measure = measure_gates_begin!(builder, "BuildCustomPredTbl"); let mut custom_predicate_table = - Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size); + Vec::with_capacity(params.max_custom_predicate_batches * Params::max_custom_batch_size()); for cpb in custom_predicate_batches { let measure_cpb = measure_gates_begin!(builder, "CustomPredBatch"); let id = cpb.id(builder); // constrain the id @@ -1655,7 +1651,7 @@ fn verify_main_pod_circuit( let mut intro_ok = is_blank_intro; for self_st in &input_pod_self_statements[1..] { let st_is_intro = self_st.pred_is_blank_intro(builder); - let st_is_none = self_st.has_native_type(builder, params, NativePredicate::None); + let st_is_none = self_st.has_native_type(builder, NativePredicate::None); let st_is_intro_or_none = builder.or(st_is_intro, st_is_none); intro_ok = builder.and(intro_ok, st_is_intro_or_none); } @@ -1671,8 +1667,7 @@ fn verify_main_pod_circuit( ); statements.push(normalized_st); } - let sts_hash = - calculate_statements_hash_circuit(params, builder, input_pod_self_statements); + let sts_hash = calculate_statements_hash_circuit(builder, input_pod_self_statements); builder.connect_hashes(expected_sts_hash, sts_hash); // @@ -1730,7 +1725,7 @@ fn verify_main_pod_circuit( )?; // 2. Calculate the Pod Id from the public statements - let sts_hash = calculate_statements_hash_circuit(params, builder, pub_statements); + let sts_hash = calculate_statements_hash_circuit(builder, pub_statements); // 5. Verify input statements for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() { @@ -1774,12 +1769,12 @@ impl MainPodVerifyTarget { input_pods_self_statements: (0..params.max_input_pods) .map(|_| { (0..params.max_input_pods_public_statements) - .map(|_| builder.add_virtual_statement(params, false)) + .map(|_| builder.add_virtual_statement(false)) .collect_vec() }) .collect(), input_statements: (0..params.max_statements) - .map(|_| builder.add_virtual_statement(params, false)) + .map(|_| builder.add_virtual_statement(false)) .collect(), operations: (0..params.max_statements) .map(|_| builder.add_virtual_operation(params)) @@ -1805,7 +1800,7 @@ impl MainPodVerifyTarget { }) .collect(), custom_predicate_batches: (0..params.max_custom_predicate_batches) - .map(|_| builder.add_virtual_custom_predicate_batch(params, true)) + .map(|_| builder.add_virtual_custom_predicate_batch(true)) .collect(), custom_predicate_verifications: (0..params.max_custom_predicate_verifications) .map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder)) @@ -1849,16 +1844,16 @@ fn set_targets_input_pods_self_statements( statements_target.len(), params.max_input_pods_public_statements ); - assert!(statements.len() <= params.num_public_statements_hash); + assert!(statements.len() <= Params::num_public_statements_hash()); for (i, statement) in statements.iter().enumerate() { - statements_target[i].set_targets(pw, params, &statement.clone().into())?; + statements_target[i].set_targets(pw, &statement.clone().into())?; } // Padding let mut none_st = mainpod::Statement::from(Statement::None); - pad_statement(params, &mut none_st); + pad_statement(&mut none_st); for statement_target in statements_target.iter().skip(statements.len()) { - statement_target.set_targets(pw, params, &none_st)?; + statement_target.set_targets(pw, &none_st)?; } Ok(()) } @@ -1903,7 +1898,7 @@ impl InnerCircuit for MainPodVerifyTarget { } // Padding if input_pods_len != self.params.max_input_pods { - let empty_pod = EmptyPod::new_boxed(&self.params, input.vds_set.clone()); + let empty_pod = EmptyPod::new_boxed(input.vds_set.clone()); let empty_pod_statements = empty_pod.pub_statements(); let empty_mt_proof = MerkleClaimAndProof { root: input.vds_set.root(), @@ -1924,7 +1919,7 @@ impl InnerCircuit for MainPodVerifyTarget { assert_eq!(input.statements.len(), self.params.max_statements); for (i, (st, op)) in zip_eq(&input.statements, &input.operations).enumerate() { - self.input_statements[i].set_targets(pw, &self.params, st)?; + self.input_statements[i].set_targets(pw, st)?; self.operations[i].set_targets(pw, &self.params, op)?; } @@ -1979,7 +1974,7 @@ impl InnerCircuit for MainPodVerifyTarget { assert!(input.custom_predicate_batches.len() <= self.params.max_custom_predicate_batches); for (i, cpb) in input.custom_predicate_batches.iter().enumerate() { - self.custom_predicate_batches[i].set_targets(pw, &self.params, cpb)?; + self.custom_predicate_batches[i].set_targets(pw, cpb)?; } // Padding let pad_cpb = CustomPredicateBatch::new( @@ -1988,7 +1983,7 @@ impl InnerCircuit for MainPodVerifyTarget { vec![CustomPredicate::empty()], ); for i in input.custom_predicate_batches.len()..self.params.max_custom_predicate_batches { - self.custom_predicate_batches[i].set_targets(pw, &self.params, &pad_cpb)?; + self.custom_predicate_batches[i].set_targets(pw, &pad_cpb)?; } assert!( @@ -2048,7 +2043,7 @@ mod tests { frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, middleware::{ hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, - RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, + RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, EMPTY_VALUE, }, }; @@ -2108,10 +2103,10 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let st_target = builder.add_virtual_statement(¶ms, false); + let st_target = builder.add_virtual_statement(false); let op_target = builder.add_virtual_operation(¶ms); let prev_statements_target: Vec<_> = (0..prev_statements.len()) - .map(|_| builder.add_virtual_statement(¶ms, false)) + .map(|_| builder.add_virtual_statement(false)) .collect(); let merkle_proofs_target: Vec<_> = aux @@ -2166,10 +2161,10 @@ mod tests { )?; let mut pw = PartialWitness::::new(); - st_target.set_targets(&mut pw, ¶ms, &st)?; + st_target.set_targets(&mut pw, &st)?; op_target.set_targets(&mut pw, ¶ms, &op)?; for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { - prev_st_target.set_targets(&mut pw, ¶ms, prev_st)?; + prev_st_target.set_targets(&mut pw, prev_st)?; } for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) { signed_by_target.set_targets(&mut pw, signed_by)? @@ -3065,11 +3060,11 @@ mod tests { let mut pw = PartialWitness::::new(); - st_tmpl_arg_target.set_targets(&mut pw, params, &st_tmpl_arg)?; + st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?; for (arg_target, arg) in args_target.iter().zip(args.iter()) { arg_target.set_targets(&mut pw, arg)?; } - expected_st_arg_target.set_targets(&mut pw, params, &expected_st_arg)?; + expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?; // generate & verify proof let data = builder.build::(); @@ -3122,7 +3117,7 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let st_tmpl_target = builder.add_virtual_statement_tmpl(params, false); + let st_tmpl_target = builder.add_virtual_statement_tmpl(false); let args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_value()) .collect(); @@ -3133,16 +3128,16 @@ mod tests { &args_target, ); // TODO: Instead of connect, assign witness to result - let expected_st_target = builder.add_virtual_statement(params, false); + let expected_st_target = builder.add_virtual_statement(false); builder.connect_flattenable(&expected_st_target, &st_target); let mut pw = PartialWitness::::new(); - st_tmpl_target.set_targets(&mut pw, params, &st_tmpl)?; + st_tmpl_target.set_targets(&mut pw, &st_tmpl)?; for (arg_target, arg) in args_target.iter().zip(args.iter()) { arg_target.set_targets(&mut pw, arg)?; } - expected_st_target.set_targets(&mut pw, params, &expected_st.into())?; + expected_st_target.set_targets(&mut pw, &expected_st.into())?; // generate & verify proof let data = builder.build::(); @@ -3179,7 +3174,7 @@ mod tests { StatementTmplArg::Literal(Value::from("value")), ], }; - let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(¶ms); + let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(); let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; let expected_st = Statement::not_equal( AnchoredKey::new(dict, Key::from("key")), @@ -3193,16 +3188,24 @@ mod tests { fn helper_custom_operation_verify_gadget( params: &Params, custom_predicate: CustomPredicateRef, - op_args: Vec, - args: Vec, + mut op_args: Vec, + mut args: Vec, expected_st: Option, ) -> Result<()> { + // Pad + for _ in op_args.len()..params.max_operation_args { + op_args.push(Statement::None); + } + for _ in args.len()..params.max_custom_predicate_wildcards { + args.push(Value::from(EMPTY_VALUE)); + } + let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params); - let op_args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_statement(params, false)) + let custom_predicate_target = builder.add_virtual_custom_predicate_entry(); + let op_args_target: Vec<_> = (0..op_args.len()) + .map(|_| builder.add_virtual_statement(false)) .collect(); let args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_value()) @@ -3218,20 +3221,20 @@ mod tests { let mut pw = PartialWitness::::new(); // Input - custom_predicate_target.set_targets(&mut pw, params, &custom_predicate)?; + custom_predicate_target.set_targets(&mut pw, &custom_predicate)?; for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { - op_arg_target.set_targets(&mut pw, params, &op_arg.into())?; + op_arg_target.set_targets(&mut pw, &op_arg.into())?; } for (arg_target, arg) in args_target.iter().zip(args.iter()) { arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; } // Expected Output if let Some(expected_st) = expected_st { - st_target.set_targets(&mut pw, params, &expected_st.into())?; + st_target.set_targets(&mut pw, &expected_st.into())?; } let expected_op_type = OperationType::Custom(custom_predicate); - op_type_target.set_targets(&mut pw, params, &expected_op_type)?; + op_type_target.set_targets(&mut pw, &expected_op_type)?; // generate & verify proof let data = builder.build::(); @@ -3242,15 +3245,7 @@ mod tests { // TODO: Add negative tests #[test] fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { - // We set the parameters to the exact sizes we have in the test so that we don't have to - // pad. - let params = Params { - max_custom_predicate_arity: 2, - max_custom_predicate_wildcards: 2, - max_operation_args: 2, - max_statement_args: 2, - ..Default::default() - }; + let params = Params::default(); use NativePredicate as NP; use StatementTmplBuilder as STB; @@ -3340,15 +3335,7 @@ mod tests { #[test] fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { - // We set the parameters to the exact sizes we have in the test so that we don't have to - // pad. - let params = Params { - max_custom_predicate_arity: 2, - max_custom_predicate_wildcards: 2, - max_operation_args: 2, - max_statement_args: 2, - ..Default::default() - }; + let params = Params::default(); use NativePredicate as NP; use StatementTmplBuilder as STB; @@ -3500,10 +3487,9 @@ mod tests { let mut builder = CircuitBuilder::new(config); let statements_target = (0..params.max_public_statements) - .map(|_| builder.add_virtual_statement(params, false)) + .map(|_| builder.add_virtual_statement(false)) .collect_vec(); - let sts_hash_target = - calculate_statements_hash_circuit(params, &mut builder, &statements_target); + let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target); let mut pw = PartialWitness::::new(); @@ -3512,15 +3498,15 @@ mod tests { .iter() .map(|st| { let mut st = mainpod::Statement::from(st.clone()); - pad_statement(params, &mut st); + pad_statement(&mut st); st }) .collect_vec(); for (st_target, st) in statements_target.iter().zip(statements.iter()) { - st_target.set_targets(&mut pw, params, st)?; + st_target.set_targets(&mut pw, st)?; } // Expected Output - let expected_sts_hash = calculate_statements_hash(&statements, params); + let expected_sts_hash = calculate_statements_hash(&statements); pw.set_hash_target( sts_hash_target, HashOut { @@ -3536,10 +3522,10 @@ mod tests { #[test] fn test_calculate_sts_hash() -> frontend::Result<()> { + assert_eq!(Params::num_public_statements_hash(), 16); // Case with no public public statements let params = Params { max_public_statements: 0, - num_public_statements_hash: 8, ..Default::default() }; @@ -3547,30 +3533,20 @@ mod tests { // Case with number of statements for the sts_hash equal to number of public statements let params = Params { - max_public_statements: 2, - num_public_statements_hash: 2, + max_public_statements: Params::num_public_statements_hash(), ..Default::default() }; let dict = Hash([F(1), F(2), F(3), F(4)]); - let statements = [ - Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)), - Statement::equal( - AnchoredKey::from((dict, "bar")), - AnchoredKey::from((dict, "baz")), - ), - ] - .into_iter() - .chain(iter::repeat(Statement::None)) - .take(params.max_public_statements) - .collect_vec(); + let statements = (0..Params::num_public_statements_hash()) + .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64))) + .collect_vec(); helper_calculate_statements_hash(¶ms, &statements).unwrap(); - // Case with more statements for the sts_hash than the number of public statements + // Case with more statements for the sts_hash than the number of public statements let params = Params { max_public_statements: 4, - num_public_statements_hash: 6, ..Default::default() }; diff --git a/src/backends/plonky2/emptypod.rs b/src/backends/plonky2/emptypod.rs index 5e0be7c..4ca92ff 100644 --- a/src/backends/plonky2/emptypod.rs +++ b/src/backends/plonky2/emptypod.rs @@ -67,11 +67,9 @@ fn verify_empty_pod_circuit( builder: &mut CircuitBuilder, empty_pod: &EmptyPodVerifyTarget, ) { - let empty_statement = StatementTarget::from_flattened( - params, - &builder.constants(&empty_statement().to_fields(params)), - ); - let sts_hash = calculate_statements_hash_circuit(params, builder, &[empty_statement]); + let empty_statement = + StatementTarget::from_flattened(params, &builder.constants(&empty_statement().to_fields())); + let sts_hash = calculate_statements_hash_circuit(builder, &[empty_statement]); builder.register_public_inputs(&sts_hash.elements); builder.register_public_inputs(&empty_pod.vds_root.elements); } @@ -126,7 +124,7 @@ fn build() -> Result<(EmptyPodVerifyTarget, CircuitData)> { } impl EmptyPod { - fn new(params: &Params, vd_set: VDSet) -> Result { + fn new(vd_set: VDSet) -> Result { let (empty_pod_verify_target, data) = &*cache_get_standard_empty_pod_circuit_data(); let mut pw = PartialWitness::::new(); @@ -139,7 +137,7 @@ impl EmptyPod { }; let common_hash = hash_common_data(&data.common).expect("hash ok"); Ok(EmptyPod { - params: params.clone(), + params: Params::default(), verifier_only: VerifierOnlyCircuitDataSerializer(data.verifier_only.clone()), common_hash, sts_hash, @@ -147,15 +145,10 @@ impl EmptyPod { proof: proof.proof, }) } - pub fn new_boxed(params: &Params, vd_set: VDSet) -> Box { - let default_params = Params::default(); - assert_eq!(default_params.id_params(), params.id_params()); - - let empty_pod = cache::get( - "empty_pod", - &(default_params, vd_set), - |(params, vd_set)| Self::new(params, vd_set.clone()).expect("prove EmptyPod"), - ) + pub fn new_boxed(vd_set: VDSet) -> Box { + let empty_pod = cache::get("empty_pod", &vd_set, |vd_set| { + Self::new(vd_set.clone()).expect("prove EmptyPod") + }) .expect("cache ok"); Box::new(empty_pod.clone()) } @@ -178,13 +171,13 @@ impl Pod for EmptyPod { .into_iter() .map(mainpod::Statement::from) .collect_vec(); - let sts_hash = calculate_statements_hash(&statements, &self.params); + let sts_hash = calculate_statements_hash(&statements); if sts_hash != self.sts_hash { return Err(Error::statements_hash_not_equal(self.sts_hash, sts_hash)); } let public_inputs = sts_hash - .to_fields(&self.params) + .to_fields() .iter() .chain(self.vd_set.root().0.iter()) .cloned() @@ -258,9 +251,7 @@ pub mod tests { #[test] fn test_empty_pod() { - let params = Params::default(); - - let empty_pod = EmptyPod::new_boxed(¶ms, VDSet::new(&[])); + let empty_pod = EmptyPod::new_boxed(VDSet::new(&[])); empty_pod.verify().unwrap(); } } diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 6c20a09..83190e9 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -50,21 +50,20 @@ use crate::{ /// circuits with a small `max_public_statements` only pay for `max_public_statements` by starting /// the poseidon state with a precomputed constant corresponding to the front-padding part: `id = /// hash(serialize(reverse(statements || none-statements)))` -pub fn calculate_statements_hash(statements: &[Statement], params: &Params) -> middleware::Hash { - assert!(statements.len() <= params.num_public_statements_hash); - assert!(params.max_public_statements <= params.num_public_statements_hash); +pub fn calculate_statements_hash(statements: &[Statement]) -> middleware::Hash { + assert!(statements.len() <= Params::num_public_statements_hash()); let mut none_st: Statement = middleware::Statement::None.into(); - pad_statement(params, &mut none_st); + pad_statement(&mut none_st); let statements_back_padded = statements .iter() .chain(iter::repeat(&none_st)) - .take(params.num_public_statements_hash) + .take(Params::num_public_statements_hash()) .collect_vec(); let field_elems = statements_back_padded .iter() .rev() - .flat_map(|statement| statement.to_fields(params)) + .flat_map(|statement| statement.to_fields()) .collect::>(); Hash(PoseidonHash::hash_no_pad(&field_elems).elements) } @@ -115,7 +114,7 @@ pub(crate) fn extract_custom_predicate_verifications( .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) .expect("find the custom predicate from the extracted unique list"); let custom_predicate_table_index = - batch_index * params.max_custom_batch_size + cpr.index; + batch_index * Params::max_custom_batch_size() + cpr.index; aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len()); table.push(CustomPredicateVerification { custom_predicate_table_index, @@ -326,8 +325,8 @@ fn fill_pad(v: &mut Vec, pad_value: T, len: usize) { } } -pub fn pad_statement(params: &Params, s: &mut Statement) { - fill_pad(&mut s.1, StatementArg::None, params.max_statement_args) +pub fn pad_statement(s: &mut Statement) { + fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args()) } fn pad_operation_args(params: &Params, args: &mut Vec) { @@ -353,7 +352,7 @@ pub(crate) fn layout_statements( // We mocking or we don't need padding so we skip creating an EmptyPod MockEmptyPod::new_boxed(params, inputs.vd_set.clone()) } else { - EmptyPod::new_boxed(params, inputs.vd_set.clone()) + EmptyPod::new_boxed(inputs.vd_set.clone()) }; let empty_pod = empty_pod_box.as_ref(); assert!(inputs.pods.len() <= params.max_input_pods); @@ -367,7 +366,7 @@ pub(crate) fn layout_statements( .unwrap_or(&middleware::Statement::None) .clone() .into(); - pad_statement(params, &mut st); + pad_statement(&mut st); statements.push(st); } } @@ -386,7 +385,7 @@ pub(crate) fn layout_statements( .unwrap_or(&middleware::Statement::None) .clone() .into(); - pad_statement(params, &mut st); + pad_statement(&mut st); statements.push(st); } @@ -399,7 +398,7 @@ pub(crate) fn layout_statements( .unwrap_or(&middleware::Statement::None) .clone() .into(); - pad_statement(params, &mut st); + pad_statement(&mut st); statements.push(st); } @@ -475,7 +474,7 @@ impl MainPodProver for Prover { // We don't need padding so we skip creating an EmptyPod MockEmptyPod::new_boxed(params, inputs.vd_set.clone()) } else { - EmptyPod::new_boxed(params, inputs.vd_set.clone()) + EmptyPod::new_boxed(inputs.vd_set.clone()) }; let inputs = MainPodInputs { pods: &inputs @@ -491,10 +490,7 @@ impl MainPodProver for Prover { let input_pods_pub_self_statements = inputs .pods .iter() - .map(|pod| { - assert_eq!(params.id_params(), pod.params().id_params()); - pod.pub_self_statements() - }) + .map(|pod| pod.pub_self_statements()) .collect_vec(); // Aux values for backend::Operation @@ -527,7 +523,7 @@ impl MainPodProver for Prover { let operations = process_public_statements_operations(params, &statements, operations)?; // get the id out of the public statements - let sts_hash = calculate_statements_hash(&public_statements, params); + let sts_hash = calculate_statements_hash(&public_statements); let common_hash: String = cache_get_rec_main_pod_common_hash(params).clone(); let proofs = inputs @@ -718,7 +714,7 @@ impl Pod for MainPod { ))); } // 2. get the id out of the public statements - let sts_hash = calculate_statements_hash(&self.public_statements, &self.params); + let sts_hash = calculate_statements_hash(&self.public_statements); if sts_hash != self.sts_hash { return Err(Error::statements_hash_not_equal(self.sts_hash, sts_hash)); } @@ -738,7 +734,7 @@ impl Pod for MainPod { let rec_main_pod_verifier_circuit_data = &*cache_get_rec_main_pod_verifier_circuit_data(&self.params); let public_inputs = sts_hash - .to_fields(&self.params) + .to_fields() .iter() .chain(self.vd_set.root().0.iter()) .cloned() @@ -998,14 +994,10 @@ pub mod tests { max_input_pods_public_statements: 2, max_statements: 5, max_public_statements: 2, - num_public_statements_hash: 4, - max_statement_args: 4, - max_operation_args: 4, + max_operation_args: 5, max_custom_predicate_batches: 2, max_custom_predicate_verifications: 2, - max_custom_predicate_arity: 2, max_custom_predicate_wildcards: 3, - max_custom_batch_size: 2, max_merkle_proofs_containers: 2, max_merkle_tree_state_transition_proofs_containers: 2, max_public_key_of: 2, @@ -1067,10 +1059,7 @@ pub mod tests { max_input_pods: 0, max_statements: 9, max_public_statements: 4, - max_statement_args: 4, - max_operation_args: 4, - max_custom_predicate_arity: 3, - max_custom_batch_size: 3, + max_operation_args: 5, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, max_merkle_proofs_containers: 3, diff --git a/src/backends/plonky2/mainpod/statement.rs b/src/backends/plonky2/mainpod/statement.rs index 8a19b57..27776a6 100644 --- a/src/backends/plonky2/mainpod/statement.rs +++ b/src/backends/plonky2/mainpod/statement.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::error::{Error, Result}, - middleware::{self, NativePredicate, Params, Predicate, StatementArg, ToFields, Value}, + middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS}, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -30,14 +30,14 @@ impl Statement { } impl ToFields for Statement { - fn to_fields(&self, params: &Params) -> Vec { - let mut fields = self.0.hash(params).to_fields(params); + fn to_fields(&self) -> Vec { + let mut fields = self.0.hash().to_fields(); fields.extend( self.1 .iter() .chain(iter::repeat(&StatementArg::None)) - .take(params.max_statement_args) - .flat_map(|arg| arg.to_fields(params)), + .take(BASE_PARAMS.max_statement_args) + .flat_map(|arg| arg.to_fields()), ); fields } diff --git a/src/backends/plonky2/mock/emptypod.rs b/src/backends/plonky2/mock/emptypod.rs index fff0b5c..2ccf7cd 100644 --- a/src/backends/plonky2/mock/emptypod.rs +++ b/src/backends/plonky2/mock/emptypod.rs @@ -30,7 +30,7 @@ fn empty_statement() -> Statement { impl MockEmptyPod { pub fn new_boxed(params: &Params, vd_set: VDSet) -> Box { let statements = [mainpod::Statement::from(empty_statement())]; - let sts_hash = calculate_statements_hash(&statements, params); + let sts_hash = calculate_statements_hash(&statements); Box::new(Self { params: params.clone(), sts_hash, @@ -49,7 +49,7 @@ impl Pod for MockEmptyPod { .into_iter() .map(mainpod::Statement::from) .collect_vec(); - let sts_hash = calculate_statements_hash(&statements, &self.params); + let sts_hash = calculate_statements_hash(&statements); if sts_hash != self.sts_hash { return Err(Error::statements_hash_not_equal(self.sts_hash, sts_hash)); } diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index ddd3c39..dcb1355 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -167,7 +167,7 @@ impl MockMainPod { let operations = process_public_statements_operations(params, &statements, operations)?; // get the id out of the public statements - let sts_hash = calculate_statements_hash(&public_statements, params); + let sts_hash = calculate_statements_hash(&public_statements); let pad_pod = MockEmptyPod::new_boxed(params, inputs.vd_set.clone()); let input_pods: Vec> = inputs diff --git a/src/examples/mod.rs b/src/examples/mod.rs index 0801978..5a5775f 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -258,21 +258,21 @@ pub fn great_boy_pod_builder( let mut great_boy = MainPodBuilder::new(params, vd_set); for good_boy_signed_dict in good_boy_signed_dicts { - great_boy.pub_op(Operation::dict_signed_by(good_boy_signed_dict))?; + great_boy.priv_op(Operation::dict_signed_by(good_boy_signed_dict))?; } for friend_signed_dict in friend_signed_dicts { - great_boy.pub_op(Operation::dict_signed_by(friend_signed_dict))?; + great_boy.priv_op(Operation::dict_signed_by(friend_signed_dict))?; } for good_boy_idx in 0..2 { for issuer_idx in 0..2 { // Each good boy POD comes from a valid issuer - great_boy.pub_op(Operation::set_contains( + great_boy.priv_op(Operation::set_contains( good_boy_issuers, good_boy_signed_dicts[good_boy_idx * 2 + issuer_idx].public_key, ))?; // Each good boy has 2 good boy pods - great_boy.pub_op(Operation::eq( + great_boy.priv_op(Operation::eq( (good_boy_signed_dicts[good_boy_idx * 2 + issuer_idx], "user"), friend_signed_dicts[good_boy_idx].public_key, ))?; @@ -302,8 +302,6 @@ pub fn great_boy_pod_full_flow() -> Result { max_signed_by: 6, max_input_pods: 0, max_statements: 100, - max_public_statements: 50, - num_public_statements_hash: 50, ..Default::default() }; let vd_set = &*MOCK_VD_SET; diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index c0ee1ba..fec897c 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -171,19 +171,19 @@ impl CustomPredicateBatchBuilder { priv_args: &[&str], sts: &[StatementTmplBuilder], ) -> Result { - if self.predicates.len() >= self.params.max_custom_batch_size { + if self.predicates.len() >= Params::max_custom_batch_size() { return Err(Error::max_length( "self.predicates.len".to_string(), self.predicates.len(), - self.params.max_custom_batch_size, + Params::max_custom_batch_size(), )); } - if args.len() > self.params.max_statement_args { + if args.len() > Params::max_statement_args() { return Err(Error::max_length( "args.len".to_string(), args.len(), - self.params.max_statement_args, + Params::max_statement_args(), )); } if (args.len() + priv_args.len()) > self.params.max_custom_predicate_wildcards { @@ -278,7 +278,6 @@ mod tests { use StatementTmplBuilder as STB; let params = Params { - max_statement_args: 6, max_custom_predicate_wildcards: 12, ..Default::default() }; @@ -292,7 +291,7 @@ mod tests { let eth_dos_batch_mw: middleware::CustomPredicateBatch = Arc::unwrap_or_clone(eth_dos_batch); - let fields = eth_dos_batch_mw.to_fields(¶ms); + let fields = eth_dos_batch_mw.to_fields(); println!("Batch b, serialized: {:?}", fields); Ok(()) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index c8ea847..2758660 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -600,7 +600,7 @@ impl MainPodBuilder { } wildcard_map[index] = Some(value); } - fill_wildcard_values(&self.params, pred, &args, &mut wildcard_map)?; + fill_wildcard_values(pred, &args, &mut wildcard_map)?; let v_default = Value::from(0); let st_args: Vec<_> = wildcard_map .into_iter() diff --git a/src/lang/frontend_ast_batch.rs b/src/lang/frontend_ast_batch.rs index 6b5f375..61b167e 100644 --- a/src/lang/frontend_ast_batch.rs +++ b/src/lang/frontend_ast_batch.rs @@ -355,7 +355,7 @@ pub fn batch_predicates( } // Plan batch assignments in declaration order - let assignments = plan_batch_assignments(&predicates, params.max_custom_batch_size)?; + let assignments = plan_batch_assignments(&predicates, Params::max_custom_batch_size())?; // Build reference map: name -> (batch_idx, idx_in_batch) let reference_map: HashMap = assignments @@ -1039,18 +1039,18 @@ mod tests { #[test] fn test_mutual_recursion_exceeds_capacity_error() { - // Two predicates that call each other (SCC size = 2) with max batch size 1 + // Two predicates that call each other (SCC size = 5) with max batch size 4 // Should error because an SCC cannot be split across batches let input = r#" pred1(A) = AND(pred2(A)) - pred2(B) = AND(pred1(B)) + pred2(B) = AND(pred3(B)) + pred3(B) = AND(pred4(B)) + pred4(B) = AND(pred5(B)) + pred5(B) = AND(pred1(B)) "#; let (predicates, validated) = parse_and_validate(input); - let params = Params { - max_custom_batch_size: 1, // force SCC > capacity - ..Default::default() - }; + let params = Params::default(); let result = batch_predicates( preds_to_split_results(predicates), diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index 185681e..b1db536 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -326,10 +326,10 @@ impl<'a> Lowerer<'a> { batches: Option<&PredicateBatches>, ) -> Result { // Enforce argument count limit for request statements - if stmt.args.len() > self.params.max_statement_args { + if stmt.args.len() > Params::max_statement_args() { return Err(LoweringError::TooManyStatementArgs { count: stmt.args.len(), - max: self.params.max_statement_args, + max: Params::max_statement_args(), }); } @@ -446,6 +446,7 @@ impl<'a> Lowerer<'a> { let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; split_results.push(result); } + Ok(split_results) } } @@ -676,7 +677,8 @@ mod tests { "#; let params = Params::default(); - parse_validate_and_lower(input, ¶ms).unwrap(); + let result = parse_validate_and_lower(input, ¶ms); + assert!(result.is_ok()); } #[test] diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index 303720e..cc37463 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -76,18 +76,15 @@ struct WildcardUsage { } /// Early validation: Check if predicate is fundamentally splittable -pub fn validate_predicate_is_splittable( - pred: &CustomPredicateDef, - params: &Params, -) -> Result<(), SplittingError> { +pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> { let public_args = pred.args.public_args.len(); // Check: public args must fit in operation arg limit - if public_args > params.max_statement_args { + if public_args > Params::max_statement_args() { return Err(SplittingError::TooManyPublicArgs { predicate: pred.name.name.clone(), count: public_args, - max_allowed: params.max_statement_args, + max_allowed: Params::max_statement_args(), message: "Public arguments exceed max operation args - cannot call this predicate" .to_string(), }); @@ -102,10 +99,10 @@ pub fn split_predicate_if_needed( params: &Params, ) -> Result { // Early validation - validate_predicate_is_splittable(&pred, params)?; + validate_predicate_is_splittable(&pred)?; // If within limits, no splitting needed - if pred.statements.len() <= params.max_custom_predicate_arity { + if pred.statements.len() <= Params::max_custom_predicate_arity() { return Ok(SplitResult { predicates: vec![pred], chain_info: None, @@ -173,12 +170,11 @@ struct OrderingResult { fn order_constraints_optimally( statements: Vec, _usage: &HashMap, - params: &Params, ) -> OrderingResult { let n = statements.len(); // If no splitting needed, preserve original order (identity mapping) - if n <= params.max_custom_predicate_arity { + if n <= Params::max_custom_predicate_arity() { return OrderingResult { statements, reorder_map: (0..n).collect(), @@ -191,13 +187,8 @@ fn order_constraints_optimally( let mut active_wildcards: HashSet = HashSet::new(); while !remaining.is_empty() { - let best_idx = find_best_next_statement( - &statements, - &remaining, - &active_wildcards, - ordered.len(), - params, - ); + let best_idx = + find_best_next_statement(&statements, &remaining, &active_wildcards, ordered.len()); remaining.remove(&best_idx); let stmt = &statements[best_idx]; @@ -268,10 +259,9 @@ fn find_best_next_statement( remaining: &HashSet, active_wildcards: &HashSet, ordered_count: usize, - params: &Params, ) -> usize { // Calculate distance to next split point - let bucket_size = params.max_custom_predicate_arity - 1; // Reserve slot for chain call + let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call let distance_to_split = bucket_size - (ordered_count % bucket_size); let approaching_split = distance_to_split <= 2; @@ -432,7 +422,7 @@ fn split_into_chain( let usage = analyze_wildcards(&pred.statements); let real_statement_count = pred.statements.len(); - let ordering_result = order_constraints_optimally(pred.statements, &usage, params); + let ordering_result = order_constraints_optimally(pred.statements, &usage); let ordered_statements = ordering_result.statements; let reorder_map = ordering_result.reorder_map; @@ -449,12 +439,12 @@ fn split_into_chain( while pos < ordered_statements.len() { let remaining = ordered_statements.len() - pos; - let is_last = remaining <= params.max_custom_predicate_arity; + let is_last = remaining <= Params::max_custom_predicate_arity(); let bucket_size = if is_last { remaining // Last predicate uses all remaining } else { - params.max_custom_predicate_arity - 1 // Reserve slot for chain call + Params::max_custom_predicate_arity() - 1 // Reserve slot for chain call }; let end = pos + bucket_size; @@ -475,7 +465,7 @@ fn split_into_chain( .cloned() .collect(); let total_public = incoming_public.len() + new_promotions.len(); - if total_public > params.max_statement_args { + if total_public > Params::max_statement_args() { let context = crate::lang::error::SplitContext { split_index: chain_links.len(), statement_range: (pos, end), @@ -490,7 +480,7 @@ fn split_into_chain( return Err(SplittingError::TooManyPublicArgsAtSplit { predicate: original_name.clone(), context: Box::new(context), - max_allowed: params.max_statement_args, + max_allowed: Params::max_statement_args(), suggestion: suggestion.map(Box::new), }); } @@ -688,10 +678,10 @@ fn generate_chain_predicates( fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> { for pred in chain { // Each predicate should have ≤ max_statements - assert!(pred.statements.len() <= params.max_custom_predicate_arity); + assert!(pred.statements.len() <= Params::max_custom_predicate_arity()); // Public args should fit - assert!(pred.args.public_args.len() <= params.max_statement_args); + assert!(pred.args.public_args.len() <= Params::max_statement_args()); // Total args should fit let total = @@ -729,9 +719,8 @@ mod tests { "#; let pred = parse_predicate(input); - let params = Params::default(); - assert!(validate_predicate_is_splittable(&pred, ¶ms).is_ok()); + assert!(validate_predicate_is_splittable(&pred).is_ok()); } #[test] @@ -743,9 +732,8 @@ mod tests { "#; let pred = parse_predicate(input); - let params = Params::default(); // max_statement_args = 5 - let result = validate_predicate_is_splittable(&pred, ¶ms); + let result = validate_predicate_is_splittable(&pred); assert!(matches!( result, Err(SplittingError::TooManyPublicArgs { .. }) diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 432991c..3908918 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -572,10 +572,7 @@ mod tests { max_input_pods: 3, max_statements: 31, max_public_statements: 10, - max_statement_args: 6, max_operation_args: 5, - max_custom_predicate_arity: 5, - max_custom_batch_size: 5, max_custom_predicate_wildcards: 12, ..Default::default() }; diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index 3bbb960..e6af211 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -49,7 +49,7 @@ use super::serialization::*; pub use crate::backends::plonky2::basetypes::*; #[cfg(feature = "backend_plonky2")] pub use crate::backends::plonky2::{Error as BackendError, Result as BackendResult}; -use crate::middleware::{Params, ToFields, Value}; +use crate::middleware::{ToFields, Value}; pub const HASH_SIZE: usize = 4; pub const VALUE_SIZE: usize = 4; @@ -71,7 +71,7 @@ pub struct RawValue( ); impl ToFields for RawValue { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self) -> Vec { self.0.to_vec() } } @@ -220,7 +220,7 @@ impl From for Hash { } impl ToFields for Hash { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self) -> Vec { self.0.to_vec() } } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index bfb8ce4..402ee44 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::middleware::{ hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value, - EMPTY_HASH, F, VALUE_SIZE, + BASE_PARAMS, EMPTY_HASH, F, VALUE_SIZE, }; #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] @@ -33,7 +33,7 @@ impl fmt::Display for Wildcard { } impl ToFields for Wildcard { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self) -> Vec { vec![F::from_canonical_u64(self.index as u64)] } } @@ -63,7 +63,7 @@ impl From for F { } impl ToFields for StatementTmplArg { - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { // Encoding: // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) // Literal(v) => (1, [v ], 0, 0, 0, 0) @@ -76,20 +76,20 @@ impl ToFields for StatementTmplArg { .take(Params::statement_tmpl_arg_size()) .collect_vec(), StatementTmplArg::Literal(v) => iter::once(F::from(StatementTmplArgPrefix::Literal)) - .chain(v.raw().to_fields(params)) + .chain(v.raw().to_fields()) .chain(iter::repeat(F::ZERO)) .take(Params::statement_tmpl_arg_size()) .collect_vec(), StatementTmplArg::AnchoredKey(wc1, kw2) => { iter::once(F::from(StatementTmplArgPrefix::AnchoredKey)) - .chain(wc1.to_fields(params)) + .chain(wc1.to_fields()) .chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1)) - .chain(kw2.to_fields(params)) + .chain(kw2.to_fields()) .collect_vec() } StatementTmplArg::Wildcard(wc) => { iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral)) - .chain(wc.to_fields(params)) + .chain(wc.to_fields()) .chain(iter::repeat(F::ZERO)) .take(Params::statement_tmpl_arg_size()) .collect_vec() @@ -160,16 +160,16 @@ impl From for F { } impl ToFields for PredicateOrWildcard { - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { // Encoding: // Predicate(pred) => (0, [hash(pred) ]) // Wildcard(wc) => (1, wc_index, 0...) match self { Self::Predicate(pred) => iter::once(F::from(PredicateOrWildcardPrefix::Predicate)) - .chain(pred.hash(params).to_fields(params)) + .chain(pred.hash().to_fields()) .collect_vec(), Self::Wildcard(wc) => iter::once(F::from(PredicateOrWildcardPrefix::Wildcard)) - .chain(wc.to_fields(params)) + .chain(wc.to_fields()) .chain(iter::repeat(F::ZERO)) .take(Params::pred_hash_or_wc_size()) .collect_vec(), @@ -208,7 +208,7 @@ impl fmt::Display for StatementTmpl { } impl ToFields for StatementTmpl { - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { // serialize as: // predicate (4 field elements) // then the StatementTmplArgs @@ -216,20 +216,20 @@ impl ToFields for StatementTmpl { // TODO think if this check should go into the StatementTmpl creation, // instead of at the `to_fields` method, where we should assume that the // values are already valid - if self.args.len() > params.max_statement_args { + if self.args.len() > BASE_PARAMS.max_statement_args { panic!( "Statement template has too many arguments {} > {}", self.args.len(), - params.max_statement_args + BASE_PARAMS.max_statement_args ); } self.pred_or_wc - .to_fields(params) + .to_fields() .into_iter() - .chain(self.args.iter().flat_map(|sta| sta.to_fields(params))) + .chain(self.args.iter().flat_map(|sta| sta.to_fields())) .chain(iter::repeat(F::ZERO)) - .take(params.statement_tmpl_size()) + .take(Params::statement_tmpl_size()) .collect_vec() } } @@ -300,18 +300,18 @@ impl CustomPredicate { args_len: usize, wildcard_names: Vec, ) -> Result { - if statements.len() > params.max_custom_predicate_arity { + if statements.len() > Params::max_custom_predicate_arity() { return Err(Error::max_length( "statements.len".to_string(), statements.len(), - params.max_custom_predicate_arity, + Params::max_custom_predicate_arity(), )); } - if args_len > params.max_statement_args { + if args_len > Params::max_statement_args() { return Err(Error::max_length( "statement_args.len".to_string(), args_len, - params.max_statement_args, + Params::max_statement_args(), )); } if wildcard_names.len() > params.max_custom_predicate_wildcards { @@ -358,7 +358,7 @@ impl CustomPredicate { } impl ToFields for CustomPredicate { - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { // serialize as: // conjunction (one field element) // args_len (one field element) @@ -369,7 +369,7 @@ impl ToFields for CustomPredicate { // NOTE: this method assumes that the self.params.len() is inside the // expected bound, as Self should be instantiated with the constructor // method `new` which performs the check. - if self.statements.len() > params.max_custom_predicate_arity { + if self.statements.len() > BASE_PARAMS.max_custom_predicate_arity { panic!("Custom predicate depends on too many statements"); } @@ -380,8 +380,8 @@ impl ToFields for CustomPredicate { self.statements .iter() .chain(iter::repeat(&pad_st)) - .take(params.max_custom_predicate_arity) - .flat_map(|st| st.to_fields(params)), + .take(BASE_PARAMS.max_custom_predicate_arity) + .flat_map(|st| st.to_fields()), ) .collect_vec() } @@ -434,26 +434,26 @@ impl std::hash::Hash for CustomPredicateBatch { } impl ToFields for CustomPredicateBatch { - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { // all the custom predicates in order let pad_pred = CustomPredicate::empty(); self.predicates .iter() .chain(iter::repeat(&pad_pred)) - .take(params.max_custom_batch_size) - .flat_map(|p| p.to_fields(params)) + .take(BASE_PARAMS.max_custom_batch_size) + .flat_map(|p| p.to_fields()) .collect_vec() } } impl CustomPredicateBatch { - pub fn new(params: &Params, name: String, predicates: Vec) -> Arc { + pub fn new(_params: &Params, name: String, predicates: Vec) -> Arc { let mut cpb = Self { id: EMPTY_HASH, name, predicates, }; - let id = cpb.calculate_id(params); + let id = cpb.calculate_id(); cpb.id = id; Arc::new(cpb) } @@ -467,10 +467,10 @@ impl CustomPredicateBatch { } /// Cryptographic identifier for the batch. - fn calculate_id(&self, params: &Params) -> Hash { + fn calculate_id(&self) -> Hash { // NOTE: This implementation just hashes the concatenation of all the custom predicates, // but ideally we want to use the root of a merkle tree built from the custom predicates. - let input = self.to_fields(params); + let input = self.to_fields(); hash_fields(&input) } @@ -613,7 +613,6 @@ mod tests { fn ethdos_test() -> Result<()> { let params = Params { max_custom_predicate_wildcards: 12, - max_statement_args: 6, ..Default::default() }; diff --git a/src/middleware/error.rs b/src/middleware/error.rs index 23650ce..74605da 100644 --- a/src/middleware/error.rs +++ b/src/middleware/error.rs @@ -43,7 +43,7 @@ pub enum MiddlewareInnerError { MismatchedStatementTmplArg(StatementTmplArg, StatementArg), #[error("Expected a statement of type {0}, got {1}")] MismatchedStatementType(Predicate, Predicate), - #[error("Expected a statement with hash(predicate) {0}, got {1} ({2})")] + #[error("Expected a statement with predicate value {0}, got {1} ({2})")] MismatchedStatementWildcardPredicate(Value, Value, Predicate), #[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")] MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate), diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index db28f5a..2b6ca01 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -62,6 +62,8 @@ pub enum TypedValue { PublicKey(PublicKey), // Schnorr secret key variant (scalar) SecretKey(SecretKey), + // Predicate as a value + Predicate(Predicate), // UNTAGGED TYPES: #[serde(untagged)] Set(Set), @@ -117,6 +119,12 @@ impl From for TypedValue { } } +impl From for TypedValue { + fn from(p: Predicate) -> Self { + TypedValue::Predicate(p) + } +} + impl From for TypedValue { fn from(s: Set) -> Self { TypedValue::Set(s) @@ -194,6 +202,17 @@ impl TryFrom<&TypedValue> for SecretKey { } } +impl TryFrom<&TypedValue> for Predicate { + type Error = Error; + fn try_from(v: &TypedValue) -> std::result::Result { + if let TypedValue::Predicate(p) = v { + Ok(p.clone()) + } else { + Err(Error::custom("Value not a Predicate".to_string())) + } + } +} + impl fmt::Display for TypedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -240,6 +259,7 @@ impl fmt::Display for TypedValue { } TypedValue::PublicKey(p) => write!(f, "PublicKey({})", p), TypedValue::SecretKey(p) => write!(f, "SecretKey({})", p), + TypedValue::Predicate(p) => write!(f, "Predicate({})", p), TypedValue::Raw(r) => { write!(f, "Raw(0x{})", r.encode_hex::()) } @@ -259,6 +279,7 @@ impl From<&TypedValue> for RawValue { TypedValue::Raw(v) => *v, TypedValue::PublicKey(p) => RawValue::from(hash_fields(&p.as_fields())), TypedValue::SecretKey(sk) => RawValue::from(hash_fields(&sk.to_limbs())), + TypedValue::Predicate(p) => RawValue::from(p.hash()), } } } @@ -601,8 +622,8 @@ where } impl ToFields for Key { - fn to_fields(&self, params: &Params) -> Vec { - self.hash.to_fields(params) + fn to_fields(&self) -> Vec { + self.hash.to_fields() } } @@ -728,6 +749,33 @@ impl fmt::Display for PodType { } } +/// These base parameters need to be the same among different circuits to be compatible in their +/// verification. For this reason we define an instance of these parameters via `BASE_PARAMS` to +/// be used and we don't let the user of the library choose them. +pub struct BaseParams { + // + // The following parameters define how a pod id is calculated. + // + /// Number of public statements to hash to calculate the public inputs. Must be equal or + /// greater than `max_public_statements`. + pub num_public_statements_hash: usize, + pub max_statement_args: usize, + // + // The following parameters define how a custom predicate batch id is calculated. + // + /// max number of statements that can be ANDed or ORed together + /// in a custom predicate + pub max_custom_predicate_arity: usize, + pub max_custom_batch_size: usize, +} + +pub const BASE_PARAMS: BaseParams = BaseParams { + num_public_statements_hash: 16, + max_statement_args: 5, + max_custom_predicate_arity: 5, + max_custom_batch_size: 4, +}; + /// Params: non dynamic parameters that define the circuit. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] #[serde(rename_all = "camelCase")] @@ -757,21 +805,6 @@ pub struct Params { pub max_public_key_of: usize, // maximum number of signature verifications used for SignedBy operation pub max_signed_by: usize, - // - // The following parameters define how a pod id is calculated. They need to be the same among - // different circuits to be compatible in their verification. - // - // Number of public statements to hash to calculate the public inputs. Must be equal or - // greater than `max_public_statements`. - pub num_public_statements_hash: usize, - pub max_statement_args: usize, - // - // The following parameters define how a custom predicate batch id is calculated. - // - // max number of statements that can be ANDed or ORed together - // in a custom predicate - pub max_custom_predicate_arity: usize, - pub max_custom_batch_size: usize, } impl Default for Params { @@ -781,14 +814,10 @@ impl Default for Params { max_input_pods_public_statements: 8, max_statements: 48, max_public_statements: 8, - num_public_statements_hash: 16, - max_statement_args: 5, max_operation_args: 5, max_custom_predicate_batches: 4, max_custom_predicate_verifications: 8, - max_custom_predicate_arity: 5, max_custom_predicate_wildcards: 8, - max_custom_batch_size: 4, max_merkle_proofs_containers: 20, max_merkle_tree_state_transition_proofs_containers: 6, max_depth_mt_containers: 32, @@ -800,6 +829,21 @@ impl Default for Params { } impl Params { + // Convenient methods to get base params + + pub const fn num_public_statements_hash() -> usize { + BASE_PARAMS.num_public_statements_hash + } + pub const fn max_statement_args() -> usize { + BASE_PARAMS.max_statement_args + } + pub const fn max_custom_predicate_arity() -> usize { + BASE_PARAMS.max_custom_predicate_arity + } + pub const fn max_custom_batch_size() -> usize { + BASE_PARAMS.max_custom_batch_size + } + pub fn max_priv_statements(&self) -> usize { self.max_statements - self.max_public_statements } @@ -816,24 +860,25 @@ impl Params { HASH_SIZE + 2 } - pub fn statement_size(&self) -> usize { - HASH_SIZE + STATEMENT_ARG_F_LEN * self.max_statement_args + pub const fn statement_size() -> usize { + HASH_SIZE + STATEMENT_ARG_F_LEN * BASE_PARAMS.max_statement_args } pub const fn pred_hash_or_wc_size() -> usize { 1 + HASH_SIZE } - pub const fn statement_tmpl_size(&self) -> usize { - Self::pred_hash_or_wc_size() + self.max_statement_args * Self::statement_tmpl_arg_size() + pub const fn statement_tmpl_size() -> usize { + Self::pred_hash_or_wc_size() + + BASE_PARAMS.max_statement_args * Self::statement_tmpl_arg_size() } - pub fn custom_predicate_size(&self) -> usize { - self.max_custom_predicate_arity * self.statement_tmpl_size() + 2 + pub const fn custom_predicate_size() -> usize { + BASE_PARAMS.max_custom_predicate_arity * Self::statement_tmpl_size() + 2 } - pub fn custom_predicate_batch_size_field_elts(&self) -> usize { - self.max_custom_batch_size * self.custom_predicate_size() + pub const fn custom_predicate_batch_size_field_elts() -> usize { + BASE_PARAMS.max_custom_batch_size * Self::custom_predicate_size() } /// Total size of the statement table including None, input statements from signed pods and @@ -842,16 +887,6 @@ impl Params { 1 + self.max_input_pods * self.max_input_pods_public_statements + self.max_statements } - /// Parameters that define how the id is calculated - pub fn id_params(&self) -> Vec { - vec![ - self.num_public_statements_hash, - self.max_statement_args, - self.max_custom_predicate_arity, - self.max_custom_batch_size, - ] - } - pub fn print_serialized_sizes(&self) { println!("Parameter sizes:"); println!( @@ -859,11 +894,11 @@ impl Params { Self::statement_tmpl_arg_size() ); println!(" Predicate: {}", Self::predicate_size()); - println!(" Statement template: {}", self.statement_tmpl_size()); - println!(" Custom predicate: {}", self.custom_predicate_size()); + println!(" Statement template: {}", Self::statement_tmpl_size()); + println!(" Custom predicate: {}", Self::custom_predicate_size()); println!( " Custom predicate batch: {}", - self.custom_predicate_batch_size_field_elts() + Self::custom_predicate_batch_size_field_elts() ); println!(); } @@ -994,5 +1029,5 @@ pub trait MainPodProver { pub trait ToFields { /// returns `Vec` representation of the type - fn to_fields(&self, params: &Params) -> Vec; + fn to_fields(&self) -> Vec; } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 7d06d9c..526ff51 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -54,10 +54,10 @@ impl ToFields for OperationType { /// Encoding: /// - Native(native_op) => `[1, [native_op], 0, 0, 0, 0]` /// - Custom(batch, index) => `[3, [batch.id], index]` - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { let mut fields: Vec = match self { Self::Native(p) => iter::once(F::from_canonical_u64(1)) - .chain(p.to_fields(params)) + .chain(p.to_fields()) .collect(), Self::Custom(CustomPredicateRef { batch, index }) => { iter::once(F::from_canonical_u64(3)) @@ -118,7 +118,7 @@ impl NativeOperation { } impl ToFields for NativeOperation { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self) -> Vec { vec![F::from_canonical_u64(*self as u64)] } } @@ -605,14 +605,25 @@ pub fn check_st_tmpl( } pub fn fill_wildcard_values( - params: &Params, pred: &CustomPredicate, args: &[Statement], wildcard_map: &mut [Option], ) -> Result<()> { for (st_tmpl, st) in pred.statements.iter().zip(args) { if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc { - wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?; + wc_check_or_set(Value::from(st.predicate().hash()), wc, wildcard_map)?; + } + let st_args = st.args(); + + for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { + if let Err(st_tmpl_check_error) = check_st_tmpl(st_tmpl_arg, st_arg, wildcard_map) { + return Err(Error::statements_dont_match( + st.clone(), + st_tmpl.clone(), + wildcard_map.to_vec(), + st_tmpl_check_error, + )); + } } let st_args = st.args(); @@ -642,7 +653,7 @@ pub fn wildcard_values_from_op_st( .chain(core::iter::repeat(None)) .take(params.max_custom_predicate_wildcards) .collect_vec(); - fill_wildcard_values(params, pred, op_args, &mut wildcard_map)?; + fill_wildcard_values(pred, op_args, &mut wildcard_map)?; // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because // they are beyond the number of used wildcards in this custom predicate, or they could be // private arguments that are unused in a particular disjunction. @@ -653,7 +664,6 @@ pub fn wildcard_values_from_op_st( } fn check_custom_pred_argument( - params: &Params, custom_pred_ref: &CustomPredicateRef, template: &StatementTmpl, statement: &Statement, @@ -676,11 +686,11 @@ fn check_custom_pred_argument( } } PredicateOrWildcard::Wildcard(wc) => { - let pred_hash = Value::from(statement.predicate().hash(params)); - if wc_values[wc.index] != pred_hash { + let pred_value = Value::from(statement.predicate()); + if wc_values[wc.index] != pred_value { return Err(Error::mismatched_statement_wc_pred( wc_values[wc.index].clone(), - pred_hash, + pred_value, statement.predicate(), )); } @@ -748,7 +758,7 @@ pub(crate) fn check_custom_pred( if !pred.conjunction && matches!(st, Statement::None) { continue; } - check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?; + check_custom_pred_argument(custom_pred_ref, st_tmpl, st, &wc_values)?; match_exists = true; } @@ -761,7 +771,7 @@ pub(crate) fn check_custom_pred( } impl ToFields for Operation { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self) -> Vec { todo!() } } diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 20f6381..4ed1a8d 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -128,7 +128,7 @@ impl Display for NativePredicate { } impl ToFields for NativePredicate { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self) -> Vec { vec![F::from_canonical_u64(*self as u64)] } } @@ -209,7 +209,7 @@ impl From for F { } impl ToFields for Predicate { - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { // serialize: // NativePredicate(id) as (1, id, 0...) -- id: usize // BatchSelf(i) as (2, i, 0...) -- i: usize @@ -222,7 +222,7 @@ impl ToFields for Predicate { // in every case: pad to `Params::predicate_size()` field elements let mut fields: Vec = match self { Self::Native(p) => iter::once(F::from(PredicatePrefix::Native)) - .chain(p.to_fields(params)) + .chain(p.to_fields()) .collect(), Self::BatchSelf(i) => iter::once(F::from(PredicatePrefix::BatchSelf)) .chain(iter::once(F::from_canonical_usize(*i))) @@ -245,8 +245,8 @@ impl ToFields for Predicate { } impl Predicate { - pub fn hash(&self, params: &Params) -> middleware::Hash { - hash_fields(&self.to_fields(params)) + pub fn hash(&self) -> middleware::Hash { + hash_fields(&self.to_fields()) } } @@ -503,11 +503,11 @@ impl Statement { } impl ToFields for Statement { - fn to_fields(&self, params: &Params) -> Vec { - let predicate_hash = hash_fields(&self.predicate().to_fields(params)); + fn to_fields(&self) -> Vec { + let predicate_hash = hash_fields(&self.predicate().to_fields()); let mut fields = predicate_hash.0.to_vec(); - fields.extend(self.args().iter().flat_map(|arg| arg.to_fields(params))); - fields.resize_with(params.statement_size(), || F::ZERO); + fields.extend(self.args().iter().flat_map(|arg| arg.to_fields())); + fields.resize_with(Params::statement_size(), || F::ZERO); fields } } @@ -573,7 +573,7 @@ impl ToFields for StatementArg { /// - Literal(v) => `[[v], 0, 0, 0, 0]` /// - Key(root, key) => `[[root], [key]]` /// - WildcardLiteral(v) => `[[v], 0, 0, 0, 0]` - fn to_fields(&self, params: &Params) -> Vec { + fn to_fields(&self) -> Vec { // NOTE for @ax0: I removed the old comment because may `to_fields` implementations do // padding and we need fixed output length for the circuits. let f = match self { @@ -585,8 +585,8 @@ impl ToFields for StatementArg { .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE)) .collect(), StatementArg::Key(ak) => { - let mut fields = ak.root.to_fields(params); - fields.extend(ak.key.to_fields(params)); + let mut fields = ak.root.to_fields(); + fields.extend(ak.key.to_fields()); fields } };