Merkle tree for custom predicate batches (#471)

Resolve https://github.com/0xPARC/pod2/issues/466

Batches are now identified by the root of a Merkle tree that contains all of the batch's predicates (using sequential indices as keys). This means that a custom predicate reference still has the form hash + index, but the hash is computed differently.
The MainPod circuit is no longer limited by the number of batches but by the number of custom predicates; for each custom predicate we verify a Merkle proof to check its batch id.

I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that.  In a future PR we'll remove the code that handles batch splitting.

Each custom predicate needs 148.2 gates, which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 (where I actually made a mistake and considered 5 predicates per batch instead of the 4 in the previous Params).
This commit is contained in:
Eduard S. 2026-02-04 11:12:32 +01:00 committed by GitHub
parent a7a30176a7
commit 641d8dabdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 331 additions and 761 deletions

View file

@ -768,37 +768,6 @@ mod tests {
assert_eq!(batches.total_predicate_count(), 3);
}
/// Five independent predicates do not fit in a single batch (the default
/// `max_custom_batch_size` is noted as 4 below), so batching is expected to
/// produce two batches holding four and one predicates respectively.
#[test]
fn test_predicates_span_multiple_batches() {
    let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
pred4(D) = AND(Equal(D["w"], 4))
pred5(E) = AND(Equal(E["v"], 5))
"#;

    let (predicates, validated) = parse_and_validate(input);
    let params = Params::default(); // max_custom_batch_size = 4

    // `expect` instead of `assert!(is_ok())` + `unwrap()`: on failure the
    // underlying batching error is printed rather than a bare assertion.
    let batches = batch_predicates(
        preds_to_split_results(predicates),
        &params,
        "TestBatch",
        validated.symbols(),
    )
    .expect("Batching five independent predicates should succeed");

    assert_eq!(batches.batch_count(), 2);
    assert_eq!(batches.total_predicate_count(), 5);

    // First batch should have 4 predicates
    assert_eq!(batches.batches()[0].predicates().len(), 4);
    // Second batch should have 1 predicate
    assert_eq!(batches.batches()[1].predicates().len(), 1);
}
#[test]
fn test_intra_batch_forward_reference() {
// pred2 calls pred1, but pred2 is declared first
@ -869,132 +838,6 @@ mod tests {
)); // calls pred1
}
/// Checks that a call crossing a batch boundary is encoded as a custom
/// predicate reference to the other batch.
///
/// pred1-4 are expected to land in batch 0 and pred5 in batch 1, so pred5's
/// call to pred1 must carry batch 0's id.
#[test]
fn test_cross_batch_reference() {
    // 5 predicates where pred5 calls pred1
    let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
pred4(D) = AND(Equal(D["w"], 4))
pred5(E) = AND(pred1(E))
"#;

    let (predicates, validated) = parse_and_validate(input);
    let params = Params::default(); // max_custom_batch_size = 4

    // `expect` instead of `assert!(is_ok())` + `unwrap()`: on failure the
    // underlying batching error is printed rather than a bare assertion.
    let batches = batch_predicates(
        preds_to_split_results(predicates),
        &params,
        "TestBatch",
        validated.symbols(),
    )
    .expect("Batching with a cross-batch call should succeed");

    assert_eq!(batches.batch_count(), 2);

    // pred5 is the only predicate of the second batch; inspect its single statement.
    let pred5_batch = &batches.batches()[1];
    let pred5 = &pred5_batch.predicates()[0];
    let pred5_stmt = &pred5.statements[0];

    // The predicate should be a Custom reference to batch 0
    match pred5_stmt.pred_or_wc() {
        PredicateOrWildcard::Predicate(Predicate::Custom(ref_)) => {
            // Should reference batch 0 (where pred1 lives)
            assert_eq!(ref_.batch.id(), batches.batches()[0].id());
        }
        _ => panic!("Expected Custom predicate reference"),
    }
}
/// A six-statement predicate is split into a chain, and together with three
/// small predicates the resulting five predicates are expected to span two
/// batches while the split-chain metadata for `large_pred` is retained.
#[test]
fn test_split_chain_spans_batches() {
    // Create a predicate that will split into 2-3 predicates
    // Then add more predicates to force the chain to span batches
    let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
large_pred(D) = AND(
Equal(D["a"], 1)
Equal(D["b"], 2)
Equal(D["c"], 3)
Equal(D["d"], 4)
Equal(D["e"], 5)
Equal(D["f"], 6)
)
"#;

    let (predicates, validated) = parse_and_validate(input);
    let params = Params::default();

    // Split every predicate (only the large one is expected to actually split).
    let all_split_results: Vec<_> = predicates
        .into_iter()
        .map(|pred| split_predicate_if_needed(pred, &params).expect("Split failed"))
        .collect();

    // Count total predicates across all split results
    let total_preds: usize = all_split_results.iter().map(|r| r.predicates.len()).sum();
    // We should have: pred1, pred2, pred3, large_pred_1 (continuation), large_pred
    // That's 5 predicates, which spans 2 batches
    assert_eq!(total_preds, 5);

    // `expect` instead of `assert!(is_ok())` + `unwrap()`: on failure the
    // underlying batching error is printed rather than a bare assertion.
    let batches = batch_predicates(all_split_results, &params, "TestBatch", validated.symbols())
        .expect("Batch failed");

    assert_eq!(batches.batch_count(), 2);
    assert_eq!(batches.total_predicate_count(), 5);

    // Verify chain info was captured
    let info = batches
        .split_chain("large_pred")
        .expect("Expected chain info");
    assert_eq!(info.original_name, "large_pred");
    assert_eq!(info.real_statement_count, 6);
}
#[test]
fn test_forward_cross_batch_reference_avoided_by_planner() {
    // Scenario: pred4 calls pred5, which is declared after it.
    // With max_custom_batch_size = 4, packing in declaration order would push
    // pred5 into batch 1 while its caller sits in batch 0. The dependency-aware
    // planner is expected to place pred5 first so that no forward cross-batch
    // reference is needed.
    let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
pred4(D) = AND(pred5(D))
pred5(E) = AND(Equal(E["v"], 5))
"#;

    let (predicates, validated) = parse_and_validate(input);
    let params = Params::default(); // max_custom_batch_size = 4

    let split_results = preds_to_split_results(predicates);
    let symbols = validated.symbols();
    let batches = batch_predicates(split_results, &params, "TestBatch", symbols)
        .expect("Planner should avoid forward cross-batch reference");

    // Two batches are expected, with the callee in the earlier one.
    assert_eq!(batches.batch_count(), 2);

    // Callee (pred5) belongs to batch 0, caller (pred4) to batch 1
    // (given stable topo + packing).
    let callee_ref = batches.predicate_ref_by_name("pred5").unwrap();
    let caller_ref = batches.predicate_ref_by_name("pred4").unwrap();
    assert_eq!(callee_ref.batch.id(), batches.batches()[0].id());
    assert_eq!(caller_ref.batch.id(), batches.batches()[1].id());
}
#[test]
fn test_empty_input() {
let split_results: Vec<SplitResult> = vec![];
@ -1037,83 +880,6 @@ mod tests {
assert!(batches.predicate_ref_by_name("nonexistent").is_none());
}
#[test]
fn test_mutual_recursion_exceeds_capacity_error() {
    // Five predicates forming a single call cycle (SCC size = 5) with max
    // batch size 4. Batching should fail because an SCC cannot be split
    // across batches.
    let input = r#"
pred1(A) = AND(pred2(A))
pred2(B) = AND(pred3(B))
pred3(B) = AND(pred4(B))
pred4(B) = AND(pred5(B))
pred5(B) = AND(pred1(B))
"#;

    let (predicates, validated) = parse_and_validate(input);
    let params = Params::default();

    let outcome = batch_predicates(
        preds_to_split_results(predicates),
        &params,
        "TestBatch",
        validated.symbols(),
    );

    // The call must fail, and the error text must mention the capacity problem.
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("exceeds batch capacity"));
}
#[test]
fn test_split_chain_across_batches_placement() {
    // One large predicate that splits into 2 pieces, plus enough small
    // predicates to force the chain across a batch boundary; the continuation
    // piece is expected to be placed in the earlier batch.
    let input = r#"
p1(A) = AND(Equal(A["x"], 1))
p2(B) = AND(Equal(B["y"], 2))
p3(C) = AND(Equal(C["z"], 3))
large_pred(D) = AND(
Equal(D["a"], 1)
Equal(D["b"], 2)
Equal(D["c"], 3)
Equal(D["d"], 4)
Equal(D["e"], 5)
Equal(D["f"], 6)
)
"#;

    let (predicates, validated) = parse_and_validate(input);
    let params = Params::default(); // max_custom_batch_size = 4

    // Split each predicate in declaration order, then batch the results.
    let all_split_results: Vec<_> = predicates
        .into_iter()
        .map(|pred| split_predicate_if_needed(pred, &params).expect("Split failed"))
        .collect();
    let batches = batch_predicates(all_split_results, &params, "TestBatch", validated.symbols())
        .expect("Batch failed");
    assert_eq!(batches.batch_count(), 2);

    // The chain metadata must list both pieces.
    let chain_info = batches
        .split_chain("large_pred")
        .expect("Expected chain info");
    assert_eq!(chain_info.chain_pieces.len(), 2);

    // The first recorded piece is the continuation, large_pred_1 (innermost first)
    let continuation_name = &chain_info.chain_pieces[0].name;
    assert_eq!(continuation_name, "large_pred_1");

    // Continuation lives in batch 0, the main predicate in batch 1
    let continuation_ref = batches.predicate_ref_by_name("large_pred_1").unwrap();
    let head_ref = batches.predicate_ref_by_name("large_pred").unwrap();
    assert_eq!(continuation_ref.batch.id(), batches.batches()[0].id());
    assert_eq!(head_ref.batch.id(), batches.batches()[1].id());
}
/// Helper: create a unique Statement for testing
/// Uses Equal with distinct literal values to create distinguishable statements
fn test_statement(id: usize) -> Statement {