Splitter fixes (#483)
* Add failing tests
* Model statements depending on public args as cheaper than those depending on private args
* Tidying
* Fix unnecessary propagation of unused public args
* More tidying
* Tidy test comments
* Fix incorrect Delete arities
This commit is contained in:
parent
e950661090
commit
f6c6ec43ef
3 changed files with 280 additions and 140 deletions
|
|
@ -287,7 +287,8 @@ fn format_public_args_at_split_error(
|
||||||
|
|
||||||
msg.push_str(&format!(
|
msg.push_str(&format!(
|
||||||
" Statements {}-{} in this segment\n",
|
" Statements {}-{} in this segment\n",
|
||||||
context.statement_range.0, context.statement_range.1
|
context.statement_range.0,
|
||||||
|
context.statement_range.1 - 1
|
||||||
));
|
));
|
||||||
|
|
||||||
if !context.incoming_public.is_empty() {
|
if !context.incoming_public.is_empty() {
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,7 @@
|
||||||
//! We use a greedy algorithm to order the statements in a predicate to minimize
|
//! We use a greedy algorithm to order the statements in a predicate to minimize
|
||||||
//! the number of live wildcards at split boundaries.
|
//! the number of live wildcards at split boundaries.
|
||||||
|
|
||||||
use std::{
|
use std::{cmp::Reverse, collections::HashSet};
|
||||||
cmp::Reverse,
|
|
||||||
collections::{HashMap, HashSet},
|
|
||||||
};
|
|
||||||
|
|
||||||
// SplittingError is now defined in error.rs
|
// SplittingError is now defined in error.rs
|
||||||
pub use crate::lang::error::SplittingError;
|
pub use crate::lang::error::SplittingError;
|
||||||
|
|
@ -33,8 +30,8 @@ pub struct ChainLink {
|
||||||
pub public_args_in: Vec<String>,
|
pub public_args_in: Vec<String>,
|
||||||
/// Private arguments used only in this link
|
/// Private arguments used only in this link
|
||||||
pub private_args: Vec<String>,
|
pub private_args: Vec<String>,
|
||||||
/// Public arguments promoted to pass to next link (empty if last link)
|
/// Private wildcards promoted to public for the next link (empty if last link)
|
||||||
pub public_args_out: Vec<String>,
|
pub promoted_wildcards: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Information about a single piece of a split predicate chain
|
/// Information about a single piece of a split predicate chain
|
||||||
|
|
@ -71,13 +68,6 @@ pub struct SplitResult {
|
||||||
pub chain_info: Option<SplitChainInfo>,
|
pub chain_info: Option<SplitChainInfo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wildcard usage information
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct WildcardUsage {
|
|
||||||
/// Indices of statements using this wildcard
|
|
||||||
used_in_statements: HashSet<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Early validation: Check if predicate is fundamentally splittable
|
/// Early validation: Check if predicate is fundamentally splittable
|
||||||
pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> {
|
pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> {
|
||||||
let public_args = pred.args.public_args.len();
|
let public_args = pred.args.public_args.len();
|
||||||
|
|
@ -121,26 +111,6 @@ pub fn split_predicate_if_needed(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap<String, WildcardUsage> {
|
|
||||||
let mut usage: HashMap<String, WildcardUsage> = HashMap::new();
|
|
||||||
|
|
||||||
for (idx, stmt) in statements.iter().enumerate() {
|
|
||||||
let wildcards = collect_wildcards_from_statement(stmt);
|
|
||||||
|
|
||||||
for wildcard in wildcards {
|
|
||||||
usage
|
|
||||||
.entry(wildcard.clone())
|
|
||||||
.or_insert_with(|| WildcardUsage {
|
|
||||||
used_in_statements: HashSet::new(),
|
|
||||||
})
|
|
||||||
.used_in_statements
|
|
||||||
.insert(idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
usage
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Collect all wildcard names from a statement
|
/// Collect all wildcard names from a statement
|
||||||
fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
|
fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
|
||||||
let mut wildcards = HashSet::new();
|
let mut wildcards = HashSet::new();
|
||||||
|
|
@ -172,7 +142,7 @@ struct OrderingResult {
|
||||||
|
|
||||||
fn order_constraints_optimally(
|
fn order_constraints_optimally(
|
||||||
statements: Vec<StatementTmpl>,
|
statements: Vec<StatementTmpl>,
|
||||||
_usage: &HashMap<String, WildcardUsage>,
|
public_args: &HashSet<String>,
|
||||||
) -> OrderingResult {
|
) -> OrderingResult {
|
||||||
let n = statements.len();
|
let n = statements.len();
|
||||||
|
|
||||||
|
|
@ -190,8 +160,13 @@ fn order_constraints_optimally(
|
||||||
let mut active_wildcards: HashSet<String> = HashSet::new();
|
let mut active_wildcards: HashSet<String> = HashSet::new();
|
||||||
|
|
||||||
while !remaining.is_empty() {
|
while !remaining.is_empty() {
|
||||||
let best_idx =
|
let best_idx = find_best_next_statement(
|
||||||
find_best_next_statement(&statements, &remaining, &active_wildcards, ordered.len());
|
&statements,
|
||||||
|
&remaining,
|
||||||
|
&active_wildcards,
|
||||||
|
ordered.len(),
|
||||||
|
public_args,
|
||||||
|
);
|
||||||
|
|
||||||
remaining.remove(&best_idx);
|
remaining.remove(&best_idx);
|
||||||
let stmt = &statements[best_idx];
|
let stmt = &statements[best_idx];
|
||||||
|
|
@ -200,14 +175,20 @@ fn order_constraints_optimally(
|
||||||
reorder_map[best_idx] = ordered.len();
|
reorder_map[best_idx] = ordered.len();
|
||||||
ordered.push(stmt.clone());
|
ordered.push(stmt.clone());
|
||||||
|
|
||||||
// Update active wildcards
|
// Only track private wildcards in the active set — public args are always
|
||||||
|
// available at every boundary so their liveness is irrelevant to split cost.
|
||||||
let stmt_wildcards = collect_wildcards_from_statement(stmt);
|
let stmt_wildcards = collect_wildcards_from_statement(stmt);
|
||||||
active_wildcards.extend(stmt_wildcards);
|
active_wildcards.extend(
|
||||||
|
stmt_wildcards
|
||||||
|
.into_iter()
|
||||||
|
.filter(|w| !public_args.contains(w)),
|
||||||
|
);
|
||||||
|
|
||||||
// Remove wildcards no longer needed by remaining statements
|
// Remove private wildcards no longer needed by remaining statements
|
||||||
let needed_later: HashSet<_> = remaining
|
let needed_later: HashSet<_> = remaining
|
||||||
.iter()
|
.iter()
|
||||||
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
||||||
|
.filter(|w| !public_args.contains(w))
|
||||||
.collect();
|
.collect();
|
||||||
active_wildcards.retain(|w| needed_later.contains(w));
|
active_wildcards.retain(|w| needed_later.contains(w));
|
||||||
}
|
}
|
||||||
|
|
@ -225,30 +206,35 @@ fn compute_tie_breakers(
|
||||||
active_wildcards: &HashSet<String>,
|
active_wildcards: &HashSet<String>,
|
||||||
statements: &[StatementTmpl],
|
statements: &[StatementTmpl],
|
||||||
remaining: &HashSet<usize>,
|
remaining: &HashSet<usize>,
|
||||||
|
needed_later: &HashSet<String>,
|
||||||
|
public_args: &HashSet<String>,
|
||||||
) -> (usize, usize, i32) {
|
) -> (usize, usize, i32) {
|
||||||
let stmt_wildcards = collect_wildcards_from_statement(stmt);
|
let all_wildcards = collect_wildcards_from_statement(stmt);
|
||||||
|
// Only consider private wildcards for tie-breaking metrics
|
||||||
// Metric 1: Simplicity - prefer statements with fewer wildcards
|
let stmt_wildcards: HashSet<_> = all_wildcards
|
||||||
let simplicity = usize::MAX - stmt_wildcards.len();
|
.into_iter()
|
||||||
|
.filter(|w| !public_args.contains(w))
|
||||||
// Metric 2: Public closure - prefer statements that close active wildcards
|
|
||||||
// (wildcards that won't be needed by any remaining statements)
|
|
||||||
let needed_later: HashSet<String> = remaining
|
|
||||||
.iter()
|
|
||||||
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
// Metric 1: Simplicity - prefer statements with fewer private wildcards
|
||||||
|
let simplicity = usize::MAX - stmt_wildcards.len();
|
||||||
|
|
||||||
|
// Metric 2: Closure - prefer statements that close active private wildcards
|
||||||
|
// (wildcards that won't be needed by any remaining statements)
|
||||||
let closes_count = stmt_wildcards
|
let closes_count = stmt_wildcards
|
||||||
.intersection(active_wildcards)
|
.intersection(active_wildcards)
|
||||||
.filter(|w| !needed_later.contains(*w))
|
.filter(|w| !needed_later.contains(*w))
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
// Metric 3: Fanout - prefer statements with lower future usage
|
// Metric 3: Fanout - prefer statements with lower future usage
|
||||||
// (number of remaining statements that use any wildcard from this statement)
|
// (number of remaining statements sharing private wildcards with this statement)
|
||||||
let fanout = remaining
|
let fanout = remaining
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|&&i| {
|
.filter(|&&i| {
|
||||||
let other_wildcards = collect_wildcards_from_statement(&statements[i]);
|
let other_wildcards: HashSet<_> = collect_wildcards_from_statement(&statements[i])
|
||||||
|
.into_iter()
|
||||||
|
.filter(|w| !public_args.contains(w))
|
||||||
|
.collect();
|
||||||
!stmt_wildcards.is_disjoint(&other_wildcards)
|
!stmt_wildcards.is_disjoint(&other_wildcards)
|
||||||
})
|
})
|
||||||
.count();
|
.count();
|
||||||
|
|
@ -262,16 +248,33 @@ fn statement_selection_key(
|
||||||
active_wildcards: &HashSet<String>,
|
active_wildcards: &HashSet<String>,
|
||||||
remaining: &HashSet<usize>,
|
remaining: &HashSet<usize>,
|
||||||
approaching_split: bool,
|
approaching_split: bool,
|
||||||
|
public_args: &HashSet<String>,
|
||||||
) -> (i32, (usize, usize, i32), Reverse<usize>) {
|
) -> (i32, (usize, usize, i32), Reverse<usize>) {
|
||||||
|
// Pre-compute needed_later once and share between primary score and tie-breakers.
|
||||||
|
// Exclude the candidate itself: we want to know what the *other* remaining statements
|
||||||
|
// need, so that wildcards used only by this candidate correctly appear as closeable.
|
||||||
|
let needed_later: HashSet<String> = remaining
|
||||||
|
.iter()
|
||||||
|
.filter(|&&i| i != idx)
|
||||||
|
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
||||||
|
.filter(|w| !public_args.contains(w))
|
||||||
|
.collect();
|
||||||
|
|
||||||
let primary_score = score_statement(
|
let primary_score = score_statement(
|
||||||
|
&statements[idx],
|
||||||
|
active_wildcards,
|
||||||
|
approaching_split,
|
||||||
|
public_args,
|
||||||
|
&needed_later,
|
||||||
|
);
|
||||||
|
let tie_breakers = compute_tie_breakers(
|
||||||
&statements[idx],
|
&statements[idx],
|
||||||
active_wildcards,
|
active_wildcards,
|
||||||
statements,
|
statements,
|
||||||
remaining,
|
remaining,
|
||||||
approaching_split,
|
&needed_later,
|
||||||
|
public_args,
|
||||||
);
|
);
|
||||||
let tie_breakers =
|
|
||||||
compute_tie_breakers(&statements[idx], active_wildcards, statements, remaining);
|
|
||||||
|
|
||||||
// Final deterministic tie-breaker: prefer smaller original indices.
|
// Final deterministic tie-breaker: prefer smaller original indices.
|
||||||
// This avoids hash-iteration-dependent selection when scores are equal.
|
// This avoids hash-iteration-dependent selection when scores are equal.
|
||||||
|
|
@ -284,6 +287,7 @@ fn find_best_next_statement(
|
||||||
remaining: &HashSet<usize>,
|
remaining: &HashSet<usize>,
|
||||||
active_wildcards: &HashSet<String>,
|
active_wildcards: &HashSet<String>,
|
||||||
ordered_count: usize,
|
ordered_count: usize,
|
||||||
|
public_args: &HashSet<String>,
|
||||||
) -> usize {
|
) -> usize {
|
||||||
// Calculate distance to next split point
|
// Calculate distance to next split point
|
||||||
let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call
|
let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call
|
||||||
|
|
@ -299,51 +303,66 @@ fn find_best_next_statement(
|
||||||
active_wildcards,
|
active_wildcards,
|
||||||
remaining,
|
remaining,
|
||||||
approaching_split,
|
approaching_split,
|
||||||
|
public_args,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.copied()
|
.copied()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Score a statement based on how well it minimizes liveness
|
/// Score a statement based on how well it minimizes private-wildcard liveness at boundaries.
|
||||||
|
/// `needed_later` is the set of private wildcards used by any *other* remaining statement
/// (the caller excludes the candidate statement itself when building it).
|
||||||
fn score_statement(
|
fn score_statement(
|
||||||
stmt: &StatementTmpl,
|
stmt: &StatementTmpl,
|
||||||
active_wildcards: &HashSet<String>,
|
active_wildcards: &HashSet<String>,
|
||||||
statements: &[StatementTmpl],
|
|
||||||
remaining: &HashSet<usize>,
|
|
||||||
approaching_split: bool,
|
approaching_split: bool,
|
||||||
|
public_args: &HashSet<String>,
|
||||||
|
needed_later: &HashSet<String>,
|
||||||
) -> i32 {
|
) -> i32 {
|
||||||
let stmt_wildcards = collect_wildcards_from_statement(stmt);
|
let all_wildcards = collect_wildcards_from_statement(stmt);
|
||||||
|
|
||||||
// How many active wildcards does this reuse?
|
// Only score based on private wildcards. Public args are always available at every
|
||||||
let reuse_count = stmt_wildcards.intersection(active_wildcards).count();
|
// split boundary — they never consume a promotion slot, so their liveness is free.
|
||||||
|
let stmt_wildcards: HashSet<_> = all_wildcards
|
||||||
// How many new wildcards does this introduce?
|
.into_iter()
|
||||||
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count();
|
.filter(|w| !public_args.contains(w))
|
||||||
|
|
||||||
// After adding this statement, what would be active?
|
|
||||||
let mut projected_active = active_wildcards.clone();
|
|
||||||
projected_active.extend(stmt_wildcards.clone());
|
|
||||||
|
|
||||||
// Which wildcards are still needed by other remaining statements?
|
|
||||||
let needed_later: HashSet<String> = remaining
|
|
||||||
.iter()
|
|
||||||
.flat_map(|&i| collect_wildcards_from_statement(&statements[i]))
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Wildcards we can close = active now but not needed later
|
// Statements that touch only public args ("cheap" statements) waste a bucket slot
|
||||||
|
// that could be used to cluster private wildcards. Strongly defer them while any
|
||||||
|
// private-wildcard statements remain, so they fill leftover space at the end.
|
||||||
|
// `needed_later` is non-empty iff some remaining statement has a private wildcard.
|
||||||
|
if stmt_wildcards.is_empty() {
|
||||||
|
return if needed_later.is_empty() {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
i32::MIN / 2
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// How many active private wildcards does this reuse?
|
||||||
|
let reuse_count = stmt_wildcards.intersection(active_wildcards).count();
|
||||||
|
|
||||||
|
// How many new private wildcards does this introduce?
|
||||||
|
let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count();
|
||||||
|
|
||||||
|
// Which of the projected-active wildcards are still needed after this statement?
|
||||||
|
let mut projected_active = active_wildcards.clone();
|
||||||
|
projected_active.extend(stmt_wildcards);
|
||||||
projected_active.retain(|w| needed_later.contains(w));
|
projected_active.retain(|w| needed_later.contains(w));
|
||||||
let still_active_count = projected_active.len();
|
let still_active_count = projected_active.len();
|
||||||
|
|
||||||
// Base score calculation
|
// Base score:
|
||||||
// - Prefer statements that reuse active wildcards (don't introduce new liveness)
|
// +3 per reused wildcard — rewards clustering (wildcard already open, no new cost)
|
||||||
// - Penalize introducing new wildcards (increases liveness)
|
// -4 per new wildcard — penalises opening new live ranges
|
||||||
// - Penalize keeping many wildcards active (higher liveness)
|
// -2 per still-live — penalises carrying many wildcards toward the boundary
|
||||||
let base_score = (reuse_count * 3) as i32
|
let base_score = (reuse_count * 3) as i32
|
||||||
- (new_wildcard_count * 4) as i32
|
- (new_wildcard_count * 4) as i32
|
||||||
- (still_active_count * 2) as i32;
|
- (still_active_count * 2) as i32;
|
||||||
|
|
||||||
// Look-ahead bonus: when approaching split, heavily favor closing wildcards
|
// When close to a split boundary, strongly reward statements that close wildcards
|
||||||
|
// (active.len() + new - still_active = number of wildcards resolved by this statement).
|
||||||
|
// Weight 10 >> max base-score magnitude to make closing the dominant factor.
|
||||||
if approaching_split {
|
if approaching_split {
|
||||||
let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count;
|
let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count;
|
||||||
base_score + (closes_count * 10) as i32
|
base_score + (closes_count * 10) as i32
|
||||||
|
|
@ -375,8 +394,6 @@ fn calculate_live_wildcards(
|
||||||
fn generate_refactor_suggestion(
|
fn generate_refactor_suggestion(
|
||||||
crossing_wildcards: &[String],
|
crossing_wildcards: &[String],
|
||||||
ordered_statements: &[StatementTmpl],
|
ordered_statements: &[StatementTmpl],
|
||||||
_pos: usize,
|
|
||||||
_end: usize,
|
|
||||||
) -> Option<crate::lang::error::RefactorSuggestion> {
|
) -> Option<crate::lang::error::RefactorSuggestion> {
|
||||||
use crate::lang::error::RefactorSuggestion;
|
use crate::lang::error::RefactorSuggestion;
|
||||||
|
|
||||||
|
|
@ -445,13 +462,6 @@ fn split_into_chain(
|
||||||
let original_name = pred.name.name.clone();
|
let original_name = pred.name.name.clone();
|
||||||
let conjunction = pred.conjunction_type;
|
let conjunction = pred.conjunction_type;
|
||||||
|
|
||||||
let usage = analyze_wildcards(&pred.statements);
|
|
||||||
let real_statement_count = pred.statements.len();
|
|
||||||
|
|
||||||
let ordering_result = order_constraints_optimally(pred.statements, &usage);
|
|
||||||
let ordered_statements = ordering_result.statements;
|
|
||||||
let reorder_map = ordering_result.reorder_map;
|
|
||||||
|
|
||||||
let original_public_args: Vec<String> = pred
|
let original_public_args: Vec<String> = pred
|
||||||
.args
|
.args
|
||||||
.public_args
|
.public_args
|
||||||
|
|
@ -459,6 +469,14 @@ fn split_into_chain(
|
||||||
.map(|id| id.name.clone())
|
.map(|id| id.name.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
let public_args_set: HashSet<String> = original_public_args.iter().cloned().collect();
|
||||||
|
|
||||||
|
let real_statement_count = pred.statements.len();
|
||||||
|
|
||||||
|
let ordering_result = order_constraints_optimally(pred.statements, &public_args_set);
|
||||||
|
let ordered_statements = ordering_result.statements;
|
||||||
|
let reorder_map = ordering_result.reorder_map;
|
||||||
|
|
||||||
let mut chain_links = Vec::new();
|
let mut chain_links = Vec::new();
|
||||||
let mut pos = 0;
|
let mut pos = 0;
|
||||||
let mut incoming_public = original_public_args.clone();
|
let mut incoming_public = original_public_args.clone();
|
||||||
|
|
@ -501,8 +519,7 @@ fn split_into_chain(
|
||||||
total_public,
|
total_public,
|
||||||
};
|
};
|
||||||
|
|
||||||
let suggestion =
|
let suggestion = generate_refactor_suggestion(&new_promotions, &ordered_statements);
|
||||||
generate_refactor_suggestion(&new_promotions, &ordered_statements, pos, end);
|
|
||||||
|
|
||||||
return Err(SplittingError::TooManyPublicArgsAtSplit {
|
return Err(SplittingError::TooManyPublicArgsAtSplit {
|
||||||
predicate: original_name.clone(),
|
predicate: original_name.clone(),
|
||||||
|
|
@ -540,23 +557,60 @@ fn split_into_chain(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut public_args_out: Vec<String> = live_at_boundary.iter().cloned().collect();
|
|
||||||
public_args_out.sort(); // Deterministic ordering
|
|
||||||
|
|
||||||
chain_links.push(ChainLink {
|
chain_links.push(ChainLink {
|
||||||
statements: ordered_statements[pos..end].to_vec(),
|
statements: ordered_statements[pos..end].to_vec(),
|
||||||
public_args_in: incoming_public.clone(),
|
public_args_in: incoming_public.clone(),
|
||||||
private_args,
|
private_args,
|
||||||
public_args_out: public_args_out.clone(),
|
// new_promotions are already sorted and already filtered to exclude incoming_public
|
||||||
|
promoted_wildcards: new_promotions.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
pos = end;
|
pos = end;
|
||||||
|
|
||||||
// Next link's incoming public args = current incoming + newly promoted live wildcards
|
// Extend incoming_public for the next link with the newly promoted wildcards.
|
||||||
// Only add wildcards that aren't already in incoming_public to avoid duplicates
|
// new_promotions is already filtered to exclude incoming_set, so no dedup needed.
|
||||||
for wildcard in public_args_out {
|
incoming_public.extend(new_promotions);
|
||||||
if !incoming_set.contains(&wildcard) {
|
}
|
||||||
incoming_public.push(wildcard);
|
|
||||||
|
// Backward pass: prune each continuation's public args to the minimal set needed.
|
||||||
|
//
|
||||||
|
// The forward pass accumulates incoming_public monotonically, so a continuation may
|
||||||
|
// inherit original public args that none of its statements (or downstream continuations)
|
||||||
|
// ever reference. A continuation must declare every public arg it receives, and the
|
||||||
|
// proof system constrains each declared arg - an arg that goes unused has no constraints
|
||||||
|
// and will not match the value the caller passes.
|
||||||
|
//
|
||||||
|
// Propagating from the last link backward ensures each continuation declares exactly the
|
||||||
|
// args it uses directly, plus any args its successor still needs. Link 0 (the original
|
||||||
|
// predicate) is left untouched - its public-arg signature is user-declared.
|
||||||
|
{
|
||||||
|
let num_links = chain_links.len();
|
||||||
|
if num_links > 1 {
|
||||||
|
// Collect wildcards referenced by each link's statements once.
|
||||||
|
let link_wildcards: Vec<HashSet<String>> = chain_links
|
||||||
|
.iter()
|
||||||
|
.map(|link| {
|
||||||
|
link.statements
|
||||||
|
.iter()
|
||||||
|
.flat_map(collect_wildcards_from_statement)
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let last = num_links - 1;
|
||||||
|
|
||||||
|
// Seed: last link retains only args it directly references.
|
||||||
|
chain_links[last]
|
||||||
|
.public_args_in
|
||||||
|
.retain(|a| link_wildcards[last].contains(a));
|
||||||
|
|
||||||
|
// Propagate backward through intermediate continuation links (skip link 0).
|
||||||
|
for i in (1..last).rev() {
|
||||||
|
let needed_downstream: HashSet<String> =
|
||||||
|
chain_links[i + 1].public_args_in.iter().cloned().collect();
|
||||||
|
chain_links[i]
|
||||||
|
.public_args_in
|
||||||
|
.retain(|a| link_wildcards[i].contains(a) || needed_downstream.contains(a));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -590,7 +644,7 @@ fn split_into_chain(
|
||||||
let mut chain_predicates =
|
let mut chain_predicates =
|
||||||
generate_chain_predicates(&original_name, chain_links, conjunction, params)?;
|
generate_chain_predicates(&original_name, chain_links, conjunction, params)?;
|
||||||
|
|
||||||
validate_chain(&chain_predicates, params)?;
|
validate_chain(&chain_predicates, params);
|
||||||
|
|
||||||
// Reverse so continuations come before callers in declaration order.
|
// Reverse so continuations come before callers in declaration order.
|
||||||
// This ensures that when batched, continuations are in earlier batches
|
// This ensures that when batched, continuations are in earlier batches
|
||||||
|
|
@ -633,7 +687,7 @@ fn generate_chain_predicates(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create arguments for chain call: use next link's public_args_in
|
// Create arguments for chain call: use next link's public_args_in
|
||||||
// which is the deduplicated union of current public_args_in and public_args_out
|
// which is current public_args_in extended with current promoted_wildcards
|
||||||
let next_link = &chain_links[i + 1];
|
let next_link = &chain_links[i + 1];
|
||||||
let chain_call_args: Vec<StatementTmplArg> = next_link
|
let chain_call_args: Vec<StatementTmplArg> = next_link
|
||||||
.public_args_in
|
.public_args_in
|
||||||
|
|
@ -665,10 +719,12 @@ fn generate_chain_predicates(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Build private args (private + promoted for next)
|
// Build private args: segment-local private wildcards, plus any wildcards being
|
||||||
|
// promoted to public for the next link (they must be declared here so the solver
|
||||||
|
// can bind them before passing them as public args to the continuation).
|
||||||
let mut private_arg_names = link.private_args.clone();
|
let mut private_arg_names = link.private_args.clone();
|
||||||
if !is_last {
|
if !is_last {
|
||||||
private_arg_names.extend(link.public_args_out.clone());
|
private_arg_names.extend(link.promoted_wildcards.iter().cloned());
|
||||||
}
|
}
|
||||||
|
|
||||||
let private_args = if private_arg_names.is_empty() {
|
let private_args = if private_arg_names.is_empty() {
|
||||||
|
|
@ -698,25 +754,34 @@ fn generate_chain_predicates(
|
||||||
Ok(predicates)
|
Ok(predicates)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Phase 5: Validate the generated chain
|
/// Sanity-check the generated chain. All three constraints are enforced as proper errors
|
||||||
///
|
/// earlier in `split_into_chain`, so violations here indicate a bug in the algorithm.
|
||||||
/// Note: We no longer check chain length against max_custom_batch_size since
|
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) {
|
||||||
/// chains can now span multiple batches thanks to multi-batch support.
|
|
||||||
fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> {
|
|
||||||
for pred in chain {
|
for pred in chain {
|
||||||
// Each predicate should have ≤ max_statements
|
assert!(
|
||||||
assert!(pred.statements.len() <= Params::max_custom_predicate_arity());
|
pred.statements.len() <= Params::max_custom_predicate_arity(),
|
||||||
|
"chain link '{}' has {} statements, exceeds max {}",
|
||||||
// Public args should fit
|
pred.name.name,
|
||||||
assert!(pred.args.public_args.len() <= Params::max_statement_args());
|
pred.statements.len(),
|
||||||
|
Params::max_custom_predicate_arity(),
|
||||||
// Total args should fit
|
);
|
||||||
|
assert!(
|
||||||
|
pred.args.public_args.len() <= Params::max_statement_args(),
|
||||||
|
"chain link '{}' has {} public args, exceeds max {}",
|
||||||
|
pred.name.name,
|
||||||
|
pred.args.public_args.len(),
|
||||||
|
Params::max_statement_args(),
|
||||||
|
);
|
||||||
let total =
|
let total =
|
||||||
pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len());
|
pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len());
|
||||||
assert!(total <= params.max_custom_predicate_wildcards);
|
assert!(
|
||||||
|
total <= params.max_custom_predicate_wildcards,
|
||||||
|
"chain link '{}' has {} total args, exceeds max {}",
|
||||||
|
pred.name.name,
|
||||||
|
total,
|
||||||
|
params.max_custom_predicate_wildcards,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -976,8 +1041,24 @@ mod tests {
|
||||||
let remaining: HashSet<usize> = [0, 1].into_iter().collect();
|
let remaining: HashSet<usize> = [0, 1].into_iter().collect();
|
||||||
let active_wildcards = HashSet::new();
|
let active_wildcards = HashSet::new();
|
||||||
|
|
||||||
let key0 = statement_selection_key(0, &statements, &active_wildcards, &remaining, false);
|
// A and B are the public args of tie_break(A, B)
|
||||||
let key1 = statement_selection_key(1, &statements, &active_wildcards, &remaining, false);
|
let public_args: HashSet<String> = ["A".to_string(), "B".to_string()].into_iter().collect();
|
||||||
|
let key0 = statement_selection_key(
|
||||||
|
0,
|
||||||
|
&statements,
|
||||||
|
&active_wildcards,
|
||||||
|
&remaining,
|
||||||
|
false,
|
||||||
|
&public_args,
|
||||||
|
);
|
||||||
|
let key1 = statement_selection_key(
|
||||||
|
1,
|
||||||
|
&statements,
|
||||||
|
&active_wildcards,
|
||||||
|
&remaining,
|
||||||
|
false,
|
||||||
|
&public_args,
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(key0.0, key1.0, "Primary heuristic score should tie");
|
assert_eq!(key0.0, key1.0, "Primary heuristic score should tie");
|
||||||
assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie");
|
assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie");
|
||||||
|
|
@ -986,7 +1067,8 @@ mod tests {
|
||||||
"Lower original index should win deterministic final tie-breaker"
|
"Lower original index should win deterministic final tie-breaker"
|
||||||
);
|
);
|
||||||
|
|
||||||
let selected = find_best_next_statement(&statements, &remaining, &active_wildcards, 0);
|
let selected =
|
||||||
|
find_best_next_statement(&statements, &remaining, &active_wildcards, 0, &public_args);
|
||||||
assert_eq!(selected, 0);
|
assert_eq!(selected, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1084,7 +1166,7 @@ mod tests {
|
||||||
assert!(error_msg.contains("3 crossing wildcards"));
|
assert!(error_msg.contains("3 crossing wildcards"));
|
||||||
assert!(error_msg.contains("= 6 total"));
|
assert!(error_msg.contains("= 6 total"));
|
||||||
assert!(error_msg.contains("exceeds max of 5"));
|
assert!(error_msg.contains("exceeds max of 5"));
|
||||||
assert!(error_msg.contains("Statements 0-4"));
|
assert!(error_msg.contains("Statements 0-3"));
|
||||||
assert!(error_msg.contains("Incoming public args: A, B, C"));
|
assert!(error_msg.contains("Incoming public args: A, B, C"));
|
||||||
assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3"));
|
assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3"));
|
||||||
assert!(error_msg.contains("Suggestion:"));
|
assert!(error_msg.contains("Suggestion:"));
|
||||||
|
|
@ -1144,25 +1226,82 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Regression tests ---
|
||||||
|
|
||||||
|
/// Statements that reference only public args should be deferred until private-wildcard
|
||||||
|
/// statements have been clustered, so they don't consume bucket slots that would reduce
|
||||||
|
/// liveness at split boundaries.
|
||||||
|
///
|
||||||
|
/// 4 public args, 7 statements: W1 used in stmts 0,1,4; W2 used in stmts 1,2,3;
|
||||||
|
/// stmts 5,6 reference only public args. The scoring correctly defers stmts 5,6,
|
||||||
|
/// yielding bucket0={0,1,2,3}, bucket1={4,5,6} with only W1 crossing (4+1=5 <= max).
|
||||||
#[test]
|
#[test]
|
||||||
fn test_refactor_suggestion_group_wildcards() {
|
fn test_split_succeeds_with_four_public_args_and_public_only_statements() {
|
||||||
// Test the "group wildcard usages" suggestion formatting
|
// Optimal split: bucket0={0,1,2,3}, bucket1={4,5,6}
|
||||||
use crate::lang::error::RefactorSuggestion;
|
// Only W1 crosses (used in 0,1 and 4), total = 4 public + 1 crossing = 5 ✓
|
||||||
|
let input = r#"
|
||||||
|
pred(A, B, C, D, private: W1, W2) = AND(
|
||||||
|
Equal(W1["x"], A["v"])
|
||||||
|
Equal(W2["y"], W1["x"])
|
||||||
|
Equal(W2["z"], B["v"])
|
||||||
|
Equal(C["r"], W2["y"])
|
||||||
|
Equal(D["s"], W1["x"])
|
||||||
|
Equal(A["out"], C["out"])
|
||||||
|
Equal(B["out"], D["out"])
|
||||||
|
)
|
||||||
|
"#;
|
||||||
|
|
||||||
let suggestion = RefactorSuggestion::GroupWildcardUsages {
|
let pred = parse_predicate(input);
|
||||||
wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()],
|
let params = Params::default();
|
||||||
};
|
|
||||||
|
|
||||||
let suggestion_text = suggestion.format();
|
let result = split_predicate_if_needed(pred, &params);
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"Should find a valid split with ≤1 crossing wildcard, got: {:?}",
|
||||||
|
result.err()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Verify the suggestion formats correctly
|
/// Continuation predicates should only declare the public args they actually use -
|
||||||
assert!(suggestion_text.contains("Group operations for wildcards"));
|
/// original public args that are not referenced in a continuation's statements or
|
||||||
assert!(suggestion_text.contains("T1, T2, T3"));
|
/// any of its downstream continuations must be omitted from its signature.
|
||||||
assert!(suggestion_text.contains("used across multiple segments"));
|
#[test]
|
||||||
|
fn test_continuation_excludes_public_args_unused_after_split() {
|
||||||
|
// A is used only in the first segment; B is used only in the second segment.
|
||||||
|
// The continuation predicate (pred_1) must include B but not A.
|
||||||
|
let input = r#"
|
||||||
|
pred(A, B, private: T) = AND(
|
||||||
|
Equal(T["x"], A["val"])
|
||||||
|
Equal(T["y"], 1)
|
||||||
|
Equal(T["z"], 2)
|
||||||
|
Equal(T["w"], 3)
|
||||||
|
Equal(B["r"], T["x"])
|
||||||
|
Equal(B["s"], T["y"])
|
||||||
|
)
|
||||||
|
"#;
|
||||||
|
|
||||||
eprintln!(
|
let pred = parse_predicate(input);
|
||||||
"\n=== Example GroupWildcardUsages Suggestion ===\n{}\n",
|
let params = Params::default();
|
||||||
suggestion_text
|
|
||||||
|
let result = split_predicate_if_needed(pred, &params).unwrap();
|
||||||
|
// chain[0] is the continuation (_1 suffix), chain[1] is the original
|
||||||
|
let continuation = result
|
||||||
|
.predicates
|
||||||
|
.iter()
|
||||||
|
.find(|p| p.name.name == "pred_1")
|
||||||
|
.expect("Expected a pred_1 continuation predicate");
|
||||||
|
|
||||||
|
let cont_public: Vec<&str> = continuation
|
||||||
|
.args
|
||||||
|
.public_args
|
||||||
|
.iter()
|
||||||
|
.map(|id| id.name.as_str())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!cont_public.contains(&"A"),
|
||||||
|
"Continuation should drop unused public arg 'A', got: {:?}",
|
||||||
|
cont_public
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -77,14 +77,14 @@ impl NativePredicate {
|
||||||
| NativePredicate::MaxOf
|
| NativePredicate::MaxOf
|
||||||
| NativePredicate::HashOf
|
| NativePredicate::HashOf
|
||||||
| NativePredicate::SetInsert
|
| NativePredicate::SetInsert
|
||||||
| NativePredicate::SetDelete => 3,
|
| NativePredicate::SetDelete
|
||||||
|
| NativePredicate::DictDelete
|
||||||
|
| NativePredicate::ContainerDelete => 3,
|
||||||
NativePredicate::DictInsert
|
NativePredicate::DictInsert
|
||||||
| NativePredicate::DictUpdate
|
| NativePredicate::DictUpdate
|
||||||
| NativePredicate::DictDelete
|
|
||||||
| NativePredicate::ArrayUpdate
|
| NativePredicate::ArrayUpdate
|
||||||
| NativePredicate::ContainerInsert
|
| NativePredicate::ContainerInsert
|
||||||
| NativePredicate::ContainerUpdate
|
| NativePredicate::ContainerUpdate => 4,
|
||||||
| NativePredicate::ContainerDelete => 4,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue