Merkle tree for custom predicate batches (#471)

Resolve https://github.com/0xPARC/pod2/issues/466

Now batches are identified by the root of a merkle tree that contains all the predicates (using sequential indices as keys).  This means that the format to identify a custom predicate reference is still a hash + index, but the calculation of the hash is different.
The MainPod circuit now isn't limited by number of batches but instead number of custom predicates; and for each one we verify a merkle proof to verify the batch id.

I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that.  In a future PR we'll remove the code that handles batch splitting.

Each custom predicate needs 148.2 gates (which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 where I actually made a mistake and considered 5 predicates per batch instead of 4 in the previous Params).
This commit is contained in:
Eduard S. 2026-02-04 11:12:32 +01:00 committed by GitHub
parent a7a30176a7
commit 641d8dabdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 331 additions and 761 deletions

View file

@ -28,14 +28,17 @@ use crate::{
circuits::mainpod::CustomPredicateVerification,
error::Result,
mainpod::{Operation, OperationArg, OperationAux, Statement},
primitives::merkletree::{MerkleClaimAndProofTarget, MerkleTreeStateTransitionProofTarget},
primitives::merkletree::{
verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget,
MerkleProof, MerkleTreeStateTransitionProofTarget,
},
},
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard,
PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl,
StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE,
STATEMENT_ARG_F_LEN, VALUE_SIZE,
hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate,
OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix,
PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg,
StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN,
VALUE_SIZE,
},
};
@ -688,34 +691,65 @@ impl CustomPredicateTarget {
}
}
/// This type is used to build the custom predicate table, which exposes the custom predicates with
/// normalized statement templates indexed by batch_id and custom_predicate_index.
/// Custom predicate structure that can be verified to belong to a batch id at a particular index
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateBatchTarget {
pub predicates: Vec<CustomPredicateTarget>,
pub struct CustomPredicateInBatchTarget {
pub id: HashOutTarget,
pub index: Target,
/// Predicate that may use references to another predicate of the batch with BatchSelf
pub self_predicate: CustomPredicateTarget,
pub mtp: MerkleClaimAndProofTarget,
}
impl CustomPredicateBatchTarget {
pub fn id(&self, builder: &mut CircuitBuilder) -> HashOutTarget {
let flattened: Vec<_> = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
}
impl CustomPredicateInBatchTarget {
/// This constructor connects the merkle proof and claim targets with with the (index,
/// self_predicate) and id.
pub fn new_virtual(builder: &mut CircuitBuilder) -> CustomPredicateInBatchTarget {
let index = builder.add_virtual_target();
let self_predicate = builder.add_virtual_custom_predicate(true);
// Existence Merkle Tree proof of (index, hash(self_predicate)) -> id
let mtp =
MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder);
let _true = builder._true();
builder.connect(_true.target, mtp.enabled.target);
builder.connect(_true.target, mtp.existence.target);
let zero = builder.constant(F(0));
let key = ValueTarget {
elements: [index, zero, zero, zero],
};
builder.connect_values(key, mtp.key);
let id = mtp.root;
Self {
id,
index,
mtp,
self_predicate,
}
}
/// Hash the predicate, connect it to the merkle proof claim value and verify the merkle proof.
pub fn verify_circuit(&self, builder: &mut CircuitBuilder) {
let value = builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.self_predicate.flatten());
builder.connect_array(value.elements, self.mtp.value.elements);
verify_merkle_proof_circuit(builder, &self.mtp);
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
custom_predicate_batch: &CustomPredicateBatch,
predicate_ref: &CustomPredicateRef,
mtp: &MerkleProof,
) -> Result<()> {
let pad_predicate = CustomPredicate::empty();
for (i, predicate) in custom_predicate_batch
.predicates()
.iter()
.chain(iter::repeat(&pad_predicate))
.take(Params::max_custom_batch_size())
.enumerate()
{
self.predicates[i].set_targets(pw, predicate)?;
}
pw.set_target_arr(&self.id.elements, &predicate_ref.batch.id().0)?;
pw.set_target(self.index, F::from_canonical_usize(predicate_ref.index))?;
let predicate = predicate_ref.predicate();
self.self_predicate.set_targets(pw, predicate)?;
let mtp_claim = MerkleClaimAndProof {
root: predicate_ref.batch.id(),
key: Value::from(predicate_ref.index as i64).raw(),
value: RawValue::from(hash_fields(&predicate.to_fields())),
proof: mtp.clone(),
};
self.mtp.set_targets(pw, true, &mtp_claim)?;
Ok(())
}
}
@ -812,11 +846,9 @@ pub struct CustomPredicateVerifyEntryTarget {
impl CustomPredicateVerifyEntryTarget {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
let custom_predicate_table_len =
params.max_custom_predicate_batches * Params::max_custom_batch_size();
CustomPredicateVerifyEntryTarget {
custom_predicate_table_index: IndexTarget::new_virtual(
custom_predicate_table_len,
params.max_custom_predicates,
builder,
),
custom_predicate: builder.add_virtual_custom_predicate_entry(),
@ -1245,8 +1277,6 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget;
fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget;
fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget;
fn add_virtual_custom_predicate_batch(&mut self, with_pred: bool)
-> CustomPredicateBatchTarget;
fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_statement_arg(
@ -1435,18 +1465,6 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate_batch(
&mut self,
with_pred: bool,
) -> CustomPredicateBatchTarget {
CustomPredicateBatchTarget {
predicates: (0..Params::max_custom_batch_size())
.map(|_| self.add_virtual_custom_predicate(with_pred))
.collect(),
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget {
CustomPredicateEntryTarget {
@ -1869,6 +1887,8 @@ impl SimpleGenerator<F, D> for LtMaskGenerator {
#[cfg(test)]
pub(crate) mod tests {
use std::sync::Arc;
use anyhow::anyhow;
use itertools::Itertools;
use plonky2::plonk::{
@ -1878,8 +1898,10 @@ pub(crate) mod tests {
use super::*;
use crate::{
backends::plonky2::basetypes::C, examples::custom::eth_dos_batch, frontend,
frontend::CustomPredicateBatchBuilder, middleware::CustomPredicateBatch,
backends::plonky2::basetypes::C,
examples::custom::eth_dos_batch,
frontend::{self, CustomPredicateBatchBuilder},
middleware::CustomPredicateBatch,
};
pub(crate) const I64_TEST_PAIRS: [(i64, i64); 36] = [
@ -1952,50 +1974,54 @@ pub(crate) mod tests {
Ok(())
}
fn helper_custom_predicate_batch_target_id(
custom_predicate_batch: &CustomPredicateBatch,
fn helper_custom_predicate_in_batch_target(
custom_predicate_batch: &Arc<CustomPredicateBatch>,
) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
for index in 0..custom_predicate_batch.predicates().len() {
let cpr = custom_predicate_batch
.predicate_ref_by_index(index)
.unwrap();
let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(false);
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
// Calculate the id in constraints and compare it against the id calculated natively
let id_target = custom_predicate_batch_target.id(&mut builder);
let custom_pred_in_batch_target =
CustomPredicateInBatchTarget::new_virtual(&mut builder);
custom_pred_in_batch_target.verify_circuit(&mut builder);
let mut pw = PartialWitness::<F>::new();
custom_predicate_batch_target.set_targets(&mut pw, custom_predicate_batch)?;
let id = custom_predicate_batch.id();
pw.set_target_arr(&id_target.elements, &id.0)?;
let mut pw = PartialWitness::<F>::new();
let (_, mtp) = custom_predicate_batch
.mt()
.prove(&Value::from(index as i64).raw())
.unwrap();
custom_pred_in_batch_target.set_targets(&mut pw, &cpr, &mtp)?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof.clone()).unwrap();
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof.clone()).unwrap();
}
Ok(())
}
#[test]
fn test_custom_predicate_batch_target_id() -> frontend::Result<()> {
let params = Params {
max_custom_predicate_wildcards: 12,
..Default::default()
};
fn test_custom_predicate_in_batch_target() -> frontend::Result<()> {
let params = Params::default();
// Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish();
helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap();
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
// Some cases from the examples
let custom_predicate_batch = eth_dos_batch(&params)?;
helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap();
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
let custom_predicate_batch =
CustomPredicateBatch::new(&params, "empty".to_string(), vec![CustomPredicate::empty()]);
helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap();
CustomPredicateBatch::new("empty".to_string(), vec![CustomPredicate::empty()]);
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
Ok(())
}