Splitter fixes (#483)

* Add failing tests

* Model statements depending on public args as cheaper than those depending on private args

* Tidying

* Fix unnecessary propagation of unused public args

* More tidying

* Tidy test comments

* Fix incorrect Delete arities
This commit is contained in:
Rob Knight 2026-02-20 17:13:42 +01:00 committed by GitHub
parent e950661090
commit f6c6ec43ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 280 additions and 140 deletions

View file

@ -287,7 +287,8 @@ fn format_public_args_at_split_error(
msg.push_str(&format!(
" Statements {}-{} in this segment\n",
context.statement_range.0, context.statement_range.1
context.statement_range.0,
context.statement_range.1 - 1
));
if !context.incoming_public.is_empty() {

View file

@ -15,10 +15,7 @@
//! We use a greedy algorithm to order the statements in a predicate to minimize
//! the number of live wildcards at split boundaries.
use std::{
cmp::Reverse,
collections::{HashMap, HashSet},
};
use std::{cmp::Reverse, collections::HashSet};
// SplittingError is now defined in error.rs
pub use crate::lang::error::SplittingError;
@ -33,8 +30,8 @@ pub struct ChainLink {
pub public_args_in: Vec<String>,
/// Private arguments used only in this link
pub private_args: Vec<String>,
/// Public arguments promoted to pass to next link (empty if last link)
pub public_args_out: Vec<String>,
/// Private wildcards promoted to public for the next link (empty if last link)
pub promoted_wildcards: Vec<String>,
}
/// Information about a single piece of a split predicate chain
@ -71,13 +68,6 @@ pub struct SplitResult {
pub chain_info: Option<SplitChainInfo>,
}
/// Wildcard usage information
#[derive(Debug, Clone)]
struct WildcardUsage {
/// Indices of statements using this wildcard
used_in_statements: HashSet<usize>,
}
/// Early validation: Check if predicate is fundamentally splittable
pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> {
let public_args = pred.args.public_args.len();
@ -121,26 +111,6 @@ pub fn split_predicate_if_needed(
})
}
fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap<String, WildcardUsage> {
let mut usage: HashMap<String, WildcardUsage> = HashMap::new();
for (idx, stmt) in statements.iter().enumerate() {
let wildcards = collect_wildcards_from_statement(stmt);
for wildcard in wildcards {
usage
.entry(wildcard.clone())
.or_insert_with(|| WildcardUsage {
used_in_statements: HashSet::new(),
})
.used_in_statements
.insert(idx);
}
}
usage
}
/// Collect all wildcard names from a statement
fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
let mut wildcards = HashSet::new();
@ -172,7 +142,7 @@ struct OrderingResult {
fn order_constraints_optimally(
statements: Vec<StatementTmpl>,
_usage: &HashMap<String, WildcardUsage>,
public_args: &HashSet<String>,
) -> OrderingResult {
let n = statements.len();
@ -190,8 +160,13 @@ fn order_constraints_optimally(
let mut active_wildcards: HashSet<String> = HashSet::new();
while !remaining.is_empty() {
let best_idx =
find_best_next_statement(&statements, &remaining, &active_wildcards, ordered.len());
let best_idx = find_best_next_statement(
&statements,
&remaining,
&active_wildcards,
ordered.len(),
public_args,
);
remaining.remove(&best_idx);
let stmt = &statements[best_idx];
@ -200,14 +175,20 @@ fn order_constraints_optimally(
reorder_map[best_idx] = ordered.len();
ordered.push(stmt.clone());
// Update active wildcards
// Only track private wildcards in the active set — public args are always
// available at every boundary so their liveness is irrelevant to split cost.
let stmt_wildcards = collect_wildcards_from_statement(stmt);
active_wildcards.extend(stmt_wildcards);
active_wildcards.extend(
stmt_wildcards
.into_iter()
.filter(|w| !public_args.contains(w)),
);
// Remove wildcards no longer needed by remaining statements
// Remove private wildcards no longer needed by remaining statements
let needed_later: HashSet<_> = remaining
.iter()
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
.filter(|w| !public_args.contains(w))
.collect();
active_wildcards.retain(|w| needed_later.contains(w));
}
@ -225,30 +206,35 @@ fn compute_tie_breakers(
active_wildcards: &HashSet<String>,
statements: &[StatementTmpl],
remaining: &HashSet<usize>,
needed_later: &HashSet<String>,
public_args: &HashSet<String>,
) -> (usize, usize, i32) {
let stmt_wildcards = collect_wildcards_from_statement(stmt);
// Metric 1: Simplicity - prefer statements with fewer wildcards
let simplicity = usize::MAX - stmt_wildcards.len();
// Metric 2: Public closure - prefer statements that close active wildcards
// (wildcards that won't be needed by any remaining statements)
let needed_later: HashSet<String> = remaining
.iter()
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
let all_wildcards = collect_wildcards_from_statement(stmt);
// Only consider private wildcards for tie-breaking metrics
let stmt_wildcards: HashSet<_> = all_wildcards
.into_iter()
.filter(|w| !public_args.contains(w))
.collect();
// Metric 1: Simplicity - prefer statements with fewer private wildcards
let simplicity = usize::MAX - stmt_wildcards.len();
// Metric 2: Closure - prefer statements that close active private wildcards
// (wildcards that won't be needed by any remaining statements)
let closes_count = stmt_wildcards
.intersection(active_wildcards)
.filter(|w| !needed_later.contains(*w))
.count();
// Metric 3: Fanout - prefer statements with lower future usage
// (number of remaining statements that use any wildcard from this statement)
// (number of remaining statements sharing private wildcards with this statement)
let fanout = remaining
.iter()
.filter(|&&i| {
let other_wildcards = collect_wildcards_from_statement(&statements[i]);
let other_wildcards: HashSet<_> = collect_wildcards_from_statement(&statements[i])
.into_iter()
.filter(|w| !public_args.contains(w))
.collect();
!stmt_wildcards.is_disjoint(&other_wildcards)
})
.count();
@ -262,16 +248,33 @@ fn statement_selection_key(
active_wildcards: &HashSet<String>,
remaining: &HashSet<usize>,
approaching_split: bool,
public_args: &HashSet<String>,
) -> (i32, (usize, usize, i32), Reverse<usize>) {
// Pre-compute needed_later once and share between primary score and tie-breakers.
// Exclude the candidate itself: we want to know what the *other* remaining statements
// need, so that wildcards used only by this candidate correctly appear as closeable.
let needed_later: HashSet<String> = remaining
.iter()
.filter(|&&i| i != idx)
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
.filter(|w| !public_args.contains(w))
.collect();
let primary_score = score_statement(
&statements[idx],
active_wildcards,
approaching_split,
public_args,
&needed_later,
);
let tie_breakers = compute_tie_breakers(
&statements[idx],
active_wildcards,
statements,
remaining,
approaching_split,
&needed_later,
public_args,
);
let tie_breakers =
compute_tie_breakers(&statements[idx], active_wildcards, statements, remaining);
// Final deterministic tie-breaker: prefer smaller original indices.
// This avoids hash-iteration-dependent selection when scores are equal.
@ -284,6 +287,7 @@ fn find_best_next_statement(
remaining: &HashSet<usize>,
active_wildcards: &HashSet<String>,
ordered_count: usize,
public_args: &HashSet<String>,
) -> usize {
// Calculate distance to next split point
let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call
@ -299,51 +303,66 @@ fn find_best_next_statement(
active_wildcards,
remaining,
approaching_split,
public_args,
)
})
.copied()
.unwrap()
}
/// Score a statement based on how well it minimizes liveness
/// Score a statement based on how well it minimizes private-wildcard liveness at boundaries.
/// `needed_later` is the set of private wildcards used by the remaining statements
/// other than the candidate `stmt` itself (public args are excluded by the caller).
fn score_statement(
stmt: &StatementTmpl,
active_wildcards: &HashSet<String>,
statements: &[StatementTmpl],
remaining: &HashSet<usize>,
approaching_split: bool,
public_args: &HashSet<String>,
needed_later: &HashSet<String>,
) -> i32 {
let stmt_wildcards = collect_wildcards_from_statement(stmt);
let all_wildcards = collect_wildcards_from_statement(stmt);
// How many active wildcards does this reuse?
let reuse_count = stmt_wildcards.intersection(active_wildcards).count();
// How many new wildcards does this introduce?
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count();
// After adding this statement, what would be active?
let mut projected_active = active_wildcards.clone();
projected_active.extend(stmt_wildcards.clone());
// Which wildcards are still needed by other remaining statements?
let needed_later: HashSet<String> = remaining
.iter()
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
// Only score based on private wildcards. Public args are always available at every
// split boundary — they never consume a promotion slot, so their liveness is free.
let stmt_wildcards: HashSet<_> = all_wildcards
.into_iter()
.filter(|w| !public_args.contains(w))
.collect();
// Wildcards we can close = active now but not needed later
// Statements that touch only public args ("cheap" statements) waste a bucket slot
// that could be used to cluster private wildcards. Strongly defer them while any
// private-wildcard statements remain, so they fill leftover space at the end.
// `needed_later` is non-empty iff some remaining statement has a private wildcard.
if stmt_wildcards.is_empty() {
return if needed_later.is_empty() {
0
} else {
i32::MIN / 2
};
}
// How many active private wildcards does this reuse?
let reuse_count = stmt_wildcards.intersection(active_wildcards).count();
// How many new private wildcards does this introduce?
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count();
// Which of the projected-active wildcards are still needed after this statement?
let mut projected_active = active_wildcards.clone();
projected_active.extend(stmt_wildcards);
projected_active.retain(|w| needed_later.contains(w));
let still_active_count = projected_active.len();
// Base score calculation
// - Prefer statements that reuse active wildcards (don't introduce new liveness)
// - Penalize introducing new wildcards (increases liveness)
// - Penalize keeping many wildcards active (higher liveness)
// Base score:
// +3 per reused wildcard — rewards clustering (wildcard already open, no new cost)
// -4 per new wildcard — penalises opening new live ranges
// -2 per still-live — penalises carrying many wildcards toward the boundary
let base_score = (reuse_count * 3) as i32
- (new_wildcard_count * 4) as i32
- (still_active_count * 2) as i32;
// Look-ahead bonus: when approaching split, heavily favor closing wildcards
// When close to a split boundary, strongly reward statements that close wildcards
// (active.len() + new - still_active = number of wildcards resolved by this statement).
// Weight 10 >> max base-score magnitude to make closing the dominant factor.
if approaching_split {
let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count;
base_score + (closes_count * 10) as i32
@ -375,8 +394,6 @@ fn calculate_live_wildcards(
fn generate_refactor_suggestion(
crossing_wildcards: &[String],
ordered_statements: &[StatementTmpl],
_pos: usize,
_end: usize,
) -> Option<crate::lang::error::RefactorSuggestion> {
use crate::lang::error::RefactorSuggestion;
@ -445,13 +462,6 @@ fn split_into_chain(
let original_name = pred.name.name.clone();
let conjunction = pred.conjunction_type;
let usage = analyze_wildcards(&pred.statements);
let real_statement_count = pred.statements.len();
let ordering_result = order_constraints_optimally(pred.statements, &usage);
let ordered_statements = ordering_result.statements;
let reorder_map = ordering_result.reorder_map;
let original_public_args: Vec<String> = pred
.args
.public_args
@ -459,6 +469,14 @@ fn split_into_chain(
.map(|id| id.name.clone())
.collect();
let public_args_set: HashSet<String> = original_public_args.iter().cloned().collect();
let real_statement_count = pred.statements.len();
let ordering_result = order_constraints_optimally(pred.statements, &public_args_set);
let ordered_statements = ordering_result.statements;
let reorder_map = ordering_result.reorder_map;
let mut chain_links = Vec::new();
let mut pos = 0;
let mut incoming_public = original_public_args.clone();
@ -501,8 +519,7 @@ fn split_into_chain(
total_public,
};
let suggestion =
generate_refactor_suggestion(&new_promotions, &ordered_statements, pos, end);
let suggestion = generate_refactor_suggestion(&new_promotions, &ordered_statements);
return Err(SplittingError::TooManyPublicArgsAtSplit {
predicate: original_name.clone(),
@ -540,23 +557,60 @@ fn split_into_chain(
});
}
let mut public_args_out: Vec<String> = live_at_boundary.iter().cloned().collect();
public_args_out.sort(); // Deterministic ordering
chain_links.push(ChainLink {
statements: ordered_statements[pos..end].to_vec(),
public_args_in: incoming_public.clone(),
private_args,
public_args_out: public_args_out.clone(),
// new_promotions are already sorted and already filtered to exclude incoming_public
promoted_wildcards: new_promotions.clone(),
});
pos = end;
// Next link's incoming public args = current incoming + newly promoted live wildcards
// Only add wildcards that aren't already in incoming_public to avoid duplicates
for wildcard in public_args_out {
if !incoming_set.contains(&wildcard) {
incoming_public.push(wildcard);
// Extend incoming_public for the next link with the newly promoted wildcards.
// new_promotions is already filtered to exclude incoming_set, so no dedup needed.
incoming_public.extend(new_promotions);
}
// Backward pass: prune each continuation's public args to the minimal set needed.
//
// The forward pass accumulates incoming_public monotonically, so a continuation may
// inherit original public args that none of its statements (or downstream continuations)
// ever reference. A continuation must declare every public arg it receives, and the
// proof system constrains each declared arg - an arg that goes unused has no constraints
// and will not match the value the caller passes.
//
// Propagating from the last link backward ensures each continuation declares exactly the
// args it uses directly, plus any args its successor still needs. Link 0 (the original
// predicate) is left untouched - its public-arg signature is user-declared.
{
let num_links = chain_links.len();
if num_links > 1 {
// Collect wildcards referenced by each link's statements once.
let link_wildcards: Vec<HashSet<String>> = chain_links
.iter()
.map(|link| {
link.statements
.iter()
.flat_map(collect_wildcards_from_statement)
.collect()
})
.collect();
let last = num_links - 1;
// Seed: last link retains only args it directly references.
chain_links[last]
.public_args_in
.retain(|a| link_wildcards[last].contains(a));
// Propagate backward through intermediate continuation links (skip link 0).
for i in (1..last).rev() {
let needed_downstream: HashSet<String> =
chain_links[i + 1].public_args_in.iter().cloned().collect();
chain_links[i]
.public_args_in
.retain(|a| link_wildcards[i].contains(a) || needed_downstream.contains(a));
}
}
}
@ -590,7 +644,7 @@ fn split_into_chain(
let mut chain_predicates =
generate_chain_predicates(&original_name, chain_links, conjunction, params)?;
validate_chain(&chain_predicates, params)?;
validate_chain(&chain_predicates, params);
// Reverse so continuations come before callers in declaration order.
// This ensures that when batched, continuations are in earlier batches
@ -633,7 +687,7 @@ fn generate_chain_predicates(
};
// Create arguments for chain call: use next link's public_args_in
// which is the deduplicated union of current public_args_in and public_args_out
// which starts as current public_args_in extended with current promoted_wildcards,
// and is then pruned by the backward pass in split_into_chain to the args still needed
let next_link = &chain_links[i + 1];
let chain_call_args: Vec<StatementTmplArg> = next_link
.public_args_in
@ -665,10 +719,12 @@ fn generate_chain_predicates(
})
.collect();
// Build private args (private + promoted for next)
// Build private args: segment-local private wildcards, plus any wildcards being
// promoted to public for the next link (they must be declared here so the solver
// can bind them before passing them as public args to the continuation).
let mut private_arg_names = link.private_args.clone();
if !is_last {
private_arg_names.extend(link.public_args_out.clone());
private_arg_names.extend(link.promoted_wildcards.iter().cloned());
}
let private_args = if private_arg_names.is_empty() {
@ -698,25 +754,34 @@ fn generate_chain_predicates(
Ok(predicates)
}
/// Phase 5: Validate the generated chain
///
/// Note: We no longer check chain length against max_custom_batch_size since
/// chains can now span multiple batches thanks to multi-batch support.
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> {
/// Sanity-check the generated chain. All three constraints are enforced as proper errors
/// earlier in `split_into_chain`, so violations here indicate a bug in the algorithm.
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) {
for pred in chain {
// Each predicate should have ≤ max_statements
assert!(pred.statements.len() <= Params::max_custom_predicate_arity());
// Public args should fit
assert!(pred.args.public_args.len() <= Params::max_statement_args());
// Total args should fit
assert!(
pred.statements.len() <= Params::max_custom_predicate_arity(),
"chain link '{}' has {} statements, exceeds max {}",
pred.name.name,
pred.statements.len(),
Params::max_custom_predicate_arity(),
);
assert!(
pred.args.public_args.len() <= Params::max_statement_args(),
"chain link '{}' has {} public args, exceeds max {}",
pred.name.name,
pred.args.public_args.len(),
Params::max_statement_args(),
);
let total =
pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len());
assert!(total <= params.max_custom_predicate_wildcards);
assert!(
total <= params.max_custom_predicate_wildcards,
"chain link '{}' has {} total args, exceeds max {}",
pred.name.name,
total,
params.max_custom_predicate_wildcards,
);
}
Ok(())
}
#[cfg(test)]
@ -976,8 +1041,24 @@ mod tests {
let remaining: HashSet<usize> = [0, 1].into_iter().collect();
let active_wildcards = HashSet::new();
let key0 = statement_selection_key(0, &statements, &active_wildcards, &remaining, false);
let key1 = statement_selection_key(1, &statements, &active_wildcards, &remaining, false);
// A and B are the public args of tie_break(A, B)
let public_args: HashSet<String> = ["A".to_string(), "B".to_string()].into_iter().collect();
let key0 = statement_selection_key(
0,
&statements,
&active_wildcards,
&remaining,
false,
&public_args,
);
let key1 = statement_selection_key(
1,
&statements,
&active_wildcards,
&remaining,
false,
&public_args,
);
assert_eq!(key0.0, key1.0, "Primary heuristic score should tie");
assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie");
@ -986,7 +1067,8 @@ mod tests {
"Lower original index should win deterministic final tie-breaker"
);
let selected = find_best_next_statement(&statements, &remaining, &active_wildcards, 0);
let selected =
find_best_next_statement(&statements, &remaining, &active_wildcards, 0, &public_args);
assert_eq!(selected, 0);
}
@ -1084,7 +1166,7 @@ mod tests {
assert!(error_msg.contains("3 crossing wildcards"));
assert!(error_msg.contains("= 6 total"));
assert!(error_msg.contains("exceeds max of 5"));
assert!(error_msg.contains("Statements 0-4"));
assert!(error_msg.contains("Statements 0-3"));
assert!(error_msg.contains("Incoming public args: A, B, C"));
assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3"));
assert!(error_msg.contains("Suggestion:"));
@ -1144,25 +1226,82 @@ mod tests {
);
}
// --- Regression tests ---
/// Statements that reference only public args should be deferred until private-wildcard
/// statements have been clustered, so they don't consume bucket slots that would reduce
/// liveness at split boundaries.
///
/// 4 public args, 7 statements: W1 used in stmts 0,1,4; W2 used in stmts 1,2,3;
/// stmts 5,6 reference only public args. The scoring correctly defers stmts 5,6,
/// yielding bucket0={0,1,2,3}, bucket1={4,5,6} with only W1 crossing (4+1=5 <= max).
#[test]
fn test_refactor_suggestion_group_wildcards() {
// Test the "group wildcard usages" suggestion formatting
use crate::lang::error::RefactorSuggestion;
fn test_split_succeeds_with_four_public_args_and_public_only_statements() {
// Optimal split: bucket0={0,1,2,3}, bucket1={4,5,6}
// Only W1 crosses (used in 0,1 and 4), total = 4 public + 1 crossing = 5 ✓
let input = r#"
pred(A, B, C, D, private: W1, W2) = AND(
Equal(W1["x"], A["v"])
Equal(W2["y"], W1["x"])
Equal(W2["z"], B["v"])
Equal(C["r"], W2["y"])
Equal(D["s"], W1["x"])
Equal(A["out"], C["out"])
Equal(B["out"], D["out"])
)
"#;
let suggestion = RefactorSuggestion::GroupWildcardUsages {
wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()],
};
let pred = parse_predicate(input);
let params = Params::default();
let suggestion_text = suggestion.format();
let result = split_predicate_if_needed(pred, &params);
assert!(
result.is_ok(),
"Should find a valid split with ≤1 crossing wildcard, got: {:?}",
result.err()
);
}
// Verify the suggestion formats correctly
assert!(suggestion_text.contains("Group operations for wildcards"));
assert!(suggestion_text.contains("T1, T2, T3"));
assert!(suggestion_text.contains("used across multiple segments"));
/// Continuation predicates should only declare the public args they actually use -
/// original public args that are not referenced in a continuation's statements or
/// any of its downstream continuations must be omitted from its signature.
#[test]
fn test_continuation_excludes_public_args_unused_after_split() {
// A is used only in the first segment; B is used only in the second segment.
// The continuation predicate (pred_1) must include B but not A.
let input = r#"
pred(A, B, private: T) = AND(
Equal(T["x"], A["val"])
Equal(T["y"], 1)
Equal(T["z"], 2)
Equal(T["w"], 3)
Equal(B["r"], T["x"])
Equal(B["s"], T["y"])
)
"#;
eprintln!(
"\n=== Example GroupWildcardUsages Suggestion ===\n{}\n",
suggestion_text
let pred = parse_predicate(input);
let params = Params::default();
let result = split_predicate_if_needed(pred, &params).unwrap();
// chain[0] is the continuation (_1 suffix), chain[1] is the original
let continuation = result
.predicates
.iter()
.find(|p| p.name.name == "pred_1")
.expect("Expected a pred_1 continuation predicate");
let cont_public: Vec<&str> = continuation
.args
.public_args
.iter()
.map(|id| id.name.as_str())
.collect();
assert!(
!cont_public.contains(&"A"),
"Continuation should drop unused public arg 'A', got: {:?}",
cont_public
);
}
}