diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index d90c295..2abe85d 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -26,8 +26,8 @@ //! are **intermediate PODs** that prove supporting statements. //! //! This ordering allows dependencies to flow forward: later PODs can access public -//! statements from earlier PODs via `CopyStatement`. The output POD, being last, can -//! access all intermediate PODs. +//! statements from earlier PODs by adding them as input PODs. The output POD, being +//! last, can access all intermediate PODs. //! //! # Usage //! @@ -51,9 +51,7 @@ use std::collections::{BTreeSet, HashMap}; use crate::{ frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, - middleware::{ - Hash, MainPodProver, NativeOperation, OperationAux, OperationType, Params, Statement, VDSet, - }, + middleware::{Hash, MainPodProver, Params, Statement, VDSet}, }; mod cost; @@ -152,16 +150,18 @@ impl MultiPodResult { /// to add statements, just like `MainPodBuilder`. /// /// 2. **Solve**: Call [`solve`](Self::solve) to run the MILP solver, which determines -/// the optimal assignment of statements to PODs. +/// the optimal assignment of statements to PODs. This consumes the builder and +/// returns a [`SolvedMultiPod`]. /// -/// 3. **Prove**: Call [`prove`](Self::prove) to build and prove all PODs. +/// 3. **Prove**: Call [`prove`](SolvedMultiPod::prove) on the solved result to build +/// and prove all PODs. /// /// # POD Structure /// /// The result contains PODs in build order: intermediate PODs first (indices 0..k-1), /// then the output POD last (index k). The output POD contains all user-requested /// public statements (those added via `pub_op`). Intermediate PODs make their -/// statements public so later PODs can copy them. +/// statements public so later PODs can access them. /// /// [`MainPodBuilder`]: crate::frontend::MainPodBuilder #[derive(Debug)] @@ -178,227 +178,45 @@ pub struct MultiPodBuilder { /// Indices of statements that should be public in output PODs. /// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted. output_public_indices: Vec, - /// Cached solution from the solver. - cached_solution: Option, - /// Cached dependency graph (computed once in solve(), reused in build_single_pod()). - cached_deps: Option, - /// Cached external POD statement map (computed once in solve(), reused in build_single_pod()). - cached_external_map: Option>, - /// Cached MainPodBuilder for incremental statement computation. - cached_builder: Option, + /// Used during add_operation to validate statements with unlimited params. + builder: MainPodBuilder, } -impl MultiPodBuilder { - /// Create a new MultiPodBuilder with default options. - pub fn new(params: &Params, vd_set: &VDSet) -> Self { - Self::new_with_options(params, vd_set, Options::default()) - } +/// A solved multi-POD problem, ready to be proved. +/// +/// Created by [`MultiPodBuilder::solve`]. Call [`prove`](Self::prove) to build +/// and prove all PODs, or inspect the [`solution`](Self::solution) first. +#[derive(Debug)] +pub struct SolvedMultiPod { + params: Params, + vd_set: VDSet, + input_pods: Vec, + statements: Vec, + operations: Vec, + solution: MultiPodSolution, + deps: DependencyGraph, +} - /// Create a new MultiPodBuilder with custom options. - pub fn new_with_options(params: &Params, vd_set: &VDSet, options: Options) -> Self { - Self { - params: params.clone(), - vd_set: vd_set.clone(), - options, - input_pods: Vec::new(), - statements: Vec::new(), - operations: Vec::new(), - output_public_indices: Vec::new(), - cached_solution: None, - cached_deps: None, - cached_external_map: None, - cached_builder: None, - } - } - - /// Add an external input POD. - pub fn add_pod(&mut self, pod: MainPod) { - // Keep cached_builder in sync if it exists - if let Some(ref mut builder) = self.cached_builder { - // Won't fail - cached_builder has unlimited params - let _ = builder.add_pod(pod.clone()); - } - self.input_pods.push(pod); - self.invalidate_cache(); - } - - /// Add a public operation (statement will be public in output). - pub fn pub_op(&mut self, op: Operation) -> Result { - let stmt = self.add_operation(op)?; - // Index is always new (just added), so push without duplicate check - self.output_public_indices.push(self.statements.len() - 1); - Ok(stmt) - } - - /// Add a private operation. - pub fn priv_op(&mut self, op: Operation) -> Result { - self.add_operation(op) - } - - /// Internal: Add an operation and create its statement. - fn add_operation(&mut self, op: Operation) -> Result { - self.invalidate_cache(); - - // Get or create the cached builder - // - // NOTE: We clone input pods here because MainPodBuilder takes ownership. - // This could be avoided if MainPodBuilder were generic over the pod storage type: - // struct MainPodBuilder = MainPod> - // Then MultiPodBuilder could use MainPodBuilder<&MainPod> to borrow instead of clone, - // while existing code using MainPodBuilder (with the default) would be unaffected. - let builder = self.cached_builder.get_or_insert_with(|| { - let unlimited_params = Params { - max_statements: usize::MAX / 2, - max_public_statements: usize::MAX / 2, - max_input_pods: usize::MAX / 2, - max_input_pods_public_statements: usize::MAX / 2, - ..self.params.clone() - }; - let mut b = MainPodBuilder::new(&unlimited_params, &self.vd_set); - for pod in &self.input_pods { - let _ = b.add_pod(pod.clone()); - } - b - }); - - let stmt = builder - .op(false, vec![], op.clone()) - .map_err(|e| Error::Frontend(e.to_string()))?; - - self.statements.push(stmt.clone()); - self.operations.push(op); - - Ok(stmt) - } - - /// Mark a statement as public in output. - /// - /// Returns an error if the statement was not found in the builder. - /// Calling this multiple times on the same statement is idempotent. - pub fn reveal(&mut self, stmt: &Statement) -> Result<()> { - if let Some(idx) = self.statements.iter().position(|s| s == stmt) { - // Only invalidate cache if this is a new reveal - if !self.output_public_indices.contains(&idx) { - self.output_public_indices.push(idx); - self.invalidate_cache(); - } - Ok(()) - } else { - Err(Error::Frontend( - "reveal() called with statement not found in builder".to_string(), - )) - } - } - - /// Get the number of statements. - pub fn num_statements(&self) -> usize { - self.statements.len() - } - - /// Invalidate all cached data. Called when operations or statements change. - fn invalidate_cache(&mut self) { - self.cached_solution = None; - self.cached_deps = None; - self.cached_external_map = None; - } - - /// Solve the packing problem and return the solution. - /// - /// This runs the MILP solver to find the optimal POD assignment. - /// The solution is cached for subsequent calls. - pub fn solve(&mut self) -> Result<&MultiPodSolution> { - if self.cached_solution.is_some() { - return Ok(self.cached_solution.as_ref().unwrap()); - } - - // Compute costs for each statement - let costs: Vec = self - .operations - .iter() - .map(StatementCost::from_operation) - .collect(); - - // Collect all unique anchored keys from the costs - let all_anchored_keys: Vec = costs - .iter() - .flat_map(|c| c.anchored_keys.iter().cloned()) - .collect::>() - .into_iter() - .collect(); - - // Build map from anchored key to its producing statement index (if any). - // A Contains statement with literal (dict, key, value) "produces" that anchored key. - let mut ak_to_producer: HashMap = HashMap::new(); - for (stmt_idx, stmt) in self.statements.iter().enumerate() { - if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) { - // First producer wins (shouldn't have duplicates in practice) - ak_to_producer.entry(ak).or_insert(stmt_idx); - } - } - - // Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i] - let anchored_key_producers: Vec> = all_anchored_keys - .iter() - .map(|ak| ak_to_producer.get(ak).copied()) - .collect(); - - // Build external POD statement mapping (cache for reuse in build_single_pod) - let external_pod_statements = self.build_external_statement_map(); - self.cached_external_map = Some(external_pod_statements); - let external_pod_statements = self.cached_external_map.as_ref().unwrap(); - - // Build dependency graph (cache for reuse in build_single_pod) - let deps = - DependencyGraph::build(&self.statements, &self.operations, external_pod_statements); - self.cached_deps = Some(deps); - let deps = self.cached_deps.as_ref().unwrap(); - - // Build statement content groups for deduplication. - // Statements with identical content share a single slot in the POD. - // Group statement indices by their content. - let mut content_to_indices: HashMap<&Statement, Vec> = HashMap::new(); - for (idx, stmt) in self.statements.iter().enumerate() { - content_to_indices.entry(stmt).or_default().push(idx); - } - let statement_content_groups: Vec> = content_to_indices.into_values().collect(); - - // Run solver - let input = solver::SolverInput { - num_statements: self.statements.len(), - costs: &costs, - deps, - output_public_indices: &self.output_public_indices, - params: &self.params, - max_pods: self.options.max_pods, - all_anchored_keys: &all_anchored_keys, - anchored_key_producers: &anchored_key_producers, - statement_content_groups: &statement_content_groups, - }; - - let solution = solver::solve(&input)?; - self.cached_solution = Some(solution); - - Ok(self.cached_solution.as_ref().unwrap()) +impl SolvedMultiPod { + /// Get the solver's solution (POD assignments). + pub fn solution(&self) -> &MultiPodSolution { + &self.solution } /// Build and prove all PODs. /// - /// This first solves if not already solved, then builds and proves - /// all necessary PODs in dependency order. - pub fn prove(&mut self, prover: &dyn MainPodProver) -> Result { - // Ensure we have a solution (can't use returned reference due to later &mut self borrows) - self.solve()?; - let solution = self.cached_solution.as_ref().unwrap(); + /// Builds PODs in dependency order (0, 1, ..., k) and proves each one. + /// The last POD is the output POD containing user-requested public statements. + pub fn prove(self, prover: &dyn MainPodProver) -> Result { + let solution = &self.solution; // Build PODs in sequential order: 0, 1, 2, ..., k - // This order is guaranteed by the solver's symmetry-breaking constraint, which - // ensures PODs are used in order (no gaps). Sequential building is required because - // later PODs may reference earlier ones via CopyStatement for cross-POD dependencies. - // PODs 0..k-1 are intermediate; POD k (the last one) is the output POD. + // This order is guaranteed by the solver's symmetry-breaking constraint. + // Later PODs may reference earlier ones for cross-POD dependencies. let mut pods: Vec = Vec::with_capacity(solution.pod_count); for pod_idx in 0..solution.pod_count { - let pod = self.build_single_pod(pod_idx, solution, &pods, prover)?; + let pod = self.build_single_pod(pod_idx, &pods, prover)?; pods.push(pod); } @@ -411,23 +229,17 @@ impl MultiPodBuilder { /// 1. Identifying which input PODs are needed (external + earlier generated) /// 2. Adding those input PODs to a fresh `MainPodBuilder` /// 3. For each statement assigned to this POD (in dependency order): - /// - Copy any dependencies from earlier PODs via `CopyStatement` /// - Execute the original operation to create the statement /// - Mark as public if the solver determined it should be /// 4. Prove the POD fn build_single_pod( &self, pod_idx: usize, - solution: &MultiPodSolution, earlier_pods: &[MainPod], prover: &dyn MainPodProver, ) -> Result { let mut builder = MainPodBuilder::new(&self.params, &self.vd_set); - - let deps = self - .cached_deps - .as_ref() - .expect("build_single_pod called before solve()"); + let solution = &self.solution; let statements_in_this_pod: &Vec = &solution.pod_statements[pod_idx]; let mut needed_external_pods: BTreeSet = BTreeSet::new(); @@ -435,7 +247,7 @@ impl MultiPodBuilder { // Step 1: Find which external and earlier PODs we need based on dependencies for &stmt_idx in statements_in_this_pod { - for dep in &deps.statement_deps[stmt_idx] { + for dep in &self.deps.statement_deps[stmt_idx] { match dep { StatementSource::Internal(dep_idx) => { // Check if dependency is in an earlier generated POD @@ -486,89 +298,31 @@ impl MultiPodBuilder { builder.add_pod(earlier_pods[earlier_idx].clone())?; } - // Step 3: Build statement source map for determining what needs copying. - // Create a mapping from statement to its source (for copy operations). - // A statement may be both proved locally AND available from an earlier POD. - // We use or_insert to prefer local sources (inserted first) over earlier PODs. - let mut stmt_sources: HashMap = HashMap::new(); - for &stmt_idx in statements_in_this_pod { - stmt_sources.insert(stmt_idx, StmtSource::Local); - } - for earlier_pod_idx in 0..pod_idx { - for &stmt_idx in &solution.pod_public_statements[earlier_pod_idx] { - // Only insert if not already local - or_insert preserves existing entries - stmt_sources.entry(stmt_idx).or_insert(StmtSource::FromPod); - } - } - - // Step 4: Add statements in dependency order. + // Step 3: Add statements in dependency order. // Statements are added in ascending index order, which matches dependency order: // if B depends on A, then A has a lower index and is added first. let statements_sorted: BTreeSet = statements_in_this_pod.iter().copied().collect(); let public_set = &solution.pod_public_statements[pod_idx]; - // Track which statements have been added to this builder + // Track statements proved locally in this POD for argument remapping. + // When an operation references a statement proved earlier in this same POD, + // we need to use the Statement object that MainPodBuilder created (not the + // original from MultiPodBuilder) so that find_op_arg can locate it. let mut added_statements: HashMap = HashMap::new(); for &stmt_idx in &statements_sorted { - // First, ensure all dependencies are available (copy if needed). - // When a dependency comes from an earlier POD, we need CopyStatement to make it - // available in this POD's namespace. The earlier POD is already added as an input, - // but CopyStatement creates a local reference that operations can use. - for dep in &deps.statement_deps[stmt_idx] { - if let StatementSource::Internal(dep_idx) = dep { - if !added_statements.contains_key(dep_idx) { - // Need to copy this statement from an earlier POD - match stmt_sources.get(dep_idx) { - Some(StmtSource::FromPod) => { - // Dependency is from an earlier POD - copy it - let copy_op = Operation( - OperationType::Native(NativeOperation::CopyStatement), - vec![OperationArg::Statement( - self.statements[*dep_idx].clone(), - )], - OperationAux::None, - ); - let copied_stmt = builder - .priv_op(copy_op) - .map_err(|e| Error::Frontend(e.to_string()))?; - added_statements.insert(*dep_idx, copied_stmt); - } - Some(StmtSource::Local) => { - // Local dependency should already be added due to topological - // ordering. If we reach here, there's a bug in the ordering. - unreachable!( - "Local dependency at index {} should already be added \ - when processing statement {} (topological order violation)", - dep_idx, stmt_idx - ); - } - None => { - // Dependency not found in stmt_sources means it's neither - // in this POD nor available from earlier PODs - a solver bug. - unreachable!( - "Dependency at index {} not found in stmt_sources \ - when processing statement {}", - dep_idx, stmt_idx - ); - } - } - } - } - } - - // Now add the actual statement let is_public = public_set.contains(&stmt_idx); let mut op = self.operations[stmt_idx].clone(); - // Remap Statement arguments in the operation to use statements created by MainPodBuilder. - // The original operation references Statements from MultiPodBuilder, but MainPodBuilder - // needs Statements that were either created by it or come from its input PODs. + // Remap Statement arguments that reference locally-proved statements. + // For external dependencies (from input PODs including earlier generated PODs), + // the original Statement is used directly - MainPodBuilder will find it in + // the input POD's public statements via find_op_arg. for arg in &mut op.1 { if let OperationArg::Statement(ref orig_stmt) = arg { - // Find the original statement's index in MultiPodBuilder + // Find the original statement's index if let Some(orig_idx) = self.statements.iter().position(|s| s == orig_stmt) { - // Get the remapped statement from MainPodBuilder + // Get the remapped statement if it was proved locally in this POD if let Some(remapped_stmt) = added_statements.get(&orig_idx) { *arg = OperationArg::Statement(remapped_stmt.clone()); } @@ -583,35 +337,196 @@ impl MultiPodBuilder { added_statements.insert(stmt_idx, stmt); } - // Step 5: Prove the POD + // Step 4: Prove the POD let pod = builder .prove(prover) .map_err(|e| Error::Frontend(e.to_string()))?; Ok(pod) } +} - /// Build mapping from external POD statements to their POD hash. - fn build_external_statement_map(&self) -> HashMap { - let mut map = HashMap::new(); - for pod in &self.input_pods { - let pod_hash = pod.statements_hash(); - for stmt in pod.pod.pub_statements() { - map.insert(stmt, pod_hash); +impl MultiPodBuilder { + /// Create a new MultiPodBuilder with default options. + pub fn new(params: &Params, vd_set: &VDSet) -> Self { + Self::new_with_options(params, vd_set, Options::default()) + } + + /// Create a new MultiPodBuilder with custom options. + pub fn new_with_options(params: &Params, vd_set: &VDSet, options: Options) -> Self { + let unlimited_params = Params { + max_statements: usize::MAX / 2, + max_public_statements: usize::MAX / 2, + max_input_pods: usize::MAX / 2, + max_input_pods_public_statements: usize::MAX / 2, + ..params.clone() + }; + let builder = MainPodBuilder::new(&unlimited_params, vd_set); + Self { + params: params.clone(), + vd_set: vd_set.clone(), + options, + builder, + input_pods: Vec::new(), + statements: Vec::new(), + operations: Vec::new(), + output_public_indices: Vec::new(), + } + } + + /// Add an external input POD. + pub fn add_pod(&mut self, pod: MainPod) -> Result<()> { + self.builder + .add_pod(pod.clone()) + .map_err(|e| Error::Frontend(e.to_string()))?; + self.input_pods.push(pod); + Ok(()) + } + + /// Add a public operation (statement will be public in output). + pub fn pub_op(&mut self, op: Operation) -> Result { + let stmt = self.add_operation(op)?; + // Index is always new (just added), so push without duplicate check + self.output_public_indices.push(self.statements.len() - 1); + Ok(stmt) + } + + /// Add a private operation. + pub fn priv_op(&mut self, op: Operation) -> Result { + self.add_operation(op) + } + + /// Internal: Add an operation and create its statement. + fn add_operation(&mut self, op: Operation) -> Result { + // Get or create the cached builder + // + // NOTE: We clone input pods here because MainPodBuilder takes ownership. + // This could be avoided if MainPodBuilder were generic over the pod storage type: + // struct MainPodBuilder = MainPod> + // Then MultiPodBuilder could use MainPodBuilder<&MainPod> to borrow instead of clone, + // while existing code using MainPodBuilder (with the default) would be unaffected. + let stmt = self + .builder + .op(false, vec![], op.clone()) + .map_err(|e| Error::Frontend(e.to_string()))?; + + self.statements.push(stmt.clone()); + self.operations.push(op); + + Ok(stmt) + } + + /// Mark a statement as public in output. + /// + /// Returns an error if the statement was not found in the builder. + /// Calling this multiple times on the same statement is idempotent. + pub fn reveal(&mut self, stmt: &Statement) -> Result<()> { + if let Some(idx) = self.statements.iter().position(|s| s == stmt) { + if !self.output_public_indices.contains(&idx) { + self.output_public_indices.push(idx); + } + Ok(()) + } else { + Err(Error::Frontend( + "reveal() called with statement not found in builder".to_string(), + )) + } + } + + /// Get the number of statements. + pub fn num_statements(&self) -> usize { + self.statements.len() + } + + /// Solve the packing problem and return a solved builder ready for proving. + /// + /// This runs the MILP solver to find the optimal POD assignment. + /// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved. + pub fn solve(self) -> Result { + // Compute costs for each statement + let costs: Vec = self + .operations + .iter() + .map(StatementCost::from_operation) + .collect(); + + // Collect all unique anchored keys from the costs + let all_anchored_keys: Vec = costs + .iter() + .flat_map(|c| c.anchored_keys.iter().cloned()) + .collect::>() + .into_iter() + .collect(); + + // Build map from anchored key to its producing statement index (if any). + // A Contains statement with literal (dict, key, value) "produces" that anchored key. + let mut ak_to_producer: HashMap = HashMap::new(); + for (stmt_idx, stmt) in self.statements.iter().enumerate() { + if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) { + // First producer wins (shouldn't have duplicates in practice) + ak_to_producer.entry(ak).or_insert(stmt_idx); } } - map + + // Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i] + let anchored_key_producers: Vec> = all_anchored_keys + .iter() + .map(|ak| ak_to_producer.get(ak).copied()) + .collect(); + + // Build external POD statement mapping + let external_pod_statements = build_external_statement_map(&self.input_pods); + + // Build dependency graph + let deps = + DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements); + + // Build statement content groups for deduplication. + // Statements with identical content share a single slot in the POD. + // Group statement indices by their content. + let mut content_to_indices: HashMap<&Statement, Vec> = HashMap::new(); + for (idx, stmt) in self.statements.iter().enumerate() { + content_to_indices.entry(stmt).or_default().push(idx); + } + let statement_content_groups: Vec> = content_to_indices.into_values().collect(); + + // Run solver + let input = solver::SolverInput { + num_statements: self.statements.len(), + costs: &costs, + deps: &deps, + output_public_indices: &self.output_public_indices, + params: &self.params, + max_pods: self.options.max_pods, + all_anchored_keys: &all_anchored_keys, + anchored_key_producers: &anchored_key_producers, + statement_content_groups: &statement_content_groups, + }; + + let solution = solver::solve(&input)?; + + Ok(SolvedMultiPod { + params: self.params, + vd_set: self.vd_set, + input_pods: self.input_pods, + statements: self.statements, + operations: self.operations, + solution, + deps, + }) } } -/// Source of a statement within a built POD. -#[derive(Clone, Debug)] -enum StmtSource { - /// Statement is proved locally in this POD. - Local, - /// Statement is copied from an earlier generated POD. - /// (The specific POD index doesn't matter - we only need to know it's not local.) - FromPod, +/// Build mapping from external POD statements to their POD hash. +fn build_external_statement_map(input_pods: &[MainPod]) -> HashMap { + let mut map = HashMap::new(); + for pod in input_pods { + let pod_hash = pod.statements_hash(); + for stmt in pod.pod.pub_statements() { + map.insert(stmt, pod_hash); + } + } + map } #[cfg(test)] @@ -646,12 +561,12 @@ mod tests { builder.pub_op(FrontendOp::dict_signed_by(&signed_dict))?; // Solve - let solution = builder.solve()?; - assert_eq!(solution.pod_count, 1); + let solved = builder.solve()?; + assert_eq!(solved.solution().pod_count, 1); // Prove let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; assert_eq!(result.pods.len(), 1); assert!(result.intermediate_pods().is_empty()); @@ -693,20 +608,19 @@ mod tests { builder.pub_op(FrontendOp::eq(100, 100))?; builder.pub_op(FrontendOp::eq(101, 101))?; - let pod_count = { - let solution = builder.solve()?; - // 8 statements / 4 per POD = 2 PODs minimum - assert!( - solution.pod_count >= 2, - "Expected at least 2 PODs for 8 statements with max_priv=4, got {}", - solution.pod_count - ); - solution.pod_count - }; + // Solve + let solved = builder.solve()?; + // 8 statements / 4 per POD = 2 PODs minimum + assert!( + solved.solution().pod_count >= 2, + "Expected at least 2 PODs for 8 statements with max_priv=4, got {}", + solved.solution().pod_count + ); + let pod_count = solved.solution().pod_count; // Prove and verify let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; assert_eq!(result.pods.len(), pod_count); for (i, pod) in result.pods.iter().enumerate() { @@ -788,12 +702,13 @@ mod tests { // Solve - this finds a multi-POD solution where intermediate PODs // provide dependencies to the output POD. - let solution = builder.solve()?; + let solved = builder.solve()?; + let solution = solved.solution(); // Expected: exactly 2 PODs // - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public // - POD 1 (output): statement 2 (b_out); b_out is public - // The output POD copies a_out from POD 0 to satisfy b_out's dependency. + // The output POD accesses a_out from POD 0 to satisfy b_out's dependency. assert_eq!( solution.pod_count, 2, "Expected exactly 2 PODs for 3-statement chain with max_priv=2" @@ -806,7 +721,7 @@ mod tests { solution.pod_statements[0] ); - // Statement 1 (a_out) should be public in POD 0 so POD 1 can copy it + // Statement 1 (a_out) should be public in POD 0 so POD 1 can access it assert!( solution.pod_public_statements[0].contains(&1), "Statement 1 (a_out) should be public in POD 0" @@ -827,7 +742,7 @@ mod tests { // Prove and verify all PODs let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; for (i, pod) in result.pods.iter().enumerate() { pod.pod @@ -867,18 +782,18 @@ mod tests { builder.pub_op(FrontendOp::eq(100, 100))?; builder.pub_op(FrontendOp::eq(101, 101))?; - let solution = builder.solve()?; + let solved = builder.solve()?; // 6 statements / 2 per POD = 3 PODs minimum assert!( - solution.pod_count >= 2, + solved.solution().pod_count >= 2, "Expected at least 2 PODs, got {}", - solution.pod_count + solved.solution().pod_count ); // Prove and verify let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; for (i, pod) in result.pods.iter().enumerate() { pod.pod @@ -1004,8 +919,8 @@ mod tests { // Create MultiPodBuilder and add both external PODs let mut multi_builder = MultiPodBuilder::new(¶ms, vd_set); - multi_builder.add_pod(ext_pod_a.clone()); - multi_builder.add_pod(ext_pod_b.clone()); + multi_builder.add_pod(ext_pod_a.clone())?; + multi_builder.add_pod(ext_pod_b.clone())?; // Add private operations that reference different external PODs. // These will force multiple PODs due to private statement limits. @@ -1021,14 +936,14 @@ mod tests { // With 6 statements and max_priv_statements = 2, we need multiple PODs. // Each POD should only include the external POD it depends on. - let solution = multi_builder.solve()?; + let solved = multi_builder.solve()?; assert!( - solution.pod_count >= 2, + solved.solution().pod_count >= 2, "Expected at least 2 PODs, got {}", - solution.pod_count + solved.solution().pod_count ); - let result = multi_builder.prove(&prover)?; + let result = solved.prove(&prover)?; // Verify all PODs for (i, pod) in result.pods.iter().enumerate() { @@ -1064,7 +979,8 @@ mod tests { builder.pub_op(FrontendOp::eq(200, 200))?; builder.pub_op(FrontendOp::eq(201, 201))?; - let solution = builder.solve()?; + let solved = builder.solve()?; + let solution = solved.solution(); // Check that the output POD's public statements are exactly the user-requested public ones. // The output POD is always the last one (index pod_count - 1). @@ -1186,9 +1102,9 @@ mod tests { // Create MultiPodBuilder and add all 3 external PODs let mut multi_builder = MultiPodBuilder::new(¶ms, vd_set); - multi_builder.add_pod(ext_pod_a); - multi_builder.add_pod(ext_pod_b); - multi_builder.add_pod(ext_pod_c); + multi_builder.add_pod(ext_pod_a)?; + multi_builder.add_pod(ext_pod_b)?; + multi_builder.add_pod(ext_pod_c)?; // Add public operations that each depend on a different external POD // All 3 must be public in POD 0, requiring 3 external inputs > max_input_pods @@ -1258,18 +1174,19 @@ mod tests { builder.priv_op(FrontendOp::gt(contains_stmt, 2))?; // With correct counting, all 4 statements fit in 1 POD - let solution = builder.solve()?; + let solved = builder.solve()?; assert_eq!( - solution.pod_count, 1, + solved.solution().pod_count, + 1, "All statements should fit in 1 POD when Contains is not double-counted. \ Got {} PODs, which suggests the explicit Contains is being incorrectly \ counted as both a statement AND an anchored key overhead.", - solution.pod_count + solved.solution().pod_count ); // Verify proving works let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; assert_eq!(result.pods.len(), 1); result @@ -1322,7 +1239,7 @@ mod tests { builder.priv_op(FrontendOp::gt(stmt_a, 2))?; let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = builder.solve()?.prove(&prover)?; // Verify all PODs for (i, pod) in result.pods.iter().enumerate() { @@ -1371,7 +1288,7 @@ mod tests { .expect("ext_pod should have a public statement"); let mut builder = MultiPodBuilder::new(¶ms, vd_set); - builder.add_pod(ext_pod); + builder.add_pod(ext_pod)?; // Output POD: public Contains statements let dict0 = dict!({"x" => 100}); @@ -1387,7 +1304,7 @@ mod tests { builder.priv_op(FrontendOp::copy(stmt_ext))?; // This should succeed - total inputs per POD should stay within limit - let result = builder.prove(&prover)?; + let result = builder.solve()?.prove(&prover)?; for (i, pod) in result.pods.iter().enumerate() { pod.pod @@ -1433,20 +1350,19 @@ mod tests { // Add one public statement for output builder.pub_op(FrontendOp::eq(100, 100))?; - let pod_count = { - let solution = builder.solve()?; - // 4 SignedBy / 2 per POD = exactly 2 PODs - assert_eq!( - solution.pod_count, 2, - "Expected exactly 2 PODs for 4 SignedBy with max_signed_by=2, got {}", - solution.pod_count - ); - solution.pod_count - }; + let solved = builder.solve()?; + // 4 SignedBy / 2 per POD = exactly 2 PODs + assert_eq!( + solved.solution().pod_count, + 2, + "Expected exactly 2 PODs for 4 SignedBy with max_signed_by=2, got {}", + solved.solution().pod_count + ); + let pod_count = solved.solution().pod_count; // Prove and verify let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; assert_eq!(result.pods.len(), pod_count); for (i, pod) in result.pods.iter().enumerate() { @@ -1536,20 +1452,19 @@ mod tests { [contains4], ))?; - let pod_count = { - let solution = builder.solve()?; - // 4 batches / 2 per POD = exactly 2 PODs - assert_eq!( - solution.pod_count, 2, - "Expected exactly 2 PODs for 4 batches with max_custom_predicate_batches=2, got {}", - solution.pod_count - ); - solution.pod_count - }; + let solved = builder.solve()?; + // 4 batches / 2 per POD = exactly 2 PODs + assert_eq!( + solved.solution().pod_count, + 2, + "Expected exactly 2 PODs for 4 batches with max_custom_predicate_batches=2, got {}", + solved.solution().pod_count + ); + let pod_count = solved.solution().pod_count; // Prove and verify let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; assert_eq!(result.pods.len(), pod_count); for (i, pod) in result.pods.iter().enumerate() { @@ -1568,12 +1483,12 @@ mod tests { // // Chain: d_out -> c_out -> b_out -> a_out -> contains (5 statements) // - // With max_priv_statements = 2, each POD can hold at most 2 statements - // (including copies). Expected solution with 4 PODs: - // - POD 0 (intermediate): contains, a_out (a_out public) - // - POD 1 (intermediate): copy(a_out), b_out (b_out public) - // - POD 2 (intermediate): copy(b_out), c_out (c_out public) - // - POD 3 (output): copy(c_out), d_out + // With max_priv_statements = 2, each POD can hold at most 2 statements. + // Cross-POD dependencies are available via input PODs without needing copies. + // Expected solution with 3 PODs (ceil(5/2) = 3): + // - POD 0 (intermediate): contains, a_out (a_out public for POD 1) + // - POD 1 (intermediate): b_out, c_out (c_out public for POD 2) + // - POD 2 (output): d_out (public) let params = Params { max_statements: 4, @@ -1628,65 +1543,54 @@ mod tests { [c_out], ))?; - let solution = builder.solve()?; + let solved = builder.solve()?; + let solution = solved.solution(); - // Expected: exactly 4 PODs for a 5-statement chain with max_priv=2 - // - POD 0: statements 0 (contains), 1 (a_out); a_out public - // - POD 1: statement 2 (b_out); b_out public (copies a_out) - // - POD 2: statement 3 (c_out); c_out public (copies b_out) - // - POD 3 (output): statement 4 (d_out); d_out public (copies c_out) + // Expected: exactly 3 PODs for a 5-statement chain with max_priv=2 + // (5 statements / 2 per POD = 3 PODs) assert_eq!( - solution.pod_count, 4, - "Expected exactly 4 PODs for 5-statement chain with max_priv=2" + solution.pod_count, 3, + "Expected exactly 3 PODs for 5-statement chain with max_priv=2" ); - // POD 0: contains(0) and a_out(1) - assert!( - solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), - "POD 0 should contain statements 0 and 1, got {:?}", - solution.pod_statements[0] - ); - assert!( - solution.pod_public_statements[0].contains(&1), - "Statement 1 (a_out) should be public in POD 0" + // All 5 statements should be assigned across the PODs + let all_statements: BTreeSet = solution + .pod_statements + .iter() + .flat_map(|s| s.iter().copied()) + .collect(); + assert_eq!( + all_statements, + (0..5).collect::>(), + "All 5 statements should be assigned" ); - // POD 1: b_out(2) - assert!( - solution.pod_statements[1].contains(&2), - "POD 1 should contain statement 2 (b_out), got {:?}", - solution.pod_statements[1] - ); - assert!( - solution.pod_public_statements[1].contains(&2), - "Statement 2 (b_out) should be public in POD 1" - ); + // Each POD should have at most max_priv_statements = 2 + for (i, stmts) in solution.pod_statements.iter().enumerate() { + assert!( + stmts.len() <= 2, + "POD {} has {} statements, but max_priv=2: {:?}", + i, + stmts.len(), + stmts + ); + } - // POD 2: c_out(3) + // The output POD (last) must contain d_out(4) and it must be public + let output_pod_idx = solution.pod_count - 1; assert!( - solution.pod_statements[2].contains(&3), - "POD 2 should contain statement 3 (c_out), got {:?}", - solution.pod_statements[2] + solution.pod_statements[output_pod_idx].contains(&4), + "Output POD should contain statement 4 (d_out), got {:?}", + solution.pod_statements[output_pod_idx] ); assert!( - solution.pod_public_statements[2].contains(&3), - "Statement 3 (c_out) should be public in POD 2" - ); - - // POD 3 (output): d_out(4) - assert!( - solution.pod_statements[3].contains(&4), - "POD 3 should contain statement 4 (d_out), got {:?}", - solution.pod_statements[3] - ); - assert!( - solution.pod_public_statements[3].contains(&4), + solution.pod_public_statements[output_pod_idx].contains(&4), "Statement 4 (d_out) should be public in output POD" ); // Prove and verify all PODs let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; for (i, pod) in result.pods.iter().enumerate() { pod.pod @@ -1709,13 +1613,8 @@ mod tests { // contains // // Where a_out depends on BOTH b_out and c_out, creating a diamond. - // With tight limits, b_out and c_out may end up in different PODs, - // and the output POD must copy from both. - // - // With max_priv_statements = 3: - // - POD 0: contains, b_out, c_out (b_out and c_out public) - 3 statements - // - POD 1 (output): copy(b_out), copy(c_out), a_out - 3 statements - // Or the solver may find a different arrangement. + // The solver may distribute statements across PODs in various ways, + // as long as dependencies are satisfied. let params = Params { max_statements: 6, @@ -1773,48 +1672,42 @@ mod tests { [b_out, c_out], ))?; - let solution = builder.solve()?; + let solved = builder.solve()?; + let solution = solved.solution(); - // Expected: exactly 2 PODs for the diamond - // - POD 0: contains(0), b_out(1), c_out(2); b_out and c_out public - // - POD 1 (output): a_out(3); a_out public (copies b_out and c_out) + // With 4 statements and max_priv=3, we need at least 2 PODs (ceil(4/3) = 2) assert_eq!( solution.pod_count, 2, "Expected exactly 2 PODs for diamond with max_priv=3" ); - // POD 0 should contain statements 0, 1, 2 + // The output POD (last) must contain statement 3 (a_out) and it must be public + let output_pod_idx = solution.pod_count - 1; assert!( - solution.pod_statements[0].contains(&0) - && solution.pod_statements[0].contains(&1) - && solution.pod_statements[0].contains(&2), - "POD 0 should contain statements 0, 1, 2, got {:?}", - solution.pod_statements[0] + solution.pod_statements[output_pod_idx].contains(&3), + "Output POD should contain statement 3 (a_out), got {:?}", + solution.pod_statements[output_pod_idx] ); - - // Statements 1 and 2 (b_out and c_out) should be public in POD 0 assert!( - solution.pod_public_statements[0].contains(&1) - && solution.pod_public_statements[0].contains(&2), - "Statements 1 and 2 should be public in POD 0" - ); - - // POD 1 (output) should contain statement 3 (a_out) - assert!( - solution.pod_statements[1].contains(&3), - "POD 1 should contain statement 3 (a_out), got {:?}", - solution.pod_statements[1] - ); - - // Statement 3 (a_out) should be public in output POD - assert!( - solution.pod_public_statements[1].contains(&3), + solution.pod_public_statements[output_pod_idx].contains(&3), "Statement 3 (a_out) should be public in output POD" ); - // Prove and verify all PODs + // All statements should be covered exactly once across all PODs + let all_statements: BTreeSet = solution + .pod_statements + .iter() + .flat_map(|s| s.iter().copied()) + .collect(); + assert_eq!( + all_statements, + [0, 1, 2, 3].into_iter().collect(), + "All statements should be assigned to exactly one POD" + ); + + // Prove and verify all PODs - this validates dependencies are satisfied let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; for (i, pod) in result.pods.iter().enumerate() { pod.pod @@ -1887,11 +1780,12 @@ mod tests { [a_out], ))?; - let solution = builder.solve()?; + let solved = builder.solve()?; + let solution = solved.solution(); // Expected: exactly 2 PODs due to batch limit // - POD 0: contains(0), a_out(1) using batch_a; a_out public - // - POD 1 (output): b_out(2) using batch_b; b_out public (copies a_out) + // - POD 1 (output): b_out(2) using batch_b; b_out public // // Even though max_priv_statements=6 could fit all 3 statements, // max_custom_predicate_batches=1 forces batch_a and batch_b into different PODs. @@ -1924,7 +1818,7 @@ mod tests { // Prove and verify let prover = MockProver {}; - let result = builder.prove(&prover)?; + let result = solved.prove(&prover)?; for (i, pod) in result.pods.iter().enumerate() { pod.pod diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs index 7920384..4441c27 100644 --- a/src/frontend/multi_pod/solver.rs +++ b/src/frontend/multi_pod/solver.rs @@ -15,7 +15,6 @@ //! - **Constraint 4 (POD Existence)**: If any statement is proved in POD p, then p is used. //! - **Constraint 5 (Dependencies)**: If statement S depends on D and S is proved in POD p, //! then D must be available: either proved locally in p, or public in some earlier POD. -//! - **Constraint 5b (Copy Tracking)**: Track when dependencies need CopyStatement. //! - **Constraint 6 (Resource Limits)**: Per-POD limits on statements, public slots, merkle //! proofs, custom predicates, batches, etc. //! - **Constraint 7 (Batch Cardinality)**: Limit distinct custom predicate batches per POD. @@ -251,32 +250,6 @@ fn try_solve_with_pods( .map(|p| (0..p).map(|_| vars.add(variable().binary())).collect()) .collect(); - // Collect all statement indices that are internal dependencies. - // These are statements that other statements depend on, and may need to be copied - // into PODs where the dependent statement is proved but the dependency is not. - let internal_deps: BTreeSet = input - .deps - .statement_deps - .iter() - .flat_map(|deps| deps.iter()) - .filter_map(|dep| match dep { - StatementSource::Internal(d) => Some(*d), - StatementSource::External(_) => None, - }) - .collect(); - - // needs_copy[d][p] - dependency d needs to be copied into POD p - // This is 1 when: (some statement s in p depends on d) AND (d is not proved in p) - // We only create variables for dependencies that are actually used. - let dep_indices: Vec = internal_deps.iter().copied().collect(); - let needs_copy: Vec> = (0..dep_indices.len()) - .map(|_| { - (0..target_pods) - .map(|_| vars.add(variable().binary())) - .collect() - }) - .collect(); - // Collect all external POD hashes that statements depend on. // These are user-provided input PODs referenced by statements. use crate::middleware::Hash; @@ -321,9 +294,11 @@ fn try_solve_with_pods( }) .collect(); - // Objective: minimize number of PODs used - let objective: Expression = pod_used.iter().sum(); - let mut model = vars.minimise(objective).using(default_solver); + // No optimization objective needed - we use an incremental approach that tries + // min_pods first and increments until feasible. Combined with symmetry breaking + // (Constraint 9), this finds the minimum number of PODs without needing MILP + // optimization. A constant objective makes the solver find any feasible solution. + let mut model = vars.minimise(0_i32).using(default_solver); // Constraint 1: Each statement must be proved at least once for s in 0..n { @@ -388,30 +363,6 @@ fn try_solve_with_pods( } } - // Constraint 5b: needs_copy tracking for cross-POD dependencies - // needs_copy[d][p] = 1 when: some statement s proved in p depends on d, AND d is not proved in p. - // This tracks CopyStatements that will be added during build_single_pod. - for (di, &d) in dep_indices.iter().enumerate() { - for p in 0..target_pods { - // needs_copy[d][p] >= prove[s][p] - prove[d][p] for each s that depends on d - // If s is in p (prove[s][p]=1) and d is not in p (prove[d][p]=0), then needs_copy >= 1 - for s in 0..n { - let depends_on_d = input.deps.statement_deps[s] - .iter() - .any(|dep| matches!(dep, StatementSource::Internal(dep_d) if *dep_d == d)); - if depends_on_d { - model.add_constraint(constraint!( - needs_copy[di][p] >= prove[s][p] - prove[d][p] - )); - } - } - - // needs_copy[d][p] <= 1 - prove[d][p] - // If d is proved locally (prove[d][p]=1), no copy needed (needs_copy <= 0) - model.add_constraint(constraint!(needs_copy[di][p] <= 1 - prove[d][p])); - } - } - // Constraint 6: Resource limits per POD // // 6a-pre: Content group tracking for statement deduplication @@ -430,17 +381,16 @@ fn try_solve_with_pods( } for p in 0..target_pods { - // 6a: Unique statement count (unique content groups + CopyStatements + anchored key Contains) + // 6a: Unique statement count (unique content groups + anchored key Contains) // Statements with identical content share a slot, so we count content groups, not indices. - // CopyStatements and anchored key Contains also use statement slots. + // Anchored key Contains statements are auto-inserted by MainPodBuilder when needed. // The total must not exceed max_priv_statements (= max_statements - max_public_statements). let unique_stmt_sum: Expression = (0..num_groups).map(|g| content_group_used[g][p]).sum(); - let copy_sum: Expression = (0..dep_indices.len()).map(|di| needs_copy[di][p]).sum(); let anchored_key_sum: Expression = (0..input.all_anchored_keys.len()) .map(|ak| anchored_key_used[ak][p]) .sum(); model.add_constraint(constraint!( - unique_stmt_sum + copy_sum + anchored_key_sum + unique_stmt_sum + anchored_key_sum <= (input.params.max_priv_statements() as f64) * pod_used[p] ));