Merkle tree for custom predicate batches (#471)

Resolve https://github.com/0xPARC/pod2/issues/466

Now batches are identified by the root of a merkle tree that contains all the predicates (using sequential indices as keys).  This means that the format to identify a custom predicate reference is still a hash + index, but the calculation of the hash is different.
The MainPod circuit now isn't limited by number of batches but instead number of custom predicates; and for each one we verify a merkle proof to verify the batch id.

I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that.  In a future PR we'll remove the code that handles batch splitting.

Each custom predicate needs 148.2 gates (which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 where I actually made a mistake and considered 5 predicates per batch instead of 4 in the previous Params).
This commit is contained in:
Eduard S. 2026-02-04 11:12:32 +01:00 committed by GitHub
parent a7a30176a7
commit 641d8dabdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 331 additions and 761 deletions

View file

@ -1,4 +1,4 @@
use std::{array, iter, sync::Arc};
use std::{array, iter};
use itertools::{izip, zip_eq, Itertools};
use num::{BigUint, One};
@ -21,7 +21,7 @@ use crate::{
basetypes::{CircuitBuilder, VDSet},
circuits::{
common::{
CircuitBuilderPod, CustomPredicateBatchTarget, CustomPredicateEntryTarget,
CircuitBuilderPod, CustomPredicateEntryTarget, CustomPredicateInBatchTarget,
CustomPredicateTarget, CustomPredicateVerifyEntryTarget,
CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget,
MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget,
@ -44,7 +44,7 @@ use crate::{
},
merkletree::{
verify_merkle_proof_circuit, verify_merkle_state_transition_circuit,
MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleTreeOp,
MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProof, MerkleTreeOp,
MerkleTreeStateTransitionProof, MerkleTreeStateTransitionProofTarget,
},
signature::{verify_signature_circuit, SignatureVerifyTarget},
@ -1573,37 +1573,34 @@ fn normalize_st_tmpl_circuit(
fn build_custom_predicate_table_circuit(
params: &Params,
builder: &mut CircuitBuilder,
custom_predicate_batches: &[CustomPredicateBatchTarget],
custom_predicates: &[CustomPredicateInBatchTarget],
) -> Result<Vec<HashOutTarget>> {
let measure = measure_gates_begin!(builder, "BuildCustomPredTbl");
let mut custom_predicate_table =
Vec::with_capacity(params.max_custom_predicate_batches * Params::max_custom_batch_size());
for cpb in custom_predicate_batches {
let measure_cpb = measure_gates_begin!(builder, "CustomPredBatch");
let id = cpb.id(builder); // constrain the id
for (index, cp) in cpb.predicates.iter().enumerate() {
let statements = cp
.statements
.iter()
.map(|st_with_pred_tmpl| {
normalize_st_tmpl_circuit(params, builder, st_with_pred_tmpl, id)
})
.collect_vec();
let cp = CustomPredicateTarget {
conjunction: cp.conjunction,
let mut custom_predicate_table = Vec::with_capacity(params.max_custom_predicates);
for cp in custom_predicates {
let measure_cp = measure_gates_begin!(builder, "CustomPred");
cp.verify_circuit(builder);
let statements = cp
.self_predicate
.statements
.iter()
.map(|st_with_pred_tmpl| {
normalize_st_tmpl_circuit(params, builder, st_with_pred_tmpl, cp.id)
})
.collect_vec();
let entry = CustomPredicateEntryTarget {
id: cp.id, // output
index: cp.index, // input
predicate: CustomPredicateTarget {
conjunction: cp.self_predicate.conjunction,
statements,
args_len: cp.args_len,
};
let entry = CustomPredicateEntryTarget {
id, // output
index: builder.constant(F::from_canonical_usize(index)), // constant
predicate: cp.clone(), // input
};
args_len: cp.self_predicate.args_len,
}, // input
};
let in_query_hash = entry.hash(builder);
custom_predicate_table.push(in_query_hash);
}
measure_gates_end!(builder, measure_cpb);
let in_query_hash = entry.hash(builder);
custom_predicate_table.push(in_query_hash);
measure_gates_end!(builder, measure_cp);
}
measure_gates_end!(builder, measure);
Ok(custom_predicate_table)
@ -1711,7 +1708,7 @@ fn verify_main_pod_circuit(
// Table of custom predicate batches with batch_id calculation
let custom_predicate_table =
build_custom_predicate_table_circuit(params, builder, &main_pod.custom_predicate_batches)?;
build_custom_predicate_table_circuit(params, builder, &main_pod.custom_predicates)?;
let aux_table = build_operation_aux_table_circuit(
params,
@ -1754,7 +1751,7 @@ pub struct MainPodVerifyTarget {
public_key_of_sks: Vec<BigUInt320Target>,
signed_bys: Vec<SignedByTarget>,
merkle_tree_state_transition_proofs: Vec<MerkleTreeStateTransitionProofTarget>,
custom_predicate_batches: Vec<CustomPredicateBatchTarget>,
custom_predicates: Vec<CustomPredicateInBatchTarget>,
custom_predicate_verifications: Vec<CustomPredicateVerifyEntryTarget>,
}
@ -1799,8 +1796,8 @@ impl MainPodVerifyTarget {
)
})
.collect(),
custom_predicate_batches: (0..params.max_custom_predicate_batches)
.map(|_| builder.add_virtual_custom_predicate_batch(true))
custom_predicates: (0..params.max_custom_predicates)
.map(|_| CustomPredicateInBatchTarget::new_virtual(builder))
.collect(),
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
@ -1830,7 +1827,7 @@ pub struct MainPodVerifyInput {
pub public_key_of_sks: Vec<SecretKey>,
pub signed_bys: Vec<SignedBy>,
pub merkle_tree_state_transition_proofs: Vec<MerkleTreeStateTransitionProof>,
pub custom_predicate_batches: Vec<Arc<CustomPredicateBatch>>,
pub custom_predicates_with_mpt_proofs: Vec<(CustomPredicateRef, MerkleProof)>,
pub custom_predicate_verifications: Vec<CustomPredicateVerification>,
}
@ -1972,18 +1969,20 @@ impl InnerCircuit for MainPodVerifyTarget {
self.merkle_tree_state_transition_proofs[i].set_targets(pw, false, &pad_mtp)?;
}
assert!(input.custom_predicate_batches.len() <= self.params.max_custom_predicate_batches);
for (i, cpb) in input.custom_predicate_batches.iter().enumerate() {
self.custom_predicate_batches[i].set_targets(pw, cpb)?;
assert!(input.custom_predicates_with_mpt_proofs.len() <= self.params.max_custom_predicates);
for (i, (cp, mtp)) in input.custom_predicates_with_mpt_proofs.iter().enumerate() {
self.custom_predicates[i].set_targets(pw, cp, mtp)?;
}
// Padding
let pad_cpb = CustomPredicateBatch::new(
&self.params,
"empty".to_string(),
vec![CustomPredicate::empty()],
);
for i in input.custom_predicate_batches.len()..self.params.max_custom_predicate_batches {
self.custom_predicate_batches[i].set_targets(pw, &pad_cpb)?;
let pad_cpb =
CustomPredicateBatch::new("empty".to_string(), vec![CustomPredicate::empty()]);
let pad_cp = pad_cpb.predicate_ref_by_index(0).expect("index 0 exists");
let (_, pad_mtp) = pad_cpb
.mt()
.prove(&Value::from(0i64).raw())
.expect("exists");
for i in input.custom_predicates_with_mpt_proofs.len()..self.params.max_custom_predicates {
self.custom_predicates[i].set_targets(pw, &pad_cp, &pad_mtp)?;
}
assert!(
@ -2096,7 +2095,7 @@ mod tests {
.merkle_tree_state_transition_proofs
.len(),
max_custom_predicate_verifications: 0,
max_custom_predicate_batches: 0,
max_custom_predicates: 0,
..Default::default()
};