From d3fef8392e110121e0b4284995d696cc873a64f9 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Fri, 23 May 2025 10:12:28 +0200 Subject: [PATCH] calculate MainPod id in a dynamic-friendly way (#241) * calculate MainPod id in a dynamic-friendly way The MainPod id is now calculated with front padding and a fixed size independent of max_public_statements so that introduction gadgets can be verified by a MainPod while paying only for the number of statements they use. This is because with front padding of none-statements we can precompute the poseidon state corresponding to absorbing all the padding statements and only pay constraints for the non-padding statements. The id is calculated as follows: `id = hash(serialize(reverse(statements || none-statements)))` * fix test --- src/backends/plonky2/circuits/mainpod.rs | 213 +++++++++++++++++++++-- src/backends/plonky2/mainpod/mod.rs | 34 +++- src/backends/plonky2/mock/mainpod.rs | 9 +- src/examples/mod.rs | 9 +- src/frontend/mod.rs | 2 +- src/middleware/mod.rs | 4 + 6 files changed, 245 insertions(+), 26 deletions(-) diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index ad2aba8..a8faeca 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -3,9 +3,16 @@ use std::{array, iter, sync::Arc}; use itertools::{zip_eq, Itertools}; use plonky2::{ field::types::Field, - hash::{hash_types::HashOutTarget, poseidon::PoseidonHash}, - iop::{target::BoolTarget, witness::PartialWitness}, - plonk::circuit_builder::CircuitBuilder, + hash::{ + hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, + hashing::PlonkyPermutation, + poseidon::{PoseidonHash, PoseidonPermutation}, + }, + iop::{ + target::{BoolTarget, Target}, + witness::PartialWitness, + }, + plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher}, }; use crate::{ @@ -22,7 +29,7 @@ use crate::{ signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget}, }, error::Result, - mainpod, + mainpod::{self, pad_statement}, primitives::merkletree::{ MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget, }, @@ -894,6 +901,88 @@ impl CustomOperationVerifyGadget { } } +struct CalculateIdGadget { + params: Params, +} + +impl CalculateIdGadget { + /// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder + /// elements that didn't fit into a chunk. + fn precompute_hash_state>(inputs: &[F]) -> (P, &[F]) { + let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE); + let mut perm = P::new(core::iter::repeat(F::ZERO)); + + // Absorb all inputs up to the biggest multiple of RATE. + for input_chunk in inputs.chunks(P::RATE) { + perm.set_from_slice(input_chunk, 0); + perm.permute(); + } + + (perm, inputs_rem) + } + + /// Hash `inputs` starting from a circuit-constant `perm` state. + fn hash_from_state, P: PlonkyPermutation>( + builder: &mut CircuitBuilder, + perm: P, + inputs: &[Target], + ) -> HashOutTarget { + let mut state = + H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v))); + + // Absorb all input chunks. + for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) { + // Overwrite the first r elements with the inputs. This differs from a standard sponge, + // where we would xor or add in the inputs. This is a well-known variant, though, + // sometimes called "overwrite mode". + state.set_from_slice(input_chunk, 0); + state = builder.permute::(state); + } + + let num_outputs = NUM_HASH_OUT_ELTS; + // Squeeze until we have the desired number of outputs. + let mut outputs = Vec::with_capacity(num_outputs); + loop { + for &s in state.squeeze() { + outputs.push(s); + if outputs.len() == num_outputs { + return HashOutTarget::from_vec(outputs); + } + } + state = builder.permute::(state); + } + } + + fn eval( + &self, + builder: &mut CircuitBuilder, + statements: &[StatementTarget], + ) -> HashOutTarget { + let measure = measure_gates_begin!(builder, "CalculateId"); + let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten()); + let mut none_st = mainpod::Statement::from(Statement::None); + pad_statement(&self.params, &mut none_st); + let front_pad_elts = iter::repeat(&none_st) + .take(self.params.num_public_statements_id - self.params.max_public_statements) + .flat_map(|s| s.to_fields(&self.params)) + .collect_vec(); + let (perm, front_pad_elts_rem) = + Self::precompute_hash_state::>(&front_pad_elts); + + // Precompute the Poseidon state for the initial padding chunks + let inputs = front_pad_elts_rem + .iter() + .map(|v| builder.constant(*v)) + .chain(statements_rev_flattened) + .collect_vec(); + let id = + Self::hash_from_state::>(builder, perm, &inputs); + + measure_gates_end!(builder, measure); + id + } +} + struct MainPodVerifyGadget { params: Params, } @@ -1089,10 +1178,10 @@ impl MainPodVerifyGadget { self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?; // 2. Calculate the Pod Id from the public statements - let measure_calc_id = measure_gates_begin!(builder, "MainPodId"); - let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect(); - let id = builder.hash_n_to_hash_no_pad::(pub_statements_flattened); - measure_gates_end!(builder, measure_calc_id); + let id = CalculateIdGadget { + params: self.params.clone(), + } + .eval(builder, pub_statements); // 4. Verify type let type_statement = &pub_statements[0]; @@ -1266,10 +1355,12 @@ impl MainPodVerifyCircuit { #[cfg(test)] mod tests { - use std::ops::Not; + use std::{iter, ops::Not}; use plonky2::{ field::{goldilocks_field::GoldilocksField, types::Field}, + hash::hash_types::HashOut, + iop::witness::WitnessWrite, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, }; @@ -1278,7 +1369,7 @@ mod tests { backends::plonky2::{ basetypes::C, circuits::common::tests::I64_TEST_PAIRS, - mainpod::{OperationArg, OperationAux}, + mainpod::{calculate_id, OperationArg, OperationAux}, primitives::merkletree::{MerkleClaimAndProof, MerkleTree}, }, frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, @@ -2681,4 +2772,106 @@ mod tests { Ok(()) } + + fn helper_calculate_id(params: &Params, statements: &[Statement]) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let gadget = CalculateIdGadget { + params: params.clone(), + }; + + let statements_target = (0..params.max_public_statements) + .map(|_| builder.add_virtual_statement(params)) + .collect_vec(); + let id_target = gadget.eval(&mut builder, &statements_target); + + let mut pw = PartialWitness::::new(); + + // Input + let statements = statements + .into_iter() + .map(|st| { + let mut st = mainpod::Statement::from(st.clone()); + pad_statement(params, &mut st); + st + }) + .collect_vec(); + for (st_target, st) in statements_target.iter().zip(statements.iter()) { + st_target.set_targets(&mut pw, params, st)?; + } + // Expected Output + let expected_id = calculate_id(&statements, params); + pw.set_hash_target( + id_target, + HashOut { + elements: expected_id.0, + }, + )?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) + } + + #[test] + fn test_calculate_id() -> frontend::Result<()> { + // Case with no public public statements + let params = Params { + max_public_statements: 0, + num_public_statements_id: 8, + ..Default::default() + }; + + helper_calculate_id(¶ms, &[]).unwrap(); + + // Case with number of statements for the id equal to number of public statements + let params = Params { + max_public_statements: 2, + num_public_statements_id: 2, + ..Default::default() + }; + + let statements = [ + Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)), + Statement::Equal( + AnchoredKey::from((SELF, "bar")), + AnchoredKey::from((SELF, "baz")), + ), + ] + .into_iter() + .chain(iter::repeat(Statement::None)) + .take(params.max_public_statements) + .collect_vec(); + + helper_calculate_id(¶ms, &statements).unwrap(); + + // Case with more statements for the id than the number of public statements + let params = Params { + max_public_statements: 4, + num_public_statements_id: 6, + ..Default::default() + }; + + let pod_id = PodId(hash_str("pod_id")); + let statements = [ + Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)), + Statement::Equal( + AnchoredKey::from((SELF, "bar")), + AnchoredKey::from((SELF, "baz")), + ), + Statement::Lt( + AnchoredKey::from((pod_id, "one")), + AnchoredKey::from((pod_id, "two")), + ), + ] + .into_iter() + .chain(iter::repeat(Statement::None)) + .take(params.max_public_statements) + .collect_vec(); + + helper_calculate_id(¶ms, &statements).unwrap(); + + Ok(()) + } } diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index fac5241..fa602b8 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1,6 +1,6 @@ pub mod operation; pub mod statement; -use std::{any::Any, sync::Arc}; +use std::{any::Any, iter, sync::Arc}; use base64::{prelude::BASE64_STANDARD, Engine}; use itertools::Itertools; @@ -35,11 +35,30 @@ use crate::{ }, }; -/// Hash a list of public statements to derive the PodId -pub(crate) fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash { - let field_elems = statements +/// Hash a list of public statements to derive the PodId. To make circuits with different number +/// of `max_public_statements compatible we pad the statements up to `num_public_statements_id`. +/// As an optimization we front pad with none-statements so that circuits with a small +/// `max_public_statements` only pay for `max_public_statements` by starting the poseidon state +/// with a precomputed constant corresponding to the front-padding part: +/// `id = hash(serialize(reverse(statements || none-statements)))` +pub(crate) fn calculate_id(statements: &[Statement], params: &Params) -> middleware::Hash { + assert_eq!(params.max_public_statements, statements.len()); + assert!(params.max_public_statements <= params.num_public_statements_id); + statements .iter() - .flat_map(|statement| statement.clone().to_fields(_params)) + .for_each(|st| assert_eq!(params.max_statement_args, st.1.len())); + + let mut none_st: Statement = middleware::Statement::None.into(); + pad_statement(params, &mut none_st); + let statements_back_padded = statements + .iter() + .chain(iter::repeat(&none_st)) + .take(params.num_public_statements_id) + .collect_vec(); + let field_elems = statements_back_padded + .iter() + .rev() + .flat_map(|statement| statement.to_fields(params)) .collect::>(); Hash(PoseidonHash::hash_no_pad(&field_elems).elements) } @@ -421,7 +440,7 @@ impl Prover { let public_statements = statements[statements.len() - params.max_public_statements..].to_vec(); // get the id out of the public statements - let id: PodId = PodId(hash_statements(&public_statements, params)); + let id: PodId = PodId(calculate_id(&public_statements, params)); let input = MainPodVerifyInput { signed_pods: signed_pods_input, @@ -505,7 +524,7 @@ fn get_common_data(params: &Params) -> Result, Error> { impl MainPod { fn _verify(&self) -> Result<()> { // 2. get the id out of the public statements - let id: PodId = PodId(hash_statements(&self.public_statements, &self.params)); + let id: PodId = PodId(calculate_id(&self.public_statements, &self.params)); if id != self.id { return Err(Error::id_not_equal(self.id, id)); } @@ -700,6 +719,7 @@ pub mod tests { max_statements: 5, max_signed_pod_values: 2, max_public_statements: 2, + num_public_statements_id: 4, max_statement_args: 2, max_operation_args: 3, max_custom_predicate_batches: 2, diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index 616e12d..00de19d 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -11,7 +11,7 @@ use crate::{ backends::plonky2::{ error::{Error, Result}, mainpod::{ - extract_merkle_proofs, hash_statements, layout_statements, normalize_statement, + calculate_id, extract_merkle_proofs, layout_statements, normalize_statement, process_private_statements_operations, process_public_statements_operations, Operation, Statement, }, @@ -163,7 +163,7 @@ impl MockMainPod { statements[statements.len() - params.max_public_statements..].to_vec(); // get the id out of the public statements - let id: PodId = PodId(hash_statements(&public_statements, params)); + let id: PodId = PodId(calculate_id(&public_statements, params)); Ok(Self { params: params.clone(), @@ -197,7 +197,7 @@ impl MockMainPod { // get the input_statements from the self.statements let input_statements = &self.statements[input_statement_offset..]; // 2. get the id out of the public statements, and ensure it is equal to self.id - let ids_match = self.id == PodId(hash_statements(&self.public_statements, &self.params)); + let ids_match = self.id == PodId(calculate_id(&self.public_statements, &self.params)); // find a ValueOf statement from the public statements with key=KEY_TYPE and check that the // value is PodType::MockMainPod let has_type_statement = self.public_statements.iter().any(|s| { @@ -351,8 +351,7 @@ pub mod tests { #[test] fn test_mock_main_great_boy() -> frontend::Result<()> { - let params = middleware::Params::default(); - let great_boy_builder = great_boy_pod_full_flow()?; + let (params, great_boy_builder) = great_boy_pod_full_flow()?; let mut prover = MockProver {}; let great_boy_pod = great_boy_builder.prove(&mut prover, ¶ms)?; diff --git a/src/examples/mod.rs b/src/examples/mod.rs index ae51078..9c9885a 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -296,11 +296,12 @@ pub fn great_boy_pod_builder( Ok(great_boy) } -pub fn great_boy_pod_full_flow() -> Result { +pub fn great_boy_pod_full_flow() -> Result<(Params, MainPodBuilder)> { let params = Params { max_input_signed_pods: 6, max_statements: 100, max_public_statements: 50, + num_public_statements_id: 50, ..Default::default() }; @@ -349,7 +350,7 @@ pub fn great_boy_pod_full_flow() -> Result { good_boy_issuers.into_iter().map(Value::from).collect(), )?); - great_boy_pod_builder( + let builder = great_boy_pod_builder( ¶ms, [ &bob_good_boys[0], @@ -360,7 +361,9 @@ pub fn great_boy_pod_full_flow() -> Result { [&alice_friend_pods[0], &alice_friend_pods[1]], &good_boy_issuers, alice, - ) + )?; + + Ok((params, builder)) } // Tickets diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index aaa1481..5ff0c98 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -940,7 +940,7 @@ pub mod tests { #[test] fn test_front_great_boy() -> Result<()> { - let great_boy = great_boy_pod_full_flow()?; + let (_, great_boy) = great_boy_pod_full_flow()?; println!("{}", great_boy); // TODO: prove great_boy with MockProver and print it diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 8c3313e..41d7e10 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -582,6 +582,9 @@ pub struct Params { pub max_statements: usize, pub max_signed_pod_values: usize, pub max_public_statements: usize, + // Number of public statements to hash to calculate the id. Must be equal or greater than + // `max_public_statements`. + pub num_public_statements_id: usize, pub max_statement_args: usize, pub max_operation_args: usize, // max number of custom predicates batches that a MainPod can use @@ -607,6 +610,7 @@ impl Default for Params { max_statements: 20, max_signed_pod_values: 8, max_public_statements: 10, + num_public_statements_id: 16, max_statement_args: 5, max_operation_args: 5, max_custom_predicate_batches: 2,