Improved predicate splitting (#445)
* Multi-batch splitting * Invoke split predicates by name, passing in full argument list * Reorder batches to prevent failure of forward references where possible * Rename APIs for clarity * Simplify example * Add more docs * Review updates * Remove duplicate code * Comment topological sort algorithm
This commit is contained in:
parent
9c9a2c454c
commit
d1b7b4d37e
12 changed files with 2090 additions and 466 deletions
|
|
@ -34,6 +34,40 @@ pub struct ChainLink {
|
|||
pub public_args_out: Vec<String>,
|
||||
}
|
||||
|
||||
/// Information about a single piece of a split predicate chain
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SplitChainPiece {
|
||||
/// Name of this predicate piece (e.g., "foo_1")
|
||||
pub name: String,
|
||||
/// Number of "real" statements in this piece (excludes chain call)
|
||||
pub real_statement_count: usize,
|
||||
/// Whether this piece has a chain call to the next piece
|
||||
pub has_chain_call: bool,
|
||||
}
|
||||
|
||||
/// Metadata about a split predicate chain
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SplitChainInfo {
|
||||
/// Original predicate name (e.g., "foo")
|
||||
pub original_name: String,
|
||||
/// Chain pieces in execution order (innermost continuation first: [foo_2, foo_1, foo])
|
||||
pub chain_pieces: Vec<SplitChainPiece>,
|
||||
/// Total number of "real" user statements (excludes chain calls)
|
||||
pub real_statement_count: usize,
|
||||
/// Maps original statement index → reordered index
|
||||
/// e.g., if original stmt 0 became reordered stmt 3, then `reorder_map[0] = 3`
|
||||
pub reorder_map: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Result of splitting a predicate
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SplitResult {
|
||||
/// The predicates (continuations first, original last if split)
|
||||
pub predicates: Vec<CustomPredicateDef>,
|
||||
/// Split chain info, if splitting occurred (None for non-split)
|
||||
pub chain_info: Option<SplitChainInfo>,
|
||||
}
|
||||
|
||||
/// Wildcard usage information
|
||||
#[derive(Debug, Clone)]
|
||||
struct WildcardUsage {
|
||||
|
|
@ -66,19 +100,25 @@ pub fn validate_predicate_is_splittable(
|
|||
pub fn split_predicate_if_needed(
|
||||
pred: CustomPredicateDef,
|
||||
params: &Params,
|
||||
) -> Result<Vec<CustomPredicateDef>, SplittingError> {
|
||||
) -> Result<SplitResult, SplittingError> {
|
||||
// Early validation
|
||||
validate_predicate_is_splittable(&pred, params)?;
|
||||
|
||||
// If within limits, no splitting needed
|
||||
if pred.statements.len() <= params.max_custom_predicate_arity {
|
||||
return Ok(vec![pred]);
|
||||
return Ok(SplitResult {
|
||||
predicates: vec![pred],
|
||||
chain_info: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Need to split - execute the splitting algorithm
|
||||
let chain = split_into_chain(pred, params)?;
|
||||
let (predicates, chain_info) = split_into_chain(pred, params)?;
|
||||
|
||||
Ok(chain)
|
||||
Ok(SplitResult {
|
||||
predicates,
|
||||
chain_info: Some(chain_info),
|
||||
})
|
||||
}
|
||||
|
||||
fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap<String, WildcardUsage> {
|
||||
|
|
@ -121,18 +161,33 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
|
|||
}
|
||||
|
||||
/// Order constraints optimally to minimize liveness at boundaries
|
||||
/// Result of ordering statements optimally for splitting
|
||||
struct OrderingResult {
|
||||
/// Reordered statements
|
||||
statements: Vec<StatementTmpl>,
|
||||
/// Maps original statement index → reordered index
|
||||
/// reorder_map[original_idx] = new_idx
|
||||
reorder_map: Vec<usize>,
|
||||
}
|
||||
|
||||
fn order_constraints_optimally(
|
||||
statements: Vec<StatementTmpl>,
|
||||
_usage: &HashMap<String, WildcardUsage>,
|
||||
params: &Params,
|
||||
) -> Vec<StatementTmpl> {
|
||||
// If no splitting needed, preserve original order
|
||||
if statements.len() <= params.max_custom_predicate_arity {
|
||||
return statements;
|
||||
) -> OrderingResult {
|
||||
let n = statements.len();
|
||||
|
||||
// If no splitting needed, preserve original order (identity mapping)
|
||||
if n <= params.max_custom_predicate_arity {
|
||||
return OrderingResult {
|
||||
statements,
|
||||
reorder_map: (0..n).collect(),
|
||||
};
|
||||
}
|
||||
|
||||
let mut ordered = Vec::new();
|
||||
let mut remaining: HashSet<usize> = (0..statements.len()).collect();
|
||||
let mut reorder_map = vec![0; n];
|
||||
let mut remaining: HashSet<usize> = (0..n).collect();
|
||||
let mut active_wildcards: HashSet<String> = HashSet::new();
|
||||
|
||||
while !remaining.is_empty() {
|
||||
|
|
@ -146,6 +201,9 @@ fn order_constraints_optimally(
|
|||
|
||||
remaining.remove(&best_idx);
|
||||
let stmt = &statements[best_idx];
|
||||
|
||||
// Record the mapping: original index best_idx → new index ordered.len()
|
||||
reorder_map[best_idx] = ordered.len();
|
||||
ordered.push(stmt.clone());
|
||||
|
||||
// Update active wildcards
|
||||
|
|
@ -160,7 +218,10 @@ fn order_constraints_optimally(
|
|||
active_wildcards.retain(|w| needed_later.contains(w));
|
||||
}
|
||||
|
||||
ordered
|
||||
OrderingResult {
|
||||
statements: ordered,
|
||||
reorder_map,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute tie-breaker metrics for deterministic ordering when scores are equal
|
||||
|
|
@ -360,16 +421,20 @@ fn generate_refactor_suggestion(
|
|||
}
|
||||
|
||||
/// Split into chain using bucket-filling approach
|
||||
/// Returns the split predicates and metadata about the split
|
||||
fn split_into_chain(
|
||||
pred: CustomPredicateDef,
|
||||
params: &Params,
|
||||
) -> Result<Vec<CustomPredicateDef>, SplittingError> {
|
||||
) -> Result<(Vec<CustomPredicateDef>, SplitChainInfo), SplittingError> {
|
||||
let original_name = pred.name.name.clone();
|
||||
let conjunction = pred.conjunction_type;
|
||||
|
||||
let usage = analyze_wildcards(&pred.statements);
|
||||
let real_statement_count = pred.statements.len();
|
||||
|
||||
let ordered_statements = order_constraints_optimally(pred.statements, &usage, params);
|
||||
let ordering_result = order_constraints_optimally(pred.statements, &usage, params);
|
||||
let ordered_statements = ordering_result.statements;
|
||||
let reorder_map = ordering_result.reorder_map;
|
||||
|
||||
let original_public_args: Vec<String> = pred
|
||||
.args
|
||||
|
|
@ -479,12 +544,43 @@ fn split_into_chain(
|
|||
}
|
||||
}
|
||||
|
||||
let chain_predicates =
|
||||
// Build SplitChainInfo from chain_links before generating predicates
|
||||
// Pieces are in execution order: innermost continuation first, original last
|
||||
let num_links = chain_links.len();
|
||||
let mut chain_pieces = Vec::new();
|
||||
for i in (0..num_links).rev() {
|
||||
let link = &chain_links[i];
|
||||
let is_last = i == num_links - 1;
|
||||
let name = if i == 0 {
|
||||
original_name.clone()
|
||||
} else {
|
||||
format!("{}_{}", original_name, i)
|
||||
};
|
||||
chain_pieces.push(SplitChainPiece {
|
||||
name,
|
||||
real_statement_count: link.statements.len(),
|
||||
has_chain_call: !is_last,
|
||||
});
|
||||
}
|
||||
|
||||
let chain_info = SplitChainInfo {
|
||||
original_name: original_name.clone(),
|
||||
chain_pieces,
|
||||
real_statement_count,
|
||||
reorder_map,
|
||||
};
|
||||
|
||||
let mut chain_predicates =
|
||||
generate_chain_predicates(&original_name, chain_links, conjunction, params)?;
|
||||
|
||||
validate_chain(&chain_predicates, &original_name, params)?;
|
||||
validate_chain(&chain_predicates, params)?;
|
||||
|
||||
Ok(chain_predicates)
|
||||
// Reverse so continuations come before callers in declaration order.
|
||||
// This ensures that when batched, continuations are in earlier batches
|
||||
// and can be referenced by their callers.
|
||||
chain_predicates.reverse();
|
||||
|
||||
Ok((chain_predicates, chain_info))
|
||||
}
|
||||
|
||||
/// Phase 4: Generate synthetic predicates from chain links
|
||||
|
|
@ -519,20 +615,19 @@ fn generate_chain_predicates(
|
|||
span: None,
|
||||
};
|
||||
|
||||
// Create arguments for chain call: all public args (incoming + promoted)
|
||||
let mut chain_call_args = Vec::new();
|
||||
for arg_name in &link.public_args_in {
|
||||
chain_call_args.push(StatementTmplArg::Wildcard(Identifier {
|
||||
name: arg_name.clone(),
|
||||
span: None,
|
||||
}));
|
||||
}
|
||||
for arg_name in &link.public_args_out {
|
||||
chain_call_args.push(StatementTmplArg::Wildcard(Identifier {
|
||||
name: arg_name.clone(),
|
||||
span: None,
|
||||
}));
|
||||
}
|
||||
// Create arguments for chain call: use next link's public_args_in
|
||||
// which is the deduplicated union of current public_args_in and public_args_out
|
||||
let next_link = &chain_links[i + 1];
|
||||
let chain_call_args: Vec<StatementTmplArg> = next_link
|
||||
.public_args_in
|
||||
.iter()
|
||||
.map(|arg_name| {
|
||||
StatementTmplArg::Wildcard(Identifier {
|
||||
name: arg_name.clone(),
|
||||
span: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let chain_call = StatementTmpl {
|
||||
predicate: next_pred_name,
|
||||
|
|
@ -587,19 +682,10 @@ fn generate_chain_predicates(
|
|||
}
|
||||
|
||||
/// Phase 5: Validate the generated chain
|
||||
fn validate_chain(
|
||||
chain: &[CustomPredicateDef],
|
||||
original_name: &str,
|
||||
params: &Params,
|
||||
) -> Result<(), SplittingError> {
|
||||
if chain.len() > params.max_custom_batch_size {
|
||||
return Err(SplittingError::TooManyPredicatesInChain {
|
||||
predicate: original_name.to_string(),
|
||||
count: chain.len(),
|
||||
max_allowed: params.max_custom_batch_size,
|
||||
});
|
||||
}
|
||||
|
||||
///
|
||||
/// Note: We no longer check chain length against max_custom_batch_size since
|
||||
/// chains can now span multiple batches thanks to multi-batch support.
|
||||
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> {
|
||||
for pred in chain {
|
||||
// Each predicate should have ≤ max_statements
|
||||
assert!(pred.statements.len() <= params.max_custom_predicate_arity);
|
||||
|
|
@ -681,8 +767,9 @@ mod tests {
|
|||
let result = split_predicate_if_needed(pred, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let chain = result.unwrap();
|
||||
assert_eq!(chain.len(), 1); // No split needed
|
||||
let split_result = result.unwrap();
|
||||
assert_eq!(split_result.predicates.len(), 1); // No split needed
|
||||
assert!(split_result.chain_info.is_none()); // No chain info for non-split
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -704,14 +791,29 @@ mod tests {
|
|||
let result = split_predicate_if_needed(pred, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let chain = result.unwrap();
|
||||
let split_result = result.unwrap();
|
||||
let chain = &split_result.predicates;
|
||||
assert_eq!(chain.len(), 2); // Should split into 2 predicates
|
||||
|
||||
// First predicate: 4 statements + chain call = 5
|
||||
assert_eq!(chain[0].statements.len(), 5);
|
||||
// Chain is reversed: continuation comes first, original comes last
|
||||
// Last predicate (index 1): original name, 4 statements + chain call = 5
|
||||
assert_eq!(chain[1].statements.len(), 5);
|
||||
assert_eq!(chain[1].name.name, "my_pred");
|
||||
|
||||
// Second predicate: 2 remaining statements
|
||||
assert_eq!(chain[1].statements.len(), 2);
|
||||
// First predicate (index 0): continuation, 2 remaining statements
|
||||
assert_eq!(chain[0].statements.len(), 2);
|
||||
assert_eq!(chain[0].name.name, "my_pred_1");
|
||||
|
||||
// Verify chain_info is present
|
||||
let chain_info = split_result.chain_info.as_ref().unwrap();
|
||||
assert_eq!(chain_info.original_name, "my_pred");
|
||||
assert_eq!(chain_info.real_statement_count, 6);
|
||||
assert_eq!(chain_info.chain_pieces.len(), 2);
|
||||
// Pieces are in execution order: innermost first
|
||||
assert_eq!(chain_info.chain_pieces[0].name, "my_pred_1");
|
||||
assert!(!chain_info.chain_pieces[0].has_chain_call);
|
||||
assert_eq!(chain_info.chain_pieces[1].name, "my_pred");
|
||||
assert!(chain_info.chain_pieces[1].has_chain_call);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -733,12 +835,15 @@ mod tests {
|
|||
let result = split_predicate_if_needed(pred, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let chain = result.unwrap();
|
||||
let split_result = result.unwrap();
|
||||
let chain = &split_result.predicates;
|
||||
assert_eq!(chain.len(), 2); // Should split into 2 predicates
|
||||
|
||||
// First predicate should have wildcards that cross boundary promoted
|
||||
// Check that chain call is present
|
||||
let last_stmt = &chain[0].statements.last().unwrap();
|
||||
// Chain is reversed: continuation first, original last
|
||||
// Original predicate should have chain call as last statement
|
||||
let original = &chain[1];
|
||||
assert_eq!(original.name.name, "complex");
|
||||
let last_stmt = original.statements.last().unwrap();
|
||||
assert_eq!(last_stmt.predicate.name, "complex_1");
|
||||
}
|
||||
|
||||
|
|
@ -766,15 +871,29 @@ mod tests {
|
|||
let result = split_predicate_if_needed(pred, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let chain = result.unwrap();
|
||||
let split_result = result.unwrap();
|
||||
let chain = &split_result.predicates;
|
||||
assert_eq!(chain.len(), 3); // Should split into 3 predicates
|
||||
|
||||
// First: 4 + chain call = 5
|
||||
assert_eq!(chain[0].statements.len(), 5);
|
||||
// Second: 4 + chain call = 5
|
||||
// Chain is reversed: [_2, _1, original]
|
||||
// chain[0]: large_pred_2 (3 remaining statements)
|
||||
assert_eq!(chain[0].statements.len(), 3);
|
||||
assert_eq!(chain[0].name.name, "large_pred_2");
|
||||
// chain[1]: large_pred_1 (4 + chain call = 5)
|
||||
assert_eq!(chain[1].statements.len(), 5);
|
||||
// Third: 3 remaining
|
||||
assert_eq!(chain[2].statements.len(), 3);
|
||||
assert_eq!(chain[1].name.name, "large_pred_1");
|
||||
// chain[2]: large_pred (4 + chain call = 5)
|
||||
assert_eq!(chain[2].statements.len(), 5);
|
||||
assert_eq!(chain[2].name.name, "large_pred");
|
||||
|
||||
// Verify chain_info
|
||||
let chain_info = split_result.chain_info.as_ref().unwrap();
|
||||
assert_eq!(chain_info.real_statement_count, 11);
|
||||
assert_eq!(chain_info.chain_pieces.len(), 3);
|
||||
// Execution order: innermost first
|
||||
assert_eq!(chain_info.chain_pieces[0].name, "large_pred_2");
|
||||
assert_eq!(chain_info.chain_pieces[1].name, "large_pred_1");
|
||||
assert_eq!(chain_info.chain_pieces[2].name, "large_pred");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -801,7 +920,8 @@ mod tests {
|
|||
let result = split_predicate_if_needed(pred, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let chain = result.unwrap();
|
||||
let split_result = result.unwrap();
|
||||
let chain = &split_result.predicates;
|
||||
// Should split into 2 predicates
|
||||
// T is used in first segment and crosses to second, then used again in second
|
||||
assert_eq!(chain.len(), 2);
|
||||
|
|
@ -867,7 +987,8 @@ mod tests {
|
|||
let result = split_predicate_if_needed(pred, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let chain = result.unwrap();
|
||||
let split_result = result.unwrap();
|
||||
let chain = &split_result.predicates;
|
||||
assert_eq!(chain.len(), 2, "Predicate should split into 2 links");
|
||||
|
||||
let second_pred = &chain[1];
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue