Splitter fixes (#483)

* Add failing tests

* Model statements depending on public args as cheaper than those depending on private args

* Tidying

* Fix unnecessary propagation of unused public args

* More tidying

* Tidy test comments

* Fix incorrect Delete arities
This commit is contained in:
Rob Knight 2026-02-20 17:13:42 +01:00 committed by GitHub
parent e950661090
commit f6c6ec43ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 280 additions and 140 deletions

View file

@ -287,7 +287,8 @@ fn format_public_args_at_split_error(
msg.push_str(&format!( msg.push_str(&format!(
" Statements {}-{} in this segment\n", " Statements {}-{} in this segment\n",
context.statement_range.0, context.statement_range.1 context.statement_range.0,
context.statement_range.1 - 1
)); ));
if !context.incoming_public.is_empty() { if !context.incoming_public.is_empty() {

View file

@ -15,10 +15,7 @@
//! We use a greedy algorithm to order the statements in a predicate to minimize //! We use a greedy algorithm to order the statements in a predicate to minimize
//! the number of live wildcards at split boundaries. //! the number of live wildcards at split boundaries.
use std::{ use std::{cmp::Reverse, collections::HashSet};
cmp::Reverse,
collections::{HashMap, HashSet},
};
// SplittingError is now defined in error.rs // SplittingError is now defined in error.rs
pub use crate::lang::error::SplittingError; pub use crate::lang::error::SplittingError;
@ -33,8 +30,8 @@ pub struct ChainLink {
pub public_args_in: Vec<String>, pub public_args_in: Vec<String>,
/// Private arguments used only in this link /// Private arguments used only in this link
pub private_args: Vec<String>, pub private_args: Vec<String>,
/// Public arguments promoted to pass to next link (empty if last link) /// Private wildcards promoted to public for the next link (empty if last link)
pub public_args_out: Vec<String>, pub promoted_wildcards: Vec<String>,
} }
/// Information about a single piece of a split predicate chain /// Information about a single piece of a split predicate chain
@ -71,13 +68,6 @@ pub struct SplitResult {
pub chain_info: Option<SplitChainInfo>, pub chain_info: Option<SplitChainInfo>,
} }
/// Wildcard usage information
#[derive(Debug, Clone)]
struct WildcardUsage {
/// Indices of statements using this wildcard
used_in_statements: HashSet<usize>,
}
/// Early validation: Check if predicate is fundamentally splittable /// Early validation: Check if predicate is fundamentally splittable
pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> { pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> {
let public_args = pred.args.public_args.len(); let public_args = pred.args.public_args.len();
@ -121,26 +111,6 @@ pub fn split_predicate_if_needed(
}) })
} }
fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap<String, WildcardUsage> {
let mut usage: HashMap<String, WildcardUsage> = HashMap::new();
for (idx, stmt) in statements.iter().enumerate() {
let wildcards = collect_wildcards_from_statement(stmt);
for wildcard in wildcards {
usage
.entry(wildcard.clone())
.or_insert_with(|| WildcardUsage {
used_in_statements: HashSet::new(),
})
.used_in_statements
.insert(idx);
}
}
usage
}
/// Collect all wildcard names from a statement /// Collect all wildcard names from a statement
fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> { fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
let mut wildcards = HashSet::new(); let mut wildcards = HashSet::new();
@ -172,7 +142,7 @@ struct OrderingResult {
fn order_constraints_optimally( fn order_constraints_optimally(
statements: Vec<StatementTmpl>, statements: Vec<StatementTmpl>,
_usage: &HashMap<String, WildcardUsage>, public_args: &HashSet<String>,
) -> OrderingResult { ) -> OrderingResult {
let n = statements.len(); let n = statements.len();
@ -190,8 +160,13 @@ fn order_constraints_optimally(
let mut active_wildcards: HashSet<String> = HashSet::new(); let mut active_wildcards: HashSet<String> = HashSet::new();
while !remaining.is_empty() { while !remaining.is_empty() {
let best_idx = let best_idx = find_best_next_statement(
find_best_next_statement(&statements, &remaining, &active_wildcards, ordered.len()); &statements,
&remaining,
&active_wildcards,
ordered.len(),
public_args,
);
remaining.remove(&best_idx); remaining.remove(&best_idx);
let stmt = &statements[best_idx]; let stmt = &statements[best_idx];
@ -200,14 +175,20 @@ fn order_constraints_optimally(
reorder_map[best_idx] = ordered.len(); reorder_map[best_idx] = ordered.len();
ordered.push(stmt.clone()); ordered.push(stmt.clone());
// Update active wildcards // Only track private wildcards in the active set — public args are always
// available at every boundary so their liveness is irrelevant to split cost.
let stmt_wildcards = collect_wildcards_from_statement(stmt); let stmt_wildcards = collect_wildcards_from_statement(stmt);
active_wildcards.extend(stmt_wildcards); active_wildcards.extend(
stmt_wildcards
.into_iter()
.filter(|w| !public_args.contains(w)),
);
// Remove wildcards no longer needed by remaining statements // Remove private wildcards no longer needed by remaining statements
let needed_later: HashSet<_> = remaining let needed_later: HashSet<_> = remaining
.iter() .iter()
.flat_map(|&i| collect_wildcards_from_statement(&statements[i])) .flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
.filter(|w| !public_args.contains(w))
.collect(); .collect();
active_wildcards.retain(|w| needed_later.contains(w)); active_wildcards.retain(|w| needed_later.contains(w));
} }
@ -225,30 +206,35 @@ fn compute_tie_breakers(
active_wildcards: &HashSet<String>, active_wildcards: &HashSet<String>,
statements: &[StatementTmpl], statements: &[StatementTmpl],
remaining: &HashSet<usize>, remaining: &HashSet<usize>,
needed_later: &HashSet<String>,
public_args: &HashSet<String>,
) -> (usize, usize, i32) { ) -> (usize, usize, i32) {
let stmt_wildcards = collect_wildcards_from_statement(stmt); let all_wildcards = collect_wildcards_from_statement(stmt);
// Only consider private wildcards for tie-breaking metrics
// Metric 1: Simplicity - prefer statements with fewer wildcards let stmt_wildcards: HashSet<_> = all_wildcards
let simplicity = usize::MAX - stmt_wildcards.len(); .into_iter()
.filter(|w| !public_args.contains(w))
// Metric 2: Public closure - prefer statements that close active wildcards
// (wildcards that won't be needed by any remaining statements)
let needed_later: HashSet<String> = remaining
.iter()
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
.collect(); .collect();
// Metric 1: Simplicity - prefer statements with fewer private wildcards
let simplicity = usize::MAX - stmt_wildcards.len();
// Metric 2: Closure - prefer statements that close active private wildcards
// (wildcards that won't be needed by any remaining statements)
let closes_count = stmt_wildcards let closes_count = stmt_wildcards
.intersection(active_wildcards) .intersection(active_wildcards)
.filter(|w| !needed_later.contains(*w)) .filter(|w| !needed_later.contains(*w))
.count(); .count();
// Metric 3: Fanout - prefer statements with lower future usage // Metric 3: Fanout - prefer statements with lower future usage
// (number of remaining statements that use any wildcard from this statement) // (number of remaining statements sharing private wildcards with this statement)
let fanout = remaining let fanout = remaining
.iter() .iter()
.filter(|&&i| { .filter(|&&i| {
let other_wildcards = collect_wildcards_from_statement(&statements[i]); let other_wildcards: HashSet<_> = collect_wildcards_from_statement(&statements[i])
.into_iter()
.filter(|w| !public_args.contains(w))
.collect();
!stmt_wildcards.is_disjoint(&other_wildcards) !stmt_wildcards.is_disjoint(&other_wildcards)
}) })
.count(); .count();
@ -262,16 +248,33 @@ fn statement_selection_key(
active_wildcards: &HashSet<String>, active_wildcards: &HashSet<String>,
remaining: &HashSet<usize>, remaining: &HashSet<usize>,
approaching_split: bool, approaching_split: bool,
public_args: &HashSet<String>,
) -> (i32, (usize, usize, i32), Reverse<usize>) { ) -> (i32, (usize, usize, i32), Reverse<usize>) {
// Pre-compute needed_later once and share between primary score and tie-breakers.
// Exclude the candidate itself: we want to know what the *other* remaining statements
// need, so that wildcards used only by this candidate correctly appear as closeable.
let needed_later: HashSet<String> = remaining
.iter()
.filter(|&&i| i != idx)
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
.filter(|w| !public_args.contains(w))
.collect();
let primary_score = score_statement( let primary_score = score_statement(
&statements[idx],
active_wildcards,
approaching_split,
public_args,
&needed_later,
);
let tie_breakers = compute_tie_breakers(
&statements[idx], &statements[idx],
active_wildcards, active_wildcards,
statements, statements,
remaining, remaining,
approaching_split, &needed_later,
public_args,
); );
let tie_breakers =
compute_tie_breakers(&statements[idx], active_wildcards, statements, remaining);
// Final deterministic tie-breaker: prefer smaller original indices. // Final deterministic tie-breaker: prefer smaller original indices.
// This avoids hash-iteration-dependent selection when scores are equal. // This avoids hash-iteration-dependent selection when scores are equal.
@ -284,6 +287,7 @@ fn find_best_next_statement(
remaining: &HashSet<usize>, remaining: &HashSet<usize>,
active_wildcards: &HashSet<String>, active_wildcards: &HashSet<String>,
ordered_count: usize, ordered_count: usize,
public_args: &HashSet<String>,
) -> usize { ) -> usize {
// Calculate distance to next split point // Calculate distance to next split point
let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call
@ -299,51 +303,66 @@ fn find_best_next_statement(
active_wildcards, active_wildcards,
remaining, remaining,
approaching_split, approaching_split,
public_args,
) )
}) })
.copied() .copied()
.unwrap() .unwrap()
} }
/// Score a statement based on how well it minimizes liveness /// Score a statement based on how well it minimizes private-wildcard liveness at boundaries.
/// `needed_later` is the set of private wildcards used by any remaining statement.
fn score_statement( fn score_statement(
stmt: &StatementTmpl, stmt: &StatementTmpl,
active_wildcards: &HashSet<String>, active_wildcards: &HashSet<String>,
statements: &[StatementTmpl],
remaining: &HashSet<usize>,
approaching_split: bool, approaching_split: bool,
public_args: &HashSet<String>,
needed_later: &HashSet<String>,
) -> i32 { ) -> i32 {
let stmt_wildcards = collect_wildcards_from_statement(stmt); let all_wildcards = collect_wildcards_from_statement(stmt);
// How many active wildcards does this reuse? // Only score based on private wildcards. Public args are always available at every
let reuse_count = stmt_wildcards.intersection(active_wildcards).count(); // split boundary — they never consume a promotion slot, so their liveness is free.
let stmt_wildcards: HashSet<_> = all_wildcards
// How many new wildcards does this introduce? .into_iter()
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count(); .filter(|w| !public_args.contains(w))
// After adding this statement, what would be active?
let mut projected_active = active_wildcards.clone();
projected_active.extend(stmt_wildcards.clone());
// Which wildcards are still needed by other remaining statements?
let needed_later: HashSet<String> = remaining
.iter()
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
.collect(); .collect();
// Wildcards we can close = active now but not needed later // Statements that touch only public args ("cheap" statements) waste a bucket slot
// that could be used to cluster private wildcards. Strongly defer them while any
// private-wildcard statements remain, so they fill leftover space at the end.
// `needed_later` is non-empty iff some remaining statement has a private wildcard.
if stmt_wildcards.is_empty() {
return if needed_later.is_empty() {
0
} else {
i32::MIN / 2
};
}
// How many active private wildcards does this reuse?
let reuse_count = stmt_wildcards.intersection(active_wildcards).count();
// How many new private wildcards does this introduce?
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count();
// Which of the projected-active wildcards are still needed after this statement?
let mut projected_active = active_wildcards.clone();
projected_active.extend(stmt_wildcards);
projected_active.retain(|w| needed_later.contains(w)); projected_active.retain(|w| needed_later.contains(w));
let still_active_count = projected_active.len(); let still_active_count = projected_active.len();
// Base score calculation // Base score:
// - Prefer statements that reuse active wildcards (don't introduce new liveness) // +3 per reused wildcard — rewards clustering (wildcard already open, no new cost)
// - Penalize introducing new wildcards (increases liveness) // -4 per new wildcard — penalises opening new live ranges
// - Penalize keeping many wildcards active (higher liveness) // -2 per still-live — penalises carrying many wildcards toward the boundary
let base_score = (reuse_count * 3) as i32 let base_score = (reuse_count * 3) as i32
- (new_wildcard_count * 4) as i32 - (new_wildcard_count * 4) as i32
- (still_active_count * 2) as i32; - (still_active_count * 2) as i32;
// Look-ahead bonus: when approaching split, heavily favor closing wildcards // When close to a split boundary, strongly reward statements that close wildcards
// (active.len() + new - still_active = number of wildcards resolved by this statement).
// Weight 10 >> max base-score magnitude to make closing the dominant factor.
if approaching_split { if approaching_split {
let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count; let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count;
base_score + (closes_count * 10) as i32 base_score + (closes_count * 10) as i32
@ -375,8 +394,6 @@ fn calculate_live_wildcards(
fn generate_refactor_suggestion( fn generate_refactor_suggestion(
crossing_wildcards: &[String], crossing_wildcards: &[String],
ordered_statements: &[StatementTmpl], ordered_statements: &[StatementTmpl],
_pos: usize,
_end: usize,
) -> Option<crate::lang::error::RefactorSuggestion> { ) -> Option<crate::lang::error::RefactorSuggestion> {
use crate::lang::error::RefactorSuggestion; use crate::lang::error::RefactorSuggestion;
@ -445,13 +462,6 @@ fn split_into_chain(
let original_name = pred.name.name.clone(); let original_name = pred.name.name.clone();
let conjunction = pred.conjunction_type; let conjunction = pred.conjunction_type;
let usage = analyze_wildcards(&pred.statements);
let real_statement_count = pred.statements.len();
let ordering_result = order_constraints_optimally(pred.statements, &usage);
let ordered_statements = ordering_result.statements;
let reorder_map = ordering_result.reorder_map;
let original_public_args: Vec<String> = pred let original_public_args: Vec<String> = pred
.args .args
.public_args .public_args
@ -459,6 +469,14 @@ fn split_into_chain(
.map(|id| id.name.clone()) .map(|id| id.name.clone())
.collect(); .collect();
let public_args_set: HashSet<String> = original_public_args.iter().cloned().collect();
let real_statement_count = pred.statements.len();
let ordering_result = order_constraints_optimally(pred.statements, &public_args_set);
let ordered_statements = ordering_result.statements;
let reorder_map = ordering_result.reorder_map;
let mut chain_links = Vec::new(); let mut chain_links = Vec::new();
let mut pos = 0; let mut pos = 0;
let mut incoming_public = original_public_args.clone(); let mut incoming_public = original_public_args.clone();
@ -501,8 +519,7 @@ fn split_into_chain(
total_public, total_public,
}; };
let suggestion = let suggestion = generate_refactor_suggestion(&new_promotions, &ordered_statements);
generate_refactor_suggestion(&new_promotions, &ordered_statements, pos, end);
return Err(SplittingError::TooManyPublicArgsAtSplit { return Err(SplittingError::TooManyPublicArgsAtSplit {
predicate: original_name.clone(), predicate: original_name.clone(),
@ -540,23 +557,60 @@ fn split_into_chain(
}); });
} }
let mut public_args_out: Vec<String> = live_at_boundary.iter().cloned().collect();
public_args_out.sort(); // Deterministic ordering
chain_links.push(ChainLink { chain_links.push(ChainLink {
statements: ordered_statements[pos..end].to_vec(), statements: ordered_statements[pos..end].to_vec(),
public_args_in: incoming_public.clone(), public_args_in: incoming_public.clone(),
private_args, private_args,
public_args_out: public_args_out.clone(), // new_promotions are already sorted and already filtered to exclude incoming_public
promoted_wildcards: new_promotions.clone(),
}); });
pos = end; pos = end;
// Next link's incoming public args = current incoming + newly promoted live wildcards // Extend incoming_public for the next link with the newly promoted wildcards.
// Only add wildcards that aren't already in incoming_public to avoid duplicates // new_promotions is already filtered to exclude incoming_set, so no dedup needed.
for wildcard in public_args_out { incoming_public.extend(new_promotions);
if !incoming_set.contains(&wildcard) { }
incoming_public.push(wildcard);
// Backward pass: prune each continuation's public args to the minimal set needed.
//
// The forward pass accumulates incoming_public monotonically, so a continuation may
// inherit original public args that none of its statements (or downstream continuations)
// ever reference. A continuation must declare every public arg it receives, and the
// proof system constrains each declared arg - an arg that goes unused has no constraints
// and will not match the value the caller passes.
//
// Propagating from the last link backward ensures each continuation declares exactly the
// args it uses directly, plus any args its successor still needs. Link 0 (the original
// predicate) is left untouched - its public-arg signature is user-declared.
{
let num_links = chain_links.len();
if num_links > 1 {
// Collect wildcards referenced by each link's statements once.
let link_wildcards: Vec<HashSet<String>> = chain_links
.iter()
.map(|link| {
link.statements
.iter()
.flat_map(collect_wildcards_from_statement)
.collect()
})
.collect();
let last = num_links - 1;
// Seed: last link retains only args it directly references.
chain_links[last]
.public_args_in
.retain(|a| link_wildcards[last].contains(a));
// Propagate backward through intermediate continuation links (skip link 0).
for i in (1..last).rev() {
let needed_downstream: HashSet<String> =
chain_links[i + 1].public_args_in.iter().cloned().collect();
chain_links[i]
.public_args_in
.retain(|a| link_wildcards[i].contains(a) || needed_downstream.contains(a));
} }
} }
} }
@ -590,7 +644,7 @@ fn split_into_chain(
let mut chain_predicates = let mut chain_predicates =
generate_chain_predicates(&original_name, chain_links, conjunction, params)?; generate_chain_predicates(&original_name, chain_links, conjunction, params)?;
validate_chain(&chain_predicates, params)?; validate_chain(&chain_predicates, params);
// Reverse so continuations come before callers in declaration order. // Reverse so continuations come before callers in declaration order.
// This ensures that when batched, continuations are in earlier batches // This ensures that when batched, continuations are in earlier batches
@ -633,7 +687,7 @@ fn generate_chain_predicates(
}; };
// Create arguments for chain call: use next link's public_args_in // Create arguments for chain call: use next link's public_args_in
// which is the deduplicated union of current public_args_in and public_args_out // which is current public_args_in extended with current promoted_wildcards
let next_link = &chain_links[i + 1]; let next_link = &chain_links[i + 1];
let chain_call_args: Vec<StatementTmplArg> = next_link let chain_call_args: Vec<StatementTmplArg> = next_link
.public_args_in .public_args_in
@ -665,10 +719,12 @@ fn generate_chain_predicates(
}) })
.collect(); .collect();
// Build private args (private + promoted for next) // Build private args: segment-local private wildcards, plus any wildcards being
// promoted to public for the next link (they must be declared here so the solver
// can bind them before passing them as public args to the continuation).
let mut private_arg_names = link.private_args.clone(); let mut private_arg_names = link.private_args.clone();
if !is_last { if !is_last {
private_arg_names.extend(link.public_args_out.clone()); private_arg_names.extend(link.promoted_wildcards.iter().cloned());
} }
let private_args = if private_arg_names.is_empty() { let private_args = if private_arg_names.is_empty() {
@ -698,25 +754,34 @@ fn generate_chain_predicates(
Ok(predicates) Ok(predicates)
} }
/// Phase 5: Validate the generated chain /// Sanity-check the generated chain. All three constraints are enforced as proper errors
/// /// earlier in `split_into_chain`, so violations here indicate a bug in the algorithm.
/// Note: We no longer check chain length against max_custom_batch_size since fn validate_chain(chain: &[CustomPredicateDef], params: &Params) {
/// chains can now span multiple batches thanks to multi-batch support.
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> {
for pred in chain { for pred in chain {
// Each predicate should have ≤ max_statements assert!(
assert!(pred.statements.len() <= Params::max_custom_predicate_arity()); pred.statements.len() <= Params::max_custom_predicate_arity(),
"chain link '{}' has {} statements, exceeds max {}",
// Public args should fit pred.name.name,
assert!(pred.args.public_args.len() <= Params::max_statement_args()); pred.statements.len(),
Params::max_custom_predicate_arity(),
// Total args should fit );
assert!(
pred.args.public_args.len() <= Params::max_statement_args(),
"chain link '{}' has {} public args, exceeds max {}",
pred.name.name,
pred.args.public_args.len(),
Params::max_statement_args(),
);
let total = let total =
pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len()); pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len());
assert!(total <= params.max_custom_predicate_wildcards); assert!(
total <= params.max_custom_predicate_wildcards,
"chain link '{}' has {} total args, exceeds max {}",
pred.name.name,
total,
params.max_custom_predicate_wildcards,
);
} }
Ok(())
} }
#[cfg(test)] #[cfg(test)]
@ -976,8 +1041,24 @@ mod tests {
let remaining: HashSet<usize> = [0, 1].into_iter().collect(); let remaining: HashSet<usize> = [0, 1].into_iter().collect();
let active_wildcards = HashSet::new(); let active_wildcards = HashSet::new();
let key0 = statement_selection_key(0, &statements, &active_wildcards, &remaining, false); // A and B are the public args of tie_break(A, B)
let key1 = statement_selection_key(1, &statements, &active_wildcards, &remaining, false); let public_args: HashSet<String> = ["A".to_string(), "B".to_string()].into_iter().collect();
let key0 = statement_selection_key(
0,
&statements,
&active_wildcards,
&remaining,
false,
&public_args,
);
let key1 = statement_selection_key(
1,
&statements,
&active_wildcards,
&remaining,
false,
&public_args,
);
assert_eq!(key0.0, key1.0, "Primary heuristic score should tie"); assert_eq!(key0.0, key1.0, "Primary heuristic score should tie");
assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie"); assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie");
@ -986,7 +1067,8 @@ mod tests {
"Lower original index should win deterministic final tie-breaker" "Lower original index should win deterministic final tie-breaker"
); );
let selected = find_best_next_statement(&statements, &remaining, &active_wildcards, 0); let selected =
find_best_next_statement(&statements, &remaining, &active_wildcards, 0, &public_args);
assert_eq!(selected, 0); assert_eq!(selected, 0);
} }
@ -1084,7 +1166,7 @@ mod tests {
assert!(error_msg.contains("3 crossing wildcards")); assert!(error_msg.contains("3 crossing wildcards"));
assert!(error_msg.contains("= 6 total")); assert!(error_msg.contains("= 6 total"));
assert!(error_msg.contains("exceeds max of 5")); assert!(error_msg.contains("exceeds max of 5"));
assert!(error_msg.contains("Statements 0-4")); assert!(error_msg.contains("Statements 0-3"));
assert!(error_msg.contains("Incoming public args: A, B, C")); assert!(error_msg.contains("Incoming public args: A, B, C"));
assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3")); assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3"));
assert!(error_msg.contains("Suggestion:")); assert!(error_msg.contains("Suggestion:"));
@ -1144,25 +1226,82 @@ mod tests {
); );
} }
// --- Regression tests ---
/// Statements that reference only public args should be deferred until private-wildcard
/// statements have been clustered, so they don't consume bucket slots that would reduce
/// liveness at split boundaries.
///
/// 4 public args, 7 statements: W1 used in stmts 0,1,4; W2 used in stmts 1,2,3;
/// stmts 5,6 reference only public args. The scoring correctly defers stmts 5,6,
/// yielding bucket0={0,1,2,3}, bucket1={4,5,6} with only W1 crossing (4+1=5 <= max).
#[test] #[test]
fn test_refactor_suggestion_group_wildcards() { fn test_split_succeeds_with_four_public_args_and_public_only_statements() {
// Test the "group wildcard usages" suggestion formatting // Optimal split: bucket0={0,1,2,3}, bucket1={4,5,6}
use crate::lang::error::RefactorSuggestion; // Only W1 crosses (used in 0,1 and 4), total = 4 public + 1 crossing = 5 ✓
let input = r#"
pred(A, B, C, D, private: W1, W2) = AND(
Equal(W1["x"], A["v"])
Equal(W2["y"], W1["x"])
Equal(W2["z"], B["v"])
Equal(C["r"], W2["y"])
Equal(D["s"], W1["x"])
Equal(A["out"], C["out"])
Equal(B["out"], D["out"])
)
"#;
let suggestion = RefactorSuggestion::GroupWildcardUsages { let pred = parse_predicate(input);
wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()], let params = Params::default();
};
let suggestion_text = suggestion.format(); let result = split_predicate_if_needed(pred, &params);
assert!(
result.is_ok(),
"Should find a valid split with ≤1 crossing wildcard, got: {:?}",
result.err()
);
}
// Verify the suggestion formats correctly /// Continuation predicates should only declare the public args they actually use -
assert!(suggestion_text.contains("Group operations for wildcards")); /// original public args that are not referenced in a continuation's statements or
assert!(suggestion_text.contains("T1, T2, T3")); /// any of its downstream continuations must be omitted from its signature.
assert!(suggestion_text.contains("used across multiple segments")); #[test]
fn test_continuation_excludes_public_args_unused_after_split() {
// A is used only in the first segment; B is used only in the second segment.
// The continuation predicate (pred_1) must include B but not A.
let input = r#"
pred(A, B, private: T) = AND(
Equal(T["x"], A["val"])
Equal(T["y"], 1)
Equal(T["z"], 2)
Equal(T["w"], 3)
Equal(B["r"], T["x"])
Equal(B["s"], T["y"])
)
"#;
eprintln!( let pred = parse_predicate(input);
"\n=== Example GroupWildcardUsages Suggestion ===\n{}\n", let params = Params::default();
suggestion_text
let result = split_predicate_if_needed(pred, &params).unwrap();
// chain[0] is the continuation (_1 suffix), chain[1] is the original
let continuation = result
.predicates
.iter()
.find(|p| p.name.name == "pred_1")
.expect("Expected a pred_1 continuation predicate");
let cont_public: Vec<&str> = continuation
.args
.public_args
.iter()
.map(|id| id.name.as_str())
.collect();
assert!(
!cont_public.contains(&"A"),
"Continuation should drop unused public arg 'A', got: {:?}",
cont_public
); );
} }
} }

View file

@ -77,14 +77,14 @@ impl NativePredicate {
| NativePredicate::MaxOf | NativePredicate::MaxOf
| NativePredicate::HashOf | NativePredicate::HashOf
| NativePredicate::SetInsert | NativePredicate::SetInsert
| NativePredicate::SetDelete => 3, | NativePredicate::SetDelete
| NativePredicate::DictDelete
| NativePredicate::ContainerDelete => 3,
NativePredicate::DictInsert NativePredicate::DictInsert
| NativePredicate::DictUpdate | NativePredicate::DictUpdate
| NativePredicate::DictDelete
| NativePredicate::ArrayUpdate | NativePredicate::ArrayUpdate
| NativePredicate::ContainerInsert | NativePredicate::ContainerInsert
| NativePredicate::ContainerUpdate | NativePredicate::ContainerUpdate => 4,
| NativePredicate::ContainerDelete => 4,
} }
} }
} }