Improved predicate splitting (#445)

* Multi-batch splitting

* Invoke split predicates by name, passing in full argument list

* Reorder batches to prevent failure of forward references where possible

* Rename APIs for clarity

* Simplify example

* Add more docs

* Review updates

* Remove duplicate code

* Comment topological sort algorithm
This commit is contained in:
Rob Knight 2026-01-28 06:54:21 +01:00 committed by GitHub
parent 9c9a2c454c
commit d1b7b4d37e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 2090 additions and 466 deletions

View file

@ -71,6 +71,12 @@ impl From<crate::lang::LangError> for Error {
}
}
impl From<crate::lang::MultiOperationError> for Error {
fn from(value: crate::lang::MultiOperationError) -> Self {
Error::custom(value.to_string())
}
}
impl Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)

View file

@ -1390,7 +1390,11 @@ pub mod tests {
Equal(b, 5)
)
"#;
let batch = parse(input, &params, &[]).unwrap().custom_batch;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
// Try to build with wrong type in 1st arg
@ -1414,4 +1418,83 @@ pub mod tests {
Ok(())
}
#[test]
fn test_apply_predicate_e2e() -> Result<()> {
// End-to-end test of apply_predicate with MockProver
// Tests a split predicate being applied through the full pipeline
let params = Params::default();
let vd_set = &*MOCK_VD_SET;
// Create a predicate that will split (6 Equal statements)
// The predicate checks that values at different keys are equal to specific literals
let input = r#"
large_pred(A) = AND(
Equal(A["a"], 1)
Equal(A["b"], 2)
Equal(A["c"], 3)
Equal(A["d"], 4)
Equal(A["e"], 5)
Equal(A["f"], 6)
)
"#;
// Parse and batch the predicate (this handles splitting internally)
let parsed = parse(input, &params, &[])?;
let batches = &parsed.custom_batches;
// Verify it was split
assert!(batches.split_chain("large_pred").is_some());
let chain_info = batches.split_chain("large_pred").unwrap();
assert_eq!(chain_info.chain_pieces.len(), 2);
assert_eq!(chain_info.real_statement_count, 6);
// Create a signed dict with the required entries
let mut signed_builder = SignedDictBuilder::new(&params);
signed_builder.insert("a", 1);
signed_builder.insert("b", 2);
signed_builder.insert("c", 3);
signed_builder.insert("d", 4);
signed_builder.insert("e", 5);
signed_builder.insert("f", 6);
let signer = Signer(SecretKey(1u32.into()));
let signed_dict = signed_builder.sign(&signer)?;
// Build the main pod
let mut builder = MainPodBuilder::new(&params, vd_set);
builder.pub_op(Operation::dict_signed_by(&signed_dict))?;
// Create 6 Equal statements (one for each predicate constraint) in original order
// Each proves that signed_dict["x"] = n, matching the Equal(A["x"], n) template
let st_a = builder.priv_op(Operation::eq((&signed_dict, "a"), 1))?;
let st_b = builder.priv_op(Operation::eq((&signed_dict, "b"), 2))?;
let st_c = builder.priv_op(Operation::eq((&signed_dict, "c"), 3))?;
let st_d = builder.priv_op(Operation::eq((&signed_dict, "d"), 4))?;
let st_e = builder.priv_op(Operation::eq((&signed_dict, "e"), 5))?;
let st_f = builder.priv_op(Operation::eq((&signed_dict, "f"), 6))?;
// Pass statements in original declaration order
let statements = vec![st_a, st_b, st_c, st_d, st_e, st_f];
// Use apply_predicate (primary API) to automatically wire the split chain
let result = batches.apply_predicate(&mut builder, "large_pred", statements, true)?;
// The result should be a valid statement
let predicate = batches.predicate_ref_by_name("large_pred").unwrap();
match &result {
Statement::Custom(pred_ref, _) => {
assert_eq!(pred_ref, &predicate);
}
_ => panic!("Expected Statement::Custom, got {:?}", result),
}
// Prove with MockProver
let prover = MockProver {};
let pod = builder.prove(&prover)?;
// Verify the pod
pod.pod.verify()?;
Ok(())
}
}