Improved diagnostics on failure

This commit is contained in:
Rob Knight 2026-04-30 18:58:07 +01:00
parent 20204548b3
commit fcf0888c25
No known key found for this signature in database
4 changed files with 393 additions and 91 deletions

View file

@ -554,7 +554,7 @@ impl<'a> Lowerer<'a> {
let mut split_results = Vec::new();
for mut pred in predicates {
self.rewrite_typed_dot_access(&mut pred);
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
let result = frontend_ast_split::split_predicate_if_needed(&pred, self.params)?;
split_results.push(result);
}

View file

@ -25,7 +25,10 @@
#![allow(clippy::needless_range_loop)]
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
fmt,
};
use good_lp::{
constraint, default_solver, variable, Expression, ProblemVariables, Solution, SolverModel,
@ -87,6 +90,68 @@ pub struct SplitResult {
pub chain_info: Option<SplitChainInfo>,
}
/// Per-link bottleneck found by [`analyze_infeasibility`]: how far one link
/// overshoots the per-link caps, and which wildcards crowd it.
///
/// Values come from the elastic solution, i.e. the assignment that minimises
/// total cap violation, so they describe the *least bad* layout rather than an
/// arbitrary one.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare overshoots
/// directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LinkOvershoot {
    /// Zero-based index of the link this entry describes.
    pub link_index: usize,
    /// Number of public-args slots over `max_statement_args` for this link.
    pub public_args_overflow: usize,
    /// Number of total declared-wildcard slots over `max_custom_predicate_wildcards`.
    pub total_args_overflow: usize,
    /// Wildcards passed in to this link as public args (in the elastic solution).
    pub public_args_in: Vec<String>,
    /// Wildcards declared as private at this link (in the elastic solution).
    pub private_args: Vec<String>,
}
/// Explanation of why [`split_predicate_if_needed`] failed with
/// [`SplittingError::Infeasible`].
///
/// The splitter never builds this itself, because doing so costs a second LP
/// solve; call [`analyze_infeasibility`] on the failure path to obtain one.
#[derive(Debug, Clone)]
pub struct InfeasibilityReport {
    /// Name of the predicate that was analysed.
    pub predicate: String,
    /// Link count the elastic LP was solved at (the minimum K).
    pub k: usize,
    /// Overshooting links, ordered by link index. A link that stays within
    /// every cap has no entry here.
    pub overshoots: Vec<LinkOvershoot>,
}
/// Human-readable rendering of an [`InfeasibilityReport`].
///
/// A report with overshoots prints one header line followed by one line per
/// violated cap per link. A report with no overshoots prints a single line
/// saying so, rather than a header that announces overflow and then lists
/// nothing.
impl fmt::Display for InfeasibilityReport {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // An empty list means no cap violation was identified at this K, so
        // the "cannot be split" header below would be misleading.
        if self.overshoots.is_empty() {
            return writeln!(
                f,
                "Predicate '{}': no per-link cap overshoot found at {} link(s)",
                self.predicate, self.k
            );
        }

        writeln!(
            f,
            "Predicate '{}' cannot be split into {} link(s) without overflowing per-link caps:",
            self.predicate, self.k
        )?;

        // Only the public-args cap is available here. The total-wildcards cap
        // comes from a `Params` instance that the report does not store, so
        // its line shows the overflow amount but not the cap value.
        let max_args = Params::max_statement_args();
        for o in &self.overshoots {
            if o.public_args_overflow > 0 {
                writeln!(
                    f,
                    "  link {}: public_args_in = [{}] ({} args, {} over the {}-arg cap)",
                    o.link_index,
                    o.public_args_in.join(", "),
                    o.public_args_in.len(),
                    o.public_args_overflow,
                    max_args
                )?;
            }
            if o.total_args_overflow > 0 {
                writeln!(
                    f,
                    "  link {}: declared {} wildcards (public_args_in + private_args), {} over the cap",
                    o.link_index,
                    o.public_args_in.len() + o.private_args.len(),
                    o.total_args_overflow,
                )?;
            }
        }
        Ok(())
    }
}
/// Early validation: Check if predicate is fundamentally splittable
pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> {
let public_args = pred.args.public_args.len();
@ -105,23 +170,23 @@ pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(),
Ok(())
}
/// Split a predicate into a chain if it exceeds statement limit
/// Split a predicate into a chain if it exceeds statement limit.
///
/// Takes `pred` by reference so callers can re-use it (for example to call
/// [`analyze_infeasibility`] on the failure path) without cloning preemptively.
pub fn split_predicate_if_needed(
pred: CustomPredicateDef,
pred: &CustomPredicateDef,
params: &Params,
) -> Result<SplitResult, SplittingError> {
// Early validation
validate_predicate_is_splittable(&pred)?;
validate_predicate_is_splittable(pred)?;
// If within limits, no splitting needed
if pred.statements.len() <= Params::max_custom_predicate_arity() {
return Ok(SplitResult {
predicates: vec![pred],
predicates: vec![pred.clone()],
chain_info: None,
});
}
// Need to split - execute the splitting algorithm
let (predicates, chain_info) = split_into_chain(pred, params)?;
Ok(SplitResult {
@ -152,60 +217,87 @@ fn compute_min_links(n: usize) -> usize {
/// indices placed in link i, in original order.
type LinkAssignment = Vec<Vec<usize>>;
/// Try to partition `n` statements into exactly `k` links using MILP.
/// MILP variables shared by the strict feasibility solve and the elastic
/// diagnostic solve.
///
/// Returns `Some(assignment)` if a feasible partition exists, `None` if the
/// model is infeasible at this K (caller should try a larger K).
///
/// Variables (all binary):
/// - `assign[s][i]`: statement `s` is placed in link `i`.
/// - `u[w][i]`: wildcard `w` is referenced by some statement at link `i`.
/// - `before[w][i]`, `after[w][i]`: cumulative ORs of `u[w][·]` from the left
/// and right respectively. `before[w][i] = 1` iff w is used at link ≤ i.
/// - `pubin[w][i]`: w appears in link i's `public_args_in`.
/// - `privin[w][i]`: w appears in link i's `private_args` list.
///
/// All non-`assign` variables are forced to be exact functions of `assign`
/// via two-sided (≤ and ≥) constraints, so the LP relaxation has a unique
/// solution given any integer assignment. The objective is a small
/// source-order tiebreaker on `assign`.
fn solve_milp_for_k(
/// All variables are binary. Constraints (C1..C7 below) make every variable
/// other than `assign` an exact function of the assignment, so the strict and
/// elastic models differ only in how they handle the per-link caps (C8/C9).
struct MilpVars {
n: usize,
k: usize,
statements_using: &[Vec<usize>],
is_original_public: &[bool],
params: &Params,
) -> Option<LinkAssignment> {
let max_arity = Params::max_custom_predicate_arity();
let max_args = Params::max_statement_args();
let max_wildcards = params.max_custom_predicate_wildcards;
let num_wildcards = is_original_public.len();
num_wildcards: usize,
/// `assign[s][i]`: statement `s` placed in link `i`.
assign: Vec<Vec<Variable>>,
/// `u[w][i]`: wildcard `w` referenced by some statement at link `i`.
u: Vec<Vec<Variable>>,
/// `before[w][i]`: cumulative OR of `u[w][·]` from the left — w is used at link ≤ i.
before: Vec<Vec<Variable>>,
/// `after[w][i]`: cumulative OR of `u[w][·]` from the right — w is used at link ≥ i.
after: Vec<Vec<Variable>>,
/// `pubin[w][i]`: w appears in link i's `public_args_in`.
pubin: Vec<Vec<Variable>>,
/// `privin[w][i]`: w appears in link i's `private_args` list.
privin: Vec<Vec<Variable>>,
}
let mut vars = ProblemVariables::new();
fn mk_grid(vars: &mut ProblemVariables, rows: usize, cols: usize) -> Vec<Vec<Variable>> {
/// Allocates a `rows` × `cols` grid of fresh binary variables in `vars`.
///
/// Variables are registered row by row, left to right within each row.
fn mk_binary_grid(vars: &mut ProblemVariables, rows: usize, cols: usize) -> Vec<Vec<Variable>> {
    let mut grid: Vec<Vec<Variable>> = Vec::with_capacity(rows);
    for _ in 0..rows {
        let row: Vec<Variable> = (0..cols).map(|_| vars.add(variable().binary())).collect();
        grid.push(row);
    }
    grid
}
let assign = mk_grid(&mut vars, n, k);
let u = mk_grid(&mut vars, num_wildcards, k);
let before = mk_grid(&mut vars, num_wildcards, k);
let after = mk_grid(&mut vars, num_wildcards, k);
let pubin = mk_grid(&mut vars, num_wildcards, k);
let privin = mk_grid(&mut vars, num_wildcards, k);
// Source-order tiebreaker. Coefficient `(n-1-s) * i` is largest for
// low-index statements at high-index links, so minimization pulls
// low-original-index statements toward low-link-index assignments —
// i.e., the chain roughly preserves source order. The objective only
// breaks ties among feasibility-equivalent solutions.
let objective: Expression = (0..n)
.flat_map(|s| (0..k).map(move |i| (s, i)))
.map(|(s, i)| ((n - 1 - s) as f64) * (i as f64) * assign[s][i])
.sum();
/// Registers every binary decision variable of the splitting model in `vars`
/// and returns them bundled as [`MilpVars`].
///
/// `assign` is an `n` × `k` grid (statement × link); the remaining grids are
/// `num_wildcards` × `k` (wildcard × link).
fn declare_milp_vars(
    vars: &mut ProblemVariables,
    n: usize,
    k: usize,
    num_wildcards: usize,
) -> MilpVars {
    // Allocation order is fixed: assign, u, before, after, pubin, privin.
    let assign = mk_binary_grid(vars, n, k);
    let u = mk_binary_grid(vars, num_wildcards, k);
    let before = mk_binary_grid(vars, num_wildcards, k);
    let after = mk_binary_grid(vars, num_wildcards, k);
    let pubin = mk_binary_grid(vars, num_wildcards, k);
    let privin = mk_binary_grid(vars, num_wildcards, k);

    MilpVars {
        n,
        k,
        num_wildcards,
        assign,
        u,
        before,
        after,
        pubin,
        privin,
    }
}
let mut model = vars.minimise(objective).using(default_solver);
/// Builds the source-order tiebreaker objective.
///
/// Placing statement `s` in link `i` costs `(n - 1 - s) * i`, so minimising
/// the sum favours low-original-index statements at low link indices. The
/// chain therefore roughly keeps source order when nothing else forces a
/// rearrangement.
fn source_order_tiebreaker(v: &MilpVars) -> Expression {
    let mut terms = Vec::with_capacity(v.n * v.k);
    for s in 0..v.n {
        let statement_weight = (v.n - 1 - s) as f64;
        for i in 0..v.k {
            terms.push(statement_weight * (i as f64) * v.assign[s][i]);
        }
    }
    terms.into_iter().sum()
}
/// Add the MILP's structural constraints (C1..C7): assignment, link size,
/// `u`/`before`/`after`/`pubin`/`privin` definitions. Cap constraints (C8/C9)
/// are added by the caller — the strict and elastic versions differ there.
fn add_structural_constraints<M: SolverModel>(
model: &mut M,
v: &MilpVars,
statements_using: &[Vec<usize>],
is_original_public: &[bool],
) {
let max_arity = Params::max_custom_predicate_arity();
let MilpVars {
n,
k,
num_wildcards,
assign,
u,
before,
after,
pubin,
privin,
} = v;
let (n, k, num_wildcards) = (*n, *k, *num_wildcards);
// C1: Each statement assigned to exactly one link.
for s in 0..n {
@ -298,17 +390,39 @@ fn solve_milp_for_k(
}
}
}
}
/// Try to partition `n` statements into exactly `k` links using MILP.
///
/// Returns `Some(assignment)` if a feasible partition exists, `None` if the
/// model is infeasible at this K (caller should try a larger K).
fn solve_milp_for_k(
n: usize,
k: usize,
statements_using: &[Vec<usize>],
is_original_public: &[bool],
params: &Params,
) -> Option<LinkAssignment> {
let max_args = Params::max_statement_args();
let max_wildcards = params.max_custom_predicate_wildcards;
let num_wildcards = is_original_public.len();
let mut vars = ProblemVariables::new();
let v = declare_milp_vars(&mut vars, n, k, num_wildcards);
let objective = source_order_tiebreaker(&v);
let mut model = vars.minimise(objective).using(default_solver);
add_structural_constraints(&mut model, &v, statements_using, is_original_public);
// C8: per-link public-args cap (incoming chain-call args).
for i in 0..k {
let sum: Expression = (0..num_wildcards).map(|w| pubin[w][i]).sum();
let sum: Expression = (0..num_wildcards).map(|w| v.pubin[w][i]).sum();
model.add_constraint(constraint!(sum <= max_args as f64));
}
// C9: per-link total declared wildcards cap.
for i in 0..k {
let sum: Expression = (0..num_wildcards)
.map(|w| Expression::from(pubin[w][i]) + privin[w][i])
.map(|w| Expression::from(v.pubin[w][i]) + v.privin[w][i])
.sum();
model.add_constraint(constraint!(sum <= max_wildcards as f64));
}
@ -319,7 +433,7 @@ fn solve_milp_for_k(
let mut links: LinkAssignment = vec![Vec::new(); k];
for s in 0..n {
for i in 0..k {
if solution.value(assign[s][i]) > SOLVER_BINARY_THRESHOLD {
if solution.value(v.assign[s][i]) > SOLVER_BINARY_THRESHOLD {
links[i].push(s);
break;
}
@ -409,18 +523,17 @@ fn build_chain_links_from_assignment(
result
}
/// Split a predicate into a chain via MILP. Tries `K = K_min, K_min+1, ...`,
/// returning the first feasible chain or [`SplittingError::Infeasible`] if
/// no `K` up to `n` works.
fn split_into_chain(
pred: CustomPredicateDef,
params: &Params,
) -> Result<(Vec<CustomPredicateDef>, SplitChainInfo), SplittingError> {
let original_name = pred.name.name.clone();
let conjunction = pred.conjunction_type;
let n = pred.statements.len();
let real_statement_count = n;
/// A predicate's wildcard graph reduced to plain indices and flags, in the
/// form consumed by both the strict MILP and the elastic diagnostic LP.
struct MilpInput {
    /// Number of statements in the predicate.
    n: usize,
    /// Wildcard names; a wildcard's position here is its index everywhere else.
    wildcard_names: Vec<String>,
    /// For each wildcard index, the indices of the statements that reference it.
    statements_using: Vec<Vec<usize>>,
    /// For each wildcard index, whether it is one of the predicate's public args.
    is_original_public: Vec<bool>,
    /// Names of the predicate's declared public args.
    original_public_args: Vec<String>,
}
fn prepare_milp_input(pred: &CustomPredicateDef) -> MilpInput {
let original_public_args: Vec<String> = pred
.args
.public_args
@ -428,8 +541,8 @@ fn split_into_chain(
.map(|id| id.name.clone())
.collect();
// Build a stable, sorted index over all wildcards referenced by the
// predicate (statements + declared public args).
// Stable, sorted index over wildcards referenced by statements OR declared
// as public args (a public arg may be unused in any statement).
let mut wildcard_set: HashSet<String> = pred
.statements
.iter()
@ -446,8 +559,7 @@ fn split_into_chain(
.map(|(i, name)| (name.clone(), i))
.collect();
// Inverse map from wildcard index to the statements that reference it.
// Loop-invariant across the K-search, so build once.
// Inverse: which statements reference each wildcard (by index).
let mut statements_using: Vec<Vec<usize>> = vec![Vec::new(); wildcard_names.len()];
for (s, stmt) in pred.statements.iter().enumerate() {
let mut seen: HashSet<usize> = HashSet::new();
@ -464,12 +576,136 @@ fn split_into_chain(
is_original_public[wildcard_index[name]] = true;
}
MilpInput {
n: pred.statements.len(),
wildcard_names,
statements_using,
is_original_public,
original_public_args,
}
}
/// Solves the elastic variant of the splitting model at exactly `k` links.
///
/// Each per-link cap gets a non-negative slack variable, and the objective
/// minimises total slack. A model that is infeasible under the strict caps
/// thus becomes a search for the least-violating assignment, which shows which
/// links exceed which cap and by how much.
///
/// Returns `None` when the solver reports an error. Otherwise returns one
/// [`LinkOvershoot`] per link whose rounded slack is non-zero, in link-index
/// order, with wildcard names sorted.
fn solve_elastic_lp(k: usize, input: &MilpInput, params: &Params) -> Option<Vec<LinkOvershoot>> {
    let n = input.n;
    let num_wildcards = input.wildcard_names.len();
    let public_cap = Params::max_statement_args();
    let total_cap = params.max_custom_predicate_wildcards;

    let mut vars = ProblemVariables::new();
    let v = declare_milp_vars(&mut vars, n, k, num_wildcards);
    let slack_pub: Vec<Variable> = (0..k).map(|_| vars.add(variable().min(0.0))).collect();
    let slack_total: Vec<Variable> = (0..k).map(|_| vars.add(variable().min(0.0))).collect();

    let total_slack: Expression = slack_pub
        .iter()
        .zip(slack_total.iter())
        .map(|(&on_public, &on_total)| Expression::from(on_public) + on_total)
        .sum();

    // Tiebreaker bound is n²k². Dividing by n²k² + 1 keeps even the worst-case
    // tiebreaker total below 1, so it can never outweigh one unit of slack.
    let tiebreak_scale = 1.0 / ((n * n * k * k + 1) as f64);
    let objective = total_slack + tiebreak_scale * source_order_tiebreaker(&v);

    let mut model = vars.minimise(objective).using(default_solver);
    add_structural_constraints(
        &mut model,
        &v,
        &input.statements_using,
        &input.is_original_public,
    );

    // C8 elastic: Σ pubin[w][i] ≤ max_args + slack_pub[i].
    for i in 0..k {
        let public_count: Expression = (0..num_wildcards).map(|w| v.pubin[w][i]).sum();
        model.add_constraint(constraint!(public_count <= public_cap as f64 + slack_pub[i]));
    }

    // C9 elastic: Σ (pubin + privin)[w][i] ≤ max_wildcards + slack_total[i].
    for i in 0..k {
        let declared_count: Expression = (0..num_wildcards)
            .map(|w| Expression::from(v.pubin[w][i]) + v.privin[w][i])
            .sum();
        model.add_constraint(constraint!(declared_count <= total_cap as f64 + slack_total[i]));
    }

    let solution = model.solve().ok()?;

    // Sorted names of the wildcards whose indicator in `grid` is set at `link`.
    let names_set_at = |grid: &[Vec<Variable>], link: usize| -> Vec<String> {
        let mut names: Vec<String> = (0..num_wildcards)
            .filter(|&w| solution.value(grid[w][link]) > SOLVER_BINARY_THRESHOLD)
            .map(|w| input.wildcard_names[w].clone())
            .collect();
        names.sort();
        names
    };

    let overshoots: Vec<LinkOvershoot> = (0..k)
        .filter_map(|i| {
            let public_args_overflow = solution.value(slack_pub[i]).round() as usize;
            let total_args_overflow = solution.value(slack_total[i]).round() as usize;
            // Links that respect both caps are left out of the report.
            if public_args_overflow == 0 && total_args_overflow == 0 {
                return None;
            }
            Some(LinkOvershoot {
                link_index: i,
                public_args_overflow,
                total_args_overflow,
                public_args_in: names_set_at(&v.pubin, i),
                private_args: names_set_at(&v.privin, i),
            })
        })
        .collect();

    Some(overshoots)
}
/// Explains why the splitter rejected `pred`.
///
/// Solves an elastic version of the MILP at the minimum link count, where each
/// per-link cap may be exceeded by non-negative slack and total slack is
/// minimised. The resulting report names the links that overshoot, the cap
/// each one exceeds, and the size of the excess.
///
/// Intended for inputs that produced [`SplittingError::Infeasible`]. For a
/// feasible input the report's `overshoots` is empty; it is also empty when
/// the elastic solve itself returns nothing.
pub fn analyze_infeasibility(pred: &CustomPredicateDef, params: &Params) -> InfeasibilityReport {
    let milp_input = prepare_milp_input(pred);
    let min_links = compute_min_links(milp_input.n);

    let overshoots = match solve_elastic_lp(min_links, &milp_input, params) {
        Some(found) => found,
        None => Vec::new(),
    };

    InfeasibilityReport {
        predicate: pred.name.name.clone(),
        k: min_links,
        overshoots,
    }
}
/// Split a predicate into a chain via MILP. Tries `K = K_min, K_min+1, ...`,
/// returning the first feasible chain or [`SplittingError::Infeasible`] if
/// no `K` up to `n` works.
fn split_into_chain(
pred: &CustomPredicateDef,
params: &Params,
) -> Result<(Vec<CustomPredicateDef>, SplitChainInfo), SplittingError> {
let original_name = pred.name.name.clone();
let conjunction = pred.conjunction_type;
let real_statement_count = pred.statements.len();
let input = prepare_milp_input(pred);
let n = input.n;
let k_min = compute_min_links(n);
let mut found: Option<(usize, LinkAssignment)> = None;
for k in k_min..=n {
if let Some(assignment) =
solve_milp_for_k(n, k, &statements_using, &is_original_public, params)
{
if let Some(assignment) = solve_milp_for_k(
n,
k,
&input.statements_using,
&input.is_original_public,
params,
) {
found = Some((k, assignment));
break;
}
@ -494,8 +730,11 @@ fn split_into_chain(
}
}
let chain_links =
build_chain_links_from_assignment(assignment, &pred.statements, &original_public_args);
let chain_links = build_chain_links_from_assignment(
assignment,
&pred.statements,
&input.original_public_args,
);
// Build SplitChainInfo (execution order: innermost continuation first).
let num_links = chain_links.len();
@ -726,7 +965,7 @@ mod tests {
let pred = parse_predicate(input);
let params = Params::default();
let result = split_predicate_if_needed(pred, &params);
let result = split_predicate_if_needed(&pred, &params);
assert!(result.is_ok());
let split_result = result.unwrap();
@ -750,7 +989,7 @@ mod tests {
let pred = parse_predicate(input);
let params = Params::default(); // max_custom_predicate_arity = 5
let result = split_predicate_if_needed(pred, &params);
let result = split_predicate_if_needed(&pred, &params);
assert!(result.is_ok());
let split_result = result.unwrap();
@ -794,7 +1033,7 @@ mod tests {
let pred = parse_predicate(input);
let params = Params::default(); // max_custom_predicate_arity = 5
let result = split_predicate_if_needed(pred, &params);
let result = split_predicate_if_needed(&pred, &params);
assert!(result.is_ok());
let split_result = result.unwrap();
@ -830,7 +1069,7 @@ mod tests {
let pred = parse_predicate(input);
let params = Params::default(); // max_custom_predicate_arity = 5
let result = split_predicate_if_needed(pred, &params);
let result = split_predicate_if_needed(&pred, &params);
assert!(result.is_ok());
let split_result = result.unwrap();
@ -879,7 +1118,7 @@ mod tests {
let pred = parse_predicate(input);
let params = Params::default();
let result = split_predicate_if_needed(pred, &params);
let result = split_predicate_if_needed(&pred, &params);
assert!(result.is_ok());
let split_result = result.unwrap();
@ -935,7 +1174,7 @@ mod tests {
let pred = parse_predicate(input);
let params = Params::default();
let result = split_predicate_if_needed(pred, &params);
let result = split_predicate_if_needed(&pred, &params);
assert!(
result.is_ok(),
"Should find a valid split with ≤1 crossing wildcard, got: {:?}",
@ -964,7 +1203,7 @@ mod tests {
let pred = parse_predicate(input);
let params = Params::default();
let result = split_predicate_if_needed(pred, &params).unwrap();
let result = split_predicate_if_needed(&pred, &params).unwrap();
// chain[0] is the continuation (_1 suffix), chain[1] is the original
let continuation = result
.predicates
@ -1193,7 +1432,7 @@ mod tests {
"expected at least one feasible permutation"
);
let result = split_predicate_if_needed(pred, &params);
let result = split_predicate_if_needed(&pred, &params);
assert!(
result.is_ok(),
"splitter rejected an input with a feasible ordering ({:?}): {}",
@ -1202,6 +1441,66 @@ mod tests {
);
}
/// One statement referencing 9 distinct wildcards makes a predicate
/// unsplittable: whichever link holds that statement must declare ≥ 9
/// wildcards, above the per-link cap of 8. `analyze_infeasibility` has to
/// report that as a non-zero `total_args_overflow` and name the crowded
/// link's private args.
#[test]
fn test_analyze_infeasibility_reports_total_args_overflow() {
    let params = Params::default();
    let pred = build_pred(
        "dense",
        &["A"],
        &["W0", "W1", "W2", "W3", "W4", "W5", "W6", "W7", "W8"],
        &[
            &["W0", "W1", "W2", "W3", "W4", "W5", "W6", "W7", "W8"],
            &["W0"],
            &["W0"],
            &["W0"],
            &["W0"],
            &["W0"],
        ],
    );

    // Precondition: the ordinary splitter must reject this input.
    let split_outcome = split_predicate_if_needed(&pred, &params);
    assert!(matches!(
        split_outcome,
        Err(SplittingError::Infeasible { .. })
    ));

    let report = analyze_infeasibility(&pred, &params);
    assert_eq!(report.predicate, "dense");
    assert_eq!(report.k, 2);

    let mut total_overflow: usize = 0;
    for overshoot in &report.overshoots {
        total_overflow += overshoot.total_args_overflow;
    }
    assert!(
        total_overflow >= 1,
        "expected ≥1 total-args overflow, got {} (overshoots: {:?})",
        total_overflow,
        report.overshoots
    );

    // The dense statement pins W1..W8 to a single link as private args.
    let crowded_link_has_dense_privates = report.overshoots.iter().any(|overshoot| {
        let w_privates = overshoot
            .private_args
            .iter()
            .filter(|name| name.starts_with('W'))
            .count();
        w_privates >= 8
    });
    assert!(
        crowded_link_has_dense_privates,
        "expected a binding link to declare 8+ W-wildcards as private; got {:?}",
        report.overshoots
    );

    // Rendering must not panic and must include the predicate name.
    let rendered = report.to_string();
    assert!(rendered.contains("dense"));
}
/// Randomized counterexample search. Run with
/// `cargo test --release search_splitter -- --ignored --nocapture`.
#[test]
@ -1273,7 +1572,7 @@ mod tests {
checked += 1;
let feasible = find_any_feasible_ordering(&pred, &params);
let split = split_predicate_if_needed(pred.clone(), &params);
let split = split_predicate_if_needed(&pred, &params);
if let (Err(err), Some(perm)) = (split, feasible) {
found += 1;

View file

@ -40,7 +40,10 @@ use std::sync::Arc;
pub use diagnostics::render_error;
pub use error::{LangError, LangErrorKind};
pub use frontend_ast_split::{SplitChainInfo, SplitChainPiece, SplitResult};
pub use frontend_ast_split::{
analyze_infeasibility, InfeasibilityReport, LinkOvershoot, SplitChainInfo, SplitChainPiece,
SplitResult,
};
pub use module::{Module, MultiOperationError};
pub use parser::{parse_podlang, Pairs, ParseError, Rule};
pub use pretty_print::PrettyPrint;

View file

@ -617,7 +617,7 @@ mod tests {
// Split the predicate
let mut split_results = Vec::new();
for pred in predicates {
for pred in &predicates {
let result = split_predicate_if_needed(pred, &params).expect("Split failed");
split_results.push(result);
}