Merkle tree for custom predicate batches (#471)

Resolve https://github.com/0xPARC/pod2/issues/466

Now batches are identified by the root of a merkle tree that contains all the predicates (using sequential indices as keys).  This means that the format to identify a custom predicate reference is still a hash + index, but the calculation of the hash is different.
The MainPod circuit now isn't limited by number of batches but instead number of custom predicates; and for each one we verify a merkle proof to verify the batch id.

I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that.  In a future PR we'll remove the code that handles batch splitting.

Each custom predicate needs 148.2 gates (which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 where I actually made a mistake and considered 5 predicates per batch instead of 4 in the previous Params).
This commit is contained in:
Eduard S. 2026-02-04 11:12:32 +01:00 committed by GitHub
parent a7a30176a7
commit 641d8dabdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 331 additions and 761 deletions

View file

@ -1,13 +1,16 @@
use std::{fmt, iter, sync::Arc};
use std::{collections::HashMap, fmt, iter, sync::Arc};
use itertools::Itertools;
use plonky2::field::types::Field;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::middleware::{
hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value,
BASE_PARAMS, EMPTY_HASH, F, VALUE_SIZE,
use crate::{
backends::plonky2::primitives::merkletree::MerkleTree,
middleware::{
hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, RawValue, Result,
ToFields, Value, BASE_PARAMS, F, VALUE_SIZE,
},
};
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
@ -420,83 +423,142 @@ impl fmt::Display for CustomPredicate {
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)]
enum CustomPredicateBatchData {
Full {
#[serde(skip)]
#[schemars(skip)]
mt: MerkleTree,
predicates: Vec<CustomPredicate>,
},
Opaque {
id: Hash,
},
}
// TODO: Rename Batch for Module everywhere in the code base
impl CustomPredicateBatchData {
fn new_full(predicates: Vec<CustomPredicate>) -> Self {
let kvs: HashMap<RawValue, RawValue> = predicates
.iter()
.enumerate()
.map(|(index, pred)| {
let cp_hash = hash_fields(&pred.to_fields());
(Value::from(index as i64).raw(), Value::from(cp_hash).raw())
})
.collect();
let mt = MerkleTree::new(&kvs);
Self::Full { mt, predicates }
}
fn new_opaque(id: Hash) -> Self {
Self::Opaque { id }
}
}
impl<'de> Deserialize<'de> for CustomPredicateBatchData {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Aux {
Full { predicates: Vec<CustomPredicate> },
Opaque { id: Hash },
}
let aux = Aux::deserialize(deserializer)?;
Ok(match aux {
Aux::Opaque { id } => Self::new_opaque(id),
Aux::Full { predicates } => Self::new_full(predicates),
})
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct CustomPredicateBatch {
id: Hash,
pub name: String,
pub(crate) predicates: Vec<CustomPredicate>,
data: CustomPredicateBatchData,
}
impl std::hash::Hash for CustomPredicateBatch {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.id.hash(state);
}
}
impl ToFields for CustomPredicateBatch {
fn to_fields(&self) -> Vec<F> {
// all the custom predicates in order
let pad_pred = CustomPredicate::empty();
self.predicates
.iter()
.chain(iter::repeat(&pad_pred))
.take(BASE_PARAMS.max_custom_batch_size)
.flat_map(|p| p.to_fields())
.collect_vec()
self.id().hash(state);
}
}
impl CustomPredicateBatch {
pub fn new(_params: &Params, name: String, predicates: Vec<CustomPredicate>) -> Arc<Self> {
let mut cpb = Self {
id: EMPTY_HASH,
pub fn new(name: String, predicates: Vec<CustomPredicate>) -> Arc<Self> {
Arc::new(Self {
name,
predicates,
};
let id = cpb.calculate_id();
cpb.id = id;
Arc::new(cpb)
data: CustomPredicateBatchData::new_full(predicates),
})
}
pub fn new_opaque(name: String, id: Hash) -> Arc<Self> {
Arc::new(Self {
id,
name,
predicates: vec![],
data: CustomPredicateBatchData::Opaque { id },
})
}
/// Cryptographic identifier for the batch.
fn calculate_id(&self) -> Hash {
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
// but ideally we want to use the root of a merkle tree built from the custom predicates.
let input = self.to_fields();
hash_fields(&input)
}
pub fn id(&self) -> Hash {
self.id
match &self.data {
CustomPredicateBatchData::Opaque { id } => *id,
CustomPredicateBatchData::Full { mt, .. } => mt.root(),
}
}
pub fn predicates(&self) -> &[CustomPredicate] {
&self.predicates
match &self.data {
// TODO: Return Option here instead of panic
CustomPredicateBatchData::Opaque { .. } => panic!("opaque batch"),
CustomPredicateBatchData::Full { predicates, .. } => predicates,
}
}
pub fn mt(&self) -> &MerkleTree {
match &self.data {
// TODO: Return Option here instead of panic
CustomPredicateBatchData::Opaque { .. } => panic!("opaque batch"),
CustomPredicateBatchData::Full { mt, .. } => mt,
}
}
pub fn predicate_ref_by_name(
self: &Arc<CustomPredicateBatch>,
name: &str,
) -> Option<CustomPredicateRef> {
self.predicates
self.predicates()
.iter()
.enumerate()
.find_map(|(i, cp)| (cp.name == name).then(|| CustomPredicateRef::new(self.clone(), i)))
}
pub fn predicate_ref_by_index(
self: &Arc<CustomPredicateBatch>,
index: usize,
) -> Option<CustomPredicateRef> {
self.predicates()
.get(index)
.map(|_| CustomPredicateRef::new(self.clone(), index))
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct CustomPredicateRef {
pub batch: Arc<CustomPredicateBatch>,
pub index: usize,
}
impl std::hash::Hash for CustomPredicateRef {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
(self.batch.id(), self.index).hash(state);
}
}
impl PartialEq for CustomPredicateRef {
fn eq(&self, other: &Self) -> bool {
self.batch.id() == other.batch.id() && self.index == other.index
}
}
impl Eq for CustomPredicateRef {}
impl CustomPredicateRef {
pub fn new(batch: Arc<CustomPredicateBatch>, index: usize) -> Self {
Self { batch, index }
@ -505,7 +567,7 @@ impl CustomPredicateRef {
self.predicate().args_len
}
pub fn predicate(&self) -> &CustomPredicate {
&self.batch.predicates[self.index]
&self.batch.predicates()[self.index]
}
}
@ -556,7 +618,6 @@ mod tests {
p:product_of(S1, Constant, S2)
*/
let cust_pred_batch = CustomPredicateBatch::new(
&params,
"is_double".to_string(),
vec![CustomPredicate::and(
&params,
@ -637,7 +698,7 @@ mod tests {
)?;
let eth_friend_batch =
CustomPredicateBatch::new(&params, "eth_friend".to_string(), vec![eth_friend]);
CustomPredicateBatch::new("eth_friend".to_string(), vec![eth_friend]);
// 0
let eth_dos_base = CustomPredicate::and(
@ -714,7 +775,6 @@ mod tests {
)?;
let eth_dos_distance_batch = CustomPredicateBatch::new(
&params,
"ETHDoS_distance".to_string(),
vec![eth_dos_base, eth_dos_ind, eth_dos],
);