Merkle tree for custom predicate batches (#471)
Resolves https://github.com/0xPARC/pod2/issues/466. Batches are now identified by the root of a Merkle tree that contains all the predicates (using sequential indices as keys). This means that the format for identifying a custom predicate reference is still a hash + index, but the hash is calculated differently. The MainPod circuit is no longer limited by the number of batches but by the number of custom predicates, and for each one we verify a Merkle proof to check the batch id. I've removed a bunch of tests from lang that were testing splitting into multiple batches because there's no longer any need for that. In a future PR we'll remove the code that handles batch splitting. Each custom predicate needs 148.2 gates, which is very close to my estimate of 142.7 in https://github.com/0xPARC/pod2/issues/466#issuecomment-3823531286 (where I actually made a mistake and considered 5 predicates per batch instead of 4 in the previous Params).
This commit is contained in:
parent
a7a30176a7
commit
641d8dabdd
17 changed files with 331 additions and 761 deletions
|
|
@ -768,37 +768,6 @@ mod tests {
|
|||
assert_eq!(batches.total_predicate_count(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_predicates_span_multiple_batches() {
|
||||
let input = r#"
|
||||
pred1(A) = AND(Equal(A["x"], 1))
|
||||
pred2(B) = AND(Equal(B["y"], 2))
|
||||
pred3(C) = AND(Equal(C["z"], 3))
|
||||
pred4(D) = AND(Equal(D["w"], 4))
|
||||
pred5(E) = AND(Equal(E["v"], 5))
|
||||
"#;
|
||||
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
&params,
|
||||
"TestBatch",
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let batches = result.unwrap();
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
assert_eq!(batches.total_predicate_count(), 5);
|
||||
|
||||
// First batch should have 4 predicates
|
||||
assert_eq!(batches.batches()[0].predicates().len(), 4);
|
||||
// Second batch should have 1 predicate
|
||||
assert_eq!(batches.batches()[1].predicates().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_intra_batch_forward_reference() {
|
||||
// pred2 calls pred1, but pred2 is declared first
|
||||
|
|
@ -869,132 +838,6 @@ mod tests {
|
|||
)); // calls pred1
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_batch_reference() {
|
||||
// 5 predicates where pred5 calls pred1
|
||||
// pred1-4 go in batch 0, pred5 in batch 1
|
||||
// pred5's call to pred1 should be a cross-batch reference
|
||||
let input = r#"
|
||||
pred1(A) = AND(Equal(A["x"], 1))
|
||||
pred2(B) = AND(Equal(B["y"], 2))
|
||||
pred3(C) = AND(Equal(C["z"], 3))
|
||||
pred4(D) = AND(Equal(D["w"], 4))
|
||||
pred5(E) = AND(pred1(E))
|
||||
"#;
|
||||
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
&params,
|
||||
"TestBatch",
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let batches = result.unwrap();
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
|
||||
// pred5 should reference pred1 via CustomPredicateRef
|
||||
let pred5_batch = &batches.batches()[1];
|
||||
let pred5 = &pred5_batch.predicates()[0];
|
||||
let pred5_stmt = &pred5.statements[0];
|
||||
|
||||
// The predicate should be a Custom reference to batch 0
|
||||
match pred5_stmt.pred_or_wc() {
|
||||
PredicateOrWildcard::Predicate(Predicate::Custom(ref_)) => {
|
||||
// Should reference batch 0, index 0 (pred1)
|
||||
assert_eq!(ref_.batch.id(), batches.batches()[0].id());
|
||||
}
|
||||
_ => panic!("Expected Custom predicate reference"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_chain_spans_batches() {
|
||||
// Create a predicate that will split into 2-3 predicates
|
||||
// Then add more predicates to force the chain to span batches
|
||||
let input = r#"
|
||||
pred1(A) = AND(Equal(A["x"], 1))
|
||||
pred2(B) = AND(Equal(B["y"], 2))
|
||||
pred3(C) = AND(Equal(C["z"], 3))
|
||||
large_pred(D) = AND(
|
||||
Equal(D["a"], 1)
|
||||
Equal(D["b"], 2)
|
||||
Equal(D["c"], 3)
|
||||
Equal(D["d"], 4)
|
||||
Equal(D["e"], 5)
|
||||
Equal(D["f"], 6)
|
||||
)
|
||||
"#;
|
||||
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
// Split the large predicate
|
||||
let mut all_split_results = Vec::new();
|
||||
for pred in predicates {
|
||||
let result = split_predicate_if_needed(pred, &params).expect("Split failed");
|
||||
all_split_results.push(result);
|
||||
}
|
||||
|
||||
// Count total predicates across all split results
|
||||
let total_preds: usize = all_split_results.iter().map(|r| r.predicates.len()).sum();
|
||||
|
||||
// We should have: pred1, pred2, pred3, large_pred_1 (continuation), large_pred
|
||||
// That's 5 predicates, which spans 2 batches
|
||||
assert_eq!(total_preds, 5);
|
||||
|
||||
let result = batch_predicates(all_split_results, &params, "TestBatch", validated.symbols());
|
||||
assert!(result.is_ok());
|
||||
|
||||
let batches = result.unwrap();
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
assert_eq!(batches.total_predicate_count(), 5);
|
||||
|
||||
// Verify chain info was captured
|
||||
let chain_info = batches.split_chain("large_pred");
|
||||
assert!(chain_info.is_some());
|
||||
let info = chain_info.unwrap();
|
||||
assert_eq!(info.original_name, "large_pred");
|
||||
assert_eq!(info.real_statement_count, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forward_cross_batch_reference_avoided_by_planner() {
|
||||
// 5 predicates where pred4 calls pred5 (forward declaration)
|
||||
// With max_custom_batch_size = 4, naive packing would place pred5 in batch 1
|
||||
// The dependency-aware planner should instead pack pred5 before pred4
|
||||
// to avoid a forward cross-batch reference.
|
||||
let input = r#"
|
||||
pred1(A) = AND(Equal(A["x"], 1))
|
||||
pred2(B) = AND(Equal(B["y"], 2))
|
||||
pred3(C) = AND(Equal(C["z"], 3))
|
||||
pred4(D) = AND(pred5(D))
|
||||
pred5(E) = AND(Equal(E["v"], 5))
|
||||
"#;
|
||||
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let batches = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
&params,
|
||||
"TestBatch",
|
||||
validated.symbols(),
|
||||
)
|
||||
.expect("Planner should avoid forward cross-batch reference");
|
||||
|
||||
// Expect two batches and the reference to point within the same batch or earlier batch.
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
// pred5 should be in batch 0 and pred4 in batch 1 (given stable topo + packing)
|
||||
let pred5_ref = batches.predicate_ref_by_name("pred5").unwrap();
|
||||
let pred4_ref = batches.predicate_ref_by_name("pred4").unwrap();
|
||||
assert_eq!(pred5_ref.batch.id(), batches.batches()[0].id());
|
||||
assert_eq!(pred4_ref.batch.id(), batches.batches()[1].id());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_input() {
|
||||
let split_results: Vec<SplitResult> = vec![];
|
||||
|
|
@ -1037,83 +880,6 @@ mod tests {
|
|||
assert!(batches.predicate_ref_by_name("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mutual_recursion_exceeds_capacity_error() {
|
||||
// Five predicates that call each other in a cycle (SCC size = 5) with max batch size 4
|
||||
// Should error because an SCC cannot be split across batches
|
||||
let input = r#"
|
||||
pred1(A) = AND(pred2(A))
|
||||
pred2(B) = AND(pred3(B))
|
||||
pred3(B) = AND(pred4(B))
|
||||
pred4(B) = AND(pred5(B))
|
||||
pred5(B) = AND(pred1(B))
|
||||
"#;
|
||||
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let result = batch_predicates(
|
||||
preds_to_split_results(predicates),
|
||||
&params,
|
||||
"TestBatch",
|
||||
validated.symbols(),
|
||||
);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("exceeds batch capacity"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_chain_across_batches_placement() {
|
||||
// Create a large predicate that splits into 2 pieces, plus enough predicates
|
||||
// to force the chain to span batches; verify the continuation is placed in an earlier batch
|
||||
let input = r#"
|
||||
p1(A) = AND(Equal(A["x"], 1))
|
||||
p2(B) = AND(Equal(B["y"], 2))
|
||||
p3(C) = AND(Equal(C["z"], 3))
|
||||
large_pred(D) = AND(
|
||||
Equal(D["a"], 1)
|
||||
Equal(D["b"], 2)
|
||||
Equal(D["c"], 3)
|
||||
Equal(D["d"], 4)
|
||||
Equal(D["e"], 5)
|
||||
Equal(D["f"], 6)
|
||||
)
|
||||
"#;
|
||||
|
||||
let (predicates, validated) = parse_and_validate(input);
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
// Split and batch
|
||||
let mut all_split_results = Vec::new();
|
||||
for pred in predicates {
|
||||
let result = split_predicate_if_needed(pred, &params).expect("Split failed");
|
||||
all_split_results.push(result);
|
||||
}
|
||||
let batches =
|
||||
batch_predicates(all_split_results, &params, "TestBatch", validated.symbols())
|
||||
.expect("Batch failed");
|
||||
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
|
||||
// Verify chain info
|
||||
let chain_info = batches
|
||||
.split_chain("large_pred")
|
||||
.expect("Expected chain info");
|
||||
assert_eq!(chain_info.chain_pieces.len(), 2);
|
||||
// Expect continuation piece name to be large_pred_1 (innermost first)
|
||||
let cont_name = &chain_info.chain_pieces[0].name;
|
||||
assert_eq!(cont_name, "large_pred_1");
|
||||
|
||||
// Expect continuation in batch 0 and main in batch 1
|
||||
let cont_ref = batches.predicate_ref_by_name("large_pred_1").unwrap();
|
||||
let main_ref = batches.predicate_ref_by_name("large_pred").unwrap();
|
||||
assert_eq!(cont_ref.batch.id(), batches.batches()[0].id());
|
||||
assert_eq!(main_ref.batch.id(), batches.batches()[1].id());
|
||||
}
|
||||
|
||||
/// Helper: create a unique Statement for testing
|
||||
/// Uses Equal with distinct literal values to create distinguishable statements
|
||||
fn test_statement(id: usize) -> Statement {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue