From 641d8dabdd39c3f228a3d9db05bfcf21e3b9d6c7 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 4 Feb 2026 11:12:32 +0100 Subject: [PATCH] Merkle tree for custom predicate batches (#471) Resolve https://github.com/0xPARC/pod2/issues/466 Now batches are identified by the root of a merkle tree that contains all the predicates (using sequential indices as keys). This means that the format to identify a custom predicate reference is still a hash + index, but the calculation of the hash is different. The MainPod circuit now isn't limited by number of batches but instead number of custom predicates; and for each one we verify a merkle proof to verify the batch id. I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that. In a future PR we'll remove the code that handles batch splitting. Each custom predicate needs 148.2 gates (which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 where I actually made a mistake and considered 5 predicates per batch instead of 4 in the previous Params). --- book/src/custom.md | 2 +- book/src/customexample.md | 2 +- src/backends/plonky2/circuits/common.rs | 164 ++++++------ src/backends/plonky2/circuits/mainpod.rs | 91 ++++--- src/backends/plonky2/mainpod/mod.rs | 59 +++-- .../plonky2/primitives/merkletree/mod.rs | 7 + src/examples/custom.rs | 8 +- src/frontend/custom.rs | 4 +- src/frontend/multi_pod/mod.rs | 215 ---------------- src/frontend/multi_pod/solver.rs | 8 - src/lang/frontend_ast_batch.rs | 234 ------------------ src/lang/frontend_ast_lower.rs | 62 ----- src/lang/frontend_ast_validate.rs | 2 +- src/lang/mod.rs | 42 ++-- src/lang/pretty_print.rs | 18 +- src/middleware/custom.rs | 154 ++++++++---- src/middleware/mod.rs | 20 +- 17 files changed, 331 insertions(+), 761 deletions(-) diff --git a/book/src/custom.md b/book/src/custom.md index d27f008..ccd9201 100644 --- a/book/src/custom.md +++ b/book/src/custom.md @@ -22,7 +22,7 @@ See [examples](./customexample.md) ## Hashing and predicate IDs -Each custom predicate is defined as part of a _group_ of predicates. The definitions of all statements in the group are laid out consecutively (see [examples](./customexample.md)) and hashed. For more details, see the pages on [hashing custom statements](./customhash.md) and [custom predicates](./custompred.md). +Each custom predicate is defined as part of a _group_ of predicates. The definitions of all statements in the group are merklelized (using sequential indices as keys) (see [examples](./customexample.md)) and the root of the merkle tree is used as the identifier. For more details, see the pages on [hashing custom statements](./customhash.md) and [custom predicates](./custompred.md). ## How to prove an application of an operation diff --git a/book/src/customexample.md b/book/src/customexample.md index fc1c8c9..1310b9b 100644 --- a/book/src/customexample.md +++ b/book/src/customexample.md @@ -37,7 +37,7 @@ SELF.1(?1, ?2, ?3, ?4, ?5, ?6) = OR( ``` and similarly for the other two definitions. -The above definition is serialized in-circuit and hashed with a zk-friendly hash to generate the "group hash", a unique cryptographic identifier for the group. +The above definition is serialized in-circuit and merkelized with a zk-friendly hash to generate the "group hash", a unique cryptographic identifier for the group. Then the individual statements in the group are identified as: ``` diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 5260c96..db8c32a 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -28,14 +28,17 @@ use crate::{ circuits::mainpod::CustomPredicateVerification, error::Result, mainpod::{Operation, OperationArg, OperationAux, Statement}, - primitives::merkletree::{MerkleClaimAndProofTarget, MerkleTreeStateTransitionProofTarget}, + primitives::merkletree::{ + verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget, + MerkleProof, MerkleTreeStateTransitionProofTarget, + }, }, middleware::{ - CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard, - PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl, - StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, - STATEMENT_ARG_F_LEN, VALUE_SIZE, + hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate, + OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix, + PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg, + StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN, + VALUE_SIZE, }, }; @@ -688,34 +691,65 @@ impl CustomPredicateTarget { } } -/// This type is used to build the custom predicate table, which exposes the custom predicates with -/// normalized statement templates indexed by batch_id and custom_predicate_index. +/// Custom predicate structure that can be verified to belong to a batch id at a particular index #[derive(Clone, Serialize, Deserialize)] -pub struct CustomPredicateBatchTarget { - pub predicates: Vec, +pub struct CustomPredicateInBatchTarget { + pub id: HashOutTarget, + pub index: Target, + /// Predicate that may use references to another predicate of the batch with BatchSelf + pub self_predicate: CustomPredicateTarget, + pub mtp: MerkleClaimAndProofTarget, } -impl CustomPredicateBatchTarget { - pub fn id(&self, builder: &mut CircuitBuilder) -> HashOutTarget { - let flattened: Vec<_> = self.predicates.iter().flat_map(|cp| cp.flatten()).collect(); - builder.hash_n_to_hash_no_pad::(flattened) - } +impl CustomPredicateInBatchTarget { + /// This constructor connects the merkle proof and claim targets with with the (index, + /// self_predicate) and id. + pub fn new_virtual(builder: &mut CircuitBuilder) -> CustomPredicateInBatchTarget { + let index = builder.add_virtual_target(); + let self_predicate = builder.add_virtual_custom_predicate(true); + // Existence Merkle Tree proof of (index, hash(self_predicate)) -> id + let mtp = + MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder); + let _true = builder._true(); + builder.connect(_true.target, mtp.enabled.target); + builder.connect(_true.target, mtp.existence.target); + let zero = builder.constant(F(0)); + let key = ValueTarget { + elements: [index, zero, zero, zero], + }; + builder.connect_values(key, mtp.key); + let id = mtp.root; + Self { + id, + index, + mtp, + self_predicate, + } + } + /// Hash the predicate, connect it to the merkle proof claim value and verify the merkle proof. + pub fn verify_circuit(&self, builder: &mut CircuitBuilder) { + let value = builder.hash_n_to_hash_no_pad::(self.self_predicate.flatten()); + builder.connect_array(value.elements, self.mtp.value.elements); + verify_merkle_proof_circuit(builder, &self.mtp); + } pub fn set_targets( &self, pw: &mut PartialWitness, - custom_predicate_batch: &CustomPredicateBatch, + predicate_ref: &CustomPredicateRef, + mtp: &MerkleProof, ) -> Result<()> { - let pad_predicate = CustomPredicate::empty(); - for (i, predicate) in custom_predicate_batch - .predicates() - .iter() - .chain(iter::repeat(&pad_predicate)) - .take(Params::max_custom_batch_size()) - .enumerate() - { - self.predicates[i].set_targets(pw, predicate)?; - } + pw.set_target_arr(&self.id.elements, &predicate_ref.batch.id().0)?; + pw.set_target(self.index, F::from_canonical_usize(predicate_ref.index))?; + let predicate = predicate_ref.predicate(); + self.self_predicate.set_targets(pw, predicate)?; + let mtp_claim = MerkleClaimAndProof { + root: predicate_ref.batch.id(), + key: Value::from(predicate_ref.index as i64).raw(), + value: RawValue::from(hash_fields(&predicate.to_fields())), + proof: mtp.clone(), + }; + self.mtp.set_targets(pw, true, &mtp_claim)?; Ok(()) } } @@ -812,11 +846,9 @@ pub struct CustomPredicateVerifyEntryTarget { impl CustomPredicateVerifyEntryTarget { pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { - let custom_predicate_table_len = - params.max_custom_predicate_batches * Params::max_custom_batch_size(); CustomPredicateVerifyEntryTarget { custom_predicate_table_index: IndexTarget::new_virtual( - custom_predicate_table_len, + params.max_custom_predicates, builder, ), custom_predicate: builder.add_virtual_custom_predicate_entry(), @@ -1245,8 +1277,6 @@ pub trait CircuitBuilderPod, const D: usize> { fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget; fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget; fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget; - fn add_virtual_custom_predicate_batch(&mut self, with_pred: bool) - -> CustomPredicateBatchTarget; fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_statement_arg( @@ -1435,18 +1465,6 @@ impl CircuitBuilderPod for CircuitBuilder { } } - /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. - fn add_virtual_custom_predicate_batch( - &mut self, - with_pred: bool, - ) -> CustomPredicateBatchTarget { - CustomPredicateBatchTarget { - predicates: (0..Params::max_custom_batch_size()) - .map(|_| self.add_virtual_custom_predicate(with_pred)) - .collect(), - } - } - /// See `add_virtual_statement_tmpl` for the meaning of `with_pred`. fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget { CustomPredicateEntryTarget { @@ -1869,6 +1887,8 @@ impl SimpleGenerator for LtMaskGenerator { #[cfg(test)] pub(crate) mod tests { + use std::sync::Arc; + use anyhow::anyhow; use itertools::Itertools; use plonky2::plonk::{ @@ -1878,8 +1898,10 @@ pub(crate) mod tests { use super::*; use crate::{ - backends::plonky2::basetypes::C, examples::custom::eth_dos_batch, frontend, - frontend::CustomPredicateBatchBuilder, middleware::CustomPredicateBatch, + backends::plonky2::basetypes::C, + examples::custom::eth_dos_batch, + frontend::{self, CustomPredicateBatchBuilder}, + middleware::CustomPredicateBatch, }; pub(crate) const I64_TEST_PAIRS: [(i64, i64); 36] = [ @@ -1952,50 +1974,54 @@ pub(crate) mod tests { Ok(()) } - fn helper_custom_predicate_batch_target_id( - custom_predicate_batch: &CustomPredicateBatch, + fn helper_custom_predicate_in_batch_target( + custom_predicate_batch: &Arc, ) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); + for index in 0..custom_predicate_batch.predicates().len() { + let cpr = custom_predicate_batch + .predicate_ref_by_index(index) + .unwrap(); - let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(false); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); - // Calculate the id in constraints and compare it against the id calculated natively - let id_target = custom_predicate_batch_target.id(&mut builder); + let custom_pred_in_batch_target = + CustomPredicateInBatchTarget::new_virtual(&mut builder); + custom_pred_in_batch_target.verify_circuit(&mut builder); - let mut pw = PartialWitness::::new(); - custom_predicate_batch_target.set_targets(&mut pw, custom_predicate_batch)?; - let id = custom_predicate_batch.id(); - pw.set_target_arr(&id_target.elements, &id.0)?; + let mut pw = PartialWitness::::new(); + let (_, mtp) = custom_predicate_batch + .mt() + .prove(&Value::from(index as i64).raw()) + .unwrap(); + custom_pred_in_batch_target.set_targets(&mut pw, &cpr, &mtp)?; - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + } Ok(()) } #[test] - fn test_custom_predicate_batch_target_id() -> frontend::Result<()> { - let params = Params { - max_custom_predicate_wildcards: 12, - ..Default::default() - }; + fn test_custom_predicate_in_batch_target() -> frontend::Result<()> { + let params = Params::default(); // Empty case let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); _ = cpb_builder.predicate_and("empty", &[], &[], &[])?; let custom_predicate_batch = cpb_builder.finish(); - helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap(); + helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); // Some cases from the examples let custom_predicate_batch = eth_dos_batch(¶ms)?; - helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap(); + helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); let custom_predicate_batch = - CustomPredicateBatch::new(¶ms, "empty".to_string(), vec![CustomPredicate::empty()]); - helper_custom_predicate_batch_target_id(&custom_predicate_batch).unwrap(); + CustomPredicateBatch::new("empty".to_string(), vec![CustomPredicate::empty()]); + helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); Ok(()) } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 845f445..ebe77b4 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -1,4 +1,4 @@ -use std::{array, iter, sync::Arc}; +use std::{array, iter}; use itertools::{izip, zip_eq, Itertools}; use num::{BigUint, One}; @@ -21,7 +21,7 @@ use crate::{ basetypes::{CircuitBuilder, VDSet}, circuits::{ common::{ - CircuitBuilderPod, CustomPredicateBatchTarget, CustomPredicateEntryTarget, + CircuitBuilderPod, CustomPredicateEntryTarget, CustomPredicateInBatchTarget, CustomPredicateTarget, CustomPredicateVerifyEntryTarget, CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget, MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget, @@ -44,7 +44,7 @@ use crate::{ }, merkletree::{ verify_merkle_proof_circuit, verify_merkle_state_transition_circuit, - MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleTreeOp, + MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProof, MerkleTreeOp, MerkleTreeStateTransitionProof, MerkleTreeStateTransitionProofTarget, }, signature::{verify_signature_circuit, SignatureVerifyTarget}, @@ -1573,37 +1573,34 @@ fn normalize_st_tmpl_circuit( fn build_custom_predicate_table_circuit( params: &Params, builder: &mut CircuitBuilder, - custom_predicate_batches: &[CustomPredicateBatchTarget], + custom_predicates: &[CustomPredicateInBatchTarget], ) -> Result> { let measure = measure_gates_begin!(builder, "BuildCustomPredTbl"); - let mut custom_predicate_table = - Vec::with_capacity(params.max_custom_predicate_batches * Params::max_custom_batch_size()); - for cpb in custom_predicate_batches { - let measure_cpb = measure_gates_begin!(builder, "CustomPredBatch"); - let id = cpb.id(builder); // constrain the id - for (index, cp) in cpb.predicates.iter().enumerate() { - let statements = cp - .statements - .iter() - .map(|st_with_pred_tmpl| { - normalize_st_tmpl_circuit(params, builder, st_with_pred_tmpl, id) - }) - .collect_vec(); - let cp = CustomPredicateTarget { - conjunction: cp.conjunction, + let mut custom_predicate_table = Vec::with_capacity(params.max_custom_predicates); + for cp in custom_predicates { + let measure_cp = measure_gates_begin!(builder, "CustomPred"); + cp.verify_circuit(builder); + let statements = cp + .self_predicate + .statements + .iter() + .map(|st_with_pred_tmpl| { + normalize_st_tmpl_circuit(params, builder, st_with_pred_tmpl, cp.id) + }) + .collect_vec(); + let entry = CustomPredicateEntryTarget { + id: cp.id, // output + index: cp.index, // input + predicate: CustomPredicateTarget { + conjunction: cp.self_predicate.conjunction, statements, - args_len: cp.args_len, - }; - let entry = CustomPredicateEntryTarget { - id, // output - index: builder.constant(F::from_canonical_usize(index)), // constant - predicate: cp.clone(), // input - }; + args_len: cp.self_predicate.args_len, + }, // input + }; - let in_query_hash = entry.hash(builder); - custom_predicate_table.push(in_query_hash); - } - measure_gates_end!(builder, measure_cpb); + let in_query_hash = entry.hash(builder); + custom_predicate_table.push(in_query_hash); + measure_gates_end!(builder, measure_cp); } measure_gates_end!(builder, measure); Ok(custom_predicate_table) @@ -1711,7 +1708,7 @@ fn verify_main_pod_circuit( // Table of custom predicate batches with batch_id calculation let custom_predicate_table = - build_custom_predicate_table_circuit(params, builder, &main_pod.custom_predicate_batches)?; + build_custom_predicate_table_circuit(params, builder, &main_pod.custom_predicates)?; let aux_table = build_operation_aux_table_circuit( params, @@ -1754,7 +1751,7 @@ pub struct MainPodVerifyTarget { public_key_of_sks: Vec, signed_bys: Vec, merkle_tree_state_transition_proofs: Vec, - custom_predicate_batches: Vec, + custom_predicates: Vec, custom_predicate_verifications: Vec, } @@ -1799,8 +1796,8 @@ impl MainPodVerifyTarget { ) }) .collect(), - custom_predicate_batches: (0..params.max_custom_predicate_batches) - .map(|_| builder.add_virtual_custom_predicate_batch(true)) + custom_predicates: (0..params.max_custom_predicates) + .map(|_| CustomPredicateInBatchTarget::new_virtual(builder)) .collect(), custom_predicate_verifications: (0..params.max_custom_predicate_verifications) .map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder)) @@ -1830,7 +1827,7 @@ pub struct MainPodVerifyInput { pub public_key_of_sks: Vec, pub signed_bys: Vec, pub merkle_tree_state_transition_proofs: Vec, - pub custom_predicate_batches: Vec>, + pub custom_predicates_with_mpt_proofs: Vec<(CustomPredicateRef, MerkleProof)>, pub custom_predicate_verifications: Vec, } @@ -1972,18 +1969,20 @@ impl InnerCircuit for MainPodVerifyTarget { self.merkle_tree_state_transition_proofs[i].set_targets(pw, false, &pad_mtp)?; } - assert!(input.custom_predicate_batches.len() <= self.params.max_custom_predicate_batches); - for (i, cpb) in input.custom_predicate_batches.iter().enumerate() { - self.custom_predicate_batches[i].set_targets(pw, cpb)?; + assert!(input.custom_predicates_with_mpt_proofs.len() <= self.params.max_custom_predicates); + for (i, (cp, mtp)) in input.custom_predicates_with_mpt_proofs.iter().enumerate() { + self.custom_predicates[i].set_targets(pw, cp, mtp)?; } // Padding - let pad_cpb = CustomPredicateBatch::new( - &self.params, - "empty".to_string(), - vec![CustomPredicate::empty()], - ); - for i in input.custom_predicate_batches.len()..self.params.max_custom_predicate_batches { - self.custom_predicate_batches[i].set_targets(pw, &pad_cpb)?; + let pad_cpb = + CustomPredicateBatch::new("empty".to_string(), vec![CustomPredicate::empty()]); + let pad_cp = pad_cpb.predicate_ref_by_index(0).expect("index 0 exists"); + let (_, pad_mtp) = pad_cpb + .mt() + .prove(&Value::from(0i64).raw()) + .expect("exists"); + for i in input.custom_predicates_with_mpt_proofs.len()..self.params.max_custom_predicates { + self.custom_predicates[i].set_targets(pw, &pad_cp, &pad_mtp)?; } assert!( @@ -2096,7 +2095,7 @@ mod tests { .merkle_tree_state_transition_proofs .len(), max_custom_predicate_verifications: 0, - max_custom_predicate_batches: 0, + max_custom_predicates: 0, ..Default::default() }; diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 83190e9..e6f5329 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1,7 +1,7 @@ pub mod operation; use crate::middleware::{wildcard_values_from_op_st, PodType}; pub mod statement; -use std::{iter, sync::Arc}; +use std::iter; use itertools::{zip_eq, Itertools}; use num_bigint::BigUint; @@ -37,9 +37,9 @@ use crate::{ serialize_proof, serialize_verifier_only, }, middleware::{ - self, value_from_op, CustomPredicateBatch, Error as MiddlewareError, Hash, MainPodInputs, + self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, - ToFields, VDSet, + ToFields, VDSet, Value, }, timed, }; @@ -68,27 +68,27 @@ pub fn calculate_statements_hash(statements: &[Statement]) -> middleware::Hash { Hash(PoseidonHash::hash_no_pad(&field_elems).elements) } -/// Extracts unique `CustomPredicateBatch`es from Custom ops. -pub(crate) fn extract_custom_predicate_batches( +/// Extracts unique `CustomPredicate`s from Custom ops. +pub(crate) fn extract_custom_predicates( params: &Params, operations: &[middleware::Operation], -) -> Result>> { - let custom_predicate_batches: Vec<_> = operations +) -> Result> { + let custom_predicates: Vec<_> = operations .iter() .flat_map(|op| match op { - middleware::Operation::Custom(cpr, _) => Some(cpr.batch.clone()), + middleware::Operation::Custom(cpr, _) => Some(cpr.clone()), _ => None, }) - .unique_by(|cpr| cpr.id()) + .unique() .collect(); - if custom_predicate_batches.len() > params.max_custom_predicate_batches { + if custom_predicates.len() > params.max_custom_predicates { return Err(Error::custom(format!( - "The number of required `CustomPredicateBatch`es ({}) exceeds the maximum number ({}).", - custom_predicate_batches.len(), - params.max_custom_predicate_batches + "The number of required `CustomPredicate`s ({}) exceeds the maximum number ({}).", + custom_predicates.len(), + params.max_custom_predicates ))); } - Ok(custom_predicate_batches) + Ok(custom_predicates) } /// Extracts all custom predicate operations with all the data required to verify them. @@ -97,7 +97,7 @@ pub(crate) fn extract_custom_predicate_verifications( aux_list: &mut [OperationAux], operations: &[middleware::Operation], statements: &[middleware::Statement], - custom_predicate_batches: &[Arc], + custom_predicates: &[CustomPredicateRef], ) -> Result> { let mut table = Vec::new(); for (i, (op, st)) in zip_eq(operations.iter(), statements.iter()).enumerate() { @@ -108,13 +108,11 @@ pub(crate) fn extract_custom_predicate_verifications( wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) .expect("resolved wildcards"); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); - let batch_index = custom_predicate_batches + let custom_predicate_table_index = custom_predicates .iter() .enumerate() - .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) + .find_map(|(i, table_cpr)| (table_cpr == cpr).then_some(i)) .expect("find the custom predicate from the extracted unique list"); - let custom_predicate_table_index = - batch_index * Params::max_custom_batch_size() + cpr.index; aux_list[i] = OperationAux::CustomPredVerifyIndex(table.len()); table.push(CustomPredicateVerification { custom_predicate_table_index, @@ -497,14 +495,25 @@ impl MainPodProver for Prover { let mut aux_list = vec![OperationAux::None; params.max_priv_statements()]; let merkle_proofs = extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; - let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?; + let custom_predicates = extract_custom_predicates(params, inputs.operations)?; let custom_predicate_verifications = extract_custom_predicate_verifications( params, &mut aux_list, inputs.operations, inputs.statements, - &custom_predicate_batches, + &custom_predicates, )?; + let custom_predicates_with_mpt_proofs = custom_predicates + .into_iter() + .map(|cpr| { + let (_, mtp) = cpr + .batch + .mt() + .prove(&Value::from(cpr.index as i64).raw()) + .expect("index by construction exists"); + (cpr, mtp) + }) + .collect_vec(); let public_key_of_sks = extract_public_key_of(params, &mut aux_list, inputs.operations, inputs.statements)?; let signed_bys = @@ -572,7 +581,7 @@ impl MainPodProver for Prover { public_key_of_sks, signed_bys, merkle_tree_state_transition_proofs, - custom_predicate_batches, + custom_predicates_with_mpt_proofs, custom_predicate_verifications, }; @@ -840,7 +849,7 @@ pub mod tests { // Currently the circuit uses random access that only supports vectors of length 64. // With max_input_main_pods=3 we need random access to a vector of length 73. max_input_pods: 0, - max_custom_predicate_batches: 0, + max_custom_predicates: 0, max_custom_predicate_verifications: 0, ..Default::default() }; @@ -961,7 +970,7 @@ pub mod tests { max_merkle_proofs_containers: 0, max_public_key_of: 0, max_custom_predicate_verifications: 0, - max_custom_predicate_batches: 0, + max_custom_predicates: 0, ..Default::default() }; let mut vds = DEFAULT_VD_LIST.clone(); @@ -995,7 +1004,7 @@ pub mod tests { max_statements: 5, max_public_statements: 2, max_operation_args: 5, - max_custom_predicate_batches: 2, + max_custom_predicates: 2, max_custom_predicate_verifications: 2, max_custom_predicate_wildcards: 3, max_merkle_proofs_containers: 2, diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index 07f87dc..9b3609b 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -23,6 +23,13 @@ pub struct MerkleTree { root: Node, } +impl PartialEq for MerkleTree { + fn eq(&self, other: &Self) -> bool { + self.root() == other.root() + } +} +impl Eq for MerkleTree {} + impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values pub fn new(kvs: &HashMap) -> Self { diff --git a/src/examples/custom.rs b/src/examples/custom.rs index 8c4550c..b64b8a4 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -37,10 +37,10 @@ pub fn eth_dos_batch(params: &Params) -> Result> { .first_batch() .expect("Expected batch") .clone(); - println!("a.0. {}", batch.predicates[0]); - println!("a.1. {}", batch.predicates[1]); - println!("a.2. {}", batch.predicates[2]); - println!("a.3. {}", batch.predicates[3]); + println!("a.0. {}", batch.predicates()[0]); + println!("a.1. {}", batch.predicates()[1]); + println!("a.2. {}", batch.predicates()[2]); + println!("a.3. {}", batch.predicates()[3]); Ok(batch) } diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index fec897c..92fdc4f 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -244,7 +244,7 @@ impl CustomPredicateBatchBuilder { } pub fn finish(self) -> Arc { - CustomPredicateBatch::new(&self.params, self.name, self.predicates) + CustomPredicateBatch::new(self.name, self.predicates) } } @@ -291,8 +291,6 @@ mod tests { let eth_dos_batch_mw: middleware::CustomPredicateBatch = Arc::unwrap_or_clone(eth_dos_batch); - let fields = eth_dos_batch_mw.to_fields(); - println!("Batch b, serialized: {:?}", fields); Ok(()) } diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index 2abe85d..73fd8c8 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -531,8 +531,6 @@ fn build_external_statement_map(input_pods: &[MainPod]) -> HashMap Result<()> { - // Verifies that the solver respects max_custom_predicate_batches per POD (C7). - // - // Setup: - // - max_custom_predicate_batches = 2 (small limit) - // - 4 different batches, each with one simple predicate - // - 4 operations, one from each batch - // - // Expected: Solver creates exactly 2 PODs since 4 batches / 2 per POD = 2 PODs - let params = Params { - max_statements: 48, - max_public_statements: 8, - max_custom_predicate_batches: 2, // Small limit to force splitting - max_input_pods: 10, - max_input_pods_public_statements: 20, - ..Params::default() - }; - let vd_set = &*MOCK_VD_SET; - - // Create 4 separate batches using podlang parser - // Each batch has a simple predicate that checks a Contains statement - let parsed1 = - parse(r#"pred1(A) = AND(Contains(A, "x", 1))"#, ¶ms, &[]).expect("parse batch1"); - let batch1 = parsed1 - .first_batch() - .expect("parse batch1 should have a batch"); - - let parsed2 = - parse(r#"pred2(A) = AND(Contains(A, "x", 2))"#, ¶ms, &[]).expect("parse batch2"); - let batch2 = parsed2 - .first_batch() - .expect("parse batch2 should have a batch"); - - let parsed3 = - parse(r#"pred3(A) = AND(Contains(A, "x", 3))"#, ¶ms, &[]).expect("parse batch3"); - let batch3 = parsed3 - .first_batch() - .expect("parse batch3 should have a batch"); - - let parsed4 = - parse(r#"pred4(A) = AND(Contains(A, "x", 4))"#, ¶ms, &[]).expect("parse batch4"); - let batch4 = parsed4 - .first_batch() - .expect("parse batch4 should have a batch"); - - let mut builder = MultiPodBuilder::new(¶ms, vd_set); - - // Add operations using predicates from each batch - // Each custom predicate needs a Contains statement argument - let dict1 = dict!({"x" => 1}); - let contains1 = builder.priv_op(FrontendOp::dict_contains(dict1, "x", 1))?; - builder.priv_op(FrontendOp::custom( - batch1.predicate_ref_by_name("pred1").unwrap(), - [contains1], - ))?; - - let dict2 = dict!({"x" => 2}); - let contains2 = builder.priv_op(FrontendOp::dict_contains(dict2, "x", 2))?; - builder.priv_op(FrontendOp::custom( - batch2.predicate_ref_by_name("pred2").unwrap(), - [contains2], - ))?; - - let dict3 = dict!({"x" => 3}); - let contains3 = builder.priv_op(FrontendOp::dict_contains(dict3, "x", 3))?; - builder.priv_op(FrontendOp::custom( - batch3.predicate_ref_by_name("pred3").unwrap(), - [contains3], - ))?; - - let dict4 = dict!({"x" => 4}); - let contains4 = builder.priv_op(FrontendOp::dict_contains(dict4, "x", 4))?; - builder.pub_op(FrontendOp::custom( - batch4.predicate_ref_by_name("pred4").unwrap(), - [contains4], - ))?; - - let solved = builder.solve()?; - // 4 batches / 2 per POD = exactly 2 PODs - assert_eq!( - solved.solution().pod_count, - 2, - "Expected exactly 2 PODs for 4 batches with max_custom_predicate_batches=2, got {}", - solved.solution().pod_count - ); - let pod_count = solved.solution().pod_count; - - // Prove and verify - let prover = MockProver {}; - let result = solved.prove(&prover)?; - assert_eq!(result.pods.len(), pod_count); - - for (i, pod) in result.pods.iter().enumerate() { - pod.pod - .verify() - .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; - } - - Ok(()) - } - #[test] fn test_long_dependency_chain_spans_multiple_pods() -> Result<()> { // Verifies that a long dependency chain correctly cascades through multiple @@ -1717,115 +1613,4 @@ mod tests { Ok(()) } - - #[test] - fn test_dependency_chain_with_batch_limit() -> Result<()> { - // Verifies that dependency chains work correctly when combined with - // batch cardinality limits. - // - // Setup: Two predicates in DIFFERENT batches, where pred_b depends on pred_a. - // With max_custom_predicate_batches = 1, pred_a and pred_b must be in - // different PODs due to the batch limit. The dependency must still be - // satisfied via cross-POD copying. - - let params = Params { - max_statements: 10, - max_public_statements: 4, - max_input_pods: 4, - max_input_pods_public_statements: 20, - max_custom_predicate_batches: 1, // Only 1 batch per POD - max_custom_predicate_verifications: 10, - ..Params::default() - }; - let vd_set = &*MOCK_VD_SET; - - // Create two SEPARATE batches (parsed separately to get different batch IDs) - let parsed_a = - parse(r#"pred_a(X) = AND(Contains(X, "k", 1))"#, ¶ms, &[]).expect("parse batch_a"); - let batch_a = parsed_a - .first_batch() - .expect("parse batch_a should have a batch"); - - // batch_b's pred_b accepts pred_a statements - // Must use "use batch" syntax to reference external predicates - let batch_a_id = batch_a.id().encode_hex::(); - let batch_b_src = format!( - r#" - use batch pred_a from 0x{batch_a_id} - pred_b(X) = AND(pred_a(X)) - "# - ); - let parsed_b = - parse(&batch_b_src, ¶ms, std::slice::from_ref(batch_a)).expect("parse batch_b"); - let batch_b = parsed_b - .first_batch() - .expect("parse batch_b should have a batch"); - - let mut builder = MultiPodBuilder::new(¶ms, vd_set); - - // Statement 0: Contains (no batch) - let dict = dict!({"k" => 1}); - let contains = builder.priv_op(FrontendOp::dict_contains(dict, "k", 1))?; - - // Statement 1: pred_a (batch A) - let a_out = builder.priv_op(FrontendOp::custom( - batch_a.predicate_ref_by_name("pred_a").unwrap(), - [contains], - ))?; - - // Statement 2: pred_b (batch B) - depends on a_out - // With max_custom_predicate_batches = 1, this MUST be in a different POD - let _b_out = builder.pub_op(FrontendOp::custom( - batch_b.predicate_ref_by_name("pred_b").unwrap(), - [a_out], - ))?; - - let solved = builder.solve()?; - let solution = solved.solution(); - - // Expected: exactly 2 PODs due to batch limit - // - POD 0: contains(0), a_out(1) using batch_a; a_out public - // - POD 1 (output): b_out(2) using batch_b; b_out public - // - // Even though max_priv_statements=6 could fit all 3 statements, - // max_custom_predicate_batches=1 forces batch_a and batch_b into different PODs. - assert_eq!( - solution.pod_count, 2, - "Expected exactly 2 PODs due to batch limit (max_custom_predicate_batches=1)" - ); - - // POD 0: contains(0), a_out(1) - assert!( - solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), - "POD 0 should contain statements 0 and 1, got {:?}", - solution.pod_statements[0] - ); - assert!( - solution.pod_public_statements[0].contains(&1), - "Statement 1 (a_out) should be public in POD 0" - ); - - // POD 1 (output): b_out(2) - assert!( - solution.pod_statements[1].contains(&2), - "POD 1 should contain statement 2 (b_out), got {:?}", - solution.pod_statements[1] - ); - assert!( - solution.pod_public_statements[1].contains(&2), - "Statement 2 (b_out) should be public in output POD" - ); - - // Prove and verify - let prover = MockProver {}; - let result = solved.prove(&prover)?; - - for (i, pod) in result.pods.iter().enumerate() { - pod.pod - .verify() - .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; - } - - Ok(()) - } } diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs index 4441c27..2480782 100644 --- a/src/frontend/multi_pod/solver.rs +++ b/src/frontend/multi_pod/solver.rs @@ -461,14 +461,6 @@ fn try_solve_with_pods( } } - // Batch count per POD - for p in 0..target_pods { - let batch_sum: Expression = (0..all_batches.len()).map(|b| batch_used[b][p]).sum(); - model.add_constraint(constraint!( - batch_sum <= (input.params.max_custom_predicate_batches as f64) * pod_used[p] - )); - } - // Constraint 7b: Anchored key tracking // // anchored_key_used[ak][p] = 1 when auto-insertion of a Contains is needed for anchored key ak in POD p. diff --git a/src/lang/frontend_ast_batch.rs b/src/lang/frontend_ast_batch.rs index 61b167e..37450cd 100644 --- a/src/lang/frontend_ast_batch.rs +++ b/src/lang/frontend_ast_batch.rs @@ -768,37 +768,6 @@ mod tests { assert_eq!(batches.total_predicate_count(), 3); } - #[test] - fn test_predicates_span_multiple_batches() { - let input = r#" - pred1(A) = AND(Equal(A["x"], 1)) - pred2(B) = AND(Equal(B["y"], 2)) - pred3(C) = AND(Equal(C["z"], 3)) - pred4(D) = AND(Equal(D["w"], 4)) - pred5(E) = AND(Equal(E["v"], 5)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); // max_custom_batch_size = 4 - - let result = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert_eq!(batches.batch_count(), 2); - assert_eq!(batches.total_predicate_count(), 5); - - // First batch should have 4 predicates - assert_eq!(batches.batches()[0].predicates().len(), 4); - // Second batch should have 1 predicate - assert_eq!(batches.batches()[1].predicates().len(), 1); - } - #[test] fn test_intra_batch_forward_reference() { // pred2 calls pred1, but pred2 is declared first @@ -869,132 +838,6 @@ mod tests { )); // calls pred1 } - #[test] - fn test_cross_batch_reference() { - // 5 predicates where pred5 calls pred1 - // pred1-4 go in batch 0, pred5 in batch 1 - // pred5's call to pred1 should be a cross-batch reference - let input = r#" - pred1(A) = AND(Equal(A["x"], 1)) - pred2(B) = AND(Equal(B["y"], 2)) - pred3(C) = AND(Equal(C["z"], 3)) - pred4(D) = AND(Equal(D["w"], 4)) - pred5(E) = AND(pred1(E)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); // max_custom_batch_size = 4 - - let result = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert_eq!(batches.batch_count(), 2); - - // pred5 should reference pred1 via CustomPredicateRef - let pred5_batch = &batches.batches()[1]; - let pred5 = &pred5_batch.predicates()[0]; - let pred5_stmt = &pred5.statements[0]; - - // The predicate should be a Custom reference to batch 0 - match pred5_stmt.pred_or_wc() { - PredicateOrWildcard::Predicate(Predicate::Custom(ref_)) => { - // Should reference batch 0, index 0 (pred1) - assert_eq!(ref_.batch.id(), batches.batches()[0].id()); - } - _ => panic!("Expected Custom predicate reference"), - } - } - - #[test] - fn test_split_chain_spans_batches() { - // Create a predicate that will split into 2-3 predicates - // Then add more predicates to force the chain to span batches - let input = r#" - pred1(A) = AND(Equal(A["x"], 1)) - pred2(B) = AND(Equal(B["y"], 2)) - pred3(C) = AND(Equal(C["z"], 3)) - large_pred(D) = AND( - Equal(D["a"], 1) - Equal(D["b"], 2) - Equal(D["c"], 3) - Equal(D["d"], 4) - Equal(D["e"], 5) - Equal(D["f"], 6) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - // Split the large predicate - let mut all_split_results = Vec::new(); - for pred in predicates { - let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); - all_split_results.push(result); - } - - // Count total predicates across all split results - let total_preds: usize = all_split_results.iter().map(|r| r.predicates.len()).sum(); - - // We should have: pred1, pred2, pred3, large_pred_1 (continuation), large_pred - // That's 5 predicates, which spans 2 batches - assert_eq!(total_preds, 5); - - let result = batch_predicates(all_split_results, ¶ms, "TestBatch", validated.symbols()); - assert!(result.is_ok()); - - let batches = result.unwrap(); - assert_eq!(batches.batch_count(), 2); - assert_eq!(batches.total_predicate_count(), 5); - - // Verify chain info was captured - let chain_info = batches.split_chain("large_pred"); - assert!(chain_info.is_some()); - let info = chain_info.unwrap(); - assert_eq!(info.original_name, "large_pred"); - assert_eq!(info.real_statement_count, 6); - } - - #[test] - fn test_forward_cross_batch_reference_avoided_by_planner() { - // 5 predicates where pred4 calls pred5 (forward declaration) - // With max_custom_batch_size = 4, naive packing would place pred5 in batch 1 - // The dependency-aware planner should instead pack pred5 before pred4 - // to avoid a forward cross-batch reference. - let input = r#" - pred1(A) = AND(Equal(A["x"], 1)) - pred2(B) = AND(Equal(B["y"], 2)) - pred3(C) = AND(Equal(C["z"], 3)) - pred4(D) = AND(pred5(D)) - pred5(E) = AND(Equal(E["v"], 5)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); // max_custom_batch_size = 4 - - let batches = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ) - .expect("Planner should avoid forward cross-batch reference"); - - // Expect two batches and the reference to point within the same batch or earlier batch. - assert_eq!(batches.batch_count(), 2); - // pred5 should be in batch 0 and pred4 in batch 1 (given stable topo + packing) - let pred5_ref = batches.predicate_ref_by_name("pred5").unwrap(); - let pred4_ref = batches.predicate_ref_by_name("pred4").unwrap(); - assert_eq!(pred5_ref.batch.id(), batches.batches()[0].id()); - assert_eq!(pred4_ref.batch.id(), batches.batches()[1].id()); - } - #[test] fn test_empty_input() { let split_results: Vec = vec![]; @@ -1037,83 +880,6 @@ mod tests { assert!(batches.predicate_ref_by_name("nonexistent").is_none()); } - #[test] - fn test_mutual_recursion_exceeds_capacity_error() { - // Two predicates that call each other (SCC size = 5) with max batch size 4 - // Should error because an SCC cannot be split across batches - let input = r#" - pred1(A) = AND(pred2(A)) - pred2(B) = AND(pred3(B)) - pred3(B) = AND(pred4(B)) - pred4(B) = AND(pred5(B)) - pred5(B) = AND(pred1(B)) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); - - let result = batch_predicates( - preds_to_split_results(predicates), - ¶ms, - "TestBatch", - validated.symbols(), - ); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("exceeds batch capacity")); - } - - #[test] - fn test_split_chain_across_batches_placement() { - // Create a large predicate that splits into 2 pieces, plus enough predicates - // to force the chain to span batches; verify continuation is placed earlier batch - let input = r#" - p1(A) = AND(Equal(A["x"], 1)) - p2(B) = AND(Equal(B["y"], 2)) - p3(C) = AND(Equal(C["z"], 3)) - large_pred(D) = AND( - Equal(D["a"], 1) - Equal(D["b"], 2) - Equal(D["c"], 3) - Equal(D["d"], 4) - Equal(D["e"], 5) - Equal(D["f"], 6) - ) - "#; - - let (predicates, validated) = parse_and_validate(input); - let params = Params::default(); // max_custom_batch_size = 4 - - // Split and batch - let mut all_split_results = Vec::new(); - for pred in predicates { - let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); - all_split_results.push(result); - } - let batches = - batch_predicates(all_split_results, ¶ms, "TestBatch", validated.symbols()) - .expect("Batch failed"); - - assert_eq!(batches.batch_count(), 2); - - // Verify chain info - let chain_info = batches - .split_chain("large_pred") - .expect("Expected chain info"); - assert_eq!(chain_info.chain_pieces.len(), 2); - // Expect continuation piece name to be large_pred_1 (innermost first) - let cont_name = &chain_info.chain_pieces[0].name; - assert_eq!(cont_name, "large_pred_1"); - - // Expect continuation in batch 0 and main in batch 1 - let cont_ref = batches.predicate_ref_by_name("large_pred_1").unwrap(); - let main_ref = batches.predicate_ref_by_name("large_pred").unwrap(); - assert_eq!(cont_ref.batch.id(), batches.batches()[0].id()); - assert_eq!(main_ref.batch.id(), batches.batches()[1].id()); - } - /// Helper: create a unique Statement for testing /// Uses Equal with distinct literal values to create distinguishable statements fn test_statement(id: usize) -> Statement { diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index b1db536..d020e37 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -681,68 +681,6 @@ mod tests { assert!(result.is_ok()); } - #[test] - fn test_multi_batch_packing() { - // Create more predicates than fit in a single batch - // With max_custom_batch_size = 4, 5 predicates should span 2 batches - let input = r#" - pred1(A) = AND(Equal(A["a"], 1)) - pred2(B) = AND(Equal(B["b"], 2)) - pred3(C) = AND(Equal(C["c"], 3)) - pred4(D) = AND(Equal(D["d"], 4)) - pred5(E) = AND(Equal(E["e"], 5)) - "#; - - let params = Params::default(); // max_custom_batch_size = 4 - - let result = parse_validate_and_lower(input, ¶ms); - assert!(result.is_ok()); - - let lowered = result.unwrap(); - let batches = lowered.batches.as_ref().expect("Expected batches"); - - // Should have 2 batches - assert_eq!(batches.batch_count(), 2); - assert_eq!(batches.total_predicate_count(), 5); - - // First batch should have 4 predicates - assert_eq!(batches.batches()[0].predicates().len(), 4); - // Second batch should have 1 predicate - assert_eq!(batches.batches()[1].predicates().len(), 1); - } - - #[test] - fn test_split_chains_span_batches() { - // Create predicates that will split, plus additional predicates - // to force the split chains across batch boundaries - let input = r#" - pred1(A) = AND(Equal(A["a"], 1)) - pred2(B) = AND(Equal(B["b"], 2)) - pred3(C) = AND(Equal(C["c"], 3)) - large_pred(D) = AND( - Equal(D["a"], 1) - Equal(D["b"], 2) - Equal(D["c"], 3) - Equal(D["d"], 4) - Equal(D["e"], 5) - Equal(D["f"], 6) - ) - "#; - - let params = Params::default(); - - let result = parse_validate_and_lower(input, ¶ms); - assert!(result.is_ok()); - - let lowered = result.unwrap(); - let batches = lowered.batches.as_ref().expect("Expected batches"); - - // pred1, pred2, pred3 + large_pred split into 2 = 5 total predicates - // Should span 2 batches - assert_eq!(batches.total_predicate_count(), 5); - assert_eq!(batches.batch_count(), 2); - } - #[test] fn test_intro_predicate_in_custom_predicate() { use hex::ToHex; diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index 8939c8d..bd7393f 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -777,7 +777,7 @@ mod tests { ) .unwrap(); - let batch = CustomPredicateBatch::new(¶ms, "TestBatch".to_string(), vec![pred]); + let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]); let batch_id = batch.id().encode_hex::(); let input = format!( diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 3908918..52979dd 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -162,7 +162,7 @@ mod tests { let request_result = processed.request.templates(); assert_eq!(request_result.len(), 0); - assert_eq!(batch_result.predicates.len(), 1); + assert_eq!(batch_result.predicates().len(), 1); // Expected structure let expected_statements = vec![StatementTmpl { @@ -179,11 +179,8 @@ mod tests { 2, // args_len (PodA, PodB) names(&["PodA", "PodB"]), )?; - let expected_batch = CustomPredicateBatch::new( - ¶ms, - "PodlangBatch".to_string(), - vec![expected_predicate], - ); + let expected_batch = + CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]); assert_eq!(*batch_result, expected_batch); @@ -244,7 +241,7 @@ mod tests { let request_result = processed.request.templates(); assert_eq!(request_result.len(), 0); - assert_eq!(batch_result.predicates.len(), 1); + assert_eq!(batch_result.predicates().len(), 1); // Expected structure: Public args: A (index 0). Private args: Temp (index 1) let expected_statements = vec![ @@ -270,11 +267,8 @@ mod tests { 1, // args_len (A) names(&["A", "Temp"]), )?; - let expected_batch = CustomPredicateBatch::new( - ¶ms, - "PodlangBatch".to_string(), - vec![expected_predicate], - ); + let expected_batch = + CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]); assert_eq!(*batch_result, expected_batch); @@ -298,7 +292,7 @@ mod tests { let batch_result = first_batch(&processed); let request_templates = processed.request.templates(); - assert_eq!(batch_result.predicates.len(), 1); + assert_eq!(batch_result.predicates().len(), 1); assert!(!request_templates.is_empty()); // Expected Batch structure @@ -316,11 +310,8 @@ mod tests { 2, // args_len (X, Y) names(&["X", "Y"]), )?; - let expected_batch = CustomPredicateBatch::new( - ¶ms, - "PodlangBatch".to_string(), - vec![expected_predicate], - ); + let expected_batch = + CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]); assert_eq!(*batch_result, expected_batch); @@ -362,7 +353,7 @@ mod tests { let batch_result = first_batch(&processed); let request_templates = processed.request.templates(); - assert_eq!(batch_result.predicates.len(), 1); // some_pred is defined + assert_eq!(batch_result.predicates().len(), 1); // some_pred is defined assert!(!request_templates.is_empty()); // Expected Wildcard Indices in Request Scope: @@ -607,7 +598,7 @@ mod tests { "Expected no request templates" ); assert_eq!( - first_batch(&processed).predicates.len(), + first_batch(&processed).predicates().len(), 4, "Expected 4 custom predicates" ); @@ -727,7 +718,6 @@ mod tests { )?; let expected_batch = CustomPredicateBatch::new( - ¶ms, "PodlangBatch".to_string(), vec![ expected_friend_pred, @@ -766,7 +756,7 @@ mod tests { names(&["A", "B"]), )?; let available_batch = - CustomPredicateBatch::new(¶ms, "MyBatch".to_string(), vec![imported_predicate]); + CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]); let available_batches = vec![available_batch.clone()]; // 2. Create the input string that uses the batch @@ -819,7 +809,7 @@ mod tests { let pred3 = CustomPredicate::and(¶ms, "p3".into(), vec![], 1, names(&["D"]))?; let available_batch = - CustomPredicateBatch::new(¶ms, "MyBatch".to_string(), vec![pred1, pred2, pred3]); + CustomPredicateBatch::new("MyBatch".to_string(), vec![pred1, pred2, pred3]); let available_batches = vec![available_batch.clone()]; // 2. Create the input string that uses the batch with skips @@ -883,7 +873,7 @@ mod tests { names(&["A", "B"]), )?; let available_batch = - CustomPredicateBatch::new(¶ms, "MyBatch".to_string(), vec![imported_predicate]); + CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]); let available_batches = vec![available_batch.clone()]; // 2. Create the input string that defines a new predicate using the imported one @@ -908,13 +898,13 @@ mod tests { "No request should be defined" ); assert_eq!( - first_batch(&processed).predicates.len(), + first_batch(&processed).predicates().len(), 1, "Expected one custom predicate to be defined" ); // 4. Check the resulting predicate definition - let defined_pred = &first_batch(&processed).predicates[0]; + let defined_pred = &first_batch(&processed).predicates()[0]; assert_eq!(defined_pred.name, "wrapper_pred"); assert_eq!(defined_pred.statements.len(), 1); diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs index b176681..1ffeb96 100644 --- a/src/lang/pretty_print.rs +++ b/src/lang/pretty_print.rs @@ -71,7 +71,7 @@ impl StatementTmpl { } Predicate::BatchSelf(index) => { if let Some(batch) = batch_context { - if let Some(predicate) = batch.predicates.get(*index) { + if let Some(predicate) = batch.predicates().get(*index) { write!(w, "{}", predicate.name)?; } else { write!(w, "batch_self_{}", index)?; @@ -108,7 +108,7 @@ impl PrettyPrint for StatementTmplArg { impl PrettyPrint for CustomPredicateBatch { fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result { - for (i, predicate) in self.predicates.iter().enumerate() { + for (i, predicate) in self.predicates().iter().enumerate() { if i > 0 { write!(w, "\n\n")?; } @@ -405,9 +405,11 @@ mod tests { // Step 4: Verify the ASTs are equivalent assert_eq!( - batch.predicates, reparsed_batch.predicates, + batch.predicates(), + reparsed_batch.predicates(), "Original AST should match reparsed AST.\nOriginal input:\n{}\nPretty-printed:\n{}\n", - input, pretty_printed + input, + pretty_printed ); } @@ -565,7 +567,7 @@ mod tests { let reparsed = parse(&pretty_printed, ¶ms, &[]).expect("Reparsing should succeed"); let reparsed_batch = reparsed.first_batch().expect("Expected batch"); - assert_eq!(batch.predicates, reparsed_batch.predicates); + assert_eq!(batch.predicates(), reparsed_batch.predicates()); } #[test] @@ -637,9 +639,11 @@ mod tests { let reparsed_batch = reparsed_result.first_batch().expect("Expected batch"); assert_eq!( - batch.predicates, reparsed_batch.predicates, + batch.predicates(), + reparsed_batch.predicates(), "Round-trip failed for string: {:?}\nPretty-printed: {}", - test_string, pretty_printed + test_string, + pretty_printed ); } } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 402ee44..13cc387 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -1,13 +1,16 @@ -use std::{fmt, iter, sync::Arc}; +use std::{collections::HashMap, fmt, iter, sync::Arc}; use itertools::Itertools; use plonky2::field::types::Field; use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; -use crate::middleware::{ - hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value, - BASE_PARAMS, EMPTY_HASH, F, VALUE_SIZE, +use crate::{ + backends::plonky2::primitives::merkletree::MerkleTree, + middleware::{ + hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, RawValue, Result, + ToFields, Value, BASE_PARAMS, F, VALUE_SIZE, + }, }; #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] @@ -420,83 +423,142 @@ impl fmt::Display for CustomPredicate { } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] +enum CustomPredicateBatchData { + Full { + #[serde(skip)] + #[schemars(skip)] + mt: MerkleTree, + predicates: Vec, + }, + Opaque { + id: Hash, + }, +} + +// TODO: Rename Batch for Module everywhere in the code base +impl CustomPredicateBatchData { + fn new_full(predicates: Vec) -> Self { + let kvs: HashMap = predicates + .iter() + .enumerate() + .map(|(index, pred)| { + let cp_hash = hash_fields(&pred.to_fields()); + (Value::from(index as i64).raw(), Value::from(cp_hash).raw()) + }) + .collect(); + let mt = MerkleTree::new(&kvs); + Self::Full { mt, predicates } + } + fn new_opaque(id: Hash) -> Self { + Self::Opaque { id } + } +} + +impl<'de> Deserialize<'de> for CustomPredicateBatchData { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + enum Aux { + Full { predicates: Vec }, + Opaque { id: Hash }, + } + let aux = Aux::deserialize(deserializer)?; + Ok(match aux { + Aux::Opaque { id } => Self::new_opaque(id), + Aux::Full { predicates } => Self::new_full(predicates), + }) + } +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] pub struct CustomPredicateBatch { - id: Hash, pub name: String, - pub(crate) predicates: Vec, + data: CustomPredicateBatchData, } impl std::hash::Hash for CustomPredicateBatch { fn hash(&self, state: &mut H) { - self.id.hash(state); - } -} - -impl ToFields for CustomPredicateBatch { - fn to_fields(&self) -> Vec { - // all the custom predicates in order - let pad_pred = CustomPredicate::empty(); - self.predicates - .iter() - .chain(iter::repeat(&pad_pred)) - .take(BASE_PARAMS.max_custom_batch_size) - .flat_map(|p| p.to_fields()) - .collect_vec() + self.id().hash(state); } } impl CustomPredicateBatch { - pub fn new(_params: &Params, name: String, predicates: Vec) -> Arc { - let mut cpb = Self { - id: EMPTY_HASH, + pub fn new(name: String, predicates: Vec) -> Arc { + Arc::new(Self { name, - predicates, - }; - let id = cpb.calculate_id(); - cpb.id = id; - Arc::new(cpb) + data: CustomPredicateBatchData::new_full(predicates), + }) } pub fn new_opaque(name: String, id: Hash) -> Arc { Arc::new(Self { - id, name, - predicates: vec![], + data: CustomPredicateBatchData::Opaque { id }, }) } - /// Cryptographic identifier for the batch. - fn calculate_id(&self) -> Hash { - // NOTE: This implementation just hashes the concatenation of all the custom predicates, - // but ideally we want to use the root of a merkle tree built from the custom predicates. - let input = self.to_fields(); - hash_fields(&input) - } - pub fn id(&self) -> Hash { - self.id + match &self.data { + CustomPredicateBatchData::Opaque { id } => *id, + CustomPredicateBatchData::Full { mt, .. } => mt.root(), + } } pub fn predicates(&self) -> &[CustomPredicate] { - &self.predicates + match &self.data { + // TODO: Return Option here instead of panic + CustomPredicateBatchData::Opaque { .. } => panic!("opaque batch"), + CustomPredicateBatchData::Full { predicates, .. } => predicates, + } + } + pub fn mt(&self) -> &MerkleTree { + match &self.data { + // TODO: Return Option here instead of panic + CustomPredicateBatchData::Opaque { .. } => panic!("opaque batch"), + CustomPredicateBatchData::Full { mt, .. } => mt, + } } pub fn predicate_ref_by_name( self: &Arc, name: &str, ) -> Option { - self.predicates + self.predicates() .iter() .enumerate() .find_map(|(i, cp)| (cp.name == name).then(|| CustomPredicateRef::new(self.clone(), i))) } + pub fn predicate_ref_by_index( + self: &Arc, + index: usize, + ) -> Option { + self.predicates() + .get(index) + .map(|_| CustomPredicateRef::new(self.clone(), index)) + } } -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct CustomPredicateRef { pub batch: Arc, pub index: usize, } +impl std::hash::Hash for CustomPredicateRef { + fn hash(&self, state: &mut H) { + (self.batch.id(), self.index).hash(state); + } +} + +impl PartialEq for CustomPredicateRef { + fn eq(&self, other: &Self) -> bool { + self.batch.id() == other.batch.id() && self.index == other.index + } +} + +impl Eq for CustomPredicateRef {} + impl CustomPredicateRef { pub fn new(batch: Arc, index: usize) -> Self { Self { batch, index } @@ -505,7 +567,7 @@ impl CustomPredicateRef { self.predicate().args_len } pub fn predicate(&self) -> &CustomPredicate { - &self.batch.predicates[self.index] + &self.batch.predicates()[self.index] } } @@ -556,7 +618,6 @@ mod tests { p:product_of(S1, Constant, S2) */ let cust_pred_batch = CustomPredicateBatch::new( - ¶ms, "is_double".to_string(), vec![CustomPredicate::and( ¶ms, @@ -637,7 +698,7 @@ mod tests { )?; let eth_friend_batch = - CustomPredicateBatch::new(¶ms, "eth_friend".to_string(), vec![eth_friend]); + CustomPredicateBatch::new("eth_friend".to_string(), vec![eth_friend]); // 0 let eth_dos_base = CustomPredicate::and( @@ -714,7 +775,6 @@ mod tests { )?; let eth_dos_distance_batch = CustomPredicateBatch::new( - ¶ms, "ETHDoS_distance".to_string(), vec![eth_dos_base, eth_dos_ind, eth_dos], ); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 2b6ca01..542f5b2 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -766,14 +766,14 @@ pub struct BaseParams { /// max number of statements that can be ANDed or ORed together /// in a custom predicate pub max_custom_predicate_arity: usize, - pub max_custom_batch_size: usize, + pub max_depth_custom_batch_mt: usize, } pub const BASE_PARAMS: BaseParams = BaseParams { num_public_statements_hash: 16, max_statement_args: 5, max_custom_predicate_arity: 5, - max_custom_batch_size: 4, + max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch }; /// Params: non dynamic parameters that define the circuit. @@ -785,8 +785,8 @@ pub struct Params { pub max_statements: usize, pub max_public_statements: usize, pub max_operation_args: usize, - // max number of custom predicates batches that a MainPod can use - pub max_custom_predicate_batches: usize, + // max number of different custom predicates that can be used in a MainPod + pub max_custom_predicates: usize, // max number of operations using custom predicates that can be verified in the MainPod pub max_custom_predicate_verifications: usize, pub max_custom_predicate_wildcards: usize, @@ -815,7 +815,7 @@ impl Default for Params { max_statements: 48, max_public_statements: 8, max_operation_args: 5, - max_custom_predicate_batches: 4, + max_custom_predicates: 8, max_custom_predicate_verifications: 8, max_custom_predicate_wildcards: 8, max_merkle_proofs_containers: 20, @@ -841,7 +841,7 @@ impl Params { BASE_PARAMS.max_custom_predicate_arity } pub const fn max_custom_batch_size() -> usize { - BASE_PARAMS.max_custom_batch_size + 2usize.pow(BASE_PARAMS.max_depth_custom_batch_mt as u32) } pub fn max_priv_statements(&self) -> usize { @@ -877,8 +877,8 @@ impl Params { BASE_PARAMS.max_custom_predicate_arity * Self::statement_tmpl_size() + 2 } - pub const fn custom_predicate_batch_size_field_elts() -> usize { - BASE_PARAMS.max_custom_batch_size * Self::custom_predicate_size() + pub const fn max_depth_custom_batch_mt() -> usize { + BASE_PARAMS.max_depth_custom_batch_mt } /// Total size of the statement table including None, input statements from signed pods and @@ -896,10 +896,6 @@ impl Params { println!(" Predicate: {}", Self::predicate_size()); println!(" Statement template: {}", Self::statement_tmpl_size()); println!(" Custom predicate: {}", Self::custom_predicate_size()); - println!( - " Custom predicate batch: {}", - Self::custom_predicate_batch_size_field_elts() - ); println!(); } }