From 5dab8195b4099a66d6239a97d87ac2090716436a Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Fri, 6 Feb 2026 17:18:35 +0100 Subject: [PATCH] Fix accidental inclusion of extra input PODs (#476) (#478) * Fix issue with adding extra input PODs * Panic if input PODs are missing --- src/frontend/multi_pod/mod.rs | 184 +++++++++++++++++++++++++--------- 1 file changed, 134 insertions(+), 50 deletions(-) diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index a18da24..bb80438 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -47,7 +47,10 @@ //! //! [`MainPodBuilder`]: crate::frontend::MainPodBuilder -use std::collections::{BTreeSet, HashMap}; +use std::{ + collections::{BTreeSet, HashMap}, + fmt, +}; use crate::{ frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, @@ -237,61 +240,16 @@ impl SolvedMultiPod { ) -> Result { let mut builder = MainPodBuilder::new(&self.params, &self.vd_set); let solution = &self.solution; - - let statements_in_this_pod: &Vec = &solution.pod_statements[pod_idx]; - let mut needed_external_pods: BTreeSet = BTreeSet::new(); - let mut needed_earlier_pods: BTreeSet = BTreeSet::new(); + let statements_in_this_pod = &solution.pod_statements[pod_idx]; // Step 1: Find which external and earlier PODs we need based on dependencies - for &stmt_idx in statements_in_this_pod { - for dep in &self.deps.statement_deps[stmt_idx] { - match dep { - StatementSource::Internal(dep_idx) => { - // Check if dependency is in an earlier generated POD - let mut found = false; - for earlier_pod_idx in 0..pod_idx { - if solution.pod_public_statements[earlier_pod_idx].contains(dep_idx) { - needed_earlier_pods.insert(earlier_pod_idx); - found = true; - break; - } - } - // If not found in earlier PODs, it must be local to this POD - if !found && !statements_in_this_pod.contains(dep_idx) { - unreachable!( - "Internal dependency {} for statement {} is neither local \ - nor public in any earlier POD (solver bug)", - dep_idx, stmt_idx - ); - } - } - StatementSource::External(pod_hash) => { - // Find which external POD has this hash - let ext_idx = self - .input_pods - .iter() - .position(|p| p.statements_hash() == *pod_hash); - match ext_idx { - Some(idx) => { - needed_external_pods.insert(idx); - } - None => { - unreachable!( - "External dependency with hash {:?} not found in input PODs", - pod_hash - ); - } - } - } - } - } - } + let (needed_earlier_pods, needed_external_pods) = self.compute_pod_inputs(pod_idx); // Step 2: Add input PODs to the builder - for &ext_idx in &needed_external_pods { + for ext_idx in needed_external_pods { builder.add_pod(self.input_pods[ext_idx].clone())?; } - for &earlier_idx in &needed_earlier_pods { + for earlier_idx in needed_earlier_pods { builder.add_pod(earlier_pods[earlier_idx].clone())?; } @@ -338,6 +296,132 @@ impl SolvedMultiPod { Ok(pod) } + + /// Compute which input PODs (internal and external) are needed for a given POD. + /// + /// Returns (internal_pod_indices, external_pod_indices). + fn compute_pod_inputs(&self, pod_idx: usize) -> (BTreeSet, BTreeSet) { + let solution = &self.solution; + let statements_in_pod = &solution.pod_statements[pod_idx]; + + let mut internal_pods: BTreeSet = BTreeSet::new(); + let mut external_pods: BTreeSet = BTreeSet::new(); + + for &stmt_idx in statements_in_pod { + for dep in &self.deps.statement_deps[stmt_idx] { + match dep { + StatementSource::Internal(dep_idx) => { + // Check if dependency is in an earlier POD (not local) + if !statements_in_pod.contains(dep_idx) { + let earlier_pod_idx = (0..pod_idx) + .find(|earlier_pod_idx| { + solution.pod_public_statements[*earlier_pod_idx] + .contains(dep_idx) + }) + .expect("internal pod with dependency statement"); + internal_pods.insert(earlier_pod_idx); + } + } + StatementSource::External(pod_hash) => { + let idx = self + .input_pods + .iter() + .position(|p| p.statements_hash() == *pod_hash) + .expect("external pod with dependency statement"); + external_pods.insert(idx); + } + } + } + } + + assert!(internal_pods.len() + external_pods.len() <= self.params.max_input_pods); + + (internal_pods, external_pods) + } +} + +impl fmt::Display for SolvedMultiPod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let solution = &self.solution; + let output_pod_idx = solution.pod_count.saturating_sub(1); + + // Header + writeln!( + f, + "SolvedMultiPod: {} statements → {} PODs", + self.statements.len(), + solution.pod_count + )?; + + if !self.input_pods.is_empty() { + writeln!(f, " External input PODs: {}", self.input_pods.len())?; + } + + writeln!(f)?; + + // Per-POD breakdown + for pod_idx in 0..solution.pod_count { + let is_output = pod_idx == output_pod_idx; + let role = if is_output { "output" } else { "intermediate" }; + + writeln!(f, " POD {} ({}):", pod_idx, role)?; + + // Show input PODs + let (internal_inputs, external_inputs) = self.compute_pod_inputs(pod_idx); + if !internal_inputs.is_empty() || !external_inputs.is_empty() { + let internal_str: Vec = internal_inputs + .iter() + .map(|i| format!("POD {}", i)) + .collect(); + let external_str: Vec = external_inputs + .iter() + .map(|i| format!("ext[{}]", i)) + .collect(); + let all_inputs: Vec<&str> = internal_str + .iter() + .map(|s| s.as_str()) + .chain(external_str.iter().map(|s| s.as_str())) + .collect(); + writeln!( + f, + " inputs: {} (total: {})", + all_inputs.join(", "), + all_inputs.len() + )?; + } + + // Show statements + let stmts = &solution.pod_statements[pod_idx]; + let public_stmts = &solution.pod_public_statements[pod_idx]; + + for &stmt_idx in stmts { + let stmt = &self.statements[stmt_idx]; + let is_public = public_stmts.contains(&stmt_idx); + let visibility = if is_public { "public" } else { "private" }; + + // Show dependencies for this statement + let deps = &self.deps.statement_deps[stmt_idx]; + let dep_str = if deps.is_empty() { + String::new() + } else { + let dep_parts: Vec = deps + .iter() + .map(|d| match d { + StatementSource::Internal(i) => format!("stmt[{}]", i), + StatementSource::External(_) => "ext".to_string(), + }) + .collect(); + format!(" ← {}", dep_parts.join(", ")) + }; + + writeln!(f, " [{}] {} [{}]{}", stmt_idx, stmt, visibility, dep_str)?; + } + + writeln!(f)?; + } + + Ok(()) + } } impl MultiPodBuilder {