Merkle tree for custom predicate batches (#471)

Resolve https://github.com/0xPARC/pod2/issues/466

Now batches are identified by the root of a merkle tree that contains all the predicates (using sequential indices as keys).  This means that the format to identify a custom predicate reference is still a hash + index, but the calculation of the hash is different.
The MainPod circuit now isn't limited by number of batches but instead number of custom predicates; and for each one we verify a merkle proof to verify the batch id.

I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that.  In a future PR we'll remove the code that handles batch splitting.

Each custom predicate needs 148.2 gates (which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 where I actually made a mistake and considered 5 predicates per batch instead of 4 in the previous Params).
This commit is contained in:
Eduard S. 2026-02-04 11:12:32 +01:00 committed by GitHub
parent a7a30176a7
commit 641d8dabdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 331 additions and 761 deletions

View file

@ -768,37 +768,6 @@ mod tests {
assert_eq!(batches.total_predicate_count(), 3);
}
#[test]
fn test_predicates_span_multiple_batches() {
let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
pred4(D) = AND(Equal(D["w"], 4))
pred5(E) = AND(Equal(E["v"], 5))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
validated.symbols(),
);
assert!(result.is_ok());
let batches = result.unwrap();
assert_eq!(batches.batch_count(), 2);
assert_eq!(batches.total_predicate_count(), 5);
// First batch should have 4 predicates
assert_eq!(batches.batches()[0].predicates().len(), 4);
// Second batch should have 1 predicate
assert_eq!(batches.batches()[1].predicates().len(), 1);
}
#[test]
fn test_intra_batch_forward_reference() {
// pred2 calls pred1, but pred2 is declared first
@ -869,132 +838,6 @@ mod tests {
)); // calls pred1
}
#[test]
fn test_cross_batch_reference() {
// 5 predicates where pred5 calls pred1
// pred1-4 go in batch 0, pred5 in batch 1
// pred5's call to pred1 should be a cross-batch reference
let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
pred4(D) = AND(Equal(D["w"], 4))
pred5(E) = AND(pred1(E))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
validated.symbols(),
);
assert!(result.is_ok());
let batches = result.unwrap();
assert_eq!(batches.batch_count(), 2);
// pred5 should reference pred1 via CustomPredicateRef
let pred5_batch = &batches.batches()[1];
let pred5 = &pred5_batch.predicates()[0];
let pred5_stmt = &pred5.statements[0];
// The predicate should be a Custom reference to batch 0
match pred5_stmt.pred_or_wc() {
PredicateOrWildcard::Predicate(Predicate::Custom(ref_)) => {
// Should reference batch 0, index 0 (pred1)
assert_eq!(ref_.batch.id(), batches.batches()[0].id());
}
_ => panic!("Expected Custom predicate reference"),
}
}
#[test]
fn test_split_chain_spans_batches() {
// Create a predicate that will split into 2-3 predicates
// Then add more predicates to force the chain to span batches
let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
large_pred(D) = AND(
Equal(D["a"], 1)
Equal(D["b"], 2)
Equal(D["c"], 3)
Equal(D["d"], 4)
Equal(D["e"], 5)
Equal(D["f"], 6)
)
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
// Split the large predicate
let mut all_split_results = Vec::new();
for pred in predicates {
let result = split_predicate_if_needed(pred, &params).expect("Split failed");
all_split_results.push(result);
}
// Count total predicates across all split results
let total_preds: usize = all_split_results.iter().map(|r| r.predicates.len()).sum();
// We should have: pred1, pred2, pred3, large_pred_1 (continuation), large_pred
// That's 5 predicates, which spans 2 batches
assert_eq!(total_preds, 5);
let result = batch_predicates(all_split_results, &params, "TestBatch", validated.symbols());
assert!(result.is_ok());
let batches = result.unwrap();
assert_eq!(batches.batch_count(), 2);
assert_eq!(batches.total_predicate_count(), 5);
// Verify chain info was captured
let chain_info = batches.split_chain("large_pred");
assert!(chain_info.is_some());
let info = chain_info.unwrap();
assert_eq!(info.original_name, "large_pred");
assert_eq!(info.real_statement_count, 6);
}
#[test]
fn test_forward_cross_batch_reference_avoided_by_planner() {
// 5 predicates where pred4 calls pred5 (forward declaration)
// With max_custom_batch_size = 4, naive packing would place pred5 in batch 1
// The dependency-aware planner should instead pack pred5 before pred4
// to avoid a forward cross-batch reference.
let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
pred4(D) = AND(pred5(D))
pred5(E) = AND(Equal(E["v"], 5))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
let batches = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
validated.symbols(),
)
.expect("Planner should avoid forward cross-batch reference");
// Expect two batches and the reference to point within the same batch or earlier batch.
assert_eq!(batches.batch_count(), 2);
// pred5 should be in batch 0 and pred4 in batch 1 (given stable topo + packing)
let pred5_ref = batches.predicate_ref_by_name("pred5").unwrap();
let pred4_ref = batches.predicate_ref_by_name("pred4").unwrap();
assert_eq!(pred5_ref.batch.id(), batches.batches()[0].id());
assert_eq!(pred4_ref.batch.id(), batches.batches()[1].id());
}
#[test]
fn test_empty_input() {
let split_results: Vec<SplitResult> = vec![];
@ -1037,83 +880,6 @@ mod tests {
assert!(batches.predicate_ref_by_name("nonexistent").is_none());
}
#[test]
fn test_mutual_recursion_exceeds_capacity_error() {
// Two predicates that call each other (SCC size = 5) with max batch size 4
// Should error because an SCC cannot be split across batches
let input = r#"
pred1(A) = AND(pred2(A))
pred2(B) = AND(pred3(B))
pred3(B) = AND(pred4(B))
pred4(B) = AND(pred5(B))
pred5(B) = AND(pred1(B))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = batch_predicates(
preds_to_split_results(predicates),
&params,
"TestBatch",
validated.symbols(),
);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("exceeds batch capacity"));
}
#[test]
fn test_split_chain_across_batches_placement() {
// Create a large predicate that splits into 2 pieces, plus enough predicates
// to force the chain to span batches; verify continuation is placed earlier batch
let input = r#"
p1(A) = AND(Equal(A["x"], 1))
p2(B) = AND(Equal(B["y"], 2))
p3(C) = AND(Equal(C["z"], 3))
large_pred(D) = AND(
Equal(D["a"], 1)
Equal(D["b"], 2)
Equal(D["c"], 3)
Equal(D["d"], 4)
Equal(D["e"], 5)
Equal(D["f"], 6)
)
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default(); // max_custom_batch_size = 4
// Split and batch
let mut all_split_results = Vec::new();
for pred in predicates {
let result = split_predicate_if_needed(pred, &params).expect("Split failed");
all_split_results.push(result);
}
let batches =
batch_predicates(all_split_results, &params, "TestBatch", validated.symbols())
.expect("Batch failed");
assert_eq!(batches.batch_count(), 2);
// Verify chain info
let chain_info = batches
.split_chain("large_pred")
.expect("Expected chain info");
assert_eq!(chain_info.chain_pieces.len(), 2);
// Expect continuation piece name to be large_pred_1 (innermost first)
let cont_name = &chain_info.chain_pieces[0].name;
assert_eq!(cont_name, "large_pred_1");
// Expect continuation in batch 0 and main in batch 1
let cont_ref = batches.predicate_ref_by_name("large_pred_1").unwrap();
let main_ref = batches.predicate_ref_by_name("large_pred").unwrap();
assert_eq!(cont_ref.batch.id(), batches.batches()[0].id());
assert_eq!(main_ref.batch.id(), batches.batches()[1].id());
}
/// Helper: create a unique Statement for testing
/// Uses Equal with distinct literal values to create distinguishable statements
fn test_statement(id: usize) -> Statement {

View file

@ -681,68 +681,6 @@ mod tests {
assert!(result.is_ok());
}
#[test]
fn test_multi_batch_packing() {
// Create more predicates than fit in a single batch
// With max_custom_batch_size = 4, 5 predicates should span 2 batches
let input = r#"
pred1(A) = AND(Equal(A["a"], 1))
pred2(B) = AND(Equal(B["b"], 2))
pred3(C) = AND(Equal(C["c"], 3))
pred4(D) = AND(Equal(D["d"], 4))
pred5(E) = AND(Equal(E["e"], 5))
"#;
let params = Params::default(); // max_custom_batch_size = 4
let result = parse_validate_and_lower(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let batches = lowered.batches.as_ref().expect("Expected batches");
// Should have 2 batches
assert_eq!(batches.batch_count(), 2);
assert_eq!(batches.total_predicate_count(), 5);
// First batch should have 4 predicates
assert_eq!(batches.batches()[0].predicates().len(), 4);
// Second batch should have 1 predicate
assert_eq!(batches.batches()[1].predicates().len(), 1);
}
#[test]
fn test_split_chains_span_batches() {
// Create predicates that will split, plus additional predicates
// to force the split chains across batch boundaries
let input = r#"
pred1(A) = AND(Equal(A["a"], 1))
pred2(B) = AND(Equal(B["b"], 2))
pred3(C) = AND(Equal(C["c"], 3))
large_pred(D) = AND(
Equal(D["a"], 1)
Equal(D["b"], 2)
Equal(D["c"], 3)
Equal(D["d"], 4)
Equal(D["e"], 5)
Equal(D["f"], 6)
)
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let batches = lowered.batches.as_ref().expect("Expected batches");
// pred1, pred2, pred3 + large_pred split into 2 = 5 total predicates
// Should span 2 batches
assert_eq!(batches.total_predicate_count(), 5);
assert_eq!(batches.batch_count(), 2);
}
#[test]
fn test_intro_predicate_in_custom_predicate() {
use hex::ToHex;

View file

@ -777,7 +777,7 @@ mod tests {
)
.unwrap();
let batch = CustomPredicateBatch::new(&params, "TestBatch".to_string(), vec![pred]);
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
let batch_id = batch.id().encode_hex::<String>();
let input = format!(

View file

@ -162,7 +162,7 @@ mod tests {
let request_result = processed.request.templates();
assert_eq!(request_result.len(), 0);
assert_eq!(batch_result.predicates.len(), 1);
assert_eq!(batch_result.predicates().len(), 1);
// Expected structure
let expected_statements = vec![StatementTmpl {
@ -179,11 +179,8 @@ mod tests {
2, // args_len (PodA, PodB)
names(&["PodA", "PodB"]),
)?;
let expected_batch = CustomPredicateBatch::new(
&params,
"PodlangBatch".to_string(),
vec![expected_predicate],
);
let expected_batch =
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
assert_eq!(*batch_result, expected_batch);
@ -244,7 +241,7 @@ mod tests {
let request_result = processed.request.templates();
assert_eq!(request_result.len(), 0);
assert_eq!(batch_result.predicates.len(), 1);
assert_eq!(batch_result.predicates().len(), 1);
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
let expected_statements = vec![
@ -270,11 +267,8 @@ mod tests {
1, // args_len (A)
names(&["A", "Temp"]),
)?;
let expected_batch = CustomPredicateBatch::new(
&params,
"PodlangBatch".to_string(),
vec![expected_predicate],
);
let expected_batch =
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
assert_eq!(*batch_result, expected_batch);
@ -298,7 +292,7 @@ mod tests {
let batch_result = first_batch(&processed);
let request_templates = processed.request.templates();
assert_eq!(batch_result.predicates.len(), 1);
assert_eq!(batch_result.predicates().len(), 1);
assert!(!request_templates.is_empty());
// Expected Batch structure
@ -316,11 +310,8 @@ mod tests {
2, // args_len (X, Y)
names(&["X", "Y"]),
)?;
let expected_batch = CustomPredicateBatch::new(
&params,
"PodlangBatch".to_string(),
vec![expected_predicate],
);
let expected_batch =
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
assert_eq!(*batch_result, expected_batch);
@ -362,7 +353,7 @@ mod tests {
let batch_result = first_batch(&processed);
let request_templates = processed.request.templates();
assert_eq!(batch_result.predicates.len(), 1); // some_pred is defined
assert_eq!(batch_result.predicates().len(), 1); // some_pred is defined
assert!(!request_templates.is_empty());
// Expected Wildcard Indices in Request Scope:
@ -607,7 +598,7 @@ mod tests {
"Expected no request templates"
);
assert_eq!(
first_batch(&processed).predicates.len(),
first_batch(&processed).predicates().len(),
4,
"Expected 4 custom predicates"
);
@ -727,7 +718,6 @@ mod tests {
)?;
let expected_batch = CustomPredicateBatch::new(
&params,
"PodlangBatch".to_string(),
vec![
expected_friend_pred,
@ -766,7 +756,7 @@ mod tests {
names(&["A", "B"]),
)?;
let available_batch =
CustomPredicateBatch::new(&params, "MyBatch".to_string(), vec![imported_predicate]);
CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]);
let available_batches = vec![available_batch.clone()];
// 2. Create the input string that uses the batch
@ -819,7 +809,7 @@ mod tests {
let pred3 = CustomPredicate::and(&params, "p3".into(), vec![], 1, names(&["D"]))?;
let available_batch =
CustomPredicateBatch::new(&params, "MyBatch".to_string(), vec![pred1, pred2, pred3]);
CustomPredicateBatch::new("MyBatch".to_string(), vec![pred1, pred2, pred3]);
let available_batches = vec![available_batch.clone()];
// 2. Create the input string that uses the batch with skips
@ -883,7 +873,7 @@ mod tests {
names(&["A", "B"]),
)?;
let available_batch =
CustomPredicateBatch::new(&params, "MyBatch".to_string(), vec![imported_predicate]);
CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]);
let available_batches = vec![available_batch.clone()];
// 2. Create the input string that defines a new predicate using the imported one
@ -908,13 +898,13 @@ mod tests {
"No request should be defined"
);
assert_eq!(
first_batch(&processed).predicates.len(),
first_batch(&processed).predicates().len(),
1,
"Expected one custom predicate to be defined"
);
// 4. Check the resulting predicate definition
let defined_pred = &first_batch(&processed).predicates[0];
let defined_pred = &first_batch(&processed).predicates()[0];
assert_eq!(defined_pred.name, "wrapper_pred");
assert_eq!(defined_pred.statements.len(), 1);

View file

@ -71,7 +71,7 @@ impl StatementTmpl {
}
Predicate::BatchSelf(index) => {
if let Some(batch) = batch_context {
if let Some(predicate) = batch.predicates.get(*index) {
if let Some(predicate) = batch.predicates().get(*index) {
write!(w, "{}", predicate.name)?;
} else {
write!(w, "batch_self_{}", index)?;
@ -108,7 +108,7 @@ impl PrettyPrint for StatementTmplArg {
impl PrettyPrint for CustomPredicateBatch {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, indent: usize) -> std::fmt::Result {
for (i, predicate) in self.predicates.iter().enumerate() {
for (i, predicate) in self.predicates().iter().enumerate() {
if i > 0 {
write!(w, "\n\n")?;
}
@ -405,9 +405,11 @@ mod tests {
// Step 4: Verify the ASTs are equivalent
assert_eq!(
batch.predicates, reparsed_batch.predicates,
batch.predicates(),
reparsed_batch.predicates(),
"Original AST should match reparsed AST.\nOriginal input:\n{}\nPretty-printed:\n{}\n",
input, pretty_printed
input,
pretty_printed
);
}
@ -565,7 +567,7 @@ mod tests {
let reparsed = parse(&pretty_printed, &params, &[]).expect("Reparsing should succeed");
let reparsed_batch = reparsed.first_batch().expect("Expected batch");
assert_eq!(batch.predicates, reparsed_batch.predicates);
assert_eq!(batch.predicates(), reparsed_batch.predicates());
}
#[test]
@ -637,9 +639,11 @@ mod tests {
let reparsed_batch = reparsed_result.first_batch().expect("Expected batch");
assert_eq!(
batch.predicates, reparsed_batch.predicates,
batch.predicates(),
reparsed_batch.predicates(),
"Round-trip failed for string: {:?}\nPretty-printed: {}",
test_string, pretty_printed
test_string,
pretty_printed
);
}
}