Improved predicate splitting (#445)

* Multi-batch splitting

* Invoke split predicates by name, passing in full argument list

* Reorder batches to prevent failure of forward references where possible

* Rename APIs for clarity

* Simplify example

* Add more docs

* Review updates

* Remove duplicate code

* Comment topological sort algorithm
This commit is contained in:
Rob Knight 2026-01-28 06:54:21 +01:00 committed by GitHub
parent 9c9a2c454c
commit d1b7b4d37e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 2090 additions and 466 deletions

View file

@ -34,6 +34,40 @@ pub struct ChainLink {
pub public_args_out: Vec<String>,
}
/// Information about a single piece of a split predicate chain
///
/// One `SplitChainPiece` is recorded for every chain link produced when a
/// predicate is split.
#[derive(Debug, Clone)]
pub struct SplitChainPiece {
/// Name of this predicate piece. The entry piece (link 0) keeps the original
/// predicate name (e.g., "foo"); continuation `i` is named `{original}_{i}`
/// (e.g., "foo_1").
pub name: String,
/// Number of "real" statements in this piece, i.e. the statements held by
/// the corresponding chain link (excludes chain call)
pub real_statement_count: usize,
/// Whether this piece has a chain call to the next piece. This is `false`
/// only for the last link of the chain, which has no successor to call.
pub has_chain_call: bool,
}
/// Metadata about a split predicate chain
///
/// Produced only when a predicate was actually split; it records how the
/// original statements were distributed and reordered across the pieces.
#[derive(Debug, Clone)]
pub struct SplitChainInfo {
/// Original predicate name (e.g., "foo")
pub original_name: String,
/// Chain pieces in execution order (innermost continuation first: [foo_2, foo_1, foo]).
/// The piece carrying the original name is therefore always the last entry.
pub chain_pieces: Vec<SplitChainPiece>,
/// Total number of "real" user statements (excludes chain calls).
/// This is the statement count of the original predicate before splitting.
pub real_statement_count: usize,
/// Maps original statement index → reordered index
/// e.g., if original stmt 0 became reordered stmt 3, then `reorder_map[0] = 3`.
/// Has one entry per original statement.
pub reorder_map: Vec<usize>,
}
/// Result of splitting a predicate
///
/// Returned for both outcomes: a predicate that fit within the statement limit
/// (one predicate, no chain info) and a predicate that had to be split.
#[derive(Debug, Clone)]
pub struct SplitResult {
/// The predicates (continuations first, original last if split).
/// When no split was needed this holds exactly the input predicate.
/// Continuations are placed first so that each piece is declared before the
/// piece that calls it.
pub predicates: Vec<CustomPredicateDef>,
/// Split chain info, if splitting occurred (None for non-split)
pub chain_info: Option<SplitChainInfo>,
}
/// Wildcard usage information
#[derive(Debug, Clone)]
struct WildcardUsage {
@ -66,19 +100,25 @@ pub fn validate_predicate_is_splittable(
pub fn split_predicate_if_needed(
pred: CustomPredicateDef,
params: &Params,
) -> Result<Vec<CustomPredicateDef>, SplittingError> {
) -> Result<SplitResult, SplittingError> {
// Early validation
validate_predicate_is_splittable(&pred, params)?;
// If within limits, no splitting needed
if pred.statements.len() <= params.max_custom_predicate_arity {
return Ok(vec![pred]);
return Ok(SplitResult {
predicates: vec![pred],
chain_info: None,
});
}
// Need to split - execute the splitting algorithm
let chain = split_into_chain(pred, params)?;
let (predicates, chain_info) = split_into_chain(pred, params)?;
Ok(chain)
Ok(SplitResult {
predicates,
chain_info: Some(chain_info),
})
}
fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap<String, WildcardUsage> {
@ -121,18 +161,33 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
}
/// Result of ordering statements optimally for splitting
///
/// Pairs the reordered statements with the index mapping that records where
/// each original statement ended up.
struct OrderingResult {
/// Reordered statements
statements: Vec<StatementTmpl>,
/// Maps original statement index → reordered index:
/// `reorder_map[original_idx] = new_idx`.
/// This is the identity mapping when the original order is preserved.
reorder_map: Vec<usize>,
}
/// Order constraints optimally to minimize liveness at boundaries
fn order_constraints_optimally(
statements: Vec<StatementTmpl>,
_usage: &HashMap<String, WildcardUsage>,
params: &Params,
) -> Vec<StatementTmpl> {
// If no splitting needed, preserve original order
if statements.len() <= params.max_custom_predicate_arity {
return statements;
) -> OrderingResult {
let n = statements.len();
// If no splitting needed, preserve original order (identity mapping)
if n <= params.max_custom_predicate_arity {
return OrderingResult {
statements,
reorder_map: (0..n).collect(),
};
}
let mut ordered = Vec::new();
let mut remaining: HashSet<usize> = (0..statements.len()).collect();
let mut reorder_map = vec![0; n];
let mut remaining: HashSet<usize> = (0..n).collect();
let mut active_wildcards: HashSet<String> = HashSet::new();
while !remaining.is_empty() {
@ -146,6 +201,9 @@ fn order_constraints_optimally(
remaining.remove(&best_idx);
let stmt = &statements[best_idx];
// Record the mapping: original index best_idx → new index ordered.len()
reorder_map[best_idx] = ordered.len();
ordered.push(stmt.clone());
// Update active wildcards
@ -160,7 +218,10 @@ fn order_constraints_optimally(
active_wildcards.retain(|w| needed_later.contains(w));
}
ordered
OrderingResult {
statements: ordered,
reorder_map,
}
}
/// Compute tie-breaker metrics for deterministic ordering when scores are equal
@ -360,16 +421,20 @@ fn generate_refactor_suggestion(
}
/// Split into chain using bucket-filling approach
/// Returns the split predicates and metadata about the split
fn split_into_chain(
pred: CustomPredicateDef,
params: &Params,
) -> Result<Vec<CustomPredicateDef>, SplittingError> {
) -> Result<(Vec<CustomPredicateDef>, SplitChainInfo), SplittingError> {
let original_name = pred.name.name.clone();
let conjunction = pred.conjunction_type;
let usage = analyze_wildcards(&pred.statements);
let real_statement_count = pred.statements.len();
let ordered_statements = order_constraints_optimally(pred.statements, &usage, params);
let ordering_result = order_constraints_optimally(pred.statements, &usage, params);
let ordered_statements = ordering_result.statements;
let reorder_map = ordering_result.reorder_map;
let original_public_args: Vec<String> = pred
.args
@ -479,12 +544,43 @@ fn split_into_chain(
}
}
let chain_predicates =
// Build SplitChainInfo from chain_links before generating predicates
// Pieces are in execution order: innermost continuation first, original last
let num_links = chain_links.len();
let mut chain_pieces = Vec::new();
for i in (0..num_links).rev() {
let link = &chain_links[i];
let is_last = i == num_links - 1;
let name = if i == 0 {
original_name.clone()
} else {
format!("{}_{}", original_name, i)
};
chain_pieces.push(SplitChainPiece {
name,
real_statement_count: link.statements.len(),
has_chain_call: !is_last,
});
}
let chain_info = SplitChainInfo {
original_name: original_name.clone(),
chain_pieces,
real_statement_count,
reorder_map,
};
let mut chain_predicates =
generate_chain_predicates(&original_name, chain_links, conjunction, params)?;
validate_chain(&chain_predicates, &original_name, params)?;
validate_chain(&chain_predicates, params)?;
Ok(chain_predicates)
// Reverse so continuations come before callers in declaration order.
// This ensures that when batched, continuations are in earlier batches
// and can be referenced by their callers.
chain_predicates.reverse();
Ok((chain_predicates, chain_info))
}
/// Phase 4: Generate synthetic predicates from chain links
@ -519,20 +615,19 @@ fn generate_chain_predicates(
span: None,
};
// Create arguments for chain call: all public args (incoming + promoted)
let mut chain_call_args = Vec::new();
for arg_name in &link.public_args_in {
chain_call_args.push(StatementTmplArg::Wildcard(Identifier {
name: arg_name.clone(),
span: None,
}));
}
for arg_name in &link.public_args_out {
chain_call_args.push(StatementTmplArg::Wildcard(Identifier {
name: arg_name.clone(),
span: None,
}));
}
// Create arguments for chain call: use next link's public_args_in
// which is the deduplicated union of current public_args_in and public_args_out
let next_link = &chain_links[i + 1];
let chain_call_args: Vec<StatementTmplArg> = next_link
.public_args_in
.iter()
.map(|arg_name| {
StatementTmplArg::Wildcard(Identifier {
name: arg_name.clone(),
span: None,
})
})
.collect();
let chain_call = StatementTmpl {
predicate: next_pred_name,
@ -587,19 +682,10 @@ fn generate_chain_predicates(
}
/// Phase 5: Validate the generated chain
fn validate_chain(
chain: &[CustomPredicateDef],
original_name: &str,
params: &Params,
) -> Result<(), SplittingError> {
if chain.len() > params.max_custom_batch_size {
return Err(SplittingError::TooManyPredicatesInChain {
predicate: original_name.to_string(),
count: chain.len(),
max_allowed: params.max_custom_batch_size,
});
}
///
/// Note: We no longer check chain length against max_custom_batch_size since
/// chains can now span multiple batches thanks to multi-batch support.
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> {
for pred in chain {
// Each predicate should have ≤ max_statements
assert!(pred.statements.len() <= params.max_custom_predicate_arity);
@ -681,8 +767,9 @@ mod tests {
let result = split_predicate_if_needed(pred, &params);
assert!(result.is_ok());
let chain = result.unwrap();
assert_eq!(chain.len(), 1); // No split needed
let split_result = result.unwrap();
assert_eq!(split_result.predicates.len(), 1); // No split needed
assert!(split_result.chain_info.is_none()); // No chain info for non-split
}
#[test]
@ -704,14 +791,29 @@ mod tests {
let result = split_predicate_if_needed(pred, &params);
assert!(result.is_ok());
let chain = result.unwrap();
let split_result = result.unwrap();
let chain = &split_result.predicates;
assert_eq!(chain.len(), 2); // Should split into 2 predicates
// First predicate: 4 statements + chain call = 5
assert_eq!(chain[0].statements.len(), 5);
// Chain is reversed: continuation comes first, original comes last
// Last predicate (index 1): original name, 4 statements + chain call = 5
assert_eq!(chain[1].statements.len(), 5);
assert_eq!(chain[1].name.name, "my_pred");
// Second predicate: 2 remaining statements
assert_eq!(chain[1].statements.len(), 2);
// First predicate (index 0): continuation, 2 remaining statements
assert_eq!(chain[0].statements.len(), 2);
assert_eq!(chain[0].name.name, "my_pred_1");
// Verify chain_info is present
let chain_info = split_result.chain_info.as_ref().unwrap();
assert_eq!(chain_info.original_name, "my_pred");
assert_eq!(chain_info.real_statement_count, 6);
assert_eq!(chain_info.chain_pieces.len(), 2);
// Pieces are in execution order: innermost first
assert_eq!(chain_info.chain_pieces[0].name, "my_pred_1");
assert!(!chain_info.chain_pieces[0].has_chain_call);
assert_eq!(chain_info.chain_pieces[1].name, "my_pred");
assert!(chain_info.chain_pieces[1].has_chain_call);
}
#[test]
@ -733,12 +835,15 @@ mod tests {
let result = split_predicate_if_needed(pred, &params);
assert!(result.is_ok());
let chain = result.unwrap();
let split_result = result.unwrap();
let chain = &split_result.predicates;
assert_eq!(chain.len(), 2); // Should split into 2 predicates
// First predicate should have wildcards that cross boundary promoted
// Check that chain call is present
let last_stmt = &chain[0].statements.last().unwrap();
// Chain is reversed: continuation first, original last
// Original predicate should have chain call as last statement
let original = &chain[1];
assert_eq!(original.name.name, "complex");
let last_stmt = original.statements.last().unwrap();
assert_eq!(last_stmt.predicate.name, "complex_1");
}
@ -766,15 +871,29 @@ mod tests {
let result = split_predicate_if_needed(pred, &params);
assert!(result.is_ok());
let chain = result.unwrap();
let split_result = result.unwrap();
let chain = &split_result.predicates;
assert_eq!(chain.len(), 3); // Should split into 3 predicates
// First: 4 + chain call = 5
assert_eq!(chain[0].statements.len(), 5);
// Second: 4 + chain call = 5
// Chain is reversed: [_2, _1, original]
// chain[0]: large_pred_2 (3 remaining statements)
assert_eq!(chain[0].statements.len(), 3);
assert_eq!(chain[0].name.name, "large_pred_2");
// chain[1]: large_pred_1 (4 + chain call = 5)
assert_eq!(chain[1].statements.len(), 5);
// Third: 3 remaining
assert_eq!(chain[2].statements.len(), 3);
assert_eq!(chain[1].name.name, "large_pred_1");
// chain[2]: large_pred (4 + chain call = 5)
assert_eq!(chain[2].statements.len(), 5);
assert_eq!(chain[2].name.name, "large_pred");
// Verify chain_info
let chain_info = split_result.chain_info.as_ref().unwrap();
assert_eq!(chain_info.real_statement_count, 11);
assert_eq!(chain_info.chain_pieces.len(), 3);
// Execution order: innermost first
assert_eq!(chain_info.chain_pieces[0].name, "large_pred_2");
assert_eq!(chain_info.chain_pieces[1].name, "large_pred_1");
assert_eq!(chain_info.chain_pieces[2].name, "large_pred");
}
#[test]
@ -801,7 +920,8 @@ mod tests {
let result = split_predicate_if_needed(pred, &params);
assert!(result.is_ok());
let chain = result.unwrap();
let split_result = result.unwrap();
let chain = &split_result.predicates;
// Should split into 2 predicates
// T is used in first segment and crosses to second, then used again in second
assert_eq!(chain.len(), 2);
@ -867,7 +987,8 @@ mod tests {
let result = split_predicate_if_needed(pred, &params);
assert!(result.is_ok());
let chain = result.unwrap();
let split_result = result.unwrap();
let chain = &split_result.predicates;
assert_eq!(chain.len(), 2, "Predicate should split into 2 links");
let second_pred = &chain[1];