From 2443f8552038fbe887cde3faf3c57fbd7aa935de Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Thu, 30 Apr 2026 16:56:42 +0100 Subject: [PATCH] Working MILP version of predicate splitter --- src/lang/error.rs | 10 + src/lang/frontend_ast.rs | 12 + src/lang/frontend_ast_lower.rs | 18 +- src/lang/frontend_ast_split.rs | 1265 +++++++++++++++++--------------- 4 files changed, 705 insertions(+), 600 deletions(-) diff --git a/src/lang/error.rs b/src/lang/error.rs index ca2dabc..5078f42 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -399,6 +399,16 @@ pub enum SplittingError { max_allowed: usize, suggestion: Option>, }, + + #[error("Could not split predicate '{predicate}' into a chain: no feasible partition exists with up to {max_links} links. \ + The predicate's wildcard structure may be too dense for any chain to fit within max_statement_args ({max_statement_args}) \ + and max_custom_predicate_wildcards ({max_wildcards}) per link.")] + Infeasible { + predicate: String, + max_links: usize, + max_statement_args: usize, + max_wildcards: usize, + }, } impl From for LangError { diff --git a/src/lang/frontend_ast.rs b/src/lang/frontend_ast.rs index 9843fcd..14ea33d 100644 --- a/src/lang/frontend_ast.rs +++ b/src/lang/frontend_ast.rs @@ -134,6 +134,18 @@ pub struct StatementTmpl { pub span: Option, } +impl StatementTmpl { + /// Names of all wildcards referenced by this statement's arguments, + /// in argument order with duplicates included. + pub fn wildcard_names(&self) -> impl Iterator { + self.args.iter().filter_map(|arg| match arg { + StatementTmplArg::Wildcard(id) => Some(id.name.as_str()), + StatementTmplArg::AnchoredKey(ak) => Some(ak.root.name.as_str()), + StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => None, + }) + } +} + /// Reference to a predicate (local or qualified with module name) #[derive(Debug, Clone, PartialEq)] pub enum PredicateRef { diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index c89d045..3f4a558 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -528,21 +528,9 @@ impl<'a> Lowerer<'a> { names: &mut Vec, seen: &mut HashSet, ) { - for arg in &stmt.args { - match arg { - StatementTmplArg::Wildcard(id) => { - if !seen.contains(&id.name) { - seen.insert(id.name.clone()); - names.push(id.name.clone()); - } - } - StatementTmplArg::AnchoredKey(ak) => { - if !seen.contains(&ak.root.name) { - seen.insert(ak.root.name.clone()); - names.push(ak.root.name.clone()); - } - } - StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} + for name in stmt.wildcard_names() { + if seen.insert(name.to_string()) { + names.push(name.to_string()); } } } diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index 7889f67..544f151 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -1,26 +1,45 @@ //! Predicate splitting for frontend AST //! -//! This module implements automatic predicate splitting when predicates exceed -//! middleware constraints. +//! Predicates whose statement count exceeds the middleware's +//! `max_custom_predicate_arity` are split into a chain of smaller predicates, +//! each calling the next via a tail-position chain call. Private wildcards +//! that span a split boundary must be promoted to public arguments on the +//! continuation, since they need the same binding on both sides. //! -//! When splitting a predicate, we try to group statements that use the same -//! wildcards together. However, if a private wildcard must be used across a -//! split boundary, it must be promoted to a public argument in the latter -//! predicate, to ensure that it is bound to the same value in both predicates. +//! The split is computed by an MILP that, for a given number of links K: //! -//! A wildcard is "live" at a split boundary if it is used in a statement on both -//! sides of the boundary. We want to minimize the number of live wildcards at -//! split boundaries, to minimize the number of promotions required. +//! - Assigns each statement to exactly one link. +//! - Tracks which wildcards are used and where, derives "live ranges," and +//! counts each link's declared public/private wildcards. +//! - Caps each link's public-arg count at `max_statement_args` and total +//! declared wildcards at `max_custom_predicate_wildcards`. +//! - Reserves a chain-call slot on every non-last link. //! -//! We use a greedy algorithm to order the statements in a predicate to minimize -//! the number of live wildcards at split boundaries. +//! We try `K = K_min, K_min+1, ...` and stop at the first feasible K. The +//! objective is a tiny `Σ (n-1-s) · i · assign[s][i]` tiebreaker that biases +//! statements with low original index toward low-index links — so the chain +//! roughly follows source order when nothing else forces a rearrangement. +//! +//! On infeasibility for every K up to `n`, we emit +//! [`SplittingError::Infeasible`]. -use std::{cmp::Reverse, collections::HashSet}; +#![allow(clippy::needless_range_loop)] + +use std::collections::{HashMap, HashSet}; + +use good_lp::{ + constraint, default_solver, variable, Expression, ProblemVariables, Solution, SolverModel, + Variable, +}; -// SplittingError is now defined in error.rs pub use crate::lang::error::SplittingError; use crate::{lang::frontend_ast::*, middleware::Params}; +/// Threshold for interpreting MILP solver's floating-point results as binary. +/// The solver returns continuous values in [0, 1] for binary variables; +/// values > 0.5 are interpreted as "true" (1), otherwise "false" (0). +const SOLVER_BINARY_THRESHOLD: f64 = 0.5; + /// A link in the predicate chain #[derive(Debug, Clone)] pub struct ChainLink { @@ -111,356 +130,296 @@ pub fn split_predicate_if_needed( }) } -/// Collect all wildcard names from a statement fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet { - let mut wildcards = HashSet::new(); - - for arg in &stmt.args { - match arg { - StatementTmplArg::Wildcard(id) => { - wildcards.insert(id.name.clone()); - } - StatementTmplArg::AnchoredKey(ak) => { - wildcards.insert(ak.root.name.clone()); - } - StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} - } - } - - wildcards + stmt.wildcard_names().map(str::to_string).collect() } -/// Order constraints optimally to minimize liveness at boundaries -/// Result of ordering statements optimally for splitting -struct OrderingResult { - /// Reordered statements - statements: Vec, - /// Maps original statement index → reordered index - /// reorder_map[original_idx] = new_idx - reorder_map: Vec, -} - -fn order_constraints_optimally( - statements: Vec, - public_args: &HashSet, -) -> OrderingResult { - let n = statements.len(); - - // If no splitting needed, preserve original order (identity mapping) - if n <= Params::max_custom_predicate_arity() { - return OrderingResult { - statements, - reorder_map: (0..n).collect(), - }; - } - - let mut ordered = Vec::new(); - let mut reorder_map = vec![0; n]; - let mut remaining: HashSet = (0..n).collect(); - let mut active_wildcards: HashSet = HashSet::new(); - - while !remaining.is_empty() { - let best_idx = find_best_next_statement( - &statements, - &remaining, - &active_wildcards, - ordered.len(), - public_args, - ); - - remaining.remove(&best_idx); - let stmt = &statements[best_idx]; - - // Record the mapping: original index best_idx → new index ordered.len() - reorder_map[best_idx] = ordered.len(); - ordered.push(stmt.clone()); - - // Only track private wildcards in the active set — public args are always - // available at every boundary so their liveness is irrelevant to split cost. - let stmt_wildcards = collect_wildcards_from_statement(stmt); - active_wildcards.extend( - stmt_wildcards - .into_iter() - .filter(|w| !public_args.contains(w)), - ); - - // Remove private wildcards no longer needed by remaining statements - let needed_later: HashSet<_> = remaining - .iter() - .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) - .filter(|w| !public_args.contains(w)) - .collect(); - active_wildcards.retain(|w| needed_later.contains(w)); - } - - OrderingResult { - statements: ordered, - reorder_map, - } -} - -/// Compute tie-breaker metrics for deterministic ordering when scores are equal -/// Returns (simplicity, public_closure, negative_fanout) tuple for use in max_by_key -fn compute_tie_breakers( - stmt: &StatementTmpl, - active_wildcards: &HashSet, - statements: &[StatementTmpl], - remaining: &HashSet, - needed_later: &HashSet, - public_args: &HashSet, -) -> (usize, usize, i32) { - let all_wildcards = collect_wildcards_from_statement(stmt); - // Only consider private wildcards for tie-breaking metrics - let stmt_wildcards: HashSet<_> = all_wildcards - .into_iter() - .filter(|w| !public_args.contains(w)) - .collect(); - - // Metric 1: Simplicity - prefer statements with fewer private wildcards - let simplicity = usize::MAX - stmt_wildcards.len(); - - // Metric 2: Closure - prefer statements that close active private wildcards - // (wildcards that won't be needed by any remaining statements) - let closes_count = stmt_wildcards - .intersection(active_wildcards) - .filter(|w| !needed_later.contains(*w)) - .count(); - - // Metric 3: Fanout - prefer statements with lower future usage - // (number of remaining statements sharing private wildcards with this statement) - let fanout = remaining - .iter() - .filter(|&&i| { - let other_wildcards: HashSet<_> = collect_wildcards_from_statement(&statements[i]) - .into_iter() - .filter(|w| !public_args.contains(w)) - .collect(); - !stmt_wildcards.is_disjoint(&other_wildcards) - }) - .count(); - - (simplicity, closes_count, -(fanout as i32)) -} - -fn statement_selection_key( - idx: usize, - statements: &[StatementTmpl], - active_wildcards: &HashSet, - remaining: &HashSet, - approaching_split: bool, - public_args: &HashSet, -) -> (i32, (usize, usize, i32), Reverse) { - // Pre-compute needed_later once and share between primary score and tie-breakers. - // Exclude the candidate itself: we want to know what the *other* remaining statements - // need, so that wildcards used only by this candidate correctly appear as closeable. - let needed_later: HashSet = remaining - .iter() - .filter(|&&i| i != idx) - .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) - .filter(|w| !public_args.contains(w)) - .collect(); - - let primary_score = score_statement( - &statements[idx], - active_wildcards, - approaching_split, - public_args, - &needed_later, - ); - let tie_breakers = compute_tie_breakers( - &statements[idx], - active_wildcards, - statements, - remaining, - &needed_later, - public_args, - ); - - // Final deterministic tie-breaker: prefer smaller original indices. - // This avoids hash-iteration-dependent selection when scores are equal. - (primary_score, tie_breakers, Reverse(idx)) -} - -/// Find the best next statement to add based on scoring heuristic -fn find_best_next_statement( - statements: &[StatementTmpl], - remaining: &HashSet, - active_wildcards: &HashSet, - ordered_count: usize, - public_args: &HashSet, -) -> usize { - // Calculate distance to next split point - let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call - let distance_to_split = bucket_size - (ordered_count % bucket_size); - let approaching_split = distance_to_split <= 2; - - remaining - .iter() - .max_by_key(|&&idx| { - statement_selection_key( - idx, - statements, - active_wildcards, - remaining, - approaching_split, - public_args, - ) - }) - .copied() - .unwrap() -} - -/// Score a statement based on how well it minimizes private-wildcard liveness at boundaries. -/// `needed_later` is the set of private wildcards used by any remaining statement. -fn score_statement( - stmt: &StatementTmpl, - active_wildcards: &HashSet, - approaching_split: bool, - public_args: &HashSet, - needed_later: &HashSet, -) -> i32 { - let all_wildcards = collect_wildcards_from_statement(stmt); - - // Only score based on private wildcards. Public args are always available at every - // split boundary — they never consume a promotion slot, so their liveness is free. - let stmt_wildcards: HashSet<_> = all_wildcards - .into_iter() - .filter(|w| !public_args.contains(w)) - .collect(); - - // Statements that touch only public args ("cheap" statements) waste a bucket slot - // that could be used to cluster private wildcards. Strongly defer them while any - // private-wildcard statements remain, so they fill leftover space at the end. - // `needed_later` is non-empty iff some remaining statement has a private wildcard. - if stmt_wildcards.is_empty() { - return if needed_later.is_empty() { - 0 - } else { - i32::MIN / 2 - }; - } - - // How many active private wildcards does this reuse? - let reuse_count = stmt_wildcards.intersection(active_wildcards).count(); - - // How many new private wildcards does this introduce? - let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count(); - - // Which of the projected-active wildcards are still needed after this statement? - let mut projected_active = active_wildcards.clone(); - projected_active.extend(stmt_wildcards); - projected_active.retain(|w| needed_later.contains(w)); - let still_active_count = projected_active.len(); - - // Base score: - // +3 per reused wildcard — rewards clustering (wildcard already open, no new cost) - // -4 per new wildcard — penalises opening new live ranges - // -2 per still-live — penalises carrying many wildcards toward the boundary - let base_score = (reuse_count * 3) as i32 - - (new_wildcard_count * 4) as i32 - - (still_active_count * 2) as i32; - - // When close to a split boundary, strongly reward statements that close wildcards - // (active.len() + new - still_active = number of wildcards resolved by this statement). - // Weight 10 >> max base-score magnitude to make closing the dominant factor. - if approaching_split { - let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count; - base_score + (closes_count * 10) as i32 +/// Compute the minimum number of chain links needed to fit `n` statements, +/// given that non-last links reserve 1 slot for the chain call (so they hold +/// up to `max_arity - 1` real statements) and the last link uses all of +/// `max_arity`. +fn compute_min_links(n: usize) -> usize { + let max_arity = Params::max_custom_predicate_arity(); + if n <= max_arity { + 1 } else { - base_score + // Smallest K such that (K-1)·(max_arity-1) + max_arity >= n + (n - max_arity).div_ceil(max_arity - 1) + 1 } } -/// Calculate which wildcards are live at a split boundary -fn calculate_live_wildcards( - before_split: &[StatementTmpl], - after_split: &[StatementTmpl], -) -> HashSet { - let before: HashSet<_> = before_split - .iter() - .flat_map(collect_wildcards_from_statement) - .collect(); +/// MILP outcome for a single K: `links[i]` is the list of original statement +/// indices placed in link i, in original order. +type LinkAssignment = Vec>; - let after: HashSet<_> = after_split - .iter() - .flat_map(collect_wildcards_from_statement) - .collect(); +/// Try to partition `n` statements into exactly `k` links using MILP. +/// +/// Returns `Some(assignment)` if a feasible partition exists, `None` if the +/// model is infeasible at this K (caller should try a larger K). +/// +/// Variables (all binary): +/// - `assign[s][i]`: statement `s` is placed in link `i`. +/// - `u[w][i]`: wildcard `w` is referenced by some statement at link `i`. +/// - `before[w][i]`, `after[w][i]`: cumulative ORs of `u[w][·]` from the left +/// and right respectively. `before[w][i] = 1` iff w is used at link ≤ i. +/// - `pubin[w][i]`: w appears in link i's `public_args_in`. +/// - `privin[w][i]`: w appears in link i's `private_args` list. +/// +/// All non-`assign` variables are forced to be exact functions of `assign` +/// via two-sided (≤ and ≥) constraints, so the LP relaxation has a unique +/// solution given any integer assignment. The objective is a small +/// source-order tiebreaker on `assign`. +fn solve_milp_for_k( + n: usize, + k: usize, + statements_using: &[Vec], + is_original_public: &[bool], + params: &Params, +) -> Option { + let max_arity = Params::max_custom_predicate_arity(); + let max_args = Params::max_statement_args(); + let max_wildcards = params.max_custom_predicate_wildcards; + let num_wildcards = is_original_public.len(); - // Live = in both sets (crosses boundary) - before.intersection(&after).cloned().collect() -} + let mut vars = ProblemVariables::new(); -/// Generate a refactor suggestion for wildcards crossing a boundary -fn generate_refactor_suggestion( - crossing_wildcards: &[String], - ordered_statements: &[StatementTmpl], -) -> Option { - use crate::lang::error::RefactorSuggestion; + fn mk_grid(vars: &mut ProblemVariables, rows: usize, cols: usize) -> Vec> { + (0..rows) + .map(|_| (0..cols).map(|_| vars.add(variable().binary())).collect()) + .collect() + } + let assign = mk_grid(&mut vars, n, k); + let u = mk_grid(&mut vars, num_wildcards, k); + let before = mk_grid(&mut vars, num_wildcards, k); + let after = mk_grid(&mut vars, num_wildcards, k); + let pubin = mk_grid(&mut vars, num_wildcards, k); + let privin = mk_grid(&mut vars, num_wildcards, k); - if crossing_wildcards.is_empty() { - return None; + // Source-order tiebreaker. Coefficient `(n-1-s) * i` is largest for + // low-index statements at high-index links, so minimization pulls + // low-original-index statements toward low-link-index assignments — + // i.e., the chain roughly preserves source order. The objective only + // breaks ties among feasibility-equivalent solutions. + let objective: Expression = (0..n) + .flat_map(|s| (0..k).map(move |i| (s, i))) + .map(|(s, i)| ((n - 1 - s) as f64) * (i as f64) * assign[s][i]) + .sum(); + + let mut model = vars.minimise(objective).using(default_solver); + + // C1: Each statement assigned to exactly one link. + for s in 0..n { + let sum: Expression = (0..k).map(|i| assign[s][i]).sum(); + model.add_constraint(constraint!(sum == 1)); } - // Normalize wildcard order so diagnostics are deterministic. - let mut sorted_crossing_wildcards = crossing_wildcards.to_vec(); - sorted_crossing_wildcards.sort(); + // C2: Per-link statement count. Non-last links reserve a slot for the + // chain call. Also require at least one statement per link. + for i in 0..k { + let cap = if i + 1 < k { max_arity - 1 } else { max_arity }; + let sum_le: Expression = (0..n).map(|s| assign[s][i]).sum(); + model.add_constraint(constraint!(sum_le <= cap as f64)); + let sum_ge: Expression = (0..n).map(|s| assign[s][i]).sum(); + model.add_constraint(constraint!(sum_ge >= 1)); + } - // Analyze the span of each crossing wildcard - let mut wildcard_spans: Vec<(String, usize, usize, usize)> = Vec::new(); + // C3: u[w][i] is exactly the OR over s referencing w of assign[s][i]. + for w in 0..num_wildcards { + for i in 0..k { + for &s in &statements_using[w] { + model.add_constraint(constraint!(u[w][i] >= assign[s][i])); + } + let upper: Expression = statements_using[w] + .iter() + .map(|&s| Expression::from(assign[s][i])) + .sum(); + model.add_constraint(constraint!(u[w][i] <= upper)); + } + } - for wildcard in &sorted_crossing_wildcards { - let mut first_use = None; - let mut last_use = None; + // C4: before[w][i] = u[w][0] OR u[w][1] OR ... OR u[w][i]. + for w in 0..num_wildcards { + model.add_constraint(constraint!(before[w][0] == u[w][0])); + for i in 1..k { + model.add_constraint(constraint!(before[w][i] >= before[w][i - 1])); + model.add_constraint(constraint!(before[w][i] >= u[w][i])); + model.add_constraint(constraint!(before[w][i] <= before[w][i - 1] + u[w][i])); + } + } - for (i, stmt) in ordered_statements.iter().enumerate() { - let wildcards = collect_wildcards_from_statement(stmt); - if wildcards.contains(wildcard) { - if first_use.is_none() { - first_use = Some(i); - } - last_use = Some(i); + // C5: after[w][i] = u[w][i] OR u[w][i+1] OR ... OR u[w][k-1]. + for w in 0..num_wildcards { + model.add_constraint(constraint!(after[w][k - 1] == u[w][k - 1])); + for i in (0..k - 1).rev() { + model.add_constraint(constraint!(after[w][i] >= after[w][i + 1])); + model.add_constraint(constraint!(after[w][i] >= u[w][i])); + model.add_constraint(constraint!(after[w][i] <= after[w][i + 1] + u[w][i])); + } + } + + // C6: pubin definitions. + for w in 0..num_wildcards { + if is_original_public[w] { + // Original public args: declared at link 0 (predicate signature) + // and forwarded to link i iff used at some link ≥ i. + model.add_constraint(constraint!(pubin[w][0] == 1)); + for i in 1..k { + model.add_constraint(constraint!(pubin[w][i] == after[w][i])); + } + } else { + // Private wildcards: pubin[w][i] = before[w][i-1] AND after[w][i] + // (used somewhere strictly before AND somewhere at i or later). + model.add_constraint(constraint!(pubin[w][0] == 0)); + for i in 1..k { + model.add_constraint(constraint!(pubin[w][i] <= before[w][i - 1])); + model.add_constraint(constraint!(pubin[w][i] <= after[w][i])); + model.add_constraint(constraint!( + pubin[w][i] >= before[w][i - 1] + after[w][i] - 1 + )); } } + } - if let (Some(first), Some(last)) = (first_use, last_use) { - let span = last - first; - wildcard_spans.push((wildcard.clone(), first, last, span)); + // C7: privin definitions. + for w in 0..num_wildcards { + if is_original_public[w] { + for i in 0..k { + model.add_constraint(constraint!(privin[w][i] == 0)); + } + } else { + // privin[w][0] = u[w][0]: at link 0 there is no "before," so a + // private wildcard used at link 0 is necessarily declared private. + model.add_constraint(constraint!(privin[w][0] == u[w][0])); + for i in 1..k { + // privin[w][i] = u[w][i] AND NOT before[w][i-1] + model.add_constraint(constraint!(privin[w][i] <= u[w][i])); + model.add_constraint(constraint!(privin[w][i] <= 1 - before[w][i - 1])); + model.add_constraint(constraint!(privin[w][i] >= u[w][i] - before[w][i - 1])); + } } } - // Sort by span (largest first) - wildcard_spans.sort_by(|a, b| b.3.cmp(&a.3)); + // C8: per-link public-args cap (incoming chain-call args). + for i in 0..k { + let sum: Expression = (0..num_wildcards).map(|w| pubin[w][i]).sum(); + model.add_constraint(constraint!(sum <= max_args as f64)); + } - if let Some((wildcard, first, last, span)) = wildcard_spans.first() { - // If a single wildcard has a large span, suggest reducing it - if *span > 3 { - return Some(RefactorSuggestion::ReduceWildcardSpan { - wildcard: wildcard.clone(), - first_use: *first, - last_use: *last, - span: *span, - }); + // C9: per-link total declared wildcards cap. + for i in 0..k { + let sum: Expression = (0..num_wildcards) + .map(|w| Expression::from(pubin[w][i]) + privin[w][i]) + .sum(); + model.add_constraint(constraint!(sum <= max_wildcards as f64)); + } + + let solution = model.solve().ok()?; + + // Extract per-link statement lists in original-index order. + let mut links: LinkAssignment = vec![Vec::new(); k]; + for s in 0..n { + for i in 0..k { + if solution.value(assign[s][i]) > SOLVER_BINARY_THRESHOLD { + links[i].push(s); + break; + } } } - - // If multiple wildcards cross the boundary, suggest grouping - if sorted_crossing_wildcards.len() > 1 { - return Some(RefactorSuggestion::GroupWildcardUsages { - wildcards: sorted_crossing_wildcards, - }); - } - - None + Some(links) } -/// Split into chain using bucket-filling approach -/// Returns the split predicates and metadata about the split +/// Convert an MILP link assignment into [`ChainLink`]s, computing each link's +/// public/private/promoted wildcards from the assignment plus the original +/// public-args list. +fn build_chain_links_from_assignment( + links: LinkAssignment, + statements: &[StatementTmpl], + original_public_args: &[String], +) -> Vec { + let k = links.len(); + let stmt_wcs: Vec> = statements + .iter() + .map(collect_wildcards_from_statement) + .collect(); + let link_wcs: Vec> = (0..k) + .map(|i| { + links[i] + .iter() + .flat_map(|&s| stmt_wcs[s].iter().cloned()) + .collect() + }) + .collect(); + + let mut result = Vec::with_capacity(k); + let mut incoming: Vec = original_public_args.to_vec(); + + for i in 0..k { + let stmts: Vec = links[i].iter().map(|&s| statements[s].clone()).collect(); + + // Wildcards crossing forward from link i (used here AND later). + let after_wcs: HashSet = (i + 1..k) + .flat_map(|j| link_wcs[j].iter().cloned()) + .collect(); + let crossings: HashSet = link_wcs[i].intersection(&after_wcs).cloned().collect(); + + let incoming_set: HashSet = incoming.iter().cloned().collect(); + + let mut promotions: Vec = crossings + .iter() + .filter(|w| !incoming_set.contains(*w)) + .cloned() + .collect(); + promotions.sort(); + + let mut private_args: Vec = link_wcs[i] + .difference(&incoming_set) + .filter(|w| !crossings.contains(*w)) + .cloned() + .collect(); + private_args.sort(); + + result.push(ChainLink { + statements: stmts, + public_args_in: incoming.clone(), + private_args, + promoted_wildcards: promotions.clone(), + }); + + incoming.extend(promotions); + } + + // Backward pruning: drop public args from continuations that no link + // (this one or downstream) actually references. Link 0 keeps its full + // user-declared signature. + let num_links = result.len(); + if num_links > 1 { + let last = num_links - 1; + result[last] + .public_args_in + .retain(|a| link_wcs[last].contains(a)); + for i in (1..last).rev() { + let needed_downstream: HashSet = + result[i + 1].public_args_in.iter().cloned().collect(); + result[i] + .public_args_in + .retain(|a| link_wcs[i].contains(a) || needed_downstream.contains(a)); + } + } + + result +} + +/// Split a predicate into a chain via MILP. Tries `K = K_min, K_min+1, ...`, +/// returning the first feasible chain or [`SplittingError::Infeasible`] if +/// no `K` up to `n` works. fn split_into_chain( pred: CustomPredicateDef, params: &Params, ) -> Result<(Vec, SplitChainInfo), SplittingError> { let original_name = pred.name.name.clone(); let conjunction = pred.conjunction_type; + let n = pred.statements.len(); + let real_statement_count = n; let original_public_args: Vec = pred .args @@ -469,154 +428,76 @@ fn split_into_chain( .map(|id| id.name.clone()) .collect(); - let public_args_set: HashSet = original_public_args.iter().cloned().collect(); - - let real_statement_count = pred.statements.len(); - - let ordering_result = order_constraints_optimally(pred.statements, &public_args_set); - let ordered_statements = ordering_result.statements; - let reorder_map = ordering_result.reorder_map; - - let mut chain_links = Vec::new(); - let mut pos = 0; - let mut incoming_public = original_public_args.clone(); - - while pos < ordered_statements.len() { - let remaining = ordered_statements.len() - pos; - let is_last = remaining <= Params::max_custom_predicate_arity(); - - let bucket_size = if is_last { - remaining // Last predicate uses all remaining - } else { - Params::max_custom_predicate_arity() - 1 // Reserve slot for chain call - }; - - let end = pos + bucket_size; - - // Calculate liveness at this split boundary - let live_at_boundary = if is_last { - HashSet::new() - } else { - calculate_live_wildcards(&ordered_statements[pos..end], &ordered_statements[end..]) - }; - - // Check: Can we fit promoted wildcards in public args? - // Need to account for possible overlap between incoming_public and live_at_boundary - let incoming_set: HashSet<_> = incoming_public.iter().cloned().collect(); - let mut new_promotions: Vec<_> = live_at_boundary - .iter() - .filter(|w| !incoming_set.contains(*w)) - .cloned() - .collect(); - new_promotions.sort(); - let total_public = incoming_public.len() + new_promotions.len(); - if total_public > Params::max_statement_args() { - let context = crate::lang::error::SplitContext { - split_index: chain_links.len(), - statement_range: (pos, end), - incoming_public: incoming_public.clone(), - crossing_wildcards: new_promotions.clone(), - total_public, - }; - - let suggestion = generate_refactor_suggestion(&new_promotions, &ordered_statements); - - return Err(SplittingError::TooManyPublicArgsAtSplit { - predicate: original_name.clone(), - context: Box::new(context), - max_allowed: Params::max_statement_args(), - suggestion: suggestion.map(Box::new), - }); - } - - // Calculate private args (used in this segment but not incoming and not outgoing) - let segment_wildcards: HashSet<_> = ordered_statements[pos..end] - .iter() - .flat_map(collect_wildcards_from_statement) - .collect(); - - let mut private_args: Vec = segment_wildcards - .difference(&incoming_set) - .filter(|w| !live_at_boundary.contains(*w)) - .cloned() - .collect(); - private_args.sort(); // Deterministic ordering - - // Check: Total args constraint (incoming + new promotions + private) - let public_count = incoming_public.len() + new_promotions.len(); - let private_count = private_args.len(); - let total_args = public_count + private_count; - if total_args > params.max_custom_predicate_wildcards { - return Err(SplittingError::TooManyTotalArgsInChainLink { - predicate: original_name.clone(), - link_index: chain_links.len(), - public_count, - private_count, - total_count: total_args, - max_allowed: params.max_custom_predicate_wildcards, - }); - } - - chain_links.push(ChainLink { - statements: ordered_statements[pos..end].to_vec(), - public_args_in: incoming_public.clone(), - private_args, - // new_promotions are already sorted and already filtered to exclude incoming_public - promoted_wildcards: new_promotions.clone(), - }); - - pos = end; - - // Extend incoming_public for the next link with the newly promoted wildcards. - // new_promotions is already filtered to exclude incoming_set, so no dedup needed. - incoming_public.extend(new_promotions); + // Build a stable, sorted index over all wildcards referenced by the + // predicate (statements + declared public args). + let mut wildcard_set: HashSet = pred + .statements + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + for name in &original_public_args { + wildcard_set.insert(name.clone()); } + let mut wildcard_names: Vec = wildcard_set.into_iter().collect(); + wildcard_names.sort(); + let wildcard_index: HashMap = wildcard_names + .iter() + .enumerate() + .map(|(i, name)| (name.clone(), i)) + .collect(); - // Backward pass: prune each continuation's public args to the minimal set needed. - // - // The forward pass accumulates incoming_public monotonically, so a continuation may - // inherit original public args that none of its statements (or downstream continuations) - // ever reference. A continuation must declare every public arg it receives, and the - // proof system constrains each declared arg - an arg that goes unused has no constraints - // and will not match the value the caller passes. - // - // Propagating from the last link backward ensures each continuation declares exactly the - // args it uses directly, plus any args its successor still needs. Link 0 (the original - // predicate) is left untouched - its public-arg signature is user-declared. - { - let num_links = chain_links.len(); - if num_links > 1 { - // Collect wildcards referenced by each link's statements once. - let link_wildcards: Vec> = chain_links - .iter() - .map(|link| { - link.statements - .iter() - .flat_map(collect_wildcards_from_statement) - .collect() - }) - .collect(); - - let last = num_links - 1; - - // Seed: last link retains only args it directly references. - chain_links[last] - .public_args_in - .retain(|a| link_wildcards[last].contains(a)); - - // Propagate backward through intermediate continuation links (skip link 0). - for i in (1..last).rev() { - let needed_downstream: HashSet = - chain_links[i + 1].public_args_in.iter().cloned().collect(); - chain_links[i] - .public_args_in - .retain(|a| link_wildcards[i].contains(a) || needed_downstream.contains(a)); + // Inverse map from wildcard index to the statements that reference it. + // Loop-invariant across the K-search, so build once. + let mut statements_using: Vec> = vec![Vec::new(); wildcard_names.len()]; + for (s, stmt) in pred.statements.iter().enumerate() { + let mut seen: HashSet = HashSet::new(); + for name in stmt.wildcard_names() { + let w = wildcard_index[name]; + if seen.insert(w) { + statements_using[w].push(s); } } } - // Build SplitChainInfo from chain_links before generating predicates - // Pieces are in execution order: innermost continuation first, original last + let mut is_original_public = vec![false; wildcard_names.len()]; + for name in &original_public_args { + is_original_public[wildcard_index[name]] = true; + } + + let k_min = compute_min_links(n); + let mut found: Option<(usize, LinkAssignment)> = None; + for k in k_min..=n { + if let Some(assignment) = + solve_milp_for_k(n, k, &statements_using, &is_original_public, params) + { + found = Some((k, assignment)); + break; + } + } + + let (_k, assignment) = found.ok_or_else(|| SplittingError::Infeasible { + predicate: original_name.clone(), + max_links: n, + max_statement_args: Params::max_statement_args(), + max_wildcards: params.max_custom_predicate_wildcards, + })?; + + // Reorder map: original index → position in flattened chain. + let mut reorder_map = vec![0usize; n]; + { + let mut flat = 0usize; + for link in &assignment { + for &s in link { + reorder_map[s] = flat; + flat += 1; + } + } + } + + let chain_links = + build_chain_links_from_assignment(assignment, &pred.statements, &original_public_args); + + // Build SplitChainInfo (execution order: innermost continuation first). let num_links = chain_links.len(); let mut chain_pieces = Vec::new(); for i in (0..num_links).rev() { @@ -647,14 +528,13 @@ fn split_into_chain( validate_chain(&chain_predicates, params); // Reverse so continuations come before callers in declaration order. - // This ensures that when batched, continuations are in earlier batches - // and can be referenced by their callers. chain_predicates.reverse(); Ok((chain_predicates, chain_info)) } -/// Phase 4: Generate synthetic predicates from chain links +/// Build the chain's [`CustomPredicateDef`]s from the per-link metadata, +/// inserting a chain call on every non-last link. fn generate_chain_predicates( original_name: &str, chain_links: Vec, @@ -679,15 +559,12 @@ fn generate_chain_predicates( let is_last = i == chain_links.len() - 1; let mut statements = link.statements.clone(); - // Add chain call if not last if !is_last { let next_pred_name = Identifier { name: format!("{}_{}", original_name, i + 1), span: None, }; - // Create arguments for chain call: use next link's public_args_in - // which is current public_args_in extended with current promoted_wildcards let next_link = &chain_links[i + 1]; let chain_call_args: Vec = next_link .public_args_in @@ -1030,113 +907,6 @@ mod tests { ); } - #[test] - fn test_statement_selection_prefers_lower_index_on_tie() { - // Two structurally symmetric statements produce identical heuristic scores. - // Determinism comes from the final index-based tie breaker. - let input = r#" - tie_break(A, B) = AND ( - Equal(A["x"], B["x"]) - Equal(A["y"], B["y"]) - ) - "#; - - let pred = parse_predicate(input); - let statements = pred.statements; - let remaining: HashSet = [0, 1].into_iter().collect(); - let active_wildcards = HashSet::new(); - - // A and B are the public args of tie_break(A, B) - let public_args: HashSet = ["A".to_string(), "B".to_string()].into_iter().collect(); - let key0 = statement_selection_key( - 0, - &statements, - &active_wildcards, - &remaining, - false, - &public_args, - ); - let key1 = statement_selection_key( - 1, - &statements, - &active_wildcards, - &remaining, - false, - &public_args, - ); - - assert_eq!(key0.0, key1.0, "Primary heuristic score should tie"); - assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie"); - assert!( - key0 > key1, - "Lower original index should win deterministic final tie-breaker" - ); - - let selected = - find_best_next_statement(&statements, &remaining, &active_wildcards, 0, &public_args); - assert_eq!(selected, 0); - } - - #[test] - fn test_greedy_ordering_reduces_liveness() { - // This test verifies that our greedy ordering algorithm reduces wildcard liveness - // by clustering statements that use the same wildcards together. - // - // The predicate has 8 statements using 3 private wildcards (T1, T2, T3): - // - T1 used in statements 1, 4, 7 - // - T2 used in statements 2, 5, 8 - // - T3 used in statements 3, 6 - // - // NAIVE ORDERING (original order): - // Would interleave T1, T2, T3 usage throughout the predicate. - // When splitting at statement limit (5 statements per predicate): - // Predicate 1: statements 1-5 (introduces T1, T2, T3 - none complete) - // Predicate 2: statements 6-8 (all 3 wildcards still live) - // Result: 2 public args (A, B) + 3 promoted wildcards = 5 total in predicate 2 - // - // GREEDY ORDERING (our algorithm): - // Clusters statements by wildcard to minimize liveness: - // Groups T1 statements together, then T2, then T3 - // Predicate 1: completes some wildcards before the split point - // Predicate 2: fewer wildcards need to cross the boundary - // Result: 2 public args (A, B) + 1-2 promoted wildcards = 3-4 total in predicate 2 - let input = r#" - clustered(A, B, private: T1, T2, T3) = AND ( - Equal(T1["x"], 1) - Equal(T2["y"], 2) - Equal(T3["z"], 3) - Equal(T1["a"], 4) - Equal(T2["b"], 5) - Equal(T3["c"], 6) - Equal(T1["d"], A["result"]) - Equal(T2["e"], B["value"]) - ) - "#; - - let pred = parse_predicate(input); - let params = Params::default(); - - let result = split_predicate_if_needed(pred, ¶ms); - assert!(result.is_ok()); - - let split_result = result.unwrap(); - let chain = &split_result.predicates; - assert_eq!(chain.len(), 2, "Predicate should split into 2 links"); - - let second_pred = &chain[1]; - let second_pred_public_count = second_pred.args.public_args.len(); - - // Verify greedy ordering achieves better results than naive ordering would - // Started with 2 public args (A, B) - // Naive would have: 2 + 3 promoted = 5 public args in second predicate - // Greedy achieves: 2 + 1-2 promoted = 3-4 public args in second predicate - assert!( - second_pred_public_count <= 4, - "Greedy ordering should reduce promotions to ≤4 public args, but got {}", - second_pred_public_count - ); - } - #[test] fn test_error_message_formatting() { // Test that error messages format correctly with detailed context @@ -1309,4 +1079,329 @@ mod tests { cont_public ); } + + // =================================================================== + // Completeness probe for the splitter. + // + // `build_pred` constructs a CustomPredicateDef from a "wildcard set per + // statement" specification (cheaper than parsing). `find_any_feasible_ordering` + // brute-forces all permutations and uses the same per-link constraints as + // `split_into_chain` to check whether a feasible chain exists at all. + // =================================================================== + + fn build_pred( + name: &str, + public_args: &[&str], + private_args: &[&str], + stmt_wildcards: &[&[&str]], + ) -> CustomPredicateDef { + let statements: Vec = stmt_wildcards + .iter() + .map(|wcs| { + let args: Vec = wcs + .iter() + .map(|n| { + StatementTmplArg::Wildcard(Identifier { + name: n.to_string(), + span: None, + }) + }) + .collect(); + StatementTmpl { + predicate: PredicateRef::Local(Identifier { + name: "Equal".to_string(), + span: None, + }), + args, + span: None, + } + }) + .collect(); + + let private_args = if private_args.is_empty() { + None + } else { + Some( + private_args + .iter() + .map(|n| TypedArg { + name: n.to_string(), + type_name: None, + span: None, + }) + .collect(), + ) + }; + + CustomPredicateDef { + name: Identifier { + name: name.to_string(), + span: None, + }, + args: ArgSection { + public_args: public_args + .iter() + .map(|n| TypedArg { + name: n.to_string(), + type_name: None, + span: None, + }) + .collect(), + private_args, + span: None, + }, + conjunction_type: ConjunctionType::And, + statements, + span: None, + } + } + + /// Replicates the bucket-fill constraint check from `split_into_chain` for + /// a *fixed* ordering of statements. Returns Ok if the ordering produces a + /// valid chain, Err otherwise. + fn check_ordering_feasible( + ordered: &[StatementTmpl], + original_public_args: &[String], + params: &Params, + ) -> bool { + if ordered.len() <= Params::max_custom_predicate_arity() { + return true; + } + + let mut pos = 0; + let mut incoming_public: Vec = original_public_args.to_vec(); + + while pos < ordered.len() { + let remaining = ordered.len() - pos; + let is_last = remaining <= Params::max_custom_predicate_arity(); + let bucket_size = if is_last { + remaining + } else { + Params::max_custom_predicate_arity() - 1 + }; + let end = pos + bucket_size; + + let live: HashSet = if is_last { + HashSet::new() + } else { + let before: HashSet = ordered[pos..end] + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + let after: HashSet = ordered[end..] + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + before.intersection(&after).cloned().collect() + }; + + let incoming_set: HashSet = incoming_public.iter().cloned().collect(); + let new_promotions: Vec = live + .iter() + .filter(|w| !incoming_set.contains(*w)) + .cloned() + .collect(); + let total_public = incoming_public.len() + new_promotions.len(); + if total_public > Params::max_statement_args() { + return false; + } + + let segment_wildcards: HashSet = ordered[pos..end] + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + let private_args: Vec = segment_wildcards + .difference(&incoming_set) + .filter(|w| !live.contains(*w)) + .cloned() + .collect(); + let total_args = total_public + private_args.len(); + if total_args > params.max_custom_predicate_wildcards { + return false; + } + + pos = end; + incoming_public.extend(new_promotions); + } + + true + } + + /// Brute-force search over all permutations of the predicate's statements + /// for *any* ordering that produces a feasible split. Returns the + /// permutation if found, else None. Caps at 9! to keep tests cheap. + fn find_any_feasible_ordering( + pred: &CustomPredicateDef, + params: &Params, + ) -> Option> { + use itertools::Itertools; + + let n = pred.statements.len(); + assert!(n <= 9, "brute-force capped at 9! permutations"); + + let original_public_args: Vec = pred + .args + .public_args + .iter() + .map(|id| id.name.clone()) + .collect(); + + for perm in (0..n).permutations(n) { + let ordered: Vec = + perm.iter().map(|&i| pred.statements[i].clone()).collect(); + if check_ordering_feasible(&ordered, &original_public_args, params) { + return Some(perm); + } + } + None + } + + /// 6 statements with 2 public args (A0, A1) and 5 private wildcards + /// (T0..T4). A feasible 4+2 chain exists where exactly 3 wildcards cross + /// the boundary (3 promotions + 2 incoming = 5 total public, hitting the + /// cap). The splitter must find one — a partition that puts an extra + /// wildcard across the boundary fails the per-link public-arg cap. + /// + /// Found by random search with seed 0xC0FFEE; inlined for determinism. + #[test] + fn test_splitter_handles_tight_public_arg_cap() { + let pred = build_pred( + "p", + &["A0", "A1"], + &["T0", "T1", "T2", "T3", "T4"], + &[ + &["T0", "T4", "T2"], + &["T1", "T3", "T4"], + &["T2", "T3", "T1"], + &["T4", "A0", "A1"], + &["T3", "T0", "T2"], + &["T0", "A1", "T1"], + ], + ); + let params = Params::default(); + + // Sanity: brute force confirms a feasible ordering exists. + let feasible = find_any_feasible_ordering(&pred, ¶ms); + assert!( + feasible.is_some(), + "expected at least one feasible permutation" + ); + + let result = split_predicate_if_needed(pred, ¶ms); + assert!( + result.is_ok(), + "splitter rejected an input with a feasible ordering ({:?}): {}", + feasible.unwrap(), + result.err().unwrap() + ); + } + + /// Randomized counterexample search. Run with + /// `cargo test --release search_splitter -- --ignored --nocapture`. + #[test] + #[ignore] + fn search_splitter_counterexample() { + // Tiny LCG so we don't pull rand as a dep. + struct Lcg(u64); + impl Lcg { + fn next(&mut self) -> u64 { + self.0 = self + .0 + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + self.0 + } + fn rand_in(&mut self, n: usize) -> usize { + (self.next() as usize) % n + } + } + + let params = Params::default(); + let mut rng = Lcg(0xC0FFEE); + let mut checked = 0; + let mut found = 0; + + // Sweep over (n_stmts, n_pub, n_priv) combos. + for &n_stmts in &[6usize, 7, 8, 9] { + for &n_pub in &[1usize, 2, 3, 4] { + for &n_priv in &[2usize, 3, 4, 5] { + let pub_names: Vec = (0..n_pub).map(|i| format!("A{}", i)).collect(); + let priv_names: Vec = (0..n_priv).map(|i| format!("T{}", i)).collect(); + let all_names: Vec = + pub_names.iter().chain(priv_names.iter()).cloned().collect(); + + // Generate 200 random predicates per shape. + for trial in 0..200 { + // Each statement gets 2-3 distinct wildcards drawn from all_names. + let stmt_wildcards: Vec> = (0..n_stmts) + .map(|_| { + let arity = 2 + rng.rand_in(2); // 2 or 3 + let mut chosen = Vec::new(); + let mut tries = 0; + while chosen.len() < arity && tries < 20 { + let pick = all_names[rng.rand_in(all_names.len())].clone(); + if !chosen.contains(&pick) { + chosen.push(pick); + } + tries += 1; + } + chosen + }) + .collect(); + + let stmt_refs: Vec> = stmt_wildcards + .iter() + .map(|v| v.iter().map(|s| s.as_str()).collect()) + .collect(); + let stmt_slices: Vec<&[&str]> = + stmt_refs.iter().map(|v| v.as_slice()).collect(); + let pub_refs: Vec<&str> = pub_names.iter().map(|s| s.as_str()).collect(); + let priv_refs: Vec<&str> = priv_names.iter().map(|s| s.as_str()).collect(); + + let pred = build_pred("p", &pub_refs, &priv_refs, &stmt_slices); + + // Skip inputs that fail early validation (e.g. too many public args). + if validate_predicate_is_splittable(&pred).is_err() { + continue; + } + + checked += 1; + let feasible = find_any_feasible_ordering(&pred, ¶ms); + let split = split_predicate_if_needed(pred.clone(), ¶ms); + + if let (Err(err), Some(perm)) = (split, feasible) { + found += 1; + eprintln!( + "\n=== COUNTEREXAMPLE #{} ===\n\ + shape: n={}, n_pub={}, n_priv={}, trial={}\n\ + statements (original order):", + found, n_stmts, n_pub, n_priv, trial + ); + for (i, wcs) in stmt_wildcards.iter().enumerate() { + eprintln!(" s{}: {:?}", i, wcs); + } + eprintln!("feasible permutation: {:?}", perm); + eprintln!("splitter error: {}\n", err); + + if found >= 3 { + eprintln!( + "Found {} counterexamples (out of {} checked); stopping.", + found, checked + ); + return; + } + } + } + } + } + } + + eprintln!( + "Searched {} predicates; found {} counterexamples.", + checked, found + ); + if found == 0 { + eprintln!("No counterexamples found."); + } + } }