From 48aa004ae52b24fc6a9c6c1e34bd69ae0e246cba Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Wed, 28 Jan 2026 07:44:04 +0100 Subject: [PATCH] Create multiple PODs where resource limits for a single POD are exceeded (#444) * Create multiple PODs where resource limits for a single POD are exceeded * HashSet -> BTreeSet determinism fix * Fixed incorrect assignment of input PODs and added test * Ensure only a single output POD * Return error when reveal() called with unknown statement * Use unreachable! for presumed-impossible cases * Use assert_eq! rather than debug_assert_eq * Use FIFO for topological sort * Simplify bounds calculation * Some more simplifications/comments * Enforce dep_idx < idx invariant * Incrementally solve rather than estimating slack * Fix tests to correctly test dependencies between private and public statements * More tidying * Note possible optimisation of MainPodBuilder cloning of input PODs * Fix tracking of total input POD count * Refactor tests * Formatting * Small optimisation: use Vec in place of BTreeSet * Account for automatically-inserted Contains statements * Formatting * Fix possible issue with copied statements * Simplify result type given only a single result MainPod * Remove unnecessary POD count estimate functionality * Simplify dependency ordering and tracking * Remove notion of multiple output PODs from solver * Minor simplifications * Use add_constraint instead of with * Remove unnecessary check following assertion * Fix handling of anchored keys given that Contains statements are not auto-inserted if they already exist * Fix confusing dependency graph test * Remove prove_order * Fix deduplication and possible double-counting of public but not copied statements * Reorder so that the output POD is the final POD * Add more detailed tests * Remove redundant tests * Simplify POD counting * More docs * Flag more branches as unreachable * Formatting * Fix for changed custom batch parsing --- Cargo.toml | 1 + src/frontend/mod.rs | 
4 + src/frontend/multi_pod/cost.rs | 224 ++++ src/frontend/multi_pod/deps.rs | 178 +++ src/frontend/multi_pod/mod.rs | 1937 ++++++++++++++++++++++++++++++ src/frontend/multi_pod/solver.rs | 703 +++++++++++ 6 files changed, 3047 insertions(+) create mode 100644 src/frontend/multi_pod/cost.rs create mode 100644 src/frontend/multi_pod/deps.rs create mode 100644 src/frontend/multi_pod/mod.rs create mode 100644 src/frontend/multi_pod/solver.rs diff --git a/Cargo.toml b/Cargo.toml index 537c38e..810b08c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ serde_bytes = "0.11" serde_arrays = "0.2.0" sha2 = { version = "0.10.9" } rand_chacha = "0.3.1" +good_lp = { version = "1.8", default-features = false, features = ["microlp"] } # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. # [patch."https://github.com/0xPARC/plonky2"] diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index aa96f46..f600f7c 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -21,11 +21,15 @@ use crate::middleware::{ mod custom; mod error; +mod multi_pod; mod operation; mod pod_request; mod serialization; pub use custom::*; pub use error::*; +pub use multi_pod::{ + MultiPodBuilder, MultiPodResult, MultiPodSolution, Options as MultiPodOptions, +}; pub use operation::*; pub use pod_request::*; diff --git a/src/frontend/multi_pod/cost.rs b/src/frontend/multi_pod/cost.rs new file mode 100644 index 0000000..a5d89da --- /dev/null +++ b/src/frontend/multi_pod/cost.rs @@ -0,0 +1,224 @@ +//! Resource cost analysis for statements and operations. +//! +//! This module provides cost analysis for multi-POD packing. Each operation +//! consumes various resources that have per-POD limits. 
+ +use std::collections::BTreeSet; + +use crate::{ + frontend::{Operation, OperationArg}, + middleware::{ + CustomPredicateBatch, Hash, NativeOperation, OperationType, RawValue, Statement, ValueRef, + }, +}; + +/// Unique identifier for a custom predicate batch. +/// +/// Uses the batch's cryptographic hash as identifier. Two batches with the same +/// hash are considered identical for resource counting purposes. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct CustomBatchId(pub Hash); + +impl From<&CustomPredicateBatch> for CustomBatchId { + fn from(batch: &CustomPredicateBatch) -> Self { + Self(batch.id()) + } +} + +/// Unique identifier for an anchored key (dict, key) pair. +/// +/// When a Contains statement is used as an argument to operations like gt(), eq(), etc., +/// the value is accessed via an "anchored key" - a reference to a specific key in a +/// specific dictionary. Each unique anchored key used in a POD requires a Contains +/// statement to be present in that POD (auto-inserted by MainPodBuilder if needed). +/// +/// We use the raw values of the dict and key for comparison, as they uniquely identify +/// the anchored key regardless of the specific Value types involved. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct AnchoredKeyId { + /// The dictionary root value (raw representation for Ord). + pub dict: RawValue, + /// The key within the dictionary (raw representation for Ord). + pub key: RawValue, +} + +impl AnchoredKeyId { + /// Create a new anchored key ID from raw values. + pub fn new(dict: RawValue, key: RawValue) -> Self { + Self { dict, key } + } + + /// Try to extract an anchored key ID from a Contains statement with all literal values. 
+ pub fn from_contains_statement(stmt: &Statement) -> Option<Self> { + if let Statement::Contains( + ValueRef::Literal(dict), + ValueRef::Literal(key), + ValueRef::Literal(_value), + ) = stmt + { + Some(Self::new(dict.raw(), key.raw())) + } else { + None + } + } +} + +/// Resource costs for a single statement/operation. +/// +/// Each field corresponds to a resource with a per-POD limit in `Params`. +#[derive(Clone, Debug, Default)] +pub struct StatementCost { + /// Number of merkle proofs used (for Contains/NotContains). + /// Limit: `params.max_merkle_proofs_containers` + pub merkle_proofs: usize, + + /// Number of merkle tree state transition proofs (for Insert/Update/Delete). + /// Limit: `params.max_merkle_tree_state_transition_proofs_containers` + pub merkle_state_transitions: usize, + + /// Number of custom predicate verifications. + /// Limit: `params.max_custom_predicate_verifications` + pub custom_pred_verifications: usize, + + /// Number of SignedBy operations. + /// Limit: `params.max_signed_by` + pub signed_by: usize, + + /// Number of PublicKeyOf operations. + /// Limit: `params.max_public_key_of` + pub public_key_of: usize, + + /// Custom predicate batches used (for batch cardinality constraint). + /// Limit: `params.max_custom_predicate_batches` distinct batches per POD. + pub custom_batch_ids: BTreeSet<CustomBatchId>, + + /// Anchored keys referenced by this operation. + /// + /// When a Contains statement with all literal values is used as an argument, + /// the operation references an "anchored key" (dict, key pair). Each unique + /// anchored key used in a POD incurs an additional Contains statement cost, + /// as MainPodBuilder::add_entries_contains will auto-insert it if not already present. + pub anchored_keys: BTreeSet<AnchoredKeyId>, +} + +impl StatementCost { + /// Compute the resource cost of an operation. 
+ pub fn from_operation(op: &Operation) -> Self { + let mut cost = Self::default(); + + match &op.0 { + OperationType::Native(native_op) => { + match native_op { + // Operations that use merkle proofs + NativeOperation::ContainsFromEntries + | NativeOperation::NotContainsFromEntries + | NativeOperation::DictContainsFromEntries + | NativeOperation::DictNotContainsFromEntries + | NativeOperation::SetContainsFromEntries + | NativeOperation::SetNotContainsFromEntries + | NativeOperation::ArrayContainsFromEntries => { + cost.merkle_proofs = 1; + } + + // Operations that use merkle state transitions + NativeOperation::ContainerInsertFromEntries + | NativeOperation::ContainerUpdateFromEntries + | NativeOperation::ContainerDeleteFromEntries + | NativeOperation::DictInsertFromEntries + | NativeOperation::DictUpdateFromEntries + | NativeOperation::DictDeleteFromEntries + | NativeOperation::SetInsertFromEntries + | NativeOperation::SetDeleteFromEntries + | NativeOperation::ArrayUpdateFromEntries => { + cost.merkle_state_transitions = 1; + } + + // SignedBy operation + NativeOperation::SignedBy => { + cost.signed_by = 1; + } + + // PublicKeyOf operation + NativeOperation::PublicKeyOf => { + cost.public_key_of = 1; + } + + // Operations with no special resource costs + NativeOperation::None + | NativeOperation::CopyStatement + | NativeOperation::EqualFromEntries + | NativeOperation::NotEqualFromEntries + | NativeOperation::LtEqFromEntries + | NativeOperation::LtFromEntries + | NativeOperation::TransitiveEqualFromStatements + | NativeOperation::LtToNotEqual + | NativeOperation::SumOf + | NativeOperation::ProductOf + | NativeOperation::MaxOf + | NativeOperation::HashOf + // Syntactic sugar variants (lowered before proving) + | NativeOperation::GtEqFromEntries + | NativeOperation::GtFromEntries + | NativeOperation::GtToNotEqual => {} + } + } + OperationType::Custom(cpr) => { + cost.custom_pred_verifications = 1; + cost.custom_batch_ids + .insert(CustomBatchId::from(&*cpr.batch)); 
+ } + } + + // Extract anchored keys from operation arguments. + // Any argument that is a Contains statement with all literal values + // represents an anchored key reference that will require a Contains + // statement in the POD (auto-inserted by MainPodBuilder if needed). + for arg in &op.1 { + if let OperationArg::Statement(stmt) = arg { + if let Some(anchored_key) = AnchoredKeyId::from_contains_statement(stmt) { + cost.anchored_keys.insert(anchored_key); + } + } + } + + cost + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + frontend::Operation as FrontendOp, + middleware::{NativeOperation, OperationAux, OperationType}, + }; + + fn make_native_op(native_op: NativeOperation) -> FrontendOp { + FrontendOp(OperationType::Native(native_op), vec![], OperationAux::None) + } + + #[test] + fn test_cost_from_native_ops() { + // Test merkle proof ops + let contains_op = make_native_op(NativeOperation::ContainsFromEntries); + let cost = StatementCost::from_operation(&contains_op); + assert_eq!(cost.merkle_proofs, 1); + assert_eq!(cost.merkle_state_transitions, 0); + + // Test merkle state transition ops + let insert_op = make_native_op(NativeOperation::ContainerInsertFromEntries); + let cost = StatementCost::from_operation(&insert_op); + assert_eq!(cost.merkle_proofs, 0); + assert_eq!(cost.merkle_state_transitions, 1); + + // Test signed_by + let signed_op = make_native_op(NativeOperation::SignedBy); + let cost = StatementCost::from_operation(&signed_op); + assert_eq!(cost.signed_by, 1); + + // Test public_key_of + let pk_op = make_native_op(NativeOperation::PublicKeyOf); + let cost = StatementCost::from_operation(&pk_op); + assert_eq!(cost.public_key_of, 1); + } +} diff --git a/src/frontend/multi_pod/deps.rs b/src/frontend/multi_pod/deps.rs new file mode 100644 index 0000000..328dd5b --- /dev/null +++ b/src/frontend/multi_pod/deps.rs @@ -0,0 +1,178 @@ +//! Dependency analysis for statements and operations. +//! +//! 
This module analyzes dependencies between statements to determine +//! which statements must be proved before others. + +use std::collections::HashMap; + +use crate::{ + frontend::{Operation, OperationArg}, + middleware::{Hash, Statement}, +}; + +/// Represents a source of a statement dependency. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum StatementSource { + /// Statement created within this builder at the given index. + Internal(usize), + /// Statement from an external input POD (identified by POD hash). + External(Hash), +} + +/// Dependency graph for all statements in a builder. +/// +/// Each element `statement_deps[i]` is the list of dependencies for statement `i`. +#[derive(Clone, Debug)] +pub struct DependencyGraph { + /// Dependencies for each statement (indexed by statement index). + pub statement_deps: Vec<Vec<StatementSource>>, +} + +impl DependencyGraph { + /// Build a dependency graph from statements and operations. + /// + /// `statements` and `operations` should be parallel arrays where + /// `operations[i]` produces `statements[i]`. + /// + /// `external_pod_statements` maps a statement provided by an external POD to that + /// POD's hash, so references to external POD statements can be recognized. + pub fn build( + statements: &[Statement], + operations: &[Operation], + external_pod_statements: &HashMap<Statement, Hash>, + ) -> Self { + let mut statement_deps = Vec::with_capacity(statements.len()); + + // Build a map from statement to its index for internal lookup. + // Use entry().or_insert() to preserve the FIRST occurrence of each statement. + // This is important for CopyStatement: if statements[0] = A and statements[2] = copy(A) = A, + // we want statement_to_index[A] = 0 (the original), not 2 (the copy). 
+ let mut statement_to_index: HashMap<&Statement, usize> = HashMap::new(); + for (i, s) in statements.iter().enumerate() { + if !s.is_none() { + statement_to_index.entry(s).or_insert(i); + } + } + + for (idx, op) in operations.iter().enumerate() { + let mut deps = Vec::new(); + + // Examine each argument to the operation + for arg in &op.1 { + if let OperationArg::Statement(ref dep_stmt) = arg { + if dep_stmt.is_none() { + continue; + } + + // Check if this is an internal statement (created earlier in this builder) + if let Some(&dep_idx) = statement_to_index.get(dep_stmt) { + // Internal dependencies must always be from earlier statements + assert!( + dep_idx <= idx, + "Statement at index {} depends on future statement at index {}", + idx, + dep_idx + ); + + if dep_idx < idx { + // The statement was created by an earlier operation + deps.push(StatementSource::Internal(dep_idx)); + continue; + } + // dep_idx == idx: The first occurrence of this statement is at the current index, + // meaning this operation both takes and produces this statement (e.g., CopyStatement + // copying from an external POD). Fall through to check external PODs for the source. + } + + // Check if this is from an external POD + if let Some(&pod_hash) = external_pod_statements.get(dep_stmt) { + deps.push(StatementSource::External(pod_hash)); + } else { + // Statement arguments should either be internal (created earlier) + // or from external PODs. If neither, something is wrong. 
+ unreachable!( + "Statement argument not found in internal statements or external PODs: {:?}", + dep_stmt + ); + } + } + } + + statement_deps.push(deps); + } + + Self { statement_deps } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + frontend::Operation as FrontendOp, + middleware::{NativeOperation, OperationAux, OperationType, Value, ValueRef}, + }; + + fn equal_stmt(n: i64) -> Statement { + Statement::Equal( + ValueRef::Literal(Value::from(n)), + ValueRef::Literal(Value::from(n)), + ) + } + + /// None operation produces Statement::None + fn none_op() -> FrontendOp { + FrontendOp( + OperationType::Native(NativeOperation::None), + vec![], + OperationAux::None, + ) + } + + /// CopyStatement(s) produces s (the same statement) + fn copy_op(stmt: Statement) -> FrontendOp { + FrontendOp( + OperationType::Native(NativeOperation::CopyStatement), + vec![OperationArg::Statement(stmt)], + OperationAux::None, + ) + } + + #[test] + fn test_copy_creates_dependency_on_original() { + // CopyStatement(s) produces s. When we copy a statement, the copy + // depends on where that statement first appears. + // + // statements[0] = s (produced by none_op - not realistic, but we need a first occurrence) + // statements[1] = s (produced by copy_op(s)) + // + // op1's argument s matches statements[0], so statement 1 depends on statement 0. + let s = equal_stmt(1); + + let statements = vec![s.clone(), s.clone()]; + let operations = vec![ + none_op(), // Placeholder - in reality something else would produce s + copy_op(s), // Copies s, producing s. Depends on statements[0]. + ]; + + let graph = DependencyGraph::build(&statements, &operations, &HashMap::new()); + + assert!(graph.statement_deps[0].is_empty()); + assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]); + } + + #[test] + fn test_multiple_copies_depend_on_original() { + // Multiple copies of the same statement all depend on where it first appears. 
+ let s = equal_stmt(1); + + let statements = vec![s.clone(), s.clone(), s.clone()]; + let operations = vec![none_op(), copy_op(s.clone()), copy_op(s)]; + + let graph = DependencyGraph::build(&statements, &operations, &HashMap::new()); + + assert!(graph.statement_deps[0].is_empty()); + assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]); + assert_eq!(graph.statement_deps[2], vec![StatementSource::Internal(0)]); + } +} diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs new file mode 100644 index 0000000..d90c295 --- /dev/null +++ b/src/frontend/multi_pod/mod.rs @@ -0,0 +1,1937 @@ +//! Multi-POD builder for automatic statement packing. +//! +//! This module provides [`MultiPodBuilder`], a higher-level alternative to [`MainPodBuilder`] +//! that automatically handles cases where statements exceed per-POD resource limits by +//! splitting across multiple PODs. +//! +//! # Problem +//! +//! A single POD has resource limits (max statements, max custom predicate batches, etc.). +//! When a proof requires more resources than a single POD can provide, statements must +//! be split across multiple PODs with dependencies resolved via cross-POD copying. +//! +//! # Architecture +//! +//! The multi-POD system uses a MILP (Mixed Integer Linear Program) solver to find the +//! optimal assignment of statements to PODs. The solver minimizes the number of PODs +//! while respecting: +//! - Per-POD resource limits (statements, batches, merkle proofs, etc.) +//! - Statement dependencies (if A depends on B, B must be available when proving A) +//! - Input POD limits (each POD can only reference a limited number of other PODs) +//! +//! # POD Ordering +//! +//! PODs are built in index order: 0, 1, 2, ..., k. The **output POD is always last** +//! (index k), containing the user-requested public statements. Earlier PODs (0..k-1) +//! are **intermediate PODs** that prove supporting statements. +//! +//! 
This ordering allows dependencies to flow forward: later PODs can access public +//! statements from earlier PODs via `CopyStatement`. The output POD, being last, can +//! access all intermediate PODs. +//! +//! # Usage +//! +//! ```ignore +//! let mut builder = MultiPodBuilder::new(&params, &vd_set); +//! +//! // Add operations (similar to MainPodBuilder) +//! let stmt_a = builder.priv_op(FrontendOp::eq(1, 1))?; +//! let stmt_b = builder.pub_op(FrontendOp::eq(2, 2))?; // Will be public in output +//! +//! // Solve and prove +//! let result = builder.prove(&prover)?; +//! +//! // Access the output POD +//! let output = result.output_pod(); +//! ``` +//! +//! [`MainPodBuilder`]: crate::frontend::MainPodBuilder + +use std::collections::{BTreeSet, HashMap}; + +use crate::{ + frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, + middleware::{ + Hash, MainPodProver, NativeOperation, OperationAux, OperationType, Params, Statement, VDSet, + }, +}; + +mod cost; +mod deps; +mod solver; + +use cost::{AnchoredKeyId, StatementCost}; +use deps::{DependencyGraph, StatementSource}; +pub use solver::MultiPodSolution; + +/// Error type for multi-POD operations. +#[derive(Debug, Clone)] +pub enum Error { + /// Error from the frontend. + Frontend(String), + /// Error from the MILP solver. + Solver(String), + /// No solution exists (shouldn't happen with valid input). + NoSolution, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Frontend(msg) => write!(f, "Frontend error: {}", msg), + Error::Solver(msg) => write!(f, "Solver error: {}", msg), + Error::NoSolution => write!(f, "No solution exists"), + } + } +} + +impl std::error::Error for Error {} + +impl From<crate::frontend::Error> for Error { + fn from(e: crate::frontend::Error) -> Self { + Error::Frontend(e.to_string()) + } +} + +pub type Result<T> = std::result::Result<T, Error>; + +/// Default maximum number of PODs the solver will consider. 
+pub const DEFAULT_MAX_PODS: usize = 20; + +/// Options for configuring MultiPodBuilder behavior. +#[derive(Debug, Clone)] +pub struct Options { + /// Maximum number of PODs the solver will consider. + /// Defaults to 20. Increase if you have a very large number of statements. + pub max_pods: usize, +} + +impl Default for Options { + fn default() -> Self { + Self { + max_pods: DEFAULT_MAX_PODS, + } + } +} + +/// Result of proving with MultiPodBuilder. +#[derive(Debug)] +pub struct MultiPodResult { + /// All PODs in build order (0, 1, ..., k). + /// Intermediate PODs are at indices 0..k-1. + /// The output POD is at index k (the last POD). + pub pods: Vec<MainPod>, +} + +impl MultiPodResult { + /// Get the output POD (containing user-requested public statements). + /// This is always the last POD (`pods[k]`), which can access all earlier + /// intermediate PODs for dependencies. Panics if `pods` is empty. + pub fn output_pod(&self) -> &MainPod { + self.pods + .last() + .expect("MultiPodResult must have at least one POD") + } + + /// Get intermediate/supporting PODs (all PODs except the output POD). + /// These are at indices 0..k-1, built before the output POD (empty if `pods` is empty). + pub fn intermediate_pods(&self) -> &[MainPod] { + &self.pods[..self.pods.len().saturating_sub(1)] + } +} + +/// Builder for creating multiple PODs when statements exceed per-POD limits. +/// +/// # Overview +/// +/// `MultiPodBuilder` provides a similar API to [`MainPodBuilder`], but automatically +/// splits statements across multiple PODs when resource limits are exceeded. The +/// workflow is: +/// +/// 1. **Add operations**: Use [`priv_op`](Self::priv_op) and [`pub_op`](Self::pub_op) +/// to add statements, just like `MainPodBuilder`. +/// +/// 2. **Solve**: Call [`solve`](Self::solve) to run the MILP solver, which determines +/// the optimal assignment of statements to PODs. +/// +/// 3. **Prove**: Call [`prove`](Self::prove) to build and prove all PODs. 
+/// +/// # POD Structure +/// +/// The result contains PODs in build order: intermediate PODs first (indices 0..k-1), +/// then the output POD last (index k). The output POD contains all user-requested +/// public statements (those added via `pub_op`). Intermediate PODs make their +/// statements public so later PODs can copy them. +/// +/// [`MainPodBuilder`]: crate::frontend::MainPodBuilder +#[derive(Debug)] +pub struct MultiPodBuilder { + params: Params, + vd_set: VDSet, + options: Options, + /// External input PODs (already proved). + input_pods: Vec<MainPod>, + /// Statements created by this builder. + statements: Vec<Statement>, + /// Operations that produce each statement. + operations: Vec<Operation>, + /// Indices of statements that should be public in the output POD. + /// Uses Vec since max_public_statements is small (≤8); insertion-ordered and duplicate-free. + output_public_indices: Vec<usize>, + /// Cached solution from the solver. + cached_solution: Option<MultiPodSolution>, + /// Cached dependency graph (computed once in solve(), reused in build_single_pod()). + cached_deps: Option<DependencyGraph>, + /// Cached external POD statement map (computed once in solve(), reused in build_single_pod()). + cached_external_map: Option<HashMap<Statement, Hash>>, + /// Cached MainPodBuilder for incremental statement computation. + cached_builder: Option<MainPodBuilder>, +} + +impl MultiPodBuilder { + /// Create a new MultiPodBuilder with default options. + pub fn new(params: &Params, vd_set: &VDSet) -> Self { + Self::new_with_options(params, vd_set, Options::default()) + } + + /// Create a new MultiPodBuilder with custom options. + pub fn new_with_options(params: &Params, vd_set: &VDSet, options: Options) -> Self { + Self { + params: params.clone(), + vd_set: vd_set.clone(), + options, + input_pods: Vec::new(), + statements: Vec::new(), + operations: Vec::new(), + output_public_indices: Vec::new(), + cached_solution: None, + cached_deps: None, + cached_external_map: None, + cached_builder: None, + } + } + + /// Add an external input POD. 
+ pub fn add_pod(&mut self, pod: MainPod) { + // Keep cached_builder in sync if it exists + if let Some(ref mut builder) = self.cached_builder { + // Won't fail - cached_builder has unlimited params + let _ = builder.add_pod(pod.clone()); + } + self.input_pods.push(pod); + self.invalidate_cache(); + } + + /// Add a public operation (statement will be public in output). + pub fn pub_op(&mut self, op: Operation) -> Result { + let stmt = self.add_operation(op)?; + // Index is always new (just added), so push without duplicate check + self.output_public_indices.push(self.statements.len() - 1); + Ok(stmt) + } + + /// Add a private operation. + pub fn priv_op(&mut self, op: Operation) -> Result { + self.add_operation(op) + } + + /// Internal: Add an operation and create its statement. + fn add_operation(&mut self, op: Operation) -> Result { + self.invalidate_cache(); + + // Get or create the cached builder + // + // NOTE: We clone input pods here because MainPodBuilder takes ownership. + // This could be avoided if MainPodBuilder were generic over the pod storage type: + // struct MainPodBuilder = MainPod> + // Then MultiPodBuilder could use MainPodBuilder<&MainPod> to borrow instead of clone, + // while existing code using MainPodBuilder (with the default) would be unaffected. + let builder = self.cached_builder.get_or_insert_with(|| { + let unlimited_params = Params { + max_statements: usize::MAX / 2, + max_public_statements: usize::MAX / 2, + max_input_pods: usize::MAX / 2, + max_input_pods_public_statements: usize::MAX / 2, + ..self.params.clone() + }; + let mut b = MainPodBuilder::new(&unlimited_params, &self.vd_set); + for pod in &self.input_pods { + let _ = b.add_pod(pod.clone()); + } + b + }); + + let stmt = builder + .op(false, vec![], op.clone()) + .map_err(|e| Error::Frontend(e.to_string()))?; + + self.statements.push(stmt.clone()); + self.operations.push(op); + + Ok(stmt) + } + + /// Mark a statement as public in output. 
+ /// + /// Returns an error if the statement was not found in the builder. + /// Calling this multiple times on the same statement is idempotent. + pub fn reveal(&mut self, stmt: &Statement) -> Result<()> { + if let Some(idx) = self.statements.iter().position(|s| s == stmt) { + // Only invalidate cache if this is a new reveal + if !self.output_public_indices.contains(&idx) { + self.output_public_indices.push(idx); + self.invalidate_cache(); + } + Ok(()) + } else { + Err(Error::Frontend( + "reveal() called with statement not found in builder".to_string(), + )) + } + } + + /// Get the number of statements. + pub fn num_statements(&self) -> usize { + self.statements.len() + } + + /// Invalidate cached solver state (solution, deps, external map); `cached_builder` is kept. + fn invalidate_cache(&mut self) { + self.cached_solution = None; + self.cached_deps = None; + self.cached_external_map = None; + } + + /// Solve the packing problem and return the solution. + /// + /// This runs the MILP solver to find the optimal POD assignment. + /// The solution is cached for subsequent calls. + pub fn solve(&mut self) -> Result<&MultiPodSolution> { + if self.cached_solution.is_some() { + return Ok(self.cached_solution.as_ref().unwrap()); + } + + // Compute costs for each statement + let costs: Vec<StatementCost> = self + .operations + .iter() + .map(StatementCost::from_operation) + .collect(); + + // Collect all unique anchored keys from the costs + let all_anchored_keys: Vec<AnchoredKeyId> = costs + .iter() + .flat_map(|c| c.anchored_keys.iter().cloned()) + .collect::<BTreeSet<_>>() + .into_iter() + .collect(); + + // Build map from anchored key to its producing statement index (if any). + // A Contains statement with literal (dict, key, value) "produces" that anchored key. 
+ let mut ak_to_producer: HashMap<AnchoredKeyId, usize> = HashMap::new(); + for (stmt_idx, stmt) in self.statements.iter().enumerate() { + if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) { + // First producer wins (shouldn't have duplicates in practice) + ak_to_producer.entry(ak).or_insert(stmt_idx); + } + } + + // Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i] + let anchored_key_producers: Vec<Option<usize>> = all_anchored_keys + .iter() + .map(|ak| ak_to_producer.get(ak).copied()) + .collect(); + + // Build external POD statement mapping (cache for reuse in build_single_pod) + let external_pod_statements = self.build_external_statement_map(); + self.cached_external_map = Some(external_pod_statements); + let external_pod_statements = self.cached_external_map.as_ref().unwrap(); + + // Build dependency graph (cache for reuse in build_single_pod) + let deps = + DependencyGraph::build(&self.statements, &self.operations, external_pod_statements); + self.cached_deps = Some(deps); + let deps = self.cached_deps.as_ref().unwrap(); + + // Group statement indices by content: identical statements share one slot in the POD. + // Groups are sorted so the solver input does not depend on HashMap iteration order. 
+ let mut content_to_indices: HashMap<&Statement, Vec<usize>> = HashMap::new(); + for (idx, stmt) in self.statements.iter().enumerate() { + content_to_indices.entry(stmt).or_default().push(idx); + } + let mut statement_content_groups: Vec<_> = content_to_indices.into_values().collect(); + statement_content_groups.sort_unstable(); + + // Run solver + let input = solver::SolverInput { + num_statements: self.statements.len(), + costs: &costs, + deps, + output_public_indices: &self.output_public_indices, + params: &self.params, + max_pods: self.options.max_pods, + all_anchored_keys: &all_anchored_keys, + anchored_key_producers: &anchored_key_producers, + statement_content_groups: &statement_content_groups, + }; + + let solution = solver::solve(&input)?; + self.cached_solution = Some(solution); + + Ok(self.cached_solution.as_ref().unwrap()) + } + + /// Build and prove all PODs. + /// + /// This first solves if not already solved, then builds and proves + /// all necessary PODs in dependency order. + pub fn prove(&mut self, prover: &dyn MainPodProver) -> Result<MultiPodResult> { + // Ensure we have a solution (can't use returned reference due to later &mut self borrows) + self.solve()?; + let solution = self.cached_solution.as_ref().unwrap(); + + // Build PODs in sequential order: 0, 1, 2, ..., k + // This order is guaranteed by the solver's symmetry-breaking constraint, which + // ensures PODs are used in order (no gaps). Sequential building is required because + // later PODs may reference earlier ones via CopyStatement for cross-POD dependencies. + // PODs 0..k-1 are intermediate; POD k (the last one) is the output POD. 
+ let mut pods: Vec<MainPod> = Vec::with_capacity(solution.pod_count); + + for pod_idx in 0..solution.pod_count { + let pod = self.build_single_pod(pod_idx, solution, &pods, prover)?; + pods.push(pod); + } + + Ok(MultiPodResult { pods }) + } + + /// Build a single POD based on the solver solution. + /// + /// This function translates the solver's abstract assignment into a concrete POD by: + /// 1. 
Identifying which input PODs are needed (external + earlier generated) + /// 2. Adding those input PODs to a fresh `MainPodBuilder` + /// 3. For each statement assigned to this POD (in dependency order): + /// - Copy any dependencies from earlier PODs via `CopyStatement` + /// - Execute the original operation to create the statement + /// - Mark as public if the solver determined it should be + /// 4. Prove the POD + fn build_single_pod( + &self, + pod_idx: usize, + solution: &MultiPodSolution, + earlier_pods: &[MainPod], + prover: &dyn MainPodProver, + ) -> Result { + let mut builder = MainPodBuilder::new(&self.params, &self.vd_set); + + let deps = self + .cached_deps + .as_ref() + .expect("build_single_pod called before solve()"); + + let statements_in_this_pod: &Vec = &solution.pod_statements[pod_idx]; + let mut needed_external_pods: BTreeSet = BTreeSet::new(); + let mut needed_earlier_pods: BTreeSet = BTreeSet::new(); + + // Step 1: Find which external and earlier PODs we need based on dependencies + for &stmt_idx in statements_in_this_pod { + for dep in &deps.statement_deps[stmt_idx] { + match dep { + StatementSource::Internal(dep_idx) => { + // Check if dependency is in an earlier generated POD + let mut found = false; + for earlier_pod_idx in 0..pod_idx { + if solution.pod_public_statements[earlier_pod_idx].contains(dep_idx) { + needed_earlier_pods.insert(earlier_pod_idx); + found = true; + break; + } + } + // If not found in earlier PODs, it must be local to this POD + if !found && !statements_in_this_pod.contains(dep_idx) { + unreachable!( + "Internal dependency {} for statement {} is neither local \ + nor public in any earlier POD (solver bug)", + dep_idx, stmt_idx + ); + } + } + StatementSource::External(pod_hash) => { + // Find which external POD has this hash + let ext_idx = self + .input_pods + .iter() + .position(|p| p.statements_hash() == *pod_hash); + match ext_idx { + Some(idx) => { + needed_external_pods.insert(idx); + } + None => { + 
unreachable!( + "External dependency with hash {:?} not found in input PODs", + pod_hash + ); + } + } + } + } + } + } + + // Step 2: Add input PODs to the builder + for &ext_idx in &needed_external_pods { + builder.add_pod(self.input_pods[ext_idx].clone())?; + } + for &earlier_idx in &needed_earlier_pods { + builder.add_pod(earlier_pods[earlier_idx].clone())?; + } + + // Step 3: Build statement source map for determining what needs copying. + // Create a mapping from statement to its source (for copy operations). + // A statement may be both proved locally AND available from an earlier POD. + // We use or_insert to prefer local sources (inserted first) over earlier PODs. + let mut stmt_sources: HashMap = HashMap::new(); + for &stmt_idx in statements_in_this_pod { + stmt_sources.insert(stmt_idx, StmtSource::Local); + } + for earlier_pod_idx in 0..pod_idx { + for &stmt_idx in &solution.pod_public_statements[earlier_pod_idx] { + // Only insert if not already local - or_insert preserves existing entries + stmt_sources.entry(stmt_idx).or_insert(StmtSource::FromPod); + } + } + + // Step 4: Add statements in dependency order. + // Statements are added in ascending index order, which matches dependency order: + // if B depends on A, then A has a lower index and is added first. + let statements_sorted: BTreeSet = statements_in_this_pod.iter().copied().collect(); + let public_set = &solution.pod_public_statements[pod_idx]; + + // Track which statements have been added to this builder + let mut added_statements: HashMap = HashMap::new(); + + for &stmt_idx in &statements_sorted { + // First, ensure all dependencies are available (copy if needed). + // When a dependency comes from an earlier POD, we need CopyStatement to make it + // available in this POD's namespace. The earlier POD is already added as an input, + // but CopyStatement creates a local reference that operations can use. 
+ for dep in &deps.statement_deps[stmt_idx] { + if let StatementSource::Internal(dep_idx) = dep { + if !added_statements.contains_key(dep_idx) { + // Need to copy this statement from an earlier POD + match stmt_sources.get(dep_idx) { + Some(StmtSource::FromPod) => { + // Dependency is from an earlier POD - copy it + let copy_op = Operation( + OperationType::Native(NativeOperation::CopyStatement), + vec![OperationArg::Statement( + self.statements[*dep_idx].clone(), + )], + OperationAux::None, + ); + let copied_stmt = builder + .priv_op(copy_op) + .map_err(|e| Error::Frontend(e.to_string()))?; + added_statements.insert(*dep_idx, copied_stmt); + } + Some(StmtSource::Local) => { + // Local dependency should already be added due to topological + // ordering. If we reach here, there's a bug in the ordering. + unreachable!( + "Local dependency at index {} should already be added \ + when processing statement {} (topological order violation)", + dep_idx, stmt_idx + ); + } + None => { + // Dependency not found in stmt_sources means it's neither + // in this POD nor available from earlier PODs - a solver bug. + unreachable!( + "Dependency at index {} not found in stmt_sources \ + when processing statement {}", + dep_idx, stmt_idx + ); + } + } + } + } + } + + // Now add the actual statement + let is_public = public_set.contains(&stmt_idx); + let mut op = self.operations[stmt_idx].clone(); + + // Remap Statement arguments in the operation to use statements created by MainPodBuilder. + // The original operation references Statements from MultiPodBuilder, but MainPodBuilder + // needs Statements that were either created by it or come from its input PODs. 
+ for arg in &mut op.1 { + if let OperationArg::Statement(ref orig_stmt) = arg { + // Find the original statement's index in MultiPodBuilder + if let Some(orig_idx) = self.statements.iter().position(|s| s == orig_stmt) { + // Get the remapped statement from MainPodBuilder + if let Some(remapped_stmt) = added_statements.get(&orig_idx) { + *arg = OperationArg::Statement(remapped_stmt.clone()); + } + } + } + } + + let stmt = builder + .op(is_public, vec![], op) + .map_err(|e| Error::Frontend(e.to_string()))?; + + added_statements.insert(stmt_idx, stmt); + } + + // Step 5: Prove the POD + let pod = builder + .prove(prover) + .map_err(|e| Error::Frontend(e.to_string()))?; + + Ok(pod) + } + + /// Build mapping from external POD statements to their POD hash. + fn build_external_statement_map(&self) -> HashMap { + let mut map = HashMap::new(); + for pod in &self.input_pods { + let pod_hash = pod.statements_hash(); + for stmt in pod.pod.pub_statements() { + map.insert(stmt, pod_hash); + } + } + map + } +} + +/// Source of a statement within a built POD. +#[derive(Clone, Debug)] +enum StmtSource { + /// Statement is proved locally in this POD. + Local, + /// Statement is copied from an earlier generated POD. + /// (The specific POD index doesn't matter - we only need to know it's not local.) 
+ FromPod, +} + +#[cfg(test)] +mod tests { + use hex::ToHex; + + use super::*; + use crate::{ + backends::plonky2::{ + mock::mainpod::MockProver, primitives::ec::schnorr::SecretKey, signer::Signer, + }, + dict, + examples::MOCK_VD_SET, + frontend::{Operation as FrontendOp, SignedDictBuilder}, + lang::parse, + }; + + #[test] + fn test_single_pod_case() -> Result<()> { + let params = Params::default(); + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Create a simple signed dict + let mut signed_builder = SignedDictBuilder::new(¶ms); + signed_builder.insert("value", 42); + let signer = Signer(SecretKey(1u32.into())); + let signed_dict = signed_builder.sign(&signer).unwrap(); + + // Add operation + builder.pub_op(FrontendOp::dict_signed_by(&signed_dict))?; + + // Solve + let solution = builder.solve()?; + assert_eq!(solution.pod_count, 1); + + // Prove + let prover = MockProver {}; + let result = builder.prove(&prover)?; + + assert_eq!(result.pods.len(), 1); + assert!(result.intermediate_pods().is_empty()); + + // Verify the POD + result.pods[0] + .pod + .verify() + .map_err(|e| Error::Frontend(e.to_string()))?; + + Ok(()) + } + + #[test] + fn test_multi_pod_overflow() -> Result<()> { + // Verifies automatic splitting when statements exceed per-POD capacity. + // + // This test uses independent statements with no dependencies - the only + // reason for multiple PODs is the statement limit being exceeded. 
+ let params = Params { + max_statements: 6, + max_public_statements: 2, + // Derived: max_priv_statements = 6 - 2 = 4 + // With 6 private + 2 public = 8 statements, need ceil(8/4) = 2 PODs + max_input_pods: 2, + max_input_pods_public_statements: 4, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Add 6 independent private statements (no dependencies between them) + for i in 0..6i64 { + builder.priv_op(FrontendOp::eq(i, i))?; + } + + // Add 2 public statements for the output POD + builder.pub_op(FrontendOp::eq(100, 100))?; + builder.pub_op(FrontendOp::eq(101, 101))?; + + let pod_count = { + let solution = builder.solve()?; + // 8 statements / 4 per POD = 2 PODs minimum + assert!( + solution.pod_count >= 2, + "Expected at least 2 PODs for 8 statements with max_priv=4, got {}", + solution.pod_count + ); + solution.pod_count + }; + + // Prove and verify + let prover = MockProver {}; + let result = builder.prove(&prover)?; + assert_eq!(result.pods.len(), pod_count); + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_cross_pod_dependencies() -> Result<()> { + // Verifies that a dependency chain can be split across PODs. + // + // This tests the core multi-POD capability: when a dependency chain is too + // long to fit in the output POD, intermediate statements must be proved in + // earlier PODs and made public so the output POD can access them. + // + // Chain: b_out -> a_out -> contains + // - contains: base statement (dict_contains) + // - a_out: custom predicate taking contains as argument + // - b_out: custom predicate taking a_out as argument (OUTPUT-PUBLIC) + // + // With max_priv_statements = 2, we can't fit all 3 in one POD. 
+ // Expected solution: + // - POD 0 (intermediate): contains, a_out (with a_out public) + // - POD 1 (output): copy(a_out), b_out + // + // This requires intermediate PODs to feed INTO the output POD. + + // Tight params to force the dependency chain to be split. + // With max_priv_statements = 2, we can't fit contains + a_out + b_out's + // dependencies all in one POD. + let params = Params { + max_statements: 4, + max_public_statements: 2, + // max_priv_statements = 2 + max_input_pods: 4, + max_input_pods_public_statements: 20, + max_custom_predicate_verifications: 10, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + // pred_a accepts a Contains statement + // pred_b accepts a pred_a statement (Custom statement from pred_a) + let parsed = parse( + r#" + pred_a(X) = AND(Contains(X, "k", 1)) + pred_b(X) = AND(pred_a(X)) + "#, + ¶ms, + &[], + ) + .expect("parse predicates"); + let batch = parsed + .first_batch() + .expect("parse predicates should have a batch"); + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Statement 0: Contains (base of the chain) + let dict = dict!({"k" => 1}); + let contains = builder.priv_op(FrontendOp::dict_contains(dict, "k", 1))?; + + // Statement 1: Custom(pred_a), depends on contains + let a_out = builder.priv_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_a").unwrap(), + [contains], + ))?; + + // Statement 2: Custom(pred_b), depends on a_out - make this output-public + // This forces the dependency chain to be resolved for the output POD. + let _b_out = builder.pub_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_b").unwrap(), + [a_out], + ))?; + + // Solve - this finds a multi-POD solution where intermediate PODs + // provide dependencies to the output POD. 
+ let solution = builder.solve()?; + + // Expected: exactly 2 PODs + // - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public + // - POD 1 (output): statement 2 (b_out); b_out is public + // The output POD copies a_out from POD 0 to satisfy b_out's dependency. + assert_eq!( + solution.pod_count, 2, + "Expected exactly 2 PODs for 3-statement chain with max_priv=2" + ); + + // POD 0 should contain statements 0 and 1 (contains and a_out) + assert!( + solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), + "POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}", + solution.pod_statements[0] + ); + + // Statement 1 (a_out) should be public in POD 0 so POD 1 can copy it + assert!( + solution.pod_public_statements[0].contains(&1), + "Statement 1 (a_out) should be public in POD 0" + ); + + // POD 1 (output) should contain statement 2 (b_out) + assert!( + solution.pod_statements[1].contains(&2), + "POD 1 should contain statement 2 (b_out), got {:?}", + solution.pod_statements[1] + ); + + // Statement 2 (b_out) should be public in POD 1 (it's output-public) + assert!( + solution.pod_public_statements[1].contains(&2), + "Statement 2 (b_out) should be public in output POD" + ); + + // Prove and verify all PODs + let prover = MockProver {}; + let result = builder.prove(&prover)?; + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_isolated_pods_when_no_inputs_allowed() -> Result<()> { + // Verifies that PODs are completely isolated when max_input_pods = 0. + // + // With no input PODs allowed, each generated POD must independently prove + // all statements it contains - it cannot reference earlier PODs. + // This is an edge case but validates the input POD constraint. 
+ let params = Params { + max_statements: 4, + max_public_statements: 2, + // Derived: max_priv_statements = 4 - 2 = 2 + max_input_pods: 0, // No input pods allowed - each POD is isolated + max_input_pods_public_statements: 0, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Add 4 independent private statements (no dependencies) + // With max_priv=2, need 2 PODs. Since max_input_pods=0, they can't share. + for i in 0..4i64 { + builder.priv_op(FrontendOp::eq(i, i))?; + } + + // Add 2 public statements for the output POD + builder.pub_op(FrontendOp::eq(100, 100))?; + builder.pub_op(FrontendOp::eq(101, 101))?; + + let solution = builder.solve()?; + + // 6 statements / 2 per POD = 3 PODs minimum + assert!( + solution.pod_count >= 2, + "Expected at least 2 PODs, got {}", + solution.pod_count + ); + + // Prove and verify + let prover = MockProver {}; + let result = builder.prove(&prover)?; + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_zero_public_capacity_fails() { + // Test that setting max_public_statements = 0 with a public operation + // results in a solver error (infeasible configuration). + let params = Params { + max_statements: 10, + max_public_statements: 0, // No public statements allowed + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Try to add a public operation + let _ = builder.pub_op(FrontendOp::eq(1, 1)); + + // Solving should fail because we can't satisfy the public statement requirement + let result = builder.solve(); + assert!( + result.is_err(), + "Expected solver to fail with zero public capacity, but it succeeded" + ); + } + + #[test] + fn test_max_pods_exceeded_error() { + // Test that exceeding max_pods gives a clear error message. 
+ // With max_statements=3 and max_public_statements=1, we have + // max_priv_statements = 2. So 10 statements requires 5 PODs. + let params = Params { + max_statements: 3, + max_public_statements: 1, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + // Set max_pods to 2, which is less than the 5 PODs needed + let options = Options { max_pods: 2 }; + let mut builder = MultiPodBuilder::new_with_options(¶ms, vd_set, options); + + // Add 10 statements (requires 5 PODs). First one is public (required). + let _ = builder.pub_op(FrontendOp::eq(0, 0)); + for i in 1..10 { + let _ = builder.priv_op(FrontendOp::eq(i, i)); + } + + // Solving should fail with a clear error about max_pods + let result = builder.solve(); + assert!( + result.is_err(), + "Expected solver to fail when max_pods exceeded" + ); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("requires at least") && err_msg.contains("PODs"), + "Error message should explain POD requirement: {}", + err_msg + ); + assert!( + err_msg.contains("Options::max_pods"), + "Error message should suggest increasing Options::max_pods: {}", + err_msg + ); + } + + #[test] + fn test_external_pods_only_added_where_needed() -> Result<()> { + // Verifies that external input PODs are only added to generated PODs + // that actually need them based on statement dependencies. + // + // Setup: + // - Two external PODs: ext_A and ext_B, each with a public statement + // - max_input_pods = 1 (each generated POD can only have 1 input POD) + // - Private statements that copy from different external PODs force overflow + // + // With max_input_pods = 1, this only works if each generated POD + // includes only the external POD it actually depends on. 
+ + let params = Params { + max_statements: 4, // Small limit + max_public_statements: 2, // max_priv_statements = 4 - 2 = 2 + max_input_pods: 1, // Only 1 input POD allowed per generated POD + max_input_pods_public_statements: 4, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + // Create external POD A with a public statement + let prover = MockProver {}; + let mut builder_a = MainPodBuilder::new(¶ms, vd_set); + builder_a.pub_op(FrontendOp::eq(100, 100))?; + let ext_pod_a = builder_a.prove(&prover)?; + + // Create external POD B with a public statement + let mut builder_b = MainPodBuilder::new(¶ms, vd_set); + builder_b.pub_op(FrontendOp::eq(200, 200))?; + let ext_pod_b = builder_b.prove(&prover)?; + + // Get the actual statements from the proved PODs + let stmt_a = ext_pod_a + .pod + .pub_statements() + .into_iter() + .find(|s| !s.is_none()) + .expect("ext_pod_a should have a public statement"); + let stmt_b = ext_pod_b + .pod + .pub_statements() + .into_iter() + .find(|s| !s.is_none()) + .expect("ext_pod_b should have a public statement"); + + // Create MultiPodBuilder and add both external PODs + let mut multi_builder = MultiPodBuilder::new(¶ms, vd_set); + multi_builder.add_pod(ext_pod_a.clone()); + multi_builder.add_pod(ext_pod_b.clone()); + + // Add private operations that reference different external PODs. + // These will force multiple PODs due to private statement limits. + multi_builder.priv_op(FrontendOp::copy(stmt_a))?; + multi_builder.priv_op(FrontendOp::eq(101, 101))?; + multi_builder.priv_op(FrontendOp::copy(stmt_b))?; + multi_builder.priv_op(FrontendOp::eq(201, 201))?; + + // Add 2 public statements (within single output POD limit) + multi_builder.pub_op(FrontendOp::eq(300, 300))?; + multi_builder.pub_op(FrontendOp::eq(301, 301))?; + + // With 6 statements and max_priv_statements = 2, we need multiple PODs. + // Each POD should only include the external POD it depends on. 
+ + let solution = multi_builder.solve()?; + assert!( + solution.pod_count >= 2, + "Expected at least 2 PODs, got {}", + solution.pod_count + ); + + let result = multi_builder.prove(&prover)?; + + // Verify all PODs + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_private_statement_not_leaked_to_output_pod() -> Result<()> { + // Verifies that private statements do not appear in the output POD's public slots. + // The solver enforces that only user-requested public statements can be + // public in the output POD (the last POD). + + let params = Params { + max_statements: 4, + max_public_statements: 2, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Add private statements (indices 0, 1, 2) - should NOT appear in output POD public slots + builder.priv_op(FrontendOp::eq(100, 100))?; + builder.priv_op(FrontendOp::eq(101, 101))?; + builder.priv_op(FrontendOp::eq(102, 102))?; + + // Add public statements (indices 3, 4) - these SHOULD appear in output POD public slots + builder.pub_op(FrontendOp::eq(200, 200))?; + builder.pub_op(FrontendOp::eq(201, 201))?; + + let solution = builder.solve()?; + + // Check that the output POD's public statements are exactly the user-requested public ones. + // The output POD is always the last one (index pod_count - 1). 
+ let output_pod_idx = solution.pod_count - 1; + let output_public = &solution.pod_public_statements[output_pod_idx]; + assert!( + output_public.contains(&3), + "Public statement 3 should be public in output POD" + ); + assert!( + output_public.contains(&4), + "Public statement 4 should be public in output POD" + ); + + // Private statements should NOT be public in output POD + assert!( + !output_public.contains(&0), + "Private statement 0 should NOT be public in output POD" + ); + assert!( + !output_public.contains(&1), + "Private statement 1 should NOT be public in output POD" + ); + assert!( + !output_public.contains(&2), + "Private statement 2 should NOT be public in output POD" + ); + + Ok(()) + } + + #[test] + fn test_too_many_public_statements_error() -> Result<()> { + // Verifies that requesting more public statements than max_public_statements + // results in a clear error (since all public statements must fit in one output POD). + + let params = Params { + max_statements: 10, + max_public_statements: 2, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Add 3 public statements, but max is 2 + builder.pub_op(FrontendOp::eq(1, 1))?; + builder.pub_op(FrontendOp::eq(2, 2))?; + builder.pub_op(FrontendOp::eq(3, 3))?; + + let result = builder.solve(); + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Too many public statements"), + "Expected 'Too many public statements' error, got: {}", + err_msg + ); + + Ok(()) + } + + #[test] + fn test_external_pods_counted_in_input_limit() -> Result<()> { + // Verifies that external input PODs are counted toward max_input_pods. 
+ // + // Setup: + // - max_input_pods = 2 + // - 3 external PODs (A, B, C), each with a public statement + // - 3 public operations, each copying from a different external POD + // + // Since all 3 must be public in POD 0 (the output POD), and POD 0 would need + // all 3 external PODs as inputs (3 > max_input_pods), this is infeasible. + // The solver should correctly detect and report this. + + let params = Params { + max_statements: 10, + max_public_statements: 5, + max_input_pods: 2, // Only 2 input PODs allowed per generated POD + max_input_pods_public_statements: 10, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + let prover = MockProver {}; + + // Create 3 external PODs, each with a distinct public statement + let mut builder_a = MainPodBuilder::new(¶ms, vd_set); + builder_a.pub_op(FrontendOp::eq(100, 100))?; + let ext_pod_a = builder_a.prove(&prover)?; + + let mut builder_b = MainPodBuilder::new(¶ms, vd_set); + builder_b.pub_op(FrontendOp::eq(200, 200))?; + let ext_pod_b = builder_b.prove(&prover)?; + + let mut builder_c = MainPodBuilder::new(¶ms, vd_set); + builder_c.pub_op(FrontendOp::eq(300, 300))?; + let ext_pod_c = builder_c.prove(&prover)?; + + // Get the actual statements from the proved PODs + let stmt_a = ext_pod_a + .pod + .pub_statements() + .into_iter() + .find(|s| !s.is_none()) + .expect("ext_pod_a should have a public statement"); + let stmt_b = ext_pod_b + .pod + .pub_statements() + .into_iter() + .find(|s| !s.is_none()) + .expect("ext_pod_b should have a public statement"); + let stmt_c = ext_pod_c + .pod + .pub_statements() + .into_iter() + .find(|s| !s.is_none()) + .expect("ext_pod_c should have a public statement"); + + // Create MultiPodBuilder and add all 3 external PODs + let mut multi_builder = MultiPodBuilder::new(¶ms, vd_set); + multi_builder.add_pod(ext_pod_a); + multi_builder.add_pod(ext_pod_b); + multi_builder.add_pod(ext_pod_c); + + // Add public operations that each depend on a different external POD + // All 3 must 
be public in POD 0, requiring 3 external inputs > max_input_pods + multi_builder.pub_op(FrontendOp::copy(stmt_a))?; + multi_builder.pub_op(FrontendOp::copy(stmt_b))?; + multi_builder.pub_op(FrontendOp::copy(stmt_c))?; + + // Solver should correctly detect infeasibility and return an error + let result = multi_builder.solve(); + assert!( + result.is_err(), + "Expected solver to report infeasibility, but got: {:?}", + result + ); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("No feasible solution"), + "Expected 'No feasible solution' error, got: {}", + err_msg + ); + + Ok(()) + } + + #[test] + fn test_explicit_contains_not_double_counted_as_anchored_key() -> Result<()> { + // Verifies that when a Contains statement is explicitly added and then used + // as an anchored key argument, it's not double-counted in statement limits. + // + // Background: MainPodBuilder auto-inserts Contains statements for anchored keys + // (dict, key pairs used as arguments to gt(), eq(), etc.). But if the Contains + // was already explicitly added, no auto-insertion happens (PR 456). + // + // The solver must NOT count anchored key overhead when the producing Contains + // statement is already in the same POD. 
+ // + // Setup: + // - max_priv_statements = 4 + // - Statement 0: dict_contains (public) - produces anchored key (dict, "x") + // - Statements 1, 2, 3: gt(stmt_0, val) - each references the anchored key + // + // Correct counting for single POD: + // - stmt_sum = 4 (statements 0-3) + // - anchored_key_sum = 0 (statement 0 already provides the anchored key) + // - Total = 4 ≤ max_priv_statements ✓ + // + // Incorrect (double-counting) would give: + // - stmt_sum = 4 + anchored_key_sum = 1 → Total = 5 > 4 ✗ + + let params = Params { + max_statements: 5, + max_public_statements: 1, // max_priv_statements = 5 - 1 = 4 + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Statement 0: public Contains - produces anchored key (dict, "x") + let dict = dict!({"x" => 100}); + let contains_stmt = builder.pub_op(FrontendOp::dict_contains(dict, "x", 100))?; + + // Statements 1, 2, 3: each uses contains_stmt as an anchored key + builder.priv_op(FrontendOp::gt(contains_stmt.clone(), 0))?; + builder.priv_op(FrontendOp::gt(contains_stmt.clone(), 1))?; + builder.priv_op(FrontendOp::gt(contains_stmt, 2))?; + + // With correct counting, all 4 statements fit in 1 POD + let solution = builder.solve()?; + assert_eq!( + solution.pod_count, 1, + "All statements should fit in 1 POD when Contains is not double-counted. 
\ + Got {} PODs, which suggests the explicit Contains is being incorrectly \ + counted as both a statement AND an anchored key overhead.", + solution.pod_count + ); + + // Verify proving works + let prover = MockProver {}; + let result = builder.prove(&prover)?; + assert_eq!(result.pods.len(), 1); + + result + .output_pod() + .pod + .verify() + .map_err(|e| Error::Frontend(format!("Output POD verification failed: {}", e)))?; + + Ok(()) + } + + #[test] + fn test_anchored_key_overhead_counted_in_statement_limit() -> Result<()> { + // Verifies that anchored key overhead is correctly counted toward statement limits. + // + // When a Contains statement is used as an argument to operations like gt(), + // it creates an "anchored key" reference. If the gt() is proved in a different + // POD than the original Contains, MainPodBuilder auto-inserts a local Contains + // statement for that anchored key. The solver must account for this overhead. + // + // Setup: + // - max_priv_statements = 4 (small limit) + // - Statement A: dict_contains (public, in POD 0) + // - Statement B: eq (public, in POD 0) + // - Statements C, D, E: gt(A, val) - each uses A as an anchored key + // + // The solver must account for the anchored key Contains statements that will + // be auto-inserted when gt operations are proved in PODs other than POD 0. 
+ + let params = Params { + max_statements: 6, + max_public_statements: 2, // max_priv_statements = 6 - 2 = 4 + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Statement A: public Contains - proved in POD 0 + let dict = dict!({"x" => 100}); + let stmt_a = builder.pub_op(FrontendOp::dict_contains(dict, "x", 100))?; + + // Statement B: another public statement in POD 0 + builder.pub_op(FrontendOp::eq(200, 200))?; + + // Statements C, D, E: each uses stmt_a as an anchored key + // When proved in a different POD, each needs a local Contains for the anchored key + builder.priv_op(FrontendOp::gt(stmt_a.clone(), 0))?; + builder.priv_op(FrontendOp::gt(stmt_a.clone(), 1))?; + builder.priv_op(FrontendOp::gt(stmt_a, 2))?; + + let prover = MockProver {}; + let result = builder.prove(&prover)?; + + // Verify all PODs + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_mixed_internal_and_external_pods_work_within_limit() -> Result<()> { + // Verifies that scenarios with both internal and external dependencies work + // when the total input count stays within max_input_pods. + // + // Setup: + // - 1 external POD with a public statement + // - 2 public dict_contains statements (uses anchored keys) + // - 2 private gt statements that reference the dict_contains via anchored keys + // - 1 private copy of the external POD's statement + // + // This tests that mixing internal POD dependencies (from earlier generated PODs) + // and external POD dependencies (from user-provided input PODs) works correctly. 
+ + let params = Params { + max_statements: 10, + max_public_statements: 3, // max_priv_statements = 7 + max_input_pods: 3, // Allow up to 3 inputs per POD + max_input_pods_public_statements: 10, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + let prover = MockProver {}; + + // Create 1 external POD + let mut ext_builder = MainPodBuilder::new(¶ms, vd_set); + ext_builder.pub_op(FrontendOp::eq(9999, 9999))?; + let ext_pod = ext_builder.prove(&prover)?; + + let stmt_ext = ext_pod + .pod + .pub_statements() + .into_iter() + .find(|s| !s.is_none()) + .expect("ext_pod should have a public statement"); + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + builder.add_pod(ext_pod); + + // Output POD: public Contains statements + let dict0 = dict!({"x" => 100}); + let dict1 = dict!({"y" => 200}); + let contains_0 = builder.pub_op(FrontendOp::dict_contains(dict0, "x", 100))?; + let contains_1 = builder.pub_op(FrontendOp::dict_contains(dict1, "y", 200))?; + + // Statements that depend on output POD + builder.priv_op(FrontendOp::gt(contains_0, 0))?; + builder.priv_op(FrontendOp::gt(contains_1, 0))?; + + // Depend on external POD + builder.priv_op(FrontendOp::copy(stmt_ext))?; + + // This should succeed - total inputs per POD should stay within limit + let result = builder.prove(&prover)?; + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_signed_by_limit_forces_multi_pod() -> Result<()> { + // Verifies that the solver respects max_signed_by per POD (C6f). 
+ // + // Setup: + // - max_signed_by = 2 (small limit) + // - 4 SignedBy operations + // - Other limits high enough not to interfere + // + // Expected: Solver creates exactly 2 PODs since 4 SignedBy / 2 per POD = 2 PODs + let params = Params { + max_statements: 48, + max_public_statements: 8, + // Derived: max_priv_statements = 48 - 8 = 40 (plenty of room) + max_signed_by: 2, // Small limit to force splitting + max_input_pods: 10, + max_input_pods_public_statements: 20, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Create 4 different signed dicts + for i in 0..4i64 { + let mut signed_builder = SignedDictBuilder::new(¶ms); + signed_builder.insert("id", i); + let signer = Signer(SecretKey((i as u32 + 1).into())); + let signed_dict = signed_builder.sign(&signer).unwrap(); + builder.priv_op(FrontendOp::dict_signed_by(&signed_dict))?; + } + + // Add one public statement for output + builder.pub_op(FrontendOp::eq(100, 100))?; + + let pod_count = { + let solution = builder.solve()?; + // 4 SignedBy / 2 per POD = exactly 2 PODs + assert_eq!( + solution.pod_count, 2, + "Expected exactly 2 PODs for 4 SignedBy with max_signed_by=2, got {}", + solution.pod_count + ); + solution.pod_count + }; + + // Prove and verify + let prover = MockProver {}; + let result = builder.prove(&prover)?; + assert_eq!(result.pods.len(), pod_count); + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_batch_cardinality_forces_multi_pod() -> Result<()> { + // Verifies that the solver respects max_custom_predicate_batches per POD (C7). 
+ // + // Setup: + // - max_custom_predicate_batches = 2 (small limit) + // - 4 different batches, each with one simple predicate + // - 4 operations, one from each batch + // + // Expected: Solver creates exactly 2 PODs since 4 batches / 2 per POD = 2 PODs + let params = Params { + max_statements: 48, + max_public_statements: 8, + max_custom_predicate_batches: 2, // Small limit to force splitting + max_input_pods: 10, + max_input_pods_public_statements: 20, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + // Create 4 separate batches using podlang parser + // Each batch has a simple predicate that checks a Contains statement + let parsed1 = + parse(r#"pred1(A) = AND(Contains(A, "x", 1))"#, ¶ms, &[]).expect("parse batch1"); + let batch1 = parsed1 + .first_batch() + .expect("parse batch1 should have a batch"); + + let parsed2 = + parse(r#"pred2(A) = AND(Contains(A, "x", 2))"#, ¶ms, &[]).expect("parse batch2"); + let batch2 = parsed2 + .first_batch() + .expect("parse batch2 should have a batch"); + + let parsed3 = + parse(r#"pred3(A) = AND(Contains(A, "x", 3))"#, ¶ms, &[]).expect("parse batch3"); + let batch3 = parsed3 + .first_batch() + .expect("parse batch3 should have a batch"); + + let parsed4 = + parse(r#"pred4(A) = AND(Contains(A, "x", 4))"#, ¶ms, &[]).expect("parse batch4"); + let batch4 = parsed4 + .first_batch() + .expect("parse batch4 should have a batch"); + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Add operations using predicates from each batch + // Each custom predicate needs a Contains statement argument + let dict1 = dict!({"x" => 1}); + let contains1 = builder.priv_op(FrontendOp::dict_contains(dict1, "x", 1))?; + builder.priv_op(FrontendOp::custom( + batch1.predicate_ref_by_name("pred1").unwrap(), + [contains1], + ))?; + + let dict2 = dict!({"x" => 2}); + let contains2 = builder.priv_op(FrontendOp::dict_contains(dict2, "x", 2))?; + builder.priv_op(FrontendOp::custom( + batch2.predicate_ref_by_name("pred2").unwrap(), + 
[contains2], + ))?; + + let dict3 = dict!({"x" => 3}); + let contains3 = builder.priv_op(FrontendOp::dict_contains(dict3, "x", 3))?; + builder.priv_op(FrontendOp::custom( + batch3.predicate_ref_by_name("pred3").unwrap(), + [contains3], + ))?; + + let dict4 = dict!({"x" => 4}); + let contains4 = builder.priv_op(FrontendOp::dict_contains(dict4, "x", 4))?; + builder.pub_op(FrontendOp::custom( + batch4.predicate_ref_by_name("pred4").unwrap(), + [contains4], + ))?; + + let pod_count = { + let solution = builder.solve()?; + // 4 batches / 2 per POD = exactly 2 PODs + assert_eq!( + solution.pod_count, 2, + "Expected exactly 2 PODs for 4 batches with max_custom_predicate_batches=2, got {}", + solution.pod_count + ); + solution.pod_count + }; + + // Prove and verify + let prover = MockProver {}; + let result = builder.prove(&prover)?; + assert_eq!(result.pods.len(), pod_count); + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_long_dependency_chain_spans_multiple_pods() -> Result<()> { + // Verifies that a long dependency chain correctly cascades through multiple + // intermediate PODs before reaching the output POD. + // + // Chain: d_out -> c_out -> b_out -> a_out -> contains (5 statements) + // + // With max_priv_statements = 2, each POD can hold at most 2 statements + // (including copies). 
Expected solution with 4 PODs: + // - POD 0 (intermediate): contains, a_out (a_out public) + // - POD 1 (intermediate): copy(a_out), b_out (b_out public) + // - POD 2 (intermediate): copy(b_out), c_out (c_out public) + // - POD 3 (output): copy(c_out), d_out + + let params = Params { + max_statements: 4, + max_public_statements: 2, + // max_priv_statements = 2 + max_input_pods: 4, + max_input_pods_public_statements: 20, + max_custom_predicate_verifications: 10, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + // Chain of predicates: each accepts the output of the previous + let parsed = parse( + r#" + pred_a(X) = AND(Contains(X, "k", 1)) + pred_b(X) = AND(pred_a(X)) + pred_c(X) = AND(pred_b(X)) + pred_d(X) = AND(pred_c(X)) + "#, + ¶ms, + &[], + ) + .expect("parse predicates"); + let batch = parsed + .first_batch() + .expect("parse predicates should have a batch"); + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Build the chain: contains -> a_out -> b_out -> c_out -> d_out + let dict = dict!({"k" => 1}); + let contains = builder.priv_op(FrontendOp::dict_contains(dict, "k", 1))?; + + let a_out = builder.priv_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_a").unwrap(), + [contains], + ))?; + + let b_out = builder.priv_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_b").unwrap(), + [a_out], + ))?; + + let c_out = builder.priv_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_c").unwrap(), + [b_out], + ))?; + + let _d_out = builder.pub_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_d").unwrap(), + [c_out], + ))?; + + let solution = builder.solve()?; + + // Expected: exactly 4 PODs for a 5-statement chain with max_priv=2 + // - POD 0: statements 0 (contains), 1 (a_out); a_out public + // - POD 1: statement 2 (b_out); b_out public (copies a_out) + // - POD 2: statement 3 (c_out); c_out public (copies b_out) + // - POD 3 (output): statement 4 (d_out); d_out public (copies c_out) + assert_eq!( + 
solution.pod_count, 4, + "Expected exactly 4 PODs for 5-statement chain with max_priv=2" + ); + + // POD 0: contains(0) and a_out(1) + assert!( + solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), + "POD 0 should contain statements 0 and 1, got {:?}", + solution.pod_statements[0] + ); + assert!( + solution.pod_public_statements[0].contains(&1), + "Statement 1 (a_out) should be public in POD 0" + ); + + // POD 1: b_out(2) + assert!( + solution.pod_statements[1].contains(&2), + "POD 1 should contain statement 2 (b_out), got {:?}", + solution.pod_statements[1] + ); + assert!( + solution.pod_public_statements[1].contains(&2), + "Statement 2 (b_out) should be public in POD 1" + ); + + // POD 2: c_out(3) + assert!( + solution.pod_statements[2].contains(&3), + "POD 2 should contain statement 3 (c_out), got {:?}", + solution.pod_statements[2] + ); + assert!( + solution.pod_public_statements[2].contains(&3), + "Statement 3 (c_out) should be public in POD 2" + ); + + // POD 3 (output): d_out(4) + assert!( + solution.pod_statements[3].contains(&4), + "POD 3 should contain statement 4 (d_out), got {:?}", + solution.pod_statements[3] + ); + assert!( + solution.pod_public_statements[3].contains(&4), + "Statement 4 (d_out) should be public in output POD" + ); + + // Prove and verify all PODs + let prover = MockProver {}; + let result = builder.prove(&prover)?; + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_diamond_dependencies_across_pods() -> Result<()> { + // Verifies that diamond-shaped dependencies work across PODs. + // + // Diamond structure: + // a_out (output) + // / \ + // b_out c_out + // \ / + // contains + // + // Where a_out depends on BOTH b_out and c_out, creating a diamond. 
+ // With tight limits, b_out and c_out may end up in different PODs, + // and the output POD must copy from both. + // + // With max_priv_statements = 3: + // - POD 0: contains, b_out, c_out (b_out and c_out public) - 3 statements + // - POD 1 (output): copy(b_out), copy(c_out), a_out - 3 statements + // Or the solver may find a different arrangement. + + let params = Params { + max_statements: 6, + max_public_statements: 3, + // max_priv_statements = 3 + max_input_pods: 4, + max_input_pods_public_statements: 20, + max_custom_predicate_verifications: 10, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + // pred_a takes TWO custom statement arguments (b_out and c_out) + // pred_b and pred_c each take a Contains + // Note: AND clauses are newline-separated, not comma-separated + let parsed = parse( + r#" + pred_b(X) = AND(Contains(X, "k", 1)) + pred_c(X) = AND(Contains(X, "k", 1)) + pred_a(X, Y) = AND( + pred_b(X) + pred_c(Y) + ) + "#, + ¶ms, + &[], + ) + .expect("parse predicates"); + let batch = parsed + .first_batch() + .expect("parse predicates should have a batch"); + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Base: single contains statement (shared by both branches conceptually, + // but we need separate ones for pred_b and pred_c due to predicate signatures) + let dict = dict!({"k" => 1}); + let contains = builder.priv_op(FrontendOp::dict_contains(dict, "k", 1))?; + + // Left branch: b_out depends on contains + let b_out = builder.priv_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_b").unwrap(), + [contains.clone()], + ))?; + + // Right branch: c_out depends on contains + let c_out = builder.priv_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_c").unwrap(), + [contains], + ))?; + + // Top: a_out depends on BOTH b_out and c_out + let _a_out = builder.pub_op(FrontendOp::custom( + batch.predicate_ref_by_name("pred_a").unwrap(), + [b_out, c_out], + ))?; + + let solution = builder.solve()?; + + // Expected: 
exactly 2 PODs for the diamond + // - POD 0: contains(0), b_out(1), c_out(2); b_out and c_out public + // - POD 1 (output): a_out(3); a_out public (copies b_out and c_out) + assert_eq!( + solution.pod_count, 2, + "Expected exactly 2 PODs for diamond with max_priv=3" + ); + + // POD 0 should contain statements 0, 1, 2 + assert!( + solution.pod_statements[0].contains(&0) + && solution.pod_statements[0].contains(&1) + && solution.pod_statements[0].contains(&2), + "POD 0 should contain statements 0, 1, 2, got {:?}", + solution.pod_statements[0] + ); + + // Statements 1 and 2 (b_out and c_out) should be public in POD 0 + assert!( + solution.pod_public_statements[0].contains(&1) + && solution.pod_public_statements[0].contains(&2), + "Statements 1 and 2 should be public in POD 0" + ); + + // POD 1 (output) should contain statement 3 (a_out) + assert!( + solution.pod_statements[1].contains(&3), + "POD 1 should contain statement 3 (a_out), got {:?}", + solution.pod_statements[1] + ); + + // Statement 3 (a_out) should be public in output POD + assert!( + solution.pod_public_statements[1].contains(&3), + "Statement 3 (a_out) should be public in output POD" + ); + + // Prove and verify all PODs + let prover = MockProver {}; + let result = builder.prove(&prover)?; + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } + + #[test] + fn test_dependency_chain_with_batch_limit() -> Result<()> { + // Verifies that dependency chains work correctly when combined with + // batch cardinality limits. + // + // Setup: Two predicates in DIFFERENT batches, where pred_b depends on pred_a. + // With max_custom_predicate_batches = 1, pred_a and pred_b must be in + // different PODs due to the batch limit. The dependency must still be + // satisfied via cross-POD copying. 
+ + let params = Params { + max_statements: 10, + max_public_statements: 4, + max_input_pods: 4, + max_input_pods_public_statements: 20, + max_custom_predicate_batches: 1, // Only 1 batch per POD + max_custom_predicate_verifications: 10, + ..Params::default() + }; + let vd_set = &*MOCK_VD_SET; + + // Create two SEPARATE batches (parsed separately to get different batch IDs) + let parsed_a = + parse(r#"pred_a(X) = AND(Contains(X, "k", 1))"#, ¶ms, &[]).expect("parse batch_a"); + let batch_a = parsed_a + .first_batch() + .expect("parse batch_a should have a batch"); + + // batch_b's pred_b accepts pred_a statements + // Must use "use batch" syntax to reference external predicates + let batch_a_id = batch_a.id().encode_hex::(); + let batch_b_src = format!( + r#" + use batch pred_a from 0x{batch_a_id} + pred_b(X) = AND(pred_a(X)) + "# + ); + let parsed_b = + parse(&batch_b_src, ¶ms, std::slice::from_ref(batch_a)).expect("parse batch_b"); + let batch_b = parsed_b + .first_batch() + .expect("parse batch_b should have a batch"); + + let mut builder = MultiPodBuilder::new(¶ms, vd_set); + + // Statement 0: Contains (no batch) + let dict = dict!({"k" => 1}); + let contains = builder.priv_op(FrontendOp::dict_contains(dict, "k", 1))?; + + // Statement 1: pred_a (batch A) + let a_out = builder.priv_op(FrontendOp::custom( + batch_a.predicate_ref_by_name("pred_a").unwrap(), + [contains], + ))?; + + // Statement 2: pred_b (batch B) - depends on a_out + // With max_custom_predicate_batches = 1, this MUST be in a different POD + let _b_out = builder.pub_op(FrontendOp::custom( + batch_b.predicate_ref_by_name("pred_b").unwrap(), + [a_out], + ))?; + + let solution = builder.solve()?; + + // Expected: exactly 2 PODs due to batch limit + // - POD 0: contains(0), a_out(1) using batch_a; a_out public + // - POD 1 (output): b_out(2) using batch_b; b_out public (copies a_out) + // + // Even though max_priv_statements=6 could fit all 3 statements, + // max_custom_predicate_batches=1 forces 
batch_a and batch_b into different PODs. + assert_eq!( + solution.pod_count, 2, + "Expected exactly 2 PODs due to batch limit (max_custom_predicate_batches=1)" + ); + + // POD 0: contains(0), a_out(1) + assert!( + solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), + "POD 0 should contain statements 0 and 1, got {:?}", + solution.pod_statements[0] + ); + assert!( + solution.pod_public_statements[0].contains(&1), + "Statement 1 (a_out) should be public in POD 0" + ); + + // POD 1 (output): b_out(2) + assert!( + solution.pod_statements[1].contains(&2), + "POD 1 should contain statement 2 (b_out), got {:?}", + solution.pod_statements[1] + ); + assert!( + solution.pod_public_statements[1].contains(&2), + "Statement 2 (b_out) should be public in output POD" + ); + + // Prove and verify + let prover = MockProver {}; + let result = builder.prove(&prover)?; + + for (i, pod) in result.pods.iter().enumerate() { + pod.pod + .verify() + .map_err(|e| Error::Frontend(format!("POD {} verification failed: {}", i, e)))?; + } + + Ok(()) + } +} diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs new file mode 100644 index 0000000..7920384 --- /dev/null +++ b/src/frontend/multi_pod/solver.rs @@ -0,0 +1,703 @@ +//! MILP solver for multi-POD packing. +//! +//! This module builds and solves a Mixed Integer Linear Program (MILP) to minimize +//! the number of PODs needed to prove a set of statements while respecting resource +//! limits and dependency constraints. +//! +//! # Constraint Overview +//! +//! The solver uses the following constraints (numbered for reference in code comments): +//! +//! - **Constraint 1 (Coverage)**: Each statement must be proved in at least one POD. +//! - **Constraint 2 (Output POD)**: Output-public statements must be public in the last POD. +//! - **Constraint 2b (Privacy)**: Non-output-public statements cannot be public in the output POD. +//! 
- **Constraint 3 (Public ⇒ Proved)**: A statement can only be public if it's proved there. +//! - **Constraint 4 (POD Existence)**: If any statement is proved in POD p, then p is used. +//! - **Constraint 5 (Dependencies)**: If statement S depends on D and S is proved in POD p, +//! then D must be available: either proved locally in p, or public in some earlier POD. +//! - **Constraint 5b (Copy Tracking)**: Track when dependencies need CopyStatement. +//! - **Constraint 6 (Resource Limits)**: Per-POD limits on statements, public slots, merkle +//! proofs, custom predicates, batches, etc. +//! - **Constraint 7 (Batch Cardinality)**: Limit distinct custom predicate batches per POD. +//! - **Constraint 7b (Anchored Keys)**: Track auto-inserted Contains for anchored key references. +//! - **Constraint 8a (Internal Inputs)**: Track which earlier PODs are used as inputs. +//! - **Constraint 8b (External Inputs)**: Track which external PODs are used as inputs. +//! - **Constraint 8c (Input Limit)**: Total inputs (internal + external) ≤ max_input_pods. +//! - **Constraint 9 (Symmetry Breaking)**: PODs are used in order (0, 1, 2, ...) with no gaps. +//! +//! # Solution Approach +//! +//! The solver uses an incremental approach: it tries solving with the minimum possible +//! number of PODs first, then increments until a feasible solution is found. This is +//! efficient for the common case where few PODs are needed. 
+ +// MILP constraint building uses explicit index loops for clarity +#![allow(clippy::needless_range_loop)] + +use std::collections::BTreeSet; + +use good_lp::{ + constraint, default_solver, variable, Expression, ProblemVariables, Solution, SolverModel, + Variable, +}; +use itertools::Itertools; + +use super::Result; +use crate::{ + frontend::multi_pod::{ + cost::{AnchoredKeyId, CustomBatchId, StatementCost}, + deps::{DependencyGraph, StatementSource}, + }, + middleware::Params, +}; + +/// Threshold for interpreting MILP solver's floating-point results as binary. +/// The solver returns continuous values in [0, 1] for binary variables; +/// values > 0.5 are interpreted as "true" (1), otherwise "false" (0). +const SOLVER_BINARY_THRESHOLD: f64 = 0.5; + +/// Solution from the MILP solver. +#[derive(Clone, Debug)] +pub struct MultiPodSolution { + /// Number of PODs needed. + pub pod_count: usize, + + /// For each statement index, which POD(s) it is proved in. + /// (A statement may be proved in multiple PODs if re-proving is cheaper than copying.) + pub statement_to_pods: Vec>, + + /// For each POD, which statement indices are proved in it. + pub pod_statements: Vec>, + + /// For each POD, which statement indices are public in it. + pub pod_public_statements: Vec>, +} + +/// Input to the MILP solver. +pub struct SolverInput<'a> { + /// Number of statements. + pub num_statements: usize, + + /// Resource costs for each statement. + pub costs: &'a [StatementCost], + + /// Dependency graph. + pub deps: &'a DependencyGraph, + + /// Indices of statements that must be public in output PODs. + pub output_public_indices: &'a [usize], + + /// Parameters defining per-POD limits. + pub params: &'a Params, + + /// Maximum number of PODs the solver will consider. + pub max_pods: usize, + + /// All unique anchored keys referenced by any statement. + /// + /// Each unique (dict, key) pair that is used as an anchored key reference + /// in any operation. 
When a Contains statement with literal values is used + /// as an argument, it creates an anchored key reference. + pub all_anchored_keys: &'a [AnchoredKeyId], + + /// For each anchored key, the statement index that produces it (if any). + /// + /// When a Contains statement with literal (dict, key, value) args is explicitly + /// added, it "produces" that anchored key. If the producer is in the same POD + /// as statements using the anchored key, no auto-insertion is needed. + /// `anchored_key_producers[i]` corresponds to `all_anchored_keys[i]`. + pub anchored_key_producers: &'a [Option], + + /// Statement content groups for deduplication. + /// + /// Each inner Vec contains statement indices that have identical content. + /// When multiple statements with the same content are proved in the same POD, + /// they only use one statement slot (the POD deduplicates identical statements). + pub statement_content_groups: &'a [Vec], +} + +/// Solve the MILP problem to find optimal POD packing. +/// +/// Uses an incremental approach: tries solving with min_pods first, +/// then increments until a solution is found or target_pods is exceeded. +/// This is efficient for the common case where min_pods is sufficient. +pub fn solve(input: &SolverInput) -> Result { + let n = input.num_statements; + + // Require at least one public statement. A POD with no public statements + // can't prove anything to an external verifier. + if input.output_public_indices.is_empty() { + return Err(super::Error::Solver( + "No public statements requested. Use pub_op() to add at least one statement \ + that should be visible in the output POD." + .to_string(), + )); + } + + // Check that all output-public statements can fit in a single POD + let num_output_public = input.output_public_indices.len(); + if num_output_public > input.params.max_public_statements { + return Err(super::Error::Solver(format!( + "Too many public statements requested: {} requested, but max_public_statements is {}. 
\ + All public statements must fit in a single output POD.", + num_output_public, input.params.max_public_statements + ))); + } + + // Lower bound on number of PODs needed + // Note: max_priv_statements is the limit on total unique statements per POD + // (public statements are copies from private slots) + let max_stmts_per_pod = input.params.max_priv_statements(); + let min_pods_by_statements = n.div_ceil(max_stmts_per_pod); + let min_pods = min_pods_by_statements.max(1); + + // Check if the problem exceeds the configured max_pods limit + if min_pods > input.max_pods { + return Err(super::Error::Solver(format!( + "Problem requires at least {} PODs, but max_pods is set to {}. \ + Increase Options::max_pods to allow more PODs.", + min_pods, input.max_pods + ))); + } + + // Collect all unique custom batch IDs used + let all_batches: Vec = input + .costs + .iter() + .flat_map(|c| c.custom_batch_ids.iter().cloned()) + .unique() + .collect(); + + // Incremental approach: try solving with increasing POD counts + // Start with min_pods and increment until we find a feasible solution + for target_pods in min_pods..=input.max_pods { + if let Some(solution) = try_solve_with_pods(input, target_pods, &all_batches)? { + return Ok(solution); + } + // Infeasible with target_pods, try more + } + + // No feasible solution found even with max_pods + Err(super::Error::Solver(format!( + "No feasible solution found with up to {} PODs", + input.max_pods + ))) +} + +/// Try to solve the packing problem with exactly `target_pods` PODs. +/// +/// Builds a MILP model with all constraints and attempts to solve it. +/// Returns `Ok(Some(solution))` if a feasible assignment exists, +/// `Ok(None)` if the problem is infeasible with this many PODs. +/// +/// The caller (in `solve()`) handles incrementing `target_pods` when infeasible. 
+fn try_solve_with_pods( + input: &SolverInput, + target_pods: usize, + all_batches: &[CustomBatchId], +) -> Result> { + // Create variables + let mut vars = ProblemVariables::new(); + let n = input.num_statements; + + // prove[s][p] - statement s is proved in POD p + let prove: Vec> = (0..n) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // public[s][p] - statement s is public in POD p + let public: Vec> = (0..n) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // pod_used[p] - POD p is used + let pod_used: Vec = (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect(); + + // batch_used[b][p] - custom batch b is used in POD p + let batch_used: Vec> = (0..all_batches.len()) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // anchored_key_used[ak][p] - anchored key ak is used in POD p + // When a statement references an anchored key (via a Contains statement argument), + // that POD must have a Contains statement for that (dict, key) pair. + // MainPodBuilder::add_entries_contains auto-inserts these, and we must account + // for them in the statement count. + let anchored_key_used: Vec> = (0..input.all_anchored_keys.len()) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // uses_input[p][pp] - POD p uses POD pp as an input (pp < p) + // We only create variables for pp < p + let uses_input: Vec> = (0..target_pods) + .map(|p| (0..p).map(|_| vars.add(variable().binary())).collect()) + .collect(); + + // Collect all statement indices that are internal dependencies. + // These are statements that other statements depend on, and may need to be copied + // into PODs where the dependent statement is proved but the dependency is not. 
+ let internal_deps: BTreeSet = input + .deps + .statement_deps + .iter() + .flat_map(|deps| deps.iter()) + .filter_map(|dep| match dep { + StatementSource::Internal(d) => Some(*d), + StatementSource::External(_) => None, + }) + .collect(); + + // needs_copy[d][p] - dependency d needs to be copied into POD p + // This is 1 when: (some statement s in p depends on d) AND (d is not proved in p) + // We only create variables for dependencies that are actually used. + let dep_indices: Vec = internal_deps.iter().copied().collect(); + let needs_copy: Vec> = (0..dep_indices.len()) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // Collect all external POD hashes that statements depend on. + // These are user-provided input PODs referenced by statements. + use crate::middleware::Hash; + let external_pods: Vec = input + .deps + .statement_deps + .iter() + .flat_map(|deps| deps.iter()) + .filter_map(|dep| match dep { + StatementSource::External(h) => Some(*h), + StatementSource::Internal(_) => None, + }) + .collect::>() + .into_iter() + .collect(); + + // uses_external[p][e] - POD p uses external POD e as an input + let uses_external: Vec> = (0..target_pods) + .map(|_| { + (0..external_pods.len()) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // Map from external POD hash to index in uses_external + let external_to_idx: std::collections::HashMap = external_pods + .iter() + .enumerate() + .map(|(i, h)| (*h, i)) + .collect(); + + // content_group_used[g][p] - content group g has at least one statement proved in POD p + // When multiple statements have identical content, they share a slot in the POD. + // This variable tracks whether at least one statement from each content group is proved. 
+ let num_groups = input.statement_content_groups.len(); + let content_group_used: Vec> = (0..num_groups) + .map(|_| { + (0..target_pods) + .map(|_| vars.add(variable().binary())) + .collect() + }) + .collect(); + + // Objective: minimize number of PODs used + let objective: Expression = pod_used.iter().sum(); + let mut model = vars.minimise(objective).using(default_solver); + + // Constraint 1: Each statement must be proved at least once + for s in 0..n { + let sum: Expression = prove[s].iter().sum(); + model.add_constraint(constraint!(sum >= 1)); + } + + // Constraint 2: Output-public statements must be public in the output POD (last POD) + // The output POD is at index target_pods-1, allowing it to access all earlier PODs + // for dependencies. This ensures exactly one output POD with deterministic location. + let output_pod = target_pods - 1; + for &s in input.output_public_indices { + model.add_constraint(constraint!(public[s][output_pod] == 1)); + } + + // Constraint 2b: Non-output-public statements cannot be public in the output POD + // This prevents private statements from leaking to the output POD's public slots. + for s in 0..n { + if !input.output_public_indices.contains(&s) { + model.add_constraint(constraint!(public[s][output_pod] == 0)); + } + } + + // Constraint 3: Public implies proved + for s in 0..n { + for p in 0..target_pods { + model.add_constraint(constraint!(public[s][p] <= prove[s][p])); + } + } + + // Constraint 4: Pod existence - if any statement is proved in p, p is used + for s in 0..n { + for p in 0..target_pods { + model.add_constraint(constraint!(prove[s][p] <= pod_used[p])); + } + } + + // Constraint 5: Dependencies (works with Constraint 8 to enforce input POD tracking) + // + // If s depends on d (internal), and s is proved in p, then either: + // - d is proved in p (local availability), OR + // - d is public in some earlier POD p' < p (cross-POD availability) + // + // This constraint ensures dependencies are AVAILABLE. 
It does NOT track which + // earlier PODs are actually used as inputs - that's handled by Constraint 8. + // Together: + // - Constraint 5 ensures the dependency CAN be satisfied + // - Constraint 8 ensures that when we use a statement from earlier POD pp, + // we count pp as an input to pod p (for max_input_pods enforcement) + for s in 0..n { + for dep in &input.deps.statement_deps[s] { + if let StatementSource::Internal(d) = dep { + for p in 0..target_pods { + // prove[s][p] <= prove[d][p] + sum_{p' < p} public[d][p'] + let mut rhs: Expression = prove[*d][p].into(); + for pp in 0..p { + rhs += public[*d][pp]; + } + model.add_constraint(constraint!(prove[s][p] <= rhs)); + } + } + } + } + + // Constraint 5b: needs_copy tracking for cross-POD dependencies + // needs_copy[d][p] = 1 when: some statement s proved in p depends on d, AND d is not proved in p. + // This tracks CopyStatements that will be added during build_single_pod. + for (di, &d) in dep_indices.iter().enumerate() { + for p in 0..target_pods { + // needs_copy[d][p] >= prove[s][p] - prove[d][p] for each s that depends on d + // If s is in p (prove[s][p]=1) and d is not in p (prove[d][p]=0), then needs_copy >= 1 + for s in 0..n { + let depends_on_d = input.deps.statement_deps[s] + .iter() + .any(|dep| matches!(dep, StatementSource::Internal(dep_d) if *dep_d == d)); + if depends_on_d { + model.add_constraint(constraint!( + needs_copy[di][p] >= prove[s][p] - prove[d][p] + )); + } + } + + // needs_copy[d][p] <= 1 - prove[d][p] + // If d is proved locally (prove[d][p]=1), no copy needed (needs_copy <= 0) + model.add_constraint(constraint!(needs_copy[di][p] <= 1 - prove[d][p])); + } + } + + // Constraint 6: Resource limits per POD + // + // 6a-pre: Content group tracking for statement deduplication + // When multiple statement indices have identical content, they share a single slot in the POD. + // content_group_used[g][p] = 1 iff at least one statement from group g is proved in POD p. 
+ for (g, group) in input.statement_content_groups.iter().enumerate() { + for p in 0..target_pods { + // Lower bound: if any statement in the group is proved, the group is used + for &s in group { + model.add_constraint(constraint!(content_group_used[g][p] >= prove[s][p])); + } + // Upper bound: if no statements in the group are proved, the group is not used + let group_prove_sum: Expression = group.iter().map(|&s| prove[s][p]).sum(); + model.add_constraint(constraint!(content_group_used[g][p] <= group_prove_sum)); + } + } + + for p in 0..target_pods { + // 6a: Unique statement count (unique content groups + CopyStatements + anchored key Contains) + // Statements with identical content share a slot, so we count content groups, not indices. + // CopyStatements and anchored key Contains also use statement slots. + // The total must not exceed max_priv_statements (= max_statements - max_public_statements). + let unique_stmt_sum: Expression = (0..num_groups).map(|g| content_group_used[g][p]).sum(); + let copy_sum: Expression = (0..dep_indices.len()).map(|di| needs_copy[di][p]).sum(); + let anchored_key_sum: Expression = (0..input.all_anchored_keys.len()) + .map(|ak| anchored_key_used[ak][p]) + .sum(); + model.add_constraint(constraint!( + unique_stmt_sum + copy_sum + anchored_key_sum + <= (input.params.max_priv_statements() as f64) * pod_used[p] + )); + + // 6b: Public statement count + let pub_sum: Expression = (0..n).map(|s| public[s][p]).sum(); + model.add_constraint(constraint!( + pub_sum <= (input.params.max_public_statements as f64) * pod_used[p] + )); + + // 6c: Merkle proofs + let merkle_sum: Expression = (0..n) + .map(|s| (input.costs[s].merkle_proofs as f64) * prove[s][p]) + .sum(); + model.add_constraint(constraint!( + merkle_sum <= (input.params.max_merkle_proofs_containers as f64) * pod_used[p] + )); + + // 6d: Merkle state transitions + let mst_sum: Expression = (0..n) + .map(|s| (input.costs[s].merkle_state_transitions as f64) * prove[s][p]) + .sum(); + 
model.add_constraint(constraint!( + mst_sum + <= (input + .params + .max_merkle_tree_state_transition_proofs_containers as f64) + * pod_used[p] + )); + + // 6e: Custom predicate verifications + let cpv_sum: Expression = (0..n) + .map(|s| (input.costs[s].custom_pred_verifications as f64) * prove[s][p]) + .sum(); + model.add_constraint(constraint!( + cpv_sum <= (input.params.max_custom_predicate_verifications as f64) * pod_used[p] + )); + + // 6f: SignedBy + let sb_sum: Expression = (0..n) + .map(|s| (input.costs[s].signed_by as f64) * prove[s][p]) + .sum(); + model.add_constraint(constraint!( + sb_sum <= (input.params.max_signed_by as f64) * pod_used[p] + )); + + // 6g: PublicKeyOf + let pko_sum: Expression = (0..n) + .map(|s| (input.costs[s].public_key_of as f64) * prove[s][p]) + .sum(); + model.add_constraint(constraint!( + pko_sum <= (input.params.max_public_key_of as f64) * pod_used[p] + )); + } + + // Constraint 7: Batch cardinality + // batch_used[b][p] >= prove[s][p] for all s that use batch b (batch is used if any statement uses it) + // batch_used[b][p] <= sum of prove[s][p] for all s using batch b (batch is 0 if no statements use it) + for (b, batch_id) in all_batches.iter().enumerate() { + for p in 0..target_pods { + let mut sum: Expression = 0.into(); + for s in 0..n { + if input.costs[s].custom_batch_ids.contains(batch_id) { + model.add_constraint(constraint!(batch_used[b][p] >= prove[s][p])); + sum += prove[s][p]; + } + } + model.add_constraint(constraint!(batch_used[b][p] <= sum)); + } + } + + // Batch count per POD + for p in 0..target_pods { + let batch_sum: Expression = (0..all_batches.len()).map(|b| batch_used[b][p]).sum(); + model.add_constraint(constraint!( + batch_sum <= (input.params.max_custom_predicate_batches as f64) * pod_used[p] + )); + } + + // Constraint 7b: Anchored key tracking + // + // anchored_key_used[ak][p] = 1 when auto-insertion of a Contains is needed for anchored key ak in POD p. 
+ // This happens when: some statement using ak is in POD p, AND the producing Contains is NOT in POD p. + // + // If a Contains statement explicitly produces ak (anchored_key_producers[ak] = Some(prod_idx)): + // - Lower: anchored_key_used[ak][p] >= prove[s][p] - prove[prod_idx][p] for all s using ak + // - Upper: anchored_key_used[ak][p] <= 1 - prove[prod_idx][p] + // This ensures overhead is 0 when the producer is in the same POD. + // + // If no Contains produces ak (anchored_key_producers[ak] = None): + // - Lower: anchored_key_used[ak][p] >= prove[s][p] for all s using ak + // - Upper: anchored_key_used[ak][p] <= sum of prove[s][p] for all s using ak + // Auto-insertion is always needed when any user is present. + for (ak_idx, ak) in input.all_anchored_keys.iter().enumerate() { + let producer = input.anchored_key_producers[ak_idx]; + + for p in 0..target_pods { + let mut user_sum: Expression = 0.into(); + for s in 0..n { + if input.costs[s].anchored_keys.contains(ak) { + if let Some(prod_idx) = producer { + // Producer exists: only count overhead if producer not in this POD + model.add_constraint(constraint!( + anchored_key_used[ak_idx][p] >= prove[s][p] - prove[prod_idx][p] + )); + } else { + // No producer: always need auto-insertion if user is present + model.add_constraint(constraint!( + anchored_key_used[ak_idx][p] >= prove[s][p] + )); + } + user_sum += prove[s][p]; + } + } + + if let Some(prod_idx) = producer { + // If producer is in POD, no auto-insertion needed (overhead = 0) + model.add_constraint(constraint!( + anchored_key_used[ak_idx][p] <= 1 - prove[prod_idx][p] + )); + } else { + // No producer: overhead is bounded by whether any user is present + model.add_constraint(constraint!(anchored_key_used[ak_idx][p] <= user_sum)); + } + } + } + + // Constraint 8a: Internal input POD tracking using uses_input + // uses_input[p][pp] >= prove[s][p] + public[d][pp] - prove[d][p] - 1 + // for each dependency (s depends on d) + // + // If s is proved in p and 
d is public in pp, we need pp as input UNLESS d is also + // proved locally in p. Subtracting prove[d][p] ensures that when d is re-proved + // locally (prove[d][p] = 1), the constraint becomes uses_input >= 0, which is + // always satisfied without forcing the input relationship. + for s in 0..n { + for dep in &input.deps.statement_deps[s] { + if let StatementSource::Internal(d) = dep { + for p in 1..target_pods { + for pp in 0..p { + model.add_constraint(constraint!( + uses_input[p][pp] >= prove[s][p] + public[*d][pp] - prove[*d][p] - 1.0 + )); + } + } + } + } + } + + // Constraint 8b: External input POD tracking using uses_external + // If statement s is proved in POD p and s depends on external POD e, then uses_external[p][e] = 1 + for s in 0..n { + for dep in &input.deps.statement_deps[s] { + if let StatementSource::External(h) = dep { + if let Some(&e) = external_to_idx.get(h) { + for p in 0..target_pods { + // If s is proved in p, then uses_external[p][e] = 1 + model.add_constraint(constraint!(uses_external[p][e] >= prove[s][p])); + } + } + } + } + } + + // Constraint 8c: Total input PODs (internal + external) must not exceed max_input_pods + // For each POD p, the total number of inputs is: + // - Internal inputs: PODs pp < p that provide public statements used by p + // - External inputs: User-provided PODs referenced by statements in p + for p in 0..target_pods { + let internal_sum: Expression = if p > 0 { + (0..p).map(|pp| uses_input[p][pp]).sum() + } else { + 0.into() + }; + let external_sum: Expression = (0..external_pods.len()).map(|e| uses_external[p][e]).sum(); + model.add_constraint(constraint!( + internal_sum + external_sum <= (input.params.max_input_pods as f64) * pod_used[p] + )); + } + + // Constraint 9: Symmetry breaking - use PODs in order + // pod_used[p] >= pod_used[p+1] + for p in 0..target_pods - 1 { + model.add_constraint(constraint!(pod_used[p] >= pod_used[p + 1])); + } + + // Solve + let solution = match model.solve() { + Ok(sol) => 
sol, + Err(_) => { + // Infeasible with this number of PODs, try more + return Ok(None); + } + }; + + // Extract solution: count how many PODs are used. + // Symmetry breaking (Constraint 9) ensures PODs are used in order with no gaps. + let mut pod_count = 0; + for p in 0..target_pods { + if solution.value(pod_used[p]) > SOLVER_BINARY_THRESHOLD { + pod_count += 1; + } + } + + let mut statement_to_pods: Vec> = vec![vec![]; n]; + let mut pod_statements: Vec> = vec![vec![]; pod_count]; + let mut pod_public_statements: Vec> = vec![BTreeSet::new(); pod_count]; + + for s in 0..n { + for p in 0..pod_count { + if solution.value(prove[s][p]) > SOLVER_BINARY_THRESHOLD { + statement_to_pods[s].push(p); + pod_statements[p].push(s); + } + if solution.value(public[s][p]) > SOLVER_BINARY_THRESHOLD { + pod_public_statements[p].insert(s); + } + } + } + + Ok(Some(MultiPodSolution { + pod_count, + statement_to_pods, + pod_statements, + pod_public_statements, + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_no_public_statements_error() { + // At least one public statement is required - otherwise the POD can't + // prove anything to an external verifier. + let params = Params::default(); + let deps = DependencyGraph { + statement_deps: vec![], + }; + + let input = SolverInput { + num_statements: 0, + costs: &[], + deps: &deps, + output_public_indices: &[], + params: ¶ms, + max_pods: 20, + all_anchored_keys: &[], + anchored_key_producers: &[], + statement_content_groups: &[], + }; + + let result = solve(&input); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("No public statements requested")); + } +}