From 0fca00cc93ef38e5d9bde432195c9714d9015377 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Mon, 19 Jan 2026 11:02:11 +0100 Subject: [PATCH] Use predicate hash in statements instead of the literal predicate Resolve #448 Previously a predicate was 6 elements. Now it grows to 8 elements; and the hash is 4 elements. Some parts of the circuit require only require equality checks with the predicate: that works with the predicate hash. Other parts require inspecting or working with particular elements in the predicate, those need the preimage of the predicate hash. Both `StatementTarget` and `StatementTmplTarget` have been updated to include the predicate hash and optionally the predicate. When the predicate is included, constraints are automatically generated for `pred_hash = hash(pred)`. We only include the predicate when needed. --- src/backends/plonky2/circuits/common.rs | 358 ++++++++++++++++------ src/backends/plonky2/circuits/mainpod.rs | 81 +++-- src/backends/plonky2/mainpod/statement.rs | 2 +- src/middleware/basetypes.rs | 6 + src/middleware/custom.rs | 1 + src/middleware/mod.rs | 6 +- src/middleware/statement.rs | 24 +- 7 files changed, 319 insertions(+), 159 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 9c7ccb8..5d8885b 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -9,7 +9,7 @@ use plonky2::{ types::{Field, PrimeField64}, }, hash::{ - hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, + hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, poseidon::PoseidonHash, }, iop::{ @@ -17,6 +17,7 @@ use plonky2::{ target::{BoolTarget, Target}, witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite}, }, + plonk::config::Hasher, util::serialization::{Buffer, IoResult, Read, Write}, }; use serde::{Deserialize, Serialize}; @@ -136,10 +137,99 @@ impl StatementArgTarget { #[derive(Clone, Serialize, Deserialize)] pub struct StatementTarget { - pub predicate: PredicateTarget, + // If the pred is Some, then the `pred_hash` is constrained to be the `hash(pred)`. + pred: Option, + pred_hash: HashOutTarget, pub args: Vec, } +impl StatementTarget { + pub fn pred(&self) -> Option<&PredicateTarget> { + self.pred.as_ref() + } + pub fn pred_hash(&self) -> &HashOutTarget { + &self.pred_hash + } + + pub fn new(pred_hash: HashOutTarget, args: Vec) -> Self { + Self { + pred: None, + pred_hash, + args, + } + } + pub fn new_with_pred( + builder: &mut CircuitBuilder, + params: &Params, + predicate: impl Build, + args: &[StatementArgTarget], + ) -> Self { + let pred = predicate.build(builder, params); + let pred_hash = pred.hash(builder); + Self { + pred: Some(pred), + pred_hash, + args: args + .iter() + .cloned() + .chain(iter::repeat_with(|| StatementArgTarget::none(builder))) + .take(params.max_statement_args) + .collect(), + } + } + + pub fn new_native( + builder: &mut CircuitBuilder, + params: &Params, + native_predicate: impl Build, + args: &[StatementArgTarget], + ) -> Self { + let pred = PredicateTarget::new_native(builder, params, native_predicate); + Self::new_with_pred(builder, params, pred, args) + } + + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + st: &Statement, + ) -> Result<()> { + if let Some(pred) = &self.pred { + pred.set_targets(pw, params, &st.predicate())?; + } + pw.set_hash_target(self.pred_hash, HashOut::from(st.predicate().hash(params)))?; + for (i, arg) in st + .args() + .iter() + .chain(iter::repeat(&StatementArg::None)) + .take(params.max_statement_args) + .enumerate() + { + self.args[i].set_targets(pw, params, arg)?; + } + Ok(()) + } + + pub fn pred_is_blank_intro(&self, builder: &mut CircuitBuilder) -> BoolTarget { + let zero_hash = builder.constant_hash(HashOut { + elements: [F::ZERO, F::ZERO, F::ZERO, F::ZERO], + }); + let blank_intro = PredicateTarget::new_intro(builder, zero_hash).hash(builder); + builder.is_equal_flattenable(&self.pred_hash, &blank_intro) + } + + pub fn has_native_type( + &self, + builder: &mut CircuitBuilder, + params: &Params, + t: NativePredicate, + ) -> BoolTarget { + let expected_predicate_hash = + builder.constant_hash(HashOut::from(Predicate::Native(t).hash(params))); + builder.is_equal_flattenable(&self.pred_hash, &expected_predicate_hash) + } +} + pub trait Build { fn build(self, builder: &mut CircuitBuilder, params: &Params) -> T; } @@ -156,57 +246,6 @@ impl Build for T { } } -impl StatementTarget { - /// Build a new native StatementTarget. Pads the arguments. - pub fn new_native( - builder: &mut CircuitBuilder, - params: &Params, - native_predicate: impl Build, - args: &[StatementArgTarget], - ) -> Self { - // if native_predicate is const then NativePredicate -> NativePredicateTarget - // else just use as is - Self { - predicate: PredicateTarget::new_native(builder, params, native_predicate), - args: args - .iter() - .cloned() - .chain(iter::repeat_with(|| StatementArgTarget::none(builder))) - .take(params.max_statement_args) - .collect(), - } - } - - pub fn set_targets( - &self, - pw: &mut PartialWitness, - params: &Params, - st: &Statement, - ) -> Result<()> { - self.predicate.set_targets(pw, params, st.predicate())?; - for (i, arg) in st - .args() - .iter() - .chain(iter::repeat(&StatementArg::None)) - .take(params.max_statement_args) - .enumerate() - { - self.args[i].set_targets(pw, params, arg)?; - } - Ok(()) - } - - pub fn has_native_type( - &self, - builder: &mut CircuitBuilder, - params: &Params, - t: NativePredicate, - ) -> BoolTarget { - let expected_predicate = PredicateTarget::new_native(builder, params, t); - builder.is_equal_flattenable(&self.predicate, &expected_predicate) - } -} - #[derive(Clone, Serialize, Deserialize)] pub struct OperationTypeTarget { #[serde(with = "serde_arrays")] @@ -336,7 +375,7 @@ impl PredicateTarget { let id = native_predicate.build(builder, params).0; let zero = builder.zero(); Self { - elements: [prefix, id, zero, zero, zero, zero], + elements: [prefix, id, zero, zero, zero, zero, zero, zero], } } @@ -344,7 +383,7 @@ impl PredicateTarget { let prefix = builder.constant(F::from(PredicatePrefix::BatchSelf)); let zero = builder.zero(); Self { - elements: [prefix, index, zero, zero, zero, zero], + elements: [prefix, index, zero, zero, zero, zero, zero, zero], } } @@ -355,8 +394,9 @@ impl PredicateTarget { ) -> Self { let prefix = builder.constant(F::from(PredicatePrefix::Custom)); let id = batch_id.elements; + let zero = builder.zero(); Self { - elements: [prefix, id[0], id[1], id[2], id[3], index], + elements: [prefix, id[0], id[1], id[2], id[3], index, zero, zero], } } @@ -365,7 +405,7 @@ impl PredicateTarget { let vh = vd_hash.elements; let zero = builder.zero(); Self { - elements: [prefix, vh[0], vh[1], vh[2], vh[3], zero], + elements: [prefix, vh[0], vh[1], vh[2], vh[3], zero, zero, zero], } } @@ -378,10 +418,30 @@ impl PredicateTarget { &self, pw: &mut PartialWitness, params: &Params, - predicate: Predicate, + predicate: &Predicate, ) -> Result<()> { Ok(pw.set_target_arr(&self.elements, &predicate.to_fields(params))?) } + + pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { + // Optimization: if all the predicate values are constants we skip the hash circuit and + // return a hash constant + let mut predicate_values = [F::ZERO; Params::predicate_size()]; + let mut predicate_constant = true; + for (i, target) in self.elements.iter().enumerate() { + if let Some(v) = builder.target_as_constant(*target) { + predicate_values[i] = v; + } else { + predicate_constant = false; + break; + } + } + if predicate_constant { + builder.constant_hash(PoseidonHash::hash_no_pad(&predicate_values)) + } else { + builder.hash_n_to_hash_no_pad::(self.elements.to_vec()) + } + } } /// Mirrors `middleware::KeyOrWildcard` @@ -466,18 +526,46 @@ impl StatementTmplArgTarget { #[derive(Clone, Serialize, Deserialize)] pub struct StatementTmplTarget { - pub pred: PredicateTarget, + pred: Option, + pred_hash: HashOutTarget, pub args: Vec, } impl StatementTmplTarget { + pub fn new(pred_hash: HashOutTarget, args: Vec) -> Self { + Self { + pred: None, + pred_hash, + args, + } + } pub fn set_targets( &self, pw: &mut PartialWitness, params: &Params, st_tmpl: &StatementTmpl, ) -> Result<()> { - Ok(pw.set_target_arr(&self.flatten(), &st_tmpl.to_fields(params))?) + if let Some(pred) = &self.pred { + pred.set_targets(pw, params, &st_tmpl.pred)?; + } + pw.set_hash_target(self.pred_hash, HashOut::from(st_tmpl.pred.hash(params)))?; + let arg_pad = StatementTmplArg::None; + for (i, arg) in st_tmpl + .args + .iter() + .chain(iter::repeat(&arg_pad)) + .take(params.max_statement_args) + .enumerate() + { + self.args[i].set_targets(pw, params, arg)?; + } + Ok(()) + } + pub fn pred(&self) -> Option<&PredicateTarget> { + self.pred.as_ref() + } + pub fn pred_hash(&self) -> &HashOutTarget { + &self.pred_hash } } @@ -494,9 +582,24 @@ impl CustomPredicateTarget { &self, pw: &mut PartialWitness, params: &Params, - custom_predicate: &CustomPredicate, + custom_pred: &CustomPredicate, ) -> Result<()> { - Ok(pw.set_target_arr(&self.flatten(), &custom_predicate.to_fields(params))?) + pw.set_target( + self.conjunction.target, + F::from_bool(custom_pred.conjunction), + )?; + let st_tmpl_pad = custom_pred.pad_statement_tmpl(); + for (i, st_tmpl) in custom_pred + .statements + .iter() + .chain(iter::repeat(&st_tmpl_pad)) + .take(params.max_custom_predicate_arity) + .enumerate() + { + self.statements[i].set_targets(pw, params, st_tmpl)?; + } + pw.set_target(self.args_len, F::from_canonical_usize(custom_pred.args_len))?; + Ok(()) } } @@ -507,7 +610,7 @@ pub struct CustomPredicateBatchTarget { impl CustomPredicateBatchTarget { pub fn id(&self, builder: &mut CircuitBuilder) -> HashOutTarget { - let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect(); + let flattened: Vec<_> = self.predicates.iter().flat_map(|cp| cp.flatten()).collect(); builder.hash_n_to_hash_no_pad::(flattened) } @@ -621,7 +724,7 @@ pub struct CustomPredicateVerifyEntryTarget { } impl CustomPredicateVerifyEntryTarget { - pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder, with_pred: bool) -> Self { let custom_predicate_table_len = params.max_custom_predicate_batches * params.max_custom_batch_size; CustomPredicateVerifyEntryTarget { @@ -629,12 +732,12 @@ impl CustomPredicateVerifyEntryTarget { custom_predicate_table_len, builder, ), - custom_predicate: builder.add_virtual_custom_predicate_entry(params), + custom_predicate: builder.add_virtual_custom_predicate_entry(params, with_pred), args: (0..params.max_custom_predicate_wildcards) .map(|_| builder.add_virtual_value()) .collect(), op_args: (0..params.max_operation_args) - .map(|_| builder.add_virtual_statement(params)) + .map(|_| builder.add_virtual_statement(params, false)) .collect(), } } @@ -897,7 +1000,7 @@ impl Flattenable for PredicateTarget { impl Flattenable for StatementTarget { fn flatten(&self) -> Vec { - self.predicate + self.pred_hash .flatten() .into_iter() .chain(self.args.iter().flat_map(|a| &a.elements).cloned()) @@ -906,20 +1009,22 @@ impl Flattenable for StatementTarget { fn from_flattened(params: &Params, v: &[Target]) -> Self { assert_eq!(v.len(), Self::size(params)); - let predicate = PredicateTarget::from_flattened(params, &v[..Params::predicate_size()]); + let predicate_hash = HashOutTarget::from_flattened(params, &v[..HASH_SIZE]); let args = (0..params.max_statement_args) .map(|i| StatementArgTarget { - elements: array::from_fn(|j| { - v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j] - }), + elements: array::from_fn(|j| v[HASH_SIZE + i * STATEMENT_ARG_F_LEN + j]), }) .collect(); - Self { predicate, args } + Self { + pred: None, + pred_hash: predicate_hash, + args, + } } fn size(params: &Params) -> usize { - PredicateTarget::size(params) + params.max_statement_args * StatementArgTarget::size(params) + HASH_SIZE + params.max_statement_args * StatementArgTarget::size(params) } } @@ -957,7 +1062,7 @@ impl Flattenable for CustomPredicateTarget { impl Flattenable for StatementTmplTarget { fn flatten(&self) -> Vec { - self.pred + self.pred_hash .flatten() .into_iter() .chain(self.args.iter().flat_map(|sta| sta.flatten())) @@ -966,21 +1071,24 @@ impl Flattenable for StatementTmplTarget { fn from_flattened(params: &Params, v: &[Target]) -> Self { assert_eq!(v.len(), Self::size(params)); - let pred_end = Params::predicate_size(); - let pred = PredicateTarget::from_flattened(params, &v[..pred_end]); + let pred_hash_end = HASH_SIZE; + let pred_hash = HashOutTarget::from_flattened(params, &v[..pred_hash_end]); let sta_size = Params::statement_tmpl_arg_size(); let args = (0..params.max_statement_args) .map(|i| { - let sta_v = &v[pred_end + sta_size * i..pred_end + sta_size * (i + 1)]; + let sta_v = &v[pred_hash_end + sta_size * i..pred_hash_end + sta_size * (i + 1)]; StatementTmplArgTarget::from_flattened(params, sta_v) }) .collect(); - Self { pred, args } + Self { + pred: None, + pred_hash, + args, + } } fn size(params: &Params) -> usize { - PredicateTarget::size(params) - + params.max_statement_args * StatementTmplArgTarget::size(params) + HASH_SIZE + params.max_statement_args * StatementTmplArgTarget::size(params) } } @@ -1039,18 +1147,32 @@ pub trait CircuitBuilderPod, const D: usize> { fn connect_values(&mut self, x: ValueTarget, y: ValueTarget); fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); fn add_virtual_value(&mut self) -> ValueTarget; - fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget; + fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget; fn add_virtual_statement_arg(&mut self) -> StatementArgTarget; fn add_virtual_predicate(&mut self) -> PredicateTarget; fn add_virtual_operation_type(&mut self) -> OperationTypeTarget; fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget; - fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget; - fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget; - fn add_virtual_custom_predicate_batch(&mut self, params: &Params) - -> CustomPredicateBatchTarget; - fn add_virtual_custom_predicate_entry(&mut self, params: &Params) - -> CustomPredicateEntryTarget; + fn add_virtual_statement_tmpl( + &mut self, + params: &Params, + with_pred: bool, + ) -> StatementTmplTarget; + fn add_virtual_custom_predicate( + &mut self, + params: &Params, + with_pred: bool, + ) -> CustomPredicateTarget; + fn add_virtual_custom_predicate_batch( + &mut self, + params: &Params, + with_pred: bool, + ) -> CustomPredicateBatchTarget; + fn add_virtual_custom_predicate_entry( + &mut self, + params: &Params, + with_pred: bool, + ) -> CustomPredicateEntryTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_statement_arg( &mut self, @@ -1144,10 +1266,20 @@ impl CircuitBuilderPod for CircuitBuilder { } } - fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget { - let predicate = self.add_virtual_predicate(); + /// If `with_pred = true` a predicate is included and its hash constrained. + /// If `with_pred = false` only the predicate hash is included. + fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget { + let (pred, pred_hash) = if with_pred { + let pred = self.add_virtual_predicate(); + let pred_hash = pred.hash(self); + (Some(pred), pred_hash) + } else { + let pred_hash = self.add_virtual_hash(); + (None, pred_hash) + }; StatementTarget { - predicate, + pred, + pred_hash, args: (0..params.max_statement_args) .map(|_| self.add_virtual_statement_arg()) .collect(), @@ -1188,19 +1320,38 @@ impl CircuitBuilderPod for CircuitBuilder { } } - fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget { - let args = (0..params.max_statement_args) - .map(|_| self.add_virtual_statement_tmpl_arg()) - .collect(); + /// If `with_pred = true` a predicate is included and its hash constrained. + /// If `with_pred = false` only the predicate hash is included. + fn add_virtual_statement_tmpl( + &mut self, + params: &Params, + with_pred: bool, + ) -> StatementTmplTarget { + let (pred, pred_hash) = if with_pred { + let pred = self.add_virtual_predicate(); + let pred_hash = pred.hash(self); + (Some(pred), pred_hash) + } else { + let pred_hash = self.add_virtual_hash(); + (None, pred_hash) + }; StatementTmplTarget { - pred: self.add_virtual_predicate(), - args, + pred, + pred_hash, + args: (0..params.max_statement_args) + .map(|_| self.add_virtual_statement_tmpl_arg()) + .collect(), } } - fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget { + /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. + fn add_virtual_custom_predicate( + &mut self, + params: &Params, + with_pred: bool, + ) -> CustomPredicateTarget { let statements = (0..params.max_custom_predicate_arity) - .map(|_| self.add_virtual_statement_tmpl(params)) + .map(|_| self.add_virtual_statement_tmpl(params, with_pred)) .collect(); CustomPredicateTarget { conjunction: self.add_virtual_bool_target_safe(), @@ -1209,25 +1360,29 @@ impl CircuitBuilderPod for CircuitBuilder { } } + /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. fn add_virtual_custom_predicate_batch( &mut self, params: &Params, + with_pred: bool, ) -> CustomPredicateBatchTarget { CustomPredicateBatchTarget { predicates: (0..params.max_custom_batch_size) - .map(|_| self.add_virtual_custom_predicate(params)) + .map(|_| self.add_virtual_custom_predicate(params, with_pred)) .collect(), } } + /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. fn add_virtual_custom_predicate_entry( &mut self, params: &Params, + with_pred: bool, ) -> CustomPredicateEntryTarget { CustomPredicateEntryTarget { id: self.add_virtual_hash(), index: self.add_virtual_target(), - predicate: self.add_virtual_custom_predicate(params), + predicate: self.add_virtual_custom_predicate(params, with_pred), } } @@ -1734,7 +1889,8 @@ pub(crate) mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(params); + let custom_predicate_batch_target = + builder.add_virtual_custom_predicate_batch(params, false); // Calculate the id in constraints and compare it against the id calculated natively let id_target = custom_predicate_batch_target.id(&mut builder); diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 46c56aa..86d72d4 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -1396,10 +1396,7 @@ fn make_statement_from_template_circuit( }) .collect(); measure_gates_end!(builder, measure); - StatementTarget { - predicate: st_tmpl.pred.clone(), - args, - } + StatementTarget::new(*st_tmpl.pred_hash(), args) } /// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard @@ -1434,11 +1431,9 @@ fn make_custom_statement_circuit( let v = builder.select_flattenable(params, mask, arg, &arg_none); StatementArgTarget::wildcard_literal(builder, &v) }) - .collect(); - let statement = StatementTarget { - predicate: st_predicate, - args: st_args, - }; + .collect_vec(); + let statement_with_pred = + StatementTarget::new_with_pred(builder, params, st_predicate, &st_args); // Check the operation arguments // From each statement template we generate an expected statement using replacing the @@ -1470,7 +1465,7 @@ fn make_custom_statement_circuit( builder.assert_one(is_op_args_ok.target); measure_gates_end!(builder, measure); - Ok((statement, op_type)) + Ok((statement_with_pred, op_type)) } /// Replace the blank verifier_data_hash slots in intro predicates by `vd_hash` @@ -1480,19 +1475,13 @@ fn normalize_statement_circuit( statement: &StatementTarget, vd_hash: &HashOutTarget, ) -> StatementTarget { - let is_intro = statement.predicate.is_intro(builder); - let old_pred = statement.predicate.elements; - let old = HashOutTarget::try_from(&old_pred[1..1 + HASH_SIZE]).expect("len = 4"); - let new = builder - .select_flattenable(params, is_intro, vd_hash, &old) - .elements; + let is_blank_intro = statement.pred_is_blank_intro(builder); + let old_pred_hash = statement.pred_hash(); + let intro_pred_hash = PredicateTarget::new_intro(builder, *vd_hash).hash(builder); + let new_pred_hash = + builder.select_flattenable(params, is_blank_intro, &intro_pred_hash, old_pred_hash); - StatementTarget { - predicate: PredicateTarget { - elements: [old_pred[0], new[0], new[1], new[2], new[3], old_pred[5]], - }, - args: statement.args.clone(), - } + StatementTarget::new(new_pred_hash, statement.args.clone()) } /// `params.num_public_statements_hash` is the total number of statements that will be hashed. @@ -1538,15 +1527,13 @@ fn normalize_st_tmpl_circuit( st_tmpl: &StatementTmplTarget, id: HashOutTarget, ) -> StatementTmplTarget { + let pred = st_tmpl.pred().expect("StatementTmpl contains predicate"); let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf)); - let is_batch_self = builder.is_equal(st_tmpl.pred.elements[0], prefix_batch_self); - let pred_index = st_tmpl.pred.elements[1]; + let is_batch_self = builder.is_equal(pred.elements[0], prefix_batch_self); + let pred_index = pred.elements[1]; let custom_pred = PredicateTarget::new_custom(builder, id, pred_index); - let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, &st_tmpl.pred); - StatementTmplTarget { - pred, - args: st_tmpl.args.clone(), - } + let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, pred); + StatementTmplTarget::new(pred.hash(builder), st_tmpl.args.clone()) } /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as @@ -1567,7 +1554,9 @@ fn build_custom_predicate_table_circuit( let statements = cp .statements .iter() - .map(|st_tmpl| normalize_st_tmpl_circuit(params, builder, st_tmpl, id)) + .map(|st_with_pred_tmpl| { + normalize_st_tmpl_circuit(params, builder, st_with_pred_tmpl, id) + }) .collect_vec(); let cp = CustomPredicateTarget { conjunction: cp.conjunction, @@ -1625,19 +1614,19 @@ fn verify_main_pod_circuit( // NOTE: We use an EmptyPod for padding input pod slots. The EmptyPod is an introduction // pod that declares a statement with no arguments. - let is_intro = input_pod_self_statements[0].predicate.is_intro(builder); + let is_blank_intro = input_pod_self_statements[0].pred_is_blank_intro(builder); // Introduction pods can only have Introduction or None statements - let mut intro_ok = is_intro; + let mut intro_ok = is_blank_intro; for self_st in &input_pod_self_statements[1..] { - let st_is_intro = self_st.predicate.is_intro(builder); + let st_is_intro = self_st.pred_is_blank_intro(builder); let st_is_none = self_st.has_native_type(builder, params, NativePredicate::None); let st_is_intro_or_none = builder.or(st_is_intro, st_is_none); intro_ok = builder.and(intro_ok, st_is_intro_or_none); } - builder.connect(is_intro.target, intro_ok.target); + builder.connect(is_blank_intro.target, intro_ok.target); - let is_main = builder.not(is_intro); + let is_main = builder.not(is_blank_intro); for self_st in input_pod_self_statements { let normalized_st = normalize_statement_circuit( params, @@ -1750,12 +1739,12 @@ impl MainPodVerifyTarget { input_pods_self_statements: (0..params.max_input_pods) .map(|_| { (0..params.max_input_pods_public_statements) - .map(|_| builder.add_virtual_statement(params)) + .map(|_| builder.add_virtual_statement(params, false)) .collect_vec() }) .collect(), input_statements: (0..params.max_statements) - .map(|_| builder.add_virtual_statement(params)) + .map(|_| builder.add_virtual_statement(params, false)) .collect(), operations: (0..params.max_statements) .map(|_| builder.add_virtual_operation(params)) @@ -1781,10 +1770,10 @@ impl MainPodVerifyTarget { }) .collect(), custom_predicate_batches: (0..params.max_custom_predicate_batches) - .map(|_| builder.add_virtual_custom_predicate_batch(params)) + .map(|_| builder.add_virtual_custom_predicate_batch(params, true)) .collect(), custom_predicate_verifications: (0..params.max_custom_predicate_verifications) - .map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder)) + .map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder, false)) .collect(), } } @@ -2084,10 +2073,10 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let st_target = builder.add_virtual_statement(¶ms); + let st_target = builder.add_virtual_statement(¶ms, false); let op_target = builder.add_virtual_operation(¶ms); let prev_statements_target: Vec<_> = (0..prev_statements.len()) - .map(|_| builder.add_virtual_statement(¶ms)) + .map(|_| builder.add_virtual_statement(¶ms, false)) .collect(); let merkle_proofs_target: Vec<_> = aux @@ -3098,7 +3087,7 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let st_tmpl_target = builder.add_virtual_statement_tmpl(params); + let st_tmpl_target = builder.add_virtual_statement_tmpl(params, false); let args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_value()) .collect(); @@ -3109,7 +3098,7 @@ mod tests { &args_target, ); // TODO: Instead of connect, assign witness to result - let expected_st_target = builder.add_virtual_statement(params); + let expected_st_target = builder.add_virtual_statement(params, false); builder.connect_flattenable(&expected_st_target, &st_target); let mut pw = PartialWitness::::new(); @@ -3161,9 +3150,9 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::new(config); - let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params); + let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params, false); let op_args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_statement(params)) + .map(|_| builder.add_virtual_statement(params, false)) .collect(); let args_target: Vec<_> = (0..args.len()) .map(|_| builder.add_virtual_value()) @@ -3455,7 +3444,7 @@ mod tests { let mut builder = CircuitBuilder::new(config); let statements_target = (0..params.max_public_statements) - .map(|_| builder.add_virtual_statement(params)) + .map(|_| builder.add_virtual_statement(params, false)) .collect_vec(); let sts_hash_target = calculate_statements_hash_circuit(params, &mut builder, &statements_target); diff --git a/src/backends/plonky2/mainpod/statement.rs b/src/backends/plonky2/mainpod/statement.rs index 3fc4ed0..8a19b57 100644 --- a/src/backends/plonky2/mainpod/statement.rs +++ b/src/backends/plonky2/mainpod/statement.rs @@ -31,7 +31,7 @@ impl Statement { impl ToFields for Statement { fn to_fields(&self, params: &Params) -> Vec { - let mut fields = self.0.to_fields(params); + let mut fields = self.0.hash(params).to_fields(params); fields.extend( self.1 .iter() diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index d2d4eed..3bbb960 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -169,6 +169,12 @@ pub struct Hash( pub [F; HASH_SIZE], ); +impl From for HashOut { + fn from(hash: Hash) -> HashOut { + HashOut { elements: hash.0 } + } +} + impl ToHex for Hash { fn encode_hex>(&self) -> T { self.0 diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 9254760..86bb2f4 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -170,6 +170,7 @@ impl ToFields for StatementTmpl { let mut fields: Vec = self .pred + .hash(params) .to_fields(params) .into_iter() .chain(self.args.iter().flat_map(|sta| sta.to_fields(params))) diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 8ae3760..a0b55ae 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -809,7 +809,7 @@ impl Params { } pub const fn predicate_size() -> usize { - HASH_SIZE + 2 + 8 } pub const fn operation_type_size() -> usize { @@ -817,11 +817,11 @@ impl Params { } pub fn statement_size(&self) -> usize { - Self::predicate_size() + STATEMENT_ARG_F_LEN * self.max_statement_args + HASH_SIZE + STATEMENT_ARG_F_LEN * self.max_statement_args } pub const fn statement_tmpl_size(&self) -> usize { - Self::predicate_size() + self.max_statement_args * Self::statement_tmpl_arg_size() + HASH_SIZE + self.max_statement_args * Self::statement_tmpl_arg_size() } pub fn custom_predicate_size(&self) -> usize { diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 5889d5e..20f6381 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -10,7 +10,8 @@ use serde::{Deserialize, Serialize}; use strum_macros::FromRepr; use crate::middleware::{ - self, AnchoredKey, CustomPredicateRef, Error, Params, Result, ToFields, Value, F, VALUE_SIZE, + self, hash_fields, AnchoredKey, CustomPredicateRef, Error, Params, Result, ToFields, Value, F, + VALUE_SIZE, }; pub const STATEMENT_ARG_F_LEN: usize = 8; @@ -210,15 +211,15 @@ impl From for F { impl ToFields for Predicate { fn to_fields(&self, params: &Params) -> Vec { // serialize: - // NativePredicate(id) as (1, id, 0, 0, 0, 0) -- id: usize - // BatchSelf(i) as (2, i, 0, 0, 0, 0) -- i: usize + // NativePredicate(id) as (1, id, 0...) -- id: usize + // BatchSelf(i) as (2, i, 0...) -- i: usize // CustomPredicateRef(pb, i) as - // (3, [hash of pb], i) -- pb hashes to 4 field elements - // -- i: usize + // (3, [hash of pb], i, 0...) -- pb hashes to 4 field elements + // -- i: usize // IntroPredicateRef(vd_hash) as - // (4, [vd_hash], 0) + // (4, [vd_hash], 0...) - // in every case: pad to (hash_size + 2) field elements + // in every case: pad to `Params::predicate_size()` field elements let mut fields: Vec = match self { Self::Native(p) => iter::once(F::from(PredicatePrefix::Native)) .chain(p.to_fields(params)) @@ -243,6 +244,12 @@ impl ToFields for Predicate { } } +impl Predicate { + pub fn hash(&self, params: &Params) -> middleware::Hash { + hash_fields(&self.to_fields(params)) + } +} + impl fmt::Display for Predicate { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -497,7 +504,8 @@ impl Statement { impl ToFields for Statement { fn to_fields(&self, params: &Params) -> Vec { - let mut fields = self.predicate().to_fields(params); + let predicate_hash = hash_fields(&self.predicate().to_fields(params)); + let mut fields = predicate_hash.0.to_vec(); fields.extend(self.args().iter().flat_map(|arg| arg.to_fields(params))); fields.resize_with(params.statement_size(), || F::ZERO); fields