Merkle tree for custom predicate batches (#471)

Resolve https://github.com/0xPARC/pod2/issues/466

Now batches are identified by the root of a merkle tree that contains all the predicates (using sequential indices as keys).  This means that the format to identify a custom predicate reference is still a hash + index, but the calculation of the hash is different.
The MainPod circuit now isn't limited by number of batches but instead number of custom predicates; and for each one we verify a merkle proof to verify the batch id.

I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that.  In a future PR we'll remove the code that handles batch splitting.

Each custom predicate needs 148.2 gates (which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 where I actually made a mistake and considered 5 predicates per batch instead of 4 in the previous Params).
This commit is contained in:
Eduard S. 2026-02-04 11:12:32 +01:00 committed by GitHub
parent a7a30176a7
commit 641d8dabdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 331 additions and 761 deletions

View file

@ -1,7 +1,7 @@
pub mod operation;
use crate::middleware::{wildcard_values_from_op_st, PodType};
pub mod statement;
use std::{iter, sync::Arc};
use std::iter;
use itertools::{zip_eq, Itertools};
use num_bigint::BigUint;
@ -37,9 +37,9 @@ use crate::{
serialize_proof, serialize_verifier_only,
},
middleware::{
self, value_from_op, CustomPredicateBatch, Error as MiddlewareError, Hash, MainPodInputs,
self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs,
MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg,
ToFields, VDSet,
ToFields, VDSet, Value,
},
timed,
};
@ -68,27 +68,27 @@ pub fn calculate_statements_hash(statements: &[Statement]) -> middleware::Hash {
Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
}
/// Extracts unique `CustomPredicateBatch`es from Custom ops.
pub(crate) fn extract_custom_predicate_batches(
/// Extracts unique `CustomPredicate`s from Custom ops.
pub(crate) fn extract_custom_predicates(
params: &Params,
operations: &[middleware::Operation],
) -> Result<Vec<Arc<CustomPredicateBatch>>> {
let custom_predicate_batches: Vec<_> = operations
) -> Result<Vec<CustomPredicateRef>> {
let custom_predicates: Vec<_> = operations
.iter()
.flat_map(|op| match op {
middleware::Operation::Custom(cpr, _) => Some(cpr.batch.clone()),
middleware::Operation::Custom(cpr, _) => Some(cpr.clone()),
_ => None,
})
.unique_by(|cpr| cpr.id())
.unique()
.collect();
if custom_predicate_batches.len() > params.max_custom_predicate_batches {
if custom_predicates.len() > params.max_custom_predicates {
return Err(Error::custom(format!(
"The number of required `CustomPredicateBatch`es ({}) exceeds the maximum number ({}).",
custom_predicate_batches.len(),
params.max_custom_predicate_batches
"The number of required `CustomPredicate`s ({}) exceeds the maximum number ({}).",
custom_predicates.len(),
params.max_custom_predicates
)));
}
Ok(custom_predicate_batches)
Ok(custom_predicates)
}
/// Extracts all custom predicate operations with all the data required to verify them.
@ -97,7 +97,7 @@ pub(crate) fn extract_custom_predicate_verifications(
aux_list: &mut [OperationAux],
operations: &[middleware::Operation],
statements: &[middleware::Statement],
custom_predicate_batches: &[Arc<CustomPredicateBatch>],
custom_predicates: &[CustomPredicateRef],
) -> Result<Vec<CustomPredicateVerification>> {
let mut table = Vec::new();
for (i, (op, st)) in zip_eq(operations.iter(), statements.iter()).enumerate() {
@ -108,13 +108,11 @@ pub(crate) fn extract_custom_predicate_verifications(
wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args)
.expect("resolved wildcards");
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
let batch_index = custom_predicate_batches
let custom_predicate_table_index = custom_predicates
.iter()
.enumerate()
.find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i))
.find_map(|(i, table_cpr)| (table_cpr == cpr).then_some(i))
.expect("find the custom predicate from the extracted unique list");
let custom_predicate_table_index =
batch_index * Params::max_custom_batch_size() + cpr.index;
aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len());
table.push(CustomPredicateVerification {
custom_predicate_table_index,
@ -497,14 +495,25 @@ impl MainPodProver for Prover {
let mut aux_list = vec![OperationAux::None; params.max_priv_statements()];
let merkle_proofs =
extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?;
let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?;
let custom_predicates = extract_custom_predicates(params, inputs.operations)?;
let custom_predicate_verifications = extract_custom_predicate_verifications(
params,
&mut aux_list,
inputs.operations,
inputs.statements,
&custom_predicate_batches,
&custom_predicates,
)?;
let custom_predicates_with_mpt_proofs = custom_predicates
.into_iter()
.map(|cpr| {
let (_, mtp) = cpr
.batch
.mt()
.prove(&Value::from(cpr.index as i64).raw())
.expect("index by construction exists");
(cpr, mtp)
})
.collect_vec();
let public_key_of_sks =
extract_public_key_of(params, &mut aux_list, inputs.operations, inputs.statements)?;
let signed_bys =
@ -572,7 +581,7 @@ impl MainPodProver for Prover {
public_key_of_sks,
signed_bys,
merkle_tree_state_transition_proofs,
custom_predicate_batches,
custom_predicates_with_mpt_proofs,
custom_predicate_verifications,
};
@ -840,7 +849,7 @@ pub mod tests {
// Currently the circuit uses random access that only supports vectors of length 64.
// With max_input_main_pods=3 we need random access to a vector of length 73.
max_input_pods: 0,
max_custom_predicate_batches: 0,
max_custom_predicates: 0,
max_custom_predicate_verifications: 0,
..Default::default()
};
@ -961,7 +970,7 @@ pub mod tests {
max_merkle_proofs_containers: 0,
max_public_key_of: 0,
max_custom_predicate_verifications: 0,
max_custom_predicate_batches: 0,
max_custom_predicates: 0,
..Default::default()
};
let mut vds = DEFAULT_VD_LIST.clone();
@ -995,7 +1004,7 @@ pub mod tests {
max_statements: 5,
max_public_statements: 2,
max_operation_args: 5,
max_custom_predicate_batches: 2,
max_custom_predicates: 2,
max_custom_predicate_verifications: 2,
max_custom_predicate_wildcards: 3,
max_merkle_proofs_containers: 2,