calculate MainPod id in a dynamic-friendly way (#241)

* calculate MainPod id in a dynamic-friendly way

The MainPod id is now calculated with front padding and a fixed size
independent of max_public_statements so that introduction gadgets can be
verified by a MainPod while paying only for the number of statements
they use.  This is because with front padding of none-statements we can
precompute the poseidon state corresponding to absorbing all the padding
statements and only pay constraints for the non-padding statements.

The id is calculated as follows:
`id = hash(serialize(reverse(statements || none-statements)))`

* fix test
This commit is contained in:
Eduard S. 2025-05-23 10:12:28 +02:00 committed by GitHub
parent 82481e88d7
commit d3fef8392e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 245 additions and 26 deletions

View file

@ -3,9 +3,16 @@ use std::{array, iter, sync::Arc};
use itertools::{zip_eq, Itertools}; use itertools::{zip_eq, Itertools};
use plonky2::{ use plonky2::{
field::types::Field, field::types::Field,
hash::{hash_types::HashOutTarget, poseidon::PoseidonHash}, hash::{
iop::{target::BoolTarget, witness::PartialWitness}, hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS},
plonk::circuit_builder::CircuitBuilder, hashing::PlonkyPermutation,
poseidon::{PoseidonHash, PoseidonPermutation},
},
iop::{
target::{BoolTarget, Target},
witness::PartialWitness,
},
plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher},
}; };
use crate::{ use crate::{
@ -22,7 +29,7 @@ use crate::{
signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget}, signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget},
}, },
error::Result, error::Result,
mainpod, mainpod::{self, pad_statement},
primitives::merkletree::{ primitives::merkletree::{
MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget, MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget,
}, },
@ -894,6 +901,88 @@ impl CustomOperationVerifyGadget {
} }
} }
struct CalculateIdGadget {
params: Params,
}
impl CalculateIdGadget {
/// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder
/// elements that didn't fit into a chunk.
fn precompute_hash_state<F: RichField, P: PlonkyPermutation<F>>(inputs: &[F]) -> (P, &[F]) {
let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE);
let mut perm = P::new(core::iter::repeat(F::ZERO));
// Absorb all inputs up to the biggest multiple of RATE.
for input_chunk in inputs.chunks(P::RATE) {
perm.set_from_slice(input_chunk, 0);
perm.permute();
}
(perm, inputs_rem)
}
/// Hash `inputs` starting from a circuit-constant `perm` state.
fn hash_from_state<H: AlgebraicHasher<F>, P: PlonkyPermutation<F>>(
builder: &mut CircuitBuilder<F, D>,
perm: P,
inputs: &[Target],
) -> HashOutTarget {
let mut state =
H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v)));
// Absorb all input chunks.
for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) {
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
// where we would xor or add in the inputs. This is a well-known variant, though,
// sometimes called "overwrite mode".
state.set_from_slice(input_chunk, 0);
state = builder.permute::<H>(state);
}
let num_outputs = NUM_HASH_OUT_ELTS;
// Squeeze until we have the desired number of outputs.
let mut outputs = Vec::with_capacity(num_outputs);
loop {
for &s in state.squeeze() {
outputs.push(s);
if outputs.len() == num_outputs {
return HashOutTarget::from_vec(outputs);
}
}
state = builder.permute::<H>(state);
}
}
fn eval(
&self,
builder: &mut CircuitBuilder<F, D>,
statements: &[StatementTarget],
) -> HashOutTarget {
let measure = measure_gates_begin!(builder, "CalculateId");
let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten());
let mut none_st = mainpod::Statement::from(Statement::None);
pad_statement(&self.params, &mut none_st);
let front_pad_elts = iter::repeat(&none_st)
.take(self.params.num_public_statements_id - self.params.max_public_statements)
.flat_map(|s| s.to_fields(&self.params))
.collect_vec();
let (perm, front_pad_elts_rem) =
Self::precompute_hash_state::<F, PoseidonPermutation<F>>(&front_pad_elts);
// Precompute the Poseidon state for the initial padding chunks
let inputs = front_pad_elts_rem
.iter()
.map(|v| builder.constant(*v))
.chain(statements_rev_flattened)
.collect_vec();
let id =
Self::hash_from_state::<PoseidonHash, PoseidonPermutation<F>>(builder, perm, &inputs);
measure_gates_end!(builder, measure);
id
}
}
struct MainPodVerifyGadget { struct MainPodVerifyGadget {
params: Params, params: Params,
} }
@ -1089,10 +1178,10 @@ impl MainPodVerifyGadget {
self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?; self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?;
// 2. Calculate the Pod Id from the public statements // 2. Calculate the Pod Id from the public statements
let measure_calc_id = measure_gates_begin!(builder, "MainPodId"); let id = CalculateIdGadget {
let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect(); params: self.params.clone(),
let id = builder.hash_n_to_hash_no_pad::<PoseidonHash>(pub_statements_flattened); }
measure_gates_end!(builder, measure_calc_id); .eval(builder, pub_statements);
// 4. Verify type // 4. Verify type
let type_statement = &pub_statements[0]; let type_statement = &pub_statements[0];
@ -1266,10 +1355,12 @@ impl MainPodVerifyCircuit {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::ops::Not; use std::{iter, ops::Not};
use plonky2::{ use plonky2::{
field::{goldilocks_field::GoldilocksField, types::Field}, field::{goldilocks_field::GoldilocksField, types::Field},
hash::hash_types::HashOut,
iop::witness::WitnessWrite,
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
}; };
@ -1278,7 +1369,7 @@ mod tests {
backends::plonky2::{ backends::plonky2::{
basetypes::C, basetypes::C,
circuits::common::tests::I64_TEST_PAIRS, circuits::common::tests::I64_TEST_PAIRS,
mainpod::{OperationArg, OperationAux}, mainpod::{calculate_id, OperationArg, OperationAux},
primitives::merkletree::{MerkleClaimAndProof, MerkleTree}, primitives::merkletree::{MerkleClaimAndProof, MerkleTree},
}, },
frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
@ -2681,4 +2772,106 @@ mod tests {
Ok(()) Ok(())
} }
fn helper_calculate_id(params: &Params, statements: &[Statement]) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let gadget = CalculateIdGadget {
params: params.clone(),
};
let statements_target = (0..params.max_public_statements)
.map(|_| builder.add_virtual_statement(params))
.collect_vec();
let id_target = gadget.eval(&mut builder, &statements_target);
let mut pw = PartialWitness::<F>::new();
// Input
let statements = statements
.into_iter()
.map(|st| {
let mut st = mainpod::Statement::from(st.clone());
pad_statement(params, &mut st);
st
})
.collect_vec();
for (st_target, st) in statements_target.iter().zip(statements.iter()) {
st_target.set_targets(&mut pw, params, st)?;
}
// Expected Output
let expected_id = calculate_id(&statements, params);
pw.set_hash_target(
id_target,
HashOut {
elements: expected_id.0,
},
)?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw)?;
Ok(data.verify(proof.clone())?)
}
#[test]
fn test_calculate_id() -> frontend::Result<()> {
// Case with no public public statements
let params = Params {
max_public_statements: 0,
num_public_statements_id: 8,
..Default::default()
};
helper_calculate_id(&params, &[]).unwrap();
// Case with number of statements for the id equal to number of public statements
let params = Params {
max_public_statements: 2,
num_public_statements_id: 2,
..Default::default()
};
let statements = [
Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)),
Statement::Equal(
AnchoredKey::from((SELF, "bar")),
AnchoredKey::from((SELF, "baz")),
),
]
.into_iter()
.chain(iter::repeat(Statement::None))
.take(params.max_public_statements)
.collect_vec();
helper_calculate_id(&params, &statements).unwrap();
// Case with more statements for the id than the number of public statements
let params = Params {
max_public_statements: 4,
num_public_statements_id: 6,
..Default::default()
};
let pod_id = PodId(hash_str("pod_id"));
let statements = [
Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)),
Statement::Equal(
AnchoredKey::from((SELF, "bar")),
AnchoredKey::from((SELF, "baz")),
),
Statement::Lt(
AnchoredKey::from((pod_id, "one")),
AnchoredKey::from((pod_id, "two")),
),
]
.into_iter()
.chain(iter::repeat(Statement::None))
.take(params.max_public_statements)
.collect_vec();
helper_calculate_id(&params, &statements).unwrap();
Ok(())
}
} }

View file

@ -1,6 +1,6 @@
pub mod operation; pub mod operation;
pub mod statement; pub mod statement;
use std::{any::Any, sync::Arc}; use std::{any::Any, iter, sync::Arc};
use base64::{prelude::BASE64_STANDARD, Engine}; use base64::{prelude::BASE64_STANDARD, Engine};
use itertools::Itertools; use itertools::Itertools;
@ -35,11 +35,30 @@ use crate::{
}, },
}; };
/// Hash a list of public statements to derive the PodId /// Hash a list of public statements to derive the PodId. To make circuits with different number
pub(crate) fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash { /// of `max_public_statements compatible we pad the statements up to `num_public_statements_id`.
let field_elems = statements /// As an optimization we front pad with none-statements so that circuits with a small
/// `max_public_statements` only pay for `max_public_statements` by starting the poseidon state
/// with a precomputed constant corresponding to the front-padding part:
/// `id = hash(serialize(reverse(statements || none-statements)))`
pub(crate) fn calculate_id(statements: &[Statement], params: &Params) -> middleware::Hash {
assert_eq!(params.max_public_statements, statements.len());
assert!(params.max_public_statements <= params.num_public_statements_id);
statements
.iter() .iter()
.flat_map(|statement| statement.clone().to_fields(_params)) .for_each(|st| assert_eq!(params.max_statement_args, st.1.len()));
let mut none_st: Statement = middleware::Statement::None.into();
pad_statement(params, &mut none_st);
let statements_back_padded = statements
.iter()
.chain(iter::repeat(&none_st))
.take(params.num_public_statements_id)
.collect_vec();
let field_elems = statements_back_padded
.iter()
.rev()
.flat_map(|statement| statement.to_fields(params))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Hash(PoseidonHash::hash_no_pad(&field_elems).elements) Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
} }
@ -421,7 +440,7 @@ impl Prover {
let public_statements = let public_statements =
statements[statements.len() - params.max_public_statements..].to_vec(); statements[statements.len() - params.max_public_statements..].to_vec();
// get the id out of the public statements // get the id out of the public statements
let id: PodId = PodId(hash_statements(&public_statements, params)); let id: PodId = PodId(calculate_id(&public_statements, params));
let input = MainPodVerifyInput { let input = MainPodVerifyInput {
signed_pods: signed_pods_input, signed_pods: signed_pods_input,
@ -505,7 +524,7 @@ fn get_common_data(params: &Params) -> Result<CommonCircuitData<F, D>, Error> {
impl MainPod { impl MainPod {
fn _verify(&self) -> Result<()> { fn _verify(&self) -> Result<()> {
// 2. get the id out of the public statements // 2. get the id out of the public statements
let id: PodId = PodId(hash_statements(&self.public_statements, &self.params)); let id: PodId = PodId(calculate_id(&self.public_statements, &self.params));
if id != self.id { if id != self.id {
return Err(Error::id_not_equal(self.id, id)); return Err(Error::id_not_equal(self.id, id));
} }
@ -700,6 +719,7 @@ pub mod tests {
max_statements: 5, max_statements: 5,
max_signed_pod_values: 2, max_signed_pod_values: 2,
max_public_statements: 2, max_public_statements: 2,
num_public_statements_id: 4,
max_statement_args: 2, max_statement_args: 2,
max_operation_args: 3, max_operation_args: 3,
max_custom_predicate_batches: 2, max_custom_predicate_batches: 2,

View file

@ -11,7 +11,7 @@ use crate::{
backends::plonky2::{ backends::plonky2::{
error::{Error, Result}, error::{Error, Result},
mainpod::{ mainpod::{
extract_merkle_proofs, hash_statements, layout_statements, normalize_statement, calculate_id, extract_merkle_proofs, layout_statements, normalize_statement,
process_private_statements_operations, process_public_statements_operations, Operation, process_private_statements_operations, process_public_statements_operations, Operation,
Statement, Statement,
}, },
@ -163,7 +163,7 @@ impl MockMainPod {
statements[statements.len() - params.max_public_statements..].to_vec(); statements[statements.len() - params.max_public_statements..].to_vec();
// get the id out of the public statements // get the id out of the public statements
let id: PodId = PodId(hash_statements(&public_statements, params)); let id: PodId = PodId(calculate_id(&public_statements, params));
Ok(Self { Ok(Self {
params: params.clone(), params: params.clone(),
@ -197,7 +197,7 @@ impl MockMainPod {
// get the input_statements from the self.statements // get the input_statements from the self.statements
let input_statements = &self.statements[input_statement_offset..]; let input_statements = &self.statements[input_statement_offset..];
// 2. get the id out of the public statements, and ensure it is equal to self.id // 2. get the id out of the public statements, and ensure it is equal to self.id
let ids_match = self.id == PodId(hash_statements(&self.public_statements, &self.params)); let ids_match = self.id == PodId(calculate_id(&self.public_statements, &self.params));
// find a ValueOf statement from the public statements with key=KEY_TYPE and check that the // find a ValueOf statement from the public statements with key=KEY_TYPE and check that the
// value is PodType::MockMainPod // value is PodType::MockMainPod
let has_type_statement = self.public_statements.iter().any(|s| { let has_type_statement = self.public_statements.iter().any(|s| {
@ -351,8 +351,7 @@ pub mod tests {
#[test] #[test]
fn test_mock_main_great_boy() -> frontend::Result<()> { fn test_mock_main_great_boy() -> frontend::Result<()> {
let params = middleware::Params::default(); let (params, great_boy_builder) = great_boy_pod_full_flow()?;
let great_boy_builder = great_boy_pod_full_flow()?;
let mut prover = MockProver {}; let mut prover = MockProver {};
let great_boy_pod = great_boy_builder.prove(&mut prover, &params)?; let great_boy_pod = great_boy_builder.prove(&mut prover, &params)?;

View file

@ -296,11 +296,12 @@ pub fn great_boy_pod_builder(
Ok(great_boy) Ok(great_boy)
} }
pub fn great_boy_pod_full_flow() -> Result<MainPodBuilder> { pub fn great_boy_pod_full_flow() -> Result<(Params, MainPodBuilder)> {
let params = Params { let params = Params {
max_input_signed_pods: 6, max_input_signed_pods: 6,
max_statements: 100, max_statements: 100,
max_public_statements: 50, max_public_statements: 50,
num_public_statements_id: 50,
..Default::default() ..Default::default()
}; };
@ -349,7 +350,7 @@ pub fn great_boy_pod_full_flow() -> Result<MainPodBuilder> {
good_boy_issuers.into_iter().map(Value::from).collect(), good_boy_issuers.into_iter().map(Value::from).collect(),
)?); )?);
great_boy_pod_builder( let builder = great_boy_pod_builder(
&params, &params,
[ [
&bob_good_boys[0], &bob_good_boys[0],
@ -360,7 +361,9 @@ pub fn great_boy_pod_full_flow() -> Result<MainPodBuilder> {
[&alice_friend_pods[0], &alice_friend_pods[1]], [&alice_friend_pods[0], &alice_friend_pods[1]],
&good_boy_issuers, &good_boy_issuers,
alice, alice,
) )?;
Ok((params, builder))
} }
// Tickets // Tickets

View file

@ -940,7 +940,7 @@ pub mod tests {
#[test] #[test]
fn test_front_great_boy() -> Result<()> { fn test_front_great_boy() -> Result<()> {
let great_boy = great_boy_pod_full_flow()?; let (_, great_boy) = great_boy_pod_full_flow()?;
println!("{}", great_boy); println!("{}", great_boy);
// TODO: prove great_boy with MockProver and print it // TODO: prove great_boy with MockProver and print it

View file

@ -582,6 +582,9 @@ pub struct Params {
pub max_statements: usize, pub max_statements: usize,
pub max_signed_pod_values: usize, pub max_signed_pod_values: usize,
pub max_public_statements: usize, pub max_public_statements: usize,
// Number of public statements to hash to calculate the id. Must be equal or greater than
// `max_public_statements`.
pub num_public_statements_id: usize,
pub max_statement_args: usize, pub max_statement_args: usize,
pub max_operation_args: usize, pub max_operation_args: usize,
// max number of custom predicates batches that a MainPod can use // max number of custom predicates batches that a MainPod can use
@ -607,6 +610,7 @@ impl Default for Params {
max_statements: 20, max_statements: 20,
max_signed_pod_values: 8, max_signed_pod_values: 8,
max_public_statements: 10, max_public_statements: 10,
num_public_statements_id: 16,
max_statement_args: 5, max_statement_args: 5,
max_operation_args: 5, max_operation_args: 5,
max_custom_predicate_batches: 2, max_custom_predicate_batches: 2,