Fix accidental inclusion of extra input PODs (#476) (#478)

* Fix issue with adding extra input PODs

* Panic if input PODs are missing
This commit is contained in:
Rob Knight 2026-02-06 17:18:35 +01:00 committed by GitHub
parent 2bd99ef322
commit 5dab8195b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -47,7 +47,10 @@
//! //!
//! [`MainPodBuilder`]: crate::frontend::MainPodBuilder //! [`MainPodBuilder`]: crate::frontend::MainPodBuilder
use std::collections::{BTreeSet, HashMap}; use std::{
collections::{BTreeSet, HashMap},
fmt,
};
use crate::{ use crate::{
frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, frontend::{MainPod, MainPodBuilder, Operation, OperationArg},
@ -237,61 +240,16 @@ impl SolvedMultiPod {
) -> Result<MainPod> { ) -> Result<MainPod> {
let mut builder = MainPodBuilder::new(&self.params, &self.vd_set); let mut builder = MainPodBuilder::new(&self.params, &self.vd_set);
let solution = &self.solution; let solution = &self.solution;
let statements_in_this_pod = &solution.pod_statements[pod_idx];
let statements_in_this_pod: &Vec<usize> = &solution.pod_statements[pod_idx];
let mut needed_external_pods: BTreeSet<usize> = BTreeSet::new();
let mut needed_earlier_pods: BTreeSet<usize> = BTreeSet::new();
// Step 1: Find which external and earlier PODs we need based on dependencies // Step 1: Find which external and earlier PODs we need based on dependencies
for &stmt_idx in statements_in_this_pod { let (needed_earlier_pods, needed_external_pods) = self.compute_pod_inputs(pod_idx);
for dep in &self.deps.statement_deps[stmt_idx] {
match dep {
StatementSource::Internal(dep_idx) => {
// Check if dependency is in an earlier generated POD
let mut found = false;
for earlier_pod_idx in 0..pod_idx {
if solution.pod_public_statements[earlier_pod_idx].contains(dep_idx) {
needed_earlier_pods.insert(earlier_pod_idx);
found = true;
break;
}
}
// If not found in earlier PODs, it must be local to this POD
if !found && !statements_in_this_pod.contains(dep_idx) {
unreachable!(
"Internal dependency {} for statement {} is neither local \
nor public in any earlier POD (solver bug)",
dep_idx, stmt_idx
);
}
}
StatementSource::External(pod_hash) => {
// Find which external POD has this hash
let ext_idx = self
.input_pods
.iter()
.position(|p| p.statements_hash() == *pod_hash);
match ext_idx {
Some(idx) => {
needed_external_pods.insert(idx);
}
None => {
unreachable!(
"External dependency with hash {:?} not found in input PODs",
pod_hash
);
}
}
}
}
}
}
// Step 2: Add input PODs to the builder // Step 2: Add input PODs to the builder
for &ext_idx in &needed_external_pods { for ext_idx in needed_external_pods {
builder.add_pod(self.input_pods[ext_idx].clone())?; builder.add_pod(self.input_pods[ext_idx].clone())?;
} }
for &earlier_idx in &needed_earlier_pods { for earlier_idx in needed_earlier_pods {
builder.add_pod(earlier_pods[earlier_idx].clone())?; builder.add_pod(earlier_pods[earlier_idx].clone())?;
} }
@ -338,6 +296,132 @@ impl SolvedMultiPod {
Ok(pod) Ok(pod)
} }
/// Compute which input PODs (internal and external) are needed for a given POD.
///
/// Returns (internal_pod_indices, external_pod_indices).
fn compute_pod_inputs(&self, pod_idx: usize) -> (BTreeSet<usize>, BTreeSet<usize>) {
let solution = &self.solution;
let statements_in_pod = &solution.pod_statements[pod_idx];
let mut internal_pods: BTreeSet<usize> = BTreeSet::new();
let mut external_pods: BTreeSet<usize> = BTreeSet::new();
for &stmt_idx in statements_in_pod {
for dep in &self.deps.statement_deps[stmt_idx] {
match dep {
StatementSource::Internal(dep_idx) => {
// Check if dependency is in an earlier POD (not local)
if !statements_in_pod.contains(dep_idx) {
let earlier_pod_idx = (0..pod_idx)
.find(|earlier_pod_idx| {
solution.pod_public_statements[*earlier_pod_idx]
.contains(dep_idx)
})
.expect("internal pod with dependency statement");
internal_pods.insert(earlier_pod_idx);
}
}
StatementSource::External(pod_hash) => {
let idx = self
.input_pods
.iter()
.position(|p| p.statements_hash() == *pod_hash)
.expect("external pod with dependency statement");
external_pods.insert(idx);
}
}
}
}
assert!(internal_pods.len() + external_pods.len() <= self.params.max_input_pods);
(internal_pods, external_pods)
}
}
impl fmt::Display for SolvedMultiPod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let solution = &self.solution;
let output_pod_idx = solution.pod_count.saturating_sub(1);
// Header
writeln!(
f,
"SolvedMultiPod: {} statements → {} PODs",
self.statements.len(),
solution.pod_count
)?;
if !self.input_pods.is_empty() {
writeln!(f, " External input PODs: {}", self.input_pods.len())?;
}
writeln!(f)?;
// Per-POD breakdown
for pod_idx in 0..solution.pod_count {
let is_output = pod_idx == output_pod_idx;
let role = if is_output { "output" } else { "intermediate" };
writeln!(f, " POD {} ({}):", pod_idx, role)?;
// Show input PODs
let (internal_inputs, external_inputs) = self.compute_pod_inputs(pod_idx);
if !internal_inputs.is_empty() || !external_inputs.is_empty() {
let internal_str: Vec<String> = internal_inputs
.iter()
.map(|i| format!("POD {}", i))
.collect();
let external_str: Vec<String> = external_inputs
.iter()
.map(|i| format!("ext[{}]", i))
.collect();
let all_inputs: Vec<&str> = internal_str
.iter()
.map(|s| s.as_str())
.chain(external_str.iter().map(|s| s.as_str()))
.collect();
writeln!(
f,
" inputs: {} (total: {})",
all_inputs.join(", "),
all_inputs.len()
)?;
}
// Show statements
let stmts = &solution.pod_statements[pod_idx];
let public_stmts = &solution.pod_public_statements[pod_idx];
for &stmt_idx in stmts {
let stmt = &self.statements[stmt_idx];
let is_public = public_stmts.contains(&stmt_idx);
let visibility = if is_public { "public" } else { "private" };
// Show dependencies for this statement
let deps = &self.deps.statement_deps[stmt_idx];
let dep_str = if deps.is_empty() {
String::new()
} else {
let dep_parts: Vec<String> = deps
.iter()
.map(|d| match d {
StatementSource::Internal(i) => format!("stmt[{}]", i),
StatementSource::External(_) => "ext".to_string(),
})
.collect();
format!("{}", dep_parts.join(", "))
};
writeln!(f, " [{}] {} [{}]{}", stmt_idx, stmt, visibility, dep_str)?;
}
writeln!(f)?;
}
Ok(())
}
} }
impl MultiPodBuilder { impl MultiPodBuilder {