Splitter fixes (#483)
* Add failing tests
* Model statements depending on public args as cheaper than those depending on private args
* Tidying
* Fix unnecessary propagation of unused public args
* More tidying
* Tidy test comments
* Fix incorrect Delete arities
This commit is contained in:
parent
e950661090
commit
f6c6ec43ef
3 changed files with 280 additions and 140 deletions
|
|
@ -287,7 +287,8 @@ fn format_public_args_at_split_error(
|
|||
|
||||
msg.push_str(&format!(
|
||||
" Statements {}-{} in this segment\n",
|
||||
context.statement_range.0, context.statement_range.1
|
||||
context.statement_range.0,
|
||||
context.statement_range.1 - 1
|
||||
));
|
||||
|
||||
if !context.incoming_public.is_empty() {
|
||||
|
|
|
|||
|
|
@ -15,10 +15,7 @@
|
|||
//! We use a greedy algorithm to order the statements in a predicate to minimize
|
||||
//! the number of live wildcards at split boundaries.
|
||||
|
||||
use std::{
|
||||
cmp::Reverse,
|
||||
collections::{HashMap, HashSet},
|
||||
};
|
||||
use std::{cmp::Reverse, collections::HashSet};
|
||||
|
||||
// SplittingError is now defined in error.rs
|
||||
pub use crate::lang::error::SplittingError;
|
||||
|
|
@ -33,8 +30,8 @@ pub struct ChainLink {
|
|||
pub public_args_in: Vec<String>,
|
||||
/// Private arguments used only in this link
|
||||
pub private_args: Vec<String>,
|
||||
/// Public arguments promoted to pass to next link (empty if last link)
|
||||
pub public_args_out: Vec<String>,
|
||||
/// Private wildcards promoted to public for the next link (empty if last link)
|
||||
pub promoted_wildcards: Vec<String>,
|
||||
}
|
||||
|
||||
/// Information about a single piece of a split predicate chain
|
||||
|
|
@ -71,13 +68,6 @@ pub struct SplitResult {
|
|||
pub chain_info: Option<SplitChainInfo>,
|
||||
}
|
||||
|
||||
/// Wildcard usage information
|
||||
#[derive(Debug, Clone)]
|
||||
struct WildcardUsage {
|
||||
/// Indices of statements using this wildcard
|
||||
used_in_statements: HashSet<usize>,
|
||||
}
|
||||
|
||||
/// Early validation: Check if predicate is fundamentally splittable
|
||||
pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> {
|
||||
let public_args = pred.args.public_args.len();
|
||||
|
|
@ -121,26 +111,6 @@ pub fn split_predicate_if_needed(
|
|||
})
|
||||
}
|
||||
|
||||
fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap<String, WildcardUsage> {
|
||||
let mut usage: HashMap<String, WildcardUsage> = HashMap::new();
|
||||
|
||||
for (idx, stmt) in statements.iter().enumerate() {
|
||||
let wildcards = collect_wildcards_from_statement(stmt);
|
||||
|
||||
for wildcard in wildcards {
|
||||
usage
|
||||
.entry(wildcard.clone())
|
||||
.or_insert_with(|| WildcardUsage {
|
||||
used_in_statements: HashSet::new(),
|
||||
})
|
||||
.used_in_statements
|
||||
.insert(idx);
|
||||
}
|
||||
}
|
||||
|
||||
usage
|
||||
}
|
||||
|
||||
/// Collect all wildcard names from a statement
|
||||
fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
|
||||
let mut wildcards = HashSet::new();
|
||||
|
|
@ -172,7 +142,7 @@ struct OrderingResult {
|
|||
|
||||
fn order_constraints_optimally(
|
||||
statements: Vec<StatementTmpl>,
|
||||
_usage: &HashMap<String, WildcardUsage>,
|
||||
public_args: &HashSet<String>,
|
||||
) -> OrderingResult {
|
||||
let n = statements.len();
|
||||
|
||||
|
|
@ -190,8 +160,13 @@ fn order_constraints_optimally(
|
|||
let mut active_wildcards: HashSet<String> = HashSet::new();
|
||||
|
||||
while !remaining.is_empty() {
|
||||
let best_idx =
|
||||
find_best_next_statement(&statements, &remaining, &active_wildcards, ordered.len());
|
||||
let best_idx = find_best_next_statement(
|
||||
&statements,
|
||||
&remaining,
|
||||
&active_wildcards,
|
||||
ordered.len(),
|
||||
public_args,
|
||||
);
|
||||
|
||||
remaining.remove(&best_idx);
|
||||
let stmt = &statements[best_idx];
|
||||
|
|
@ -200,14 +175,20 @@ fn order_constraints_optimally(
|
|||
reorder_map[best_idx] = ordered.len();
|
||||
ordered.push(stmt.clone());
|
||||
|
||||
// Update active wildcards
|
||||
// Only track private wildcards in the active set — public args are always
|
||||
// available at every boundary so their liveness is irrelevant to split cost.
|
||||
let stmt_wildcards = collect_wildcards_from_statement(stmt);
|
||||
active_wildcards.extend(stmt_wildcards);
|
||||
active_wildcards.extend(
|
||||
stmt_wildcards
|
||||
.into_iter()
|
||||
.filter(|w| !public_args.contains(w)),
|
||||
);
|
||||
|
||||
// Remove wildcards no longer needed by remaining statements
|
||||
// Remove private wildcards no longer needed by remaining statements
|
||||
let needed_later: HashSet<_> = remaining
|
||||
.iter()
|
||||
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
||||
.filter(|w| !public_args.contains(w))
|
||||
.collect();
|
||||
active_wildcards.retain(|w| needed_later.contains(w));
|
||||
}
|
||||
|
|
@ -225,30 +206,35 @@ fn compute_tie_breakers(
|
|||
active_wildcards: &HashSet<String>,
|
||||
statements: &[StatementTmpl],
|
||||
remaining: &HashSet<usize>,
|
||||
needed_later: &HashSet<String>,
|
||||
public_args: &HashSet<String>,
|
||||
) -> (usize, usize, i32) {
|
||||
let stmt_wildcards = collect_wildcards_from_statement(stmt);
|
||||
|
||||
// Metric 1: Simplicity - prefer statements with fewer wildcards
|
||||
let simplicity = usize::MAX - stmt_wildcards.len();
|
||||
|
||||
// Metric 2: Public closure - prefer statements that close active wildcards
|
||||
// (wildcards that won't be needed by any remaining statements)
|
||||
let needed_later: HashSet<String> = remaining
|
||||
.iter()
|
||||
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
||||
let all_wildcards = collect_wildcards_from_statement(stmt);
|
||||
// Only consider private wildcards for tie-breaking metrics
|
||||
let stmt_wildcards: HashSet<_> = all_wildcards
|
||||
.into_iter()
|
||||
.filter(|w| !public_args.contains(w))
|
||||
.collect();
|
||||
|
||||
// Metric 1: Simplicity - prefer statements with fewer private wildcards
|
||||
let simplicity = usize::MAX - stmt_wildcards.len();
|
||||
|
||||
// Metric 2: Closure - prefer statements that close active private wildcards
|
||||
// (wildcards that won't be needed by any remaining statements)
|
||||
let closes_count = stmt_wildcards
|
||||
.intersection(active_wildcards)
|
||||
.filter(|w| !needed_later.contains(*w))
|
||||
.count();
|
||||
|
||||
// Metric 3: Fanout - prefer statements with lower future usage
|
||||
// (number of remaining statements that use any wildcard from this statement)
|
||||
// (number of remaining statements sharing private wildcards with this statement)
|
||||
let fanout = remaining
|
||||
.iter()
|
||||
.filter(|&&i| {
|
||||
let other_wildcards = collect_wildcards_from_statement(&statements[i]);
|
||||
let other_wildcards: HashSet<_> = collect_wildcards_from_statement(&statements[i])
|
||||
.into_iter()
|
||||
.filter(|w| !public_args.contains(w))
|
||||
.collect();
|
||||
!stmt_wildcards.is_disjoint(&other_wildcards)
|
||||
})
|
||||
.count();
|
||||
|
|
@ -262,16 +248,33 @@ fn statement_selection_key(
|
|||
active_wildcards: &HashSet<String>,
|
||||
remaining: &HashSet<usize>,
|
||||
approaching_split: bool,
|
||||
public_args: &HashSet<String>,
|
||||
) -> (i32, (usize, usize, i32), Reverse<usize>) {
|
||||
// Pre-compute needed_later once and share between primary score and tie-breakers.
|
||||
// Exclude the candidate itself: we want to know what the *other* remaining statements
|
||||
// need, so that wildcards used only by this candidate correctly appear as closeable.
|
||||
let needed_later: HashSet<String> = remaining
|
||||
.iter()
|
||||
.filter(|&&i| i != idx)
|
||||
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
||||
.filter(|w| !public_args.contains(w))
|
||||
.collect();
|
||||
|
||||
let primary_score = score_statement(
|
||||
&statements[idx],
|
||||
active_wildcards,
|
||||
approaching_split,
|
||||
public_args,
|
||||
&needed_later,
|
||||
);
|
||||
let tie_breakers = compute_tie_breakers(
|
||||
&statements[idx],
|
||||
active_wildcards,
|
||||
statements,
|
||||
remaining,
|
||||
approaching_split,
|
||||
&needed_later,
|
||||
public_args,
|
||||
);
|
||||
let tie_breakers =
|
||||
compute_tie_breakers(&statements[idx], active_wildcards, statements, remaining);
|
||||
|
||||
// Final deterministic tie-breaker: prefer smaller original indices.
|
||||
// This avoids hash-iteration-dependent selection when scores are equal.
|
||||
|
|
@ -284,6 +287,7 @@ fn find_best_next_statement(
|
|||
remaining: &HashSet<usize>,
|
||||
active_wildcards: &HashSet<String>,
|
||||
ordered_count: usize,
|
||||
public_args: &HashSet<String>,
|
||||
) -> usize {
|
||||
// Calculate distance to next split point
|
||||
let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call
|
||||
|
|
@ -299,51 +303,66 @@ fn find_best_next_statement(
|
|||
active_wildcards,
|
||||
remaining,
|
||||
approaching_split,
|
||||
public_args,
|
||||
)
|
||||
})
|
||||
.copied()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Score a statement based on how well it minimizes liveness
|
||||
/// Score a statement based on how well it minimizes private-wildcard liveness at boundaries.
|
||||
/// `needed_later` is the set of private wildcards used by any remaining statement.
|
||||
fn score_statement(
|
||||
stmt: &StatementTmpl,
|
||||
active_wildcards: &HashSet<String>,
|
||||
statements: &[StatementTmpl],
|
||||
remaining: &HashSet<usize>,
|
||||
approaching_split: bool,
|
||||
public_args: &HashSet<String>,
|
||||
needed_later: &HashSet<String>,
|
||||
) -> i32 {
|
||||
let stmt_wildcards = collect_wildcards_from_statement(stmt);
|
||||
let all_wildcards = collect_wildcards_from_statement(stmt);
|
||||
|
||||
// How many active wildcards does this reuse?
|
||||
let reuse_count = stmt_wildcards.intersection(active_wildcards).count();
|
||||
|
||||
// How many new wildcards does this introduce?
|
||||
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count();
|
||||
|
||||
// After adding this statement, what would be active?
|
||||
let mut projected_active = active_wildcards.clone();
|
||||
projected_active.extend(stmt_wildcards.clone());
|
||||
|
||||
// Which wildcards are still needed by other remaining statements?
|
||||
let needed_later: HashSet<String> = remaining
|
||||
.iter()
|
||||
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
||||
// Only score based on private wildcards. Public args are always available at every
|
||||
// split boundary — they never consume a promotion slot, so their liveness is free.
|
||||
let stmt_wildcards: HashSet<_> = all_wildcards
|
||||
.into_iter()
|
||||
.filter(|w| !public_args.contains(w))
|
||||
.collect();
|
||||
|
||||
// Wildcards we can close = active now but not needed later
|
||||
// Statements that touch only public args ("cheap" statements) waste a bucket slot
|
||||
// that could be used to cluster private wildcards. Strongly defer them while any
|
||||
// private-wildcard statements remain, so they fill leftover space at the end.
|
||||
// `needed_later` is non-empty iff some remaining statement has a private wildcard.
|
||||
if stmt_wildcards.is_empty() {
|
||||
return if needed_later.is_empty() {
|
||||
0
|
||||
} else {
|
||||
i32::MIN / 2
|
||||
};
|
||||
}
|
||||
|
||||
// How many active private wildcards does this reuse?
|
||||
let reuse_count = stmt_wildcards.intersection(active_wildcards).count();
|
||||
|
||||
// How many new private wildcards does this introduce?
|
||||
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count();
|
||||
|
||||
// Which of the projected-active wildcards are still needed after this statement?
|
||||
let mut projected_active = active_wildcards.clone();
|
||||
projected_active.extend(stmt_wildcards);
|
||||
projected_active.retain(|w| needed_later.contains(w));
|
||||
let still_active_count = projected_active.len();
|
||||
|
||||
// Base score calculation
|
||||
// - Prefer statements that reuse active wildcards (don't introduce new liveness)
|
||||
// - Penalize introducing new wildcards (increases liveness)
|
||||
// - Penalize keeping many wildcards active (higher liveness)
|
||||
// Base score:
|
||||
// +3 per reused wildcard — rewards clustering (wildcard already open, no new cost)
|
||||
// -4 per new wildcard — penalises opening new live ranges
|
||||
// -2 per still-live — penalises carrying many wildcards toward the boundary
|
||||
let base_score = (reuse_count * 3) as i32
|
||||
- (new_wildcard_count * 4) as i32
|
||||
- (still_active_count * 2) as i32;
|
||||
|
||||
// Look-ahead bonus: when approaching split, heavily favor closing wildcards
|
||||
// When close to a split boundary, strongly reward statements that close wildcards
|
||||
// (active.len() + new - still_active = number of wildcards resolved by this statement).
|
||||
// Weight 10 >> max base-score magnitude to make closing the dominant factor.
|
||||
if approaching_split {
|
||||
let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count;
|
||||
base_score + (closes_count * 10) as i32
|
||||
|
|
@ -375,8 +394,6 @@ fn calculate_live_wildcards(
|
|||
fn generate_refactor_suggestion(
|
||||
crossing_wildcards: &[String],
|
||||
ordered_statements: &[StatementTmpl],
|
||||
_pos: usize,
|
||||
_end: usize,
|
||||
) -> Option<crate::lang::error::RefactorSuggestion> {
|
||||
use crate::lang::error::RefactorSuggestion;
|
||||
|
||||
|
|
@ -445,13 +462,6 @@ fn split_into_chain(
|
|||
let original_name = pred.name.name.clone();
|
||||
let conjunction = pred.conjunction_type;
|
||||
|
||||
let usage = analyze_wildcards(&pred.statements);
|
||||
let real_statement_count = pred.statements.len();
|
||||
|
||||
let ordering_result = order_constraints_optimally(pred.statements, &usage);
|
||||
let ordered_statements = ordering_result.statements;
|
||||
let reorder_map = ordering_result.reorder_map;
|
||||
|
||||
let original_public_args: Vec<String> = pred
|
||||
.args
|
||||
.public_args
|
||||
|
|
@ -459,6 +469,14 @@ fn split_into_chain(
|
|||
.map(|id| id.name.clone())
|
||||
.collect();
|
||||
|
||||
let public_args_set: HashSet<String> = original_public_args.iter().cloned().collect();
|
||||
|
||||
let real_statement_count = pred.statements.len();
|
||||
|
||||
let ordering_result = order_constraints_optimally(pred.statements, &public_args_set);
|
||||
let ordered_statements = ordering_result.statements;
|
||||
let reorder_map = ordering_result.reorder_map;
|
||||
|
||||
let mut chain_links = Vec::new();
|
||||
let mut pos = 0;
|
||||
let mut incoming_public = original_public_args.clone();
|
||||
|
|
@ -501,8 +519,7 @@ fn split_into_chain(
|
|||
total_public,
|
||||
};
|
||||
|
||||
let suggestion =
|
||||
generate_refactor_suggestion(&new_promotions, &ordered_statements, pos, end);
|
||||
let suggestion = generate_refactor_suggestion(&new_promotions, &ordered_statements);
|
||||
|
||||
return Err(SplittingError::TooManyPublicArgsAtSplit {
|
||||
predicate: original_name.clone(),
|
||||
|
|
@ -540,23 +557,60 @@ fn split_into_chain(
|
|||
});
|
||||
}
|
||||
|
||||
let mut public_args_out: Vec<String> = live_at_boundary.iter().cloned().collect();
|
||||
public_args_out.sort(); // Deterministic ordering
|
||||
|
||||
chain_links.push(ChainLink {
|
||||
statements: ordered_statements[pos..end].to_vec(),
|
||||
public_args_in: incoming_public.clone(),
|
||||
private_args,
|
||||
public_args_out: public_args_out.clone(),
|
||||
// new_promotions are already sorted and already filtered to exclude incoming_public
|
||||
promoted_wildcards: new_promotions.clone(),
|
||||
});
|
||||
|
||||
pos = end;
|
||||
|
||||
// Next link's incoming public args = current incoming + newly promoted live wildcards
|
||||
// Only add wildcards that aren't already in incoming_public to avoid duplicates
|
||||
for wildcard in public_args_out {
|
||||
if !incoming_set.contains(&wildcard) {
|
||||
incoming_public.push(wildcard);
|
||||
// Extend incoming_public for the next link with the newly promoted wildcards.
|
||||
// new_promotions is already filtered to exclude incoming_set, so no dedup needed.
|
||||
incoming_public.extend(new_promotions);
|
||||
}
|
||||
|
||||
// Backward pass: prune each continuation's public args to the minimal set needed.
|
||||
//
|
||||
// The forward pass accumulates incoming_public monotonically, so a continuation may
|
||||
// inherit original public args that none of its statements (or downstream continuations)
|
||||
// ever reference. A continuation must declare every public arg it receives, and the
|
||||
// proof system constrains each declared arg - an arg that goes unused has no constraints
|
||||
// and will not match the value the caller passes.
|
||||
//
|
||||
// Propagating from the last link backward ensures each continuation declares exactly the
|
||||
// args it uses directly, plus any args its successor still needs. Link 0 (the original
|
||||
// predicate) is left untouched - its public-arg signature is user-declared.
|
||||
{
|
||||
let num_links = chain_links.len();
|
||||
if num_links > 1 {
|
||||
// Collect wildcards referenced by each link's statements once.
|
||||
let link_wildcards: Vec<HashSet<String>> = chain_links
|
||||
.iter()
|
||||
.map(|link| {
|
||||
link.statements
|
||||
.iter()
|
||||
.flat_map(collect_wildcards_from_statement)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let last = num_links - 1;
|
||||
|
||||
// Seed: last link retains only args it directly references.
|
||||
chain_links[last]
|
||||
.public_args_in
|
||||
.retain(|a| link_wildcards[last].contains(a));
|
||||
|
||||
// Propagate backward through intermediate continuation links (skip link 0).
|
||||
for i in (1..last).rev() {
|
||||
let needed_downstream: HashSet<String> =
|
||||
chain_links[i + 1].public_args_in.iter().cloned().collect();
|
||||
chain_links[i]
|
||||
.public_args_in
|
||||
.retain(|a| link_wildcards[i].contains(a) || needed_downstream.contains(a));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -590,7 +644,7 @@ fn split_into_chain(
|
|||
let mut chain_predicates =
|
||||
generate_chain_predicates(&original_name, chain_links, conjunction, params)?;
|
||||
|
||||
validate_chain(&chain_predicates, params)?;
|
||||
validate_chain(&chain_predicates, params);
|
||||
|
||||
// Reverse so continuations come before callers in declaration order.
|
||||
// This ensures that when batched, continuations are in earlier batches
|
||||
|
|
@ -633,7 +687,7 @@ fn generate_chain_predicates(
|
|||
};
|
||||
|
||||
// Create arguments for chain call: use next link's public_args_in
|
||||
// which is the deduplicated union of current public_args_in and public_args_out
|
||||
// which is current public_args_in extended with current promoted_wildcards
|
||||
let next_link = &chain_links[i + 1];
|
||||
let chain_call_args: Vec<StatementTmplArg> = next_link
|
||||
.public_args_in
|
||||
|
|
@ -665,10 +719,12 @@ fn generate_chain_predicates(
|
|||
})
|
||||
.collect();
|
||||
|
||||
// Build private args (private + promoted for next)
|
||||
// Build private args: segment-local private wildcards, plus any wildcards being
|
||||
// promoted to public for the next link (they must be declared here so the solver
|
||||
// can bind them before passing them as public args to the continuation).
|
||||
let mut private_arg_names = link.private_args.clone();
|
||||
if !is_last {
|
||||
private_arg_names.extend(link.public_args_out.clone());
|
||||
private_arg_names.extend(link.promoted_wildcards.iter().cloned());
|
||||
}
|
||||
|
||||
let private_args = if private_arg_names.is_empty() {
|
||||
|
|
@ -698,25 +754,34 @@ fn generate_chain_predicates(
|
|||
Ok(predicates)
|
||||
}
|
||||
|
||||
/// Phase 5: Validate the generated chain
|
||||
///
|
||||
/// Note: We no longer check chain length against max_custom_batch_size since
|
||||
/// chains can now span multiple batches thanks to multi-batch support.
|
||||
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> {
|
||||
/// Sanity-check the generated chain. All three constraints are enforced as proper errors
|
||||
/// earlier in `split_into_chain`, so violations here indicate a bug in the algorithm.
|
||||
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) {
|
||||
for pred in chain {
|
||||
// Each predicate should have ≤ max_statements
|
||||
assert!(pred.statements.len() <= Params::max_custom_predicate_arity());
|
||||
|
||||
// Public args should fit
|
||||
assert!(pred.args.public_args.len() <= Params::max_statement_args());
|
||||
|
||||
// Total args should fit
|
||||
assert!(
|
||||
pred.statements.len() <= Params::max_custom_predicate_arity(),
|
||||
"chain link '{}' has {} statements, exceeds max {}",
|
||||
pred.name.name,
|
||||
pred.statements.len(),
|
||||
Params::max_custom_predicate_arity(),
|
||||
);
|
||||
assert!(
|
||||
pred.args.public_args.len() <= Params::max_statement_args(),
|
||||
"chain link '{}' has {} public args, exceeds max {}",
|
||||
pred.name.name,
|
||||
pred.args.public_args.len(),
|
||||
Params::max_statement_args(),
|
||||
);
|
||||
let total =
|
||||
pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len());
|
||||
assert!(total <= params.max_custom_predicate_wildcards);
|
||||
assert!(
|
||||
total <= params.max_custom_predicate_wildcards,
|
||||
"chain link '{}' has {} total args, exceeds max {}",
|
||||
pred.name.name,
|
||||
total,
|
||||
params.max_custom_predicate_wildcards,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -976,8 +1041,24 @@ mod tests {
|
|||
let remaining: HashSet<usize> = [0, 1].into_iter().collect();
|
||||
let active_wildcards = HashSet::new();
|
||||
|
||||
let key0 = statement_selection_key(0, &statements, &active_wildcards, &remaining, false);
|
||||
let key1 = statement_selection_key(1, &statements, &active_wildcards, &remaining, false);
|
||||
// A and B are the public args of tie_break(A, B)
|
||||
let public_args: HashSet<String> = ["A".to_string(), "B".to_string()].into_iter().collect();
|
||||
let key0 = statement_selection_key(
|
||||
0,
|
||||
&statements,
|
||||
&active_wildcards,
|
||||
&remaining,
|
||||
false,
|
||||
&public_args,
|
||||
);
|
||||
let key1 = statement_selection_key(
|
||||
1,
|
||||
&statements,
|
||||
&active_wildcards,
|
||||
&remaining,
|
||||
false,
|
||||
&public_args,
|
||||
);
|
||||
|
||||
assert_eq!(key0.0, key1.0, "Primary heuristic score should tie");
|
||||
assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie");
|
||||
|
|
@ -986,7 +1067,8 @@ mod tests {
|
|||
"Lower original index should win deterministic final tie-breaker"
|
||||
);
|
||||
|
||||
let selected = find_best_next_statement(&statements, &remaining, &active_wildcards, 0);
|
||||
let selected =
|
||||
find_best_next_statement(&statements, &remaining, &active_wildcards, 0, &public_args);
|
||||
assert_eq!(selected, 0);
|
||||
}
|
||||
|
||||
|
|
@ -1084,7 +1166,7 @@ mod tests {
|
|||
assert!(error_msg.contains("3 crossing wildcards"));
|
||||
assert!(error_msg.contains("= 6 total"));
|
||||
assert!(error_msg.contains("exceeds max of 5"));
|
||||
assert!(error_msg.contains("Statements 0-4"));
|
||||
assert!(error_msg.contains("Statements 0-3"));
|
||||
assert!(error_msg.contains("Incoming public args: A, B, C"));
|
||||
assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3"));
|
||||
assert!(error_msg.contains("Suggestion:"));
|
||||
|
|
@ -1144,25 +1226,82 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
// --- Regression tests ---
|
||||
|
||||
/// Statements that reference only public args should be deferred until private-wildcard
|
||||
/// statements have been clustered, so they don't consume bucket slots that would reduce
|
||||
/// liveness at split boundaries.
|
||||
///
|
||||
/// 4 public args, 7 statements: W1 used in stmts 0,1,4; W2 used in stmts 1,2,3;
|
||||
/// stmts 5,6 reference only public args. The scoring correctly defers stmts 5,6,
|
||||
/// yielding bucket0={0,1,2,3}, bucket1={4,5,6} with only W1 crossing (4+1=5 <= max).
|
||||
#[test]
|
||||
fn test_refactor_suggestion_group_wildcards() {
|
||||
// Test the "group wildcard usages" suggestion formatting
|
||||
use crate::lang::error::RefactorSuggestion;
|
||||
fn test_split_succeeds_with_four_public_args_and_public_only_statements() {
|
||||
// Optimal split: bucket0={0,1,2,3}, bucket1={4,5,6}
|
||||
// Only W1 crosses (used in 0,1 and 4), total = 4 public + 1 crossing = 5 ✓
|
||||
let input = r#"
|
||||
pred(A, B, C, D, private: W1, W2) = AND(
|
||||
Equal(W1["x"], A["v"])
|
||||
Equal(W2["y"], W1["x"])
|
||||
Equal(W2["z"], B["v"])
|
||||
Equal(C["r"], W2["y"])
|
||||
Equal(D["s"], W1["x"])
|
||||
Equal(A["out"], C["out"])
|
||||
Equal(B["out"], D["out"])
|
||||
)
|
||||
"#;
|
||||
|
||||
let suggestion = RefactorSuggestion::GroupWildcardUsages {
|
||||
wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()],
|
||||
};
|
||||
let pred = parse_predicate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let suggestion_text = suggestion.format();
|
||||
let result = split_predicate_if_needed(pred, &params);
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Should find a valid split with ≤1 crossing wildcard, got: {:?}",
|
||||
result.err()
|
||||
);
|
||||
}
|
||||
|
||||
// Verify the suggestion formats correctly
|
||||
assert!(suggestion_text.contains("Group operations for wildcards"));
|
||||
assert!(suggestion_text.contains("T1, T2, T3"));
|
||||
assert!(suggestion_text.contains("used across multiple segments"));
|
||||
/// Continuation predicates should only declare the public args they actually use -
|
||||
/// original public args that are not referenced in a continuation's statements or
|
||||
/// any of its downstream continuations must be omitted from its signature.
|
||||
#[test]
|
||||
fn test_continuation_excludes_public_args_unused_after_split() {
|
||||
// A is used only in the first segment; B is used only in the second segment.
|
||||
// The continuation predicate (pred_1) must include B but not A.
|
||||
let input = r#"
|
||||
pred(A, B, private: T) = AND(
|
||||
Equal(T["x"], A["val"])
|
||||
Equal(T["y"], 1)
|
||||
Equal(T["z"], 2)
|
||||
Equal(T["w"], 3)
|
||||
Equal(B["r"], T["x"])
|
||||
Equal(B["s"], T["y"])
|
||||
)
|
||||
"#;
|
||||
|
||||
eprintln!(
|
||||
"\n=== Example GroupWildcardUsages Suggestion ===\n{}\n",
|
||||
suggestion_text
|
||||
let pred = parse_predicate(input);
|
||||
let params = Params::default();
|
||||
|
||||
let result = split_predicate_if_needed(pred, &params).unwrap();
|
||||
// chain[0] is the continuation (_1 suffix), chain[1] is the original
|
||||
let continuation = result
|
||||
.predicates
|
||||
.iter()
|
||||
.find(|p| p.name.name == "pred_1")
|
||||
.expect("Expected a pred_1 continuation predicate");
|
||||
|
||||
let cont_public: Vec<&str> = continuation
|
||||
.args
|
||||
.public_args
|
||||
.iter()
|
||||
.map(|id| id.name.as_str())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
!cont_public.contains(&"A"),
|
||||
"Continuation should drop unused public arg 'A', got: {:?}",
|
||||
cont_public
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,14 +77,14 @@ impl NativePredicate {
|
|||
| NativePredicate::MaxOf
|
||||
| NativePredicate::HashOf
|
||||
| NativePredicate::SetInsert
|
||||
| NativePredicate::SetDelete => 3,
|
||||
| NativePredicate::SetDelete
|
||||
| NativePredicate::DictDelete
|
||||
| NativePredicate::ContainerDelete => 3,
|
||||
NativePredicate::DictInsert
|
||||
| NativePredicate::DictUpdate
|
||||
| NativePredicate::DictDelete
|
||||
| NativePredicate::ArrayUpdate
|
||||
| NativePredicate::ContainerInsert
|
||||
| NativePredicate::ContainerUpdate
|
||||
| NativePredicate::ContainerDelete => 4,
|
||||
| NativePredicate::ContainerUpdate => 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue