Multipod external fix (#485)

This commit is contained in:
Rob Knight 2026-02-23 11:26:39 +00:00 committed by GitHub
parent c185d27344
commit a389ff1dc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 692 additions and 113 deletions

View file

@ -11,13 +11,22 @@ use crate::{
middleware::{Hash, Statement}, middleware::{Hash, Statement},
}; };
/// Reference to a statement sourced from an external input POD.
///
/// Pairs the hash of the source POD with the statement itself; the derived
/// equality and hashing therefore consider both fields, so the same statement
/// coming from two different external PODs yields two distinct values.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ExternalDependency {
/// Hash of the external POD containing `statement` in its public set.
pub pod_hash: Hash,
/// The statement value itself (an owned copy, not an index into the builder).
pub statement: Statement,
}
/// Represents a source of a statement dependency. /// Represents a source of a statement dependency.
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum StatementSource { pub enum StatementSource {
/// Statement created within this builder at the given index. /// Statement created within this builder at the given index.
Internal(usize), Internal(usize),
/// Statement from an external input POD (identified by POD hash). /// Statement from an external input POD.
External(Hash), External(ExternalDependency),
} }
/// Dependency graph for all statements in a builder. /// Dependency graph for all statements in a builder.
@ -87,7 +96,10 @@ impl DependencyGraph {
// Check if this is from an external POD // Check if this is from an external POD
if let Some(&pod_hash) = external_pod_statements.get(dep_stmt) { if let Some(&pod_hash) = external_pod_statements.get(dep_stmt) {
deps.push(StatementSource::External(pod_hash)); deps.push(StatementSource::External(ExternalDependency {
pod_hash,
statement: dep_stmt.clone(),
}));
} else if AnchoredKeyId::from_contains_statement(dep_stmt).is_some() { } else if AnchoredKeyId::from_contains_statement(dep_stmt).is_some() {
// Anchored-key Contains args may be implicit requirements that are // Anchored-key Contains args may be implicit requirements that are
// auto-materialized by MainPodBuilder. They are handled by anchored-key // auto-materialized by MainPodBuilder. They are handled by anchored-key

View file

@ -295,7 +295,8 @@ impl SolvedMultiPod {
added_statements_by_content.insert(original_stmt, stmt); added_statements_by_content.insert(original_stmt, stmt);
} }
// For the output pod, make statements public in the original order // For the output pod, make statements public in the original order.
// Intermediate pods use the solver-selected public set.
if pod_idx == solution.pod_count - 1 { if pod_idx == solution.pod_count - 1 {
for idx in &self.output_public_indices { for idx in &self.output_public_indices {
let stmt = added_statements_by_content let stmt = added_statements_by_content
@ -312,6 +313,13 @@ impl SolvedMultiPod {
} }
} }
// Forward external premises only when the solver selected them as public
// for this POD. These do not require local proving in this POD.
for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] {
let ext_premise = &solution.external_premises[*ext_premise_idx];
builder.reveal(&ext_premise.statement);
}
// Step 4: Prove the POD // Step 4: Prove the POD
let pod = builder.prove(prover)?; let pod = builder.prove(prover)?;
@ -323,36 +331,17 @@ impl SolvedMultiPod {
/// Returns (internal_pod_indices, external_pod_indices). /// Returns (internal_pod_indices, external_pod_indices).
fn compute_pod_inputs(&self, pod_idx: usize) -> (BTreeSet<usize>, BTreeSet<usize>) { fn compute_pod_inputs(&self, pod_idx: usize) -> (BTreeSet<usize>, BTreeSet<usize>) {
let solution = &self.solution; let solution = &self.solution;
let statements_in_pod = &solution.pod_statements[pod_idx]; let internal_pods = solution.pod_internal_inputs[pod_idx].clone();
let mut internal_pods: BTreeSet<usize> = BTreeSet::new();
let mut external_pods: BTreeSet<usize> = BTreeSet::new(); let mut external_pods: BTreeSet<usize> = BTreeSet::new();
for &stmt_idx in statements_in_pod { for external_idx in &solution.pod_external_inputs[pod_idx] {
for dep in &self.deps.statement_deps[stmt_idx] { let pod_hash = solution.external_pod_hashes[*external_idx];
match dep { let idx = self
StatementSource::Internal(dep_idx) => { .input_pods
// Check if dependency is in an earlier POD (not local) .iter()
if !statements_in_pod.contains(dep_idx) { .position(|p| p.statements_hash() == pod_hash)
let earlier_pod_idx = (0..pod_idx) .expect("external pod hash from solver solution");
.find(|earlier_pod_idx| { external_pods.insert(idx);
solution.pod_public_statements[*earlier_pod_idx]
.contains(dep_idx)
})
.expect("internal pod with dependency statement");
internal_pods.insert(earlier_pod_idx);
}
}
StatementSource::External(pod_hash) => {
let idx = self
.input_pods
.iter()
.position(|p| p.statements_hash() == *pod_hash)
.expect("external pod with dependency statement");
external_pods.insert(idx);
}
}
}
} }
assert!(internal_pods.len() + external_pods.len() <= self.params.max_input_pods); assert!(internal_pods.len() + external_pods.len() <= self.params.max_input_pods);
@ -1196,16 +1185,16 @@ mod tests {
#[test] #[test]
fn test_external_pods_counted_in_input_limit() -> Result<()> { fn test_external_pods_counted_in_input_limit() -> Result<()> {
// Verifies that external input PODs are counted toward max_input_pods. // Verifies that external input PODs are counted toward max_input_pods while
// still allowing the solver to route external premises through intermediate PODs.
// //
// Setup: // Setup:
// - max_input_pods = 2 // - max_input_pods = 2
// - 3 external PODs (A, B, C), each with a public statement // - 3 external PODs (A, B, C), each with a public statement
// - 3 public operations, each copying from a different external POD // - 3 public operations, each copying from a different external POD
// //
// Since all 3 must be public in POD 0 (the output POD), and POD 0 would need // A direct 1-POD layout would need 3 external inputs in the output POD (infeasible),
// all 3 external PODs as inputs (3 > max_input_pods), this is infeasible. // so the solver should split the work and keep each generated POD within input limits.
// The solver should correctly detect and report this.
let params = Params { let params = Params {
max_statements: 10, max_statements: 10,
@ -1256,26 +1245,24 @@ mod tests {
multi_builder.add_pod(ext_pod_b)?; multi_builder.add_pod(ext_pod_b)?;
multi_builder.add_pod(ext_pod_c)?; multi_builder.add_pod(ext_pod_c)?;
// Add public operations that each depend on a different external POD // Add public operations that each depend on a different external POD.
// All 3 must be public in POD 0, requiring 3 external inputs > max_input_pods
multi_builder.pub_op(FrontendOp::copy(stmt_a))?; multi_builder.pub_op(FrontendOp::copy(stmt_a))?;
multi_builder.pub_op(FrontendOp::copy(stmt_b))?; multi_builder.pub_op(FrontendOp::copy(stmt_b))?;
multi_builder.pub_op(FrontendOp::copy(stmt_c))?; multi_builder.pub_op(FrontendOp::copy(stmt_c))?;
// Solver should correctly detect infeasibility and return an error // Solver should find a feasible multi-POD layout that respects input limits.
let result = multi_builder.solve(); let solved = multi_builder.solve()?;
assert!( assert!(
result.is_err(), solved.solution().pod_count >= 2,
"Expected solver to report infeasibility, but got: {:?}", "Expected at least 2 PODs to satisfy max_input_pods=2 with 3 external sources"
result
); );
let err_msg = result.unwrap_err().to_string(); let result = solved.prove(&prover)?;
assert!( for (i, pod) in result.pods.iter().enumerate() {
err_msg.contains("No feasible solution"), pod.pod
"Expected 'No feasible solution' error, got: {}", .verify()
err_msg .unwrap_or_else(|_| panic!("POD {} verification failed", i));
); }
Ok(()) Ok(())
} }

View file

@ -20,9 +20,14 @@
//! - **Constraint 7 (Batch Cardinality)**: Limit distinct custom predicate batches per POD. //! - **Constraint 7 (Batch Cardinality)**: Limit distinct custom predicate batches per POD.
//! - **Constraint 7b (Anchored Keys)**: Track auto-inserted Contains for anchored key references. //! - **Constraint 7b (Anchored Keys)**: Track auto-inserted Contains for anchored key references.
//! - **Constraint 8a (Internal Inputs)**: Track which earlier PODs are used as inputs. //! - **Constraint 8a (Internal Inputs)**: Track which earlier PODs are used as inputs.
//! - **Constraint 8b (External Inputs)**: Track which external PODs are used as inputs. //! - **Constraint 8b (External Dep Inputs)**: Track when external dependencies are sourced from
//! - **Constraint 8c (Input Limit)**: Total inputs (internal + external) ≤ max_input_pods. //! earlier PODs instead of direct external inputs.
//! - **Constraint 8c (External Forward Inputs)**: Track inputs required when forwarding external
//! premises publicly across PODs.
//! - **Constraint 8d (Input Limit)**: Total inputs (internal + external) ≤ max_input_pods.
//! - **Constraint 9 (Symmetry Breaking)**: PODs are used in order (0, 1, 2, ...) with no gaps. //! - **Constraint 9 (Symmetry Breaking)**: PODs are used in order (0, 1, 2, ...) with no gaps.
//! - **Constraint 10 (External Public Availability)**: External premises can be made public only
//! when available in that POD.
//! //!
//! # Solution Approach //! # Solution Approach
//! //!
@ -33,11 +38,14 @@
// MILP constraint building uses explicit index loops for clarity // MILP constraint building uses explicit index loops for clarity
#![allow(clippy::needless_range_loop)] #![allow(clippy::needless_range_loop)]
use std::collections::BTreeSet; use std::{
collections::{BTreeSet, HashMap},
time::Instant,
};
use good_lp::{ use good_lp::{
constraint, default_solver, variable, Expression, ProblemVariables, Solution, SolverModel, constraint, default_solver, variable, Expression, ProblemVariables, ResolutionError, Solution,
Variable, SolverModel, Variable,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -45,9 +53,9 @@ use super::Result;
use crate::{ use crate::{
frontend::multi_pod::{ frontend::multi_pod::{
cost::{AnchoredKeyId, CustomBatchId, StatementCost}, cost::{AnchoredKeyId, CustomBatchId, StatementCost},
deps::{DependencyGraph, StatementSource}, deps::{DependencyGraph, ExternalDependency, StatementSource},
}, },
middleware::Params, middleware::{Hash, Params},
}; };
/// Threshold for interpreting MILP solver's floating-point results as binary. /// Threshold for interpreting MILP solver's floating-point results as binary.
@ -55,6 +63,207 @@ use crate::{
/// values > 0.5 are interpreted as "true" (1), otherwise "false" (0). /// values > 0.5 are interpreted as "true" (1), otherwise "false" (0).
const SOLVER_BINARY_THRESHOLD: f64 = 0.5; const SOLVER_BINARY_THRESHOLD: f64 = 0.5;
/// Aggregate of the per-statement resource counters across a set of statements.
///
/// Each field is the sum of the identically named counter on every
/// `StatementCost` passed to `ResourceTotals::from_costs`. The `Default`
/// value is all zeros.
#[derive(Clone, Copy, Debug, Default)]
struct ResourceTotals {
// Sum of `StatementCost::merkle_proofs`.
merkle_proofs: usize,
// Sum of `StatementCost::merkle_state_transitions`.
merkle_state_transitions: usize,
// Sum of `StatementCost::custom_pred_verifications`.
custom_pred_verifications: usize,
// Sum of `StatementCost::signed_by`.
signed_by: usize,
// Sum of `StatementCost::public_key_of`.
public_key_of: usize,
}
impl ResourceTotals {
    /// Adds up the resource counters of every entry in `costs`.
    ///
    /// Returns the all-zero default when `costs` is empty.
    fn from_costs(costs: &[StatementCost]) -> Self {
        let mut sum = Self::default();
        for cost in costs {
            sum.merkle_proofs += cost.merkle_proofs;
            sum.merkle_state_transitions += cost.merkle_state_transitions;
            sum.custom_pred_verifications += cost.custom_pred_verifications;
            sum.signed_by += cost.signed_by;
            sum.public_key_of += cost.public_key_of;
        }
        sum
    }
}
/// Edge counts of a dependency graph, split by where the dependency comes from.
///
/// Produced by `dependency_stats`, which walks every dependency list once.
#[derive(Clone, Copy, Debug, Default)]
struct DependencyStats {
// Number of `StatementSource::Internal` entries over all statements.
internal_edges: usize,
// Number of `StatementSource::External` entries over all statements.
external_edges: usize,
}
/// Problem-wide counts computed once in `solve` and passed to every solve
/// attempt, where they feed the model-size estimate emitted at debug log level.
#[derive(Clone, Copy, Debug)]
struct SolveDebugContext {
// Internal/external dependency edge counts for the whole input.
dep_stats: DependencyStats,
// Sum over all statement costs of `custom_batch_ids.len()`.
batch_memberships: usize,
// Sum over all statement costs of `anchored_keys.len()`.
anchored_key_memberships: usize,
}
/// Estimated size of the MILP model for one solve attempt at a fixed POD count.
///
/// Filled in by `ModelSizeEstimate::for_target_pods` and used only for debug
/// logging. The `vars_*` fields count decision variables per variable family,
/// the `c*` fields count constraints per constraint family; these are computed
/// from closed-form formulas, not by inspecting the built model, so they are
/// estimates.
#[derive(Clone, Copy, Debug, Default)]
struct ModelSizeEstimate {
// --- Variable counts, one field per variable family ---
vars_prove: usize,
vars_public: usize,
vars_public_external: usize,
vars_pod_used: usize,
vars_batch_used: usize,
vars_anchored_key_used: usize,
// One variable per (pod, earlier pod) pair.
vars_uses_input: usize,
vars_uses_external: usize,
vars_content_group_used: usize,
// Sum of all `vars_*` fields above.
vars_total: usize,
// --- Constraint counts, one field per constraint family ---
c1_coverage: usize,
c2_output_public: usize,
c2b_output_privacy: usize,
c3_public_implies_proved: usize,
c4_pod_existence: usize,
c5_internal_dependencies: usize,
c5_external_dependencies: usize,
c6_pre_content_group: usize,
c6_resource_limits: usize,
c7_batch_cardinality: usize,
c7b_anchored_key_tracking: usize,
c8a_internal_inputs: usize,
c8b_external_dep_inputs: usize,
c8c_external_forward_inputs: usize,
c8d_input_limit: usize,
c10_external_public_availability: usize,
c10b_external_public_implies_pod_used: usize,
c9_symmetry_breaking: usize,
// Sum of all `c*` fields above.
constraints_total: usize,
}
impl ModelSizeEstimate {
    /// Estimates how many variables and constraints the MILP model will have
    /// when solving for exactly `target_pods` PODs.
    ///
    /// Counts that exist once per POD are scaled by `target_pods`; counts that
    /// exist once per (POD, earlier POD) pair are scaled by
    /// `target_pods * (target_pods - 1) / 2`. The two `*_total` fields are the
    /// sums of the individual variable and constraint counts respectively.
    fn for_target_pods(
        input: &SolverInput,
        target_pods: usize,
        all_batches_len: usize,
        external_pods_len: usize,
        external_premises_len: usize,
        debug_ctx: &SolveDebugContext,
    ) -> Self {
        let stmts = input.num_statements;
        let groups = input.statement_content_groups.len();
        let anchored_keys = input.all_anchored_keys.len();
        let public_outputs = input.output_public_indices.len();
        let edges = &debug_ctx.dep_stats;

        // Number of ordered (pod, earlier pod) pairs.
        let pod_pairs = target_pods * target_pods.saturating_sub(1) / 2;
        let per_pod = |count: usize| count * target_pods;
        let per_pair = |count: usize| count * pod_pairs;

        let mut estimate = Self {
            vars_prove: per_pod(stmts),
            vars_public: per_pod(stmts),
            vars_public_external: per_pod(external_premises_len),
            vars_pod_used: target_pods,
            vars_batch_used: per_pod(all_batches_len),
            vars_anchored_key_used: per_pod(anchored_keys),
            vars_uses_input: pod_pairs,
            vars_uses_external: per_pod(external_pods_len),
            vars_content_group_used: per_pod(groups),
            vars_total: 0,
            c1_coverage: stmts,
            c2_output_public: public_outputs,
            c2b_output_privacy: stmts.saturating_sub(public_outputs),
            c3_public_implies_proved: per_pod(stmts),
            c4_pod_existence: per_pod(stmts),
            c5_internal_dependencies: per_pod(edges.internal_edges),
            c5_external_dependencies: per_pod(edges.external_edges),
            c6_pre_content_group: per_pod(stmts) + per_pod(groups),
            // Seven resource-limit constraints per POD.
            c6_resource_limits: per_pod(7),
            c7_batch_cardinality: per_pod(debug_ctx.batch_memberships)
                + per_pod(all_batches_len),
            c7b_anchored_key_tracking: per_pod(debug_ctx.anchored_key_memberships)
                + per_pod(anchored_keys),
            c8a_internal_inputs: per_pair(edges.internal_edges),
            c8b_external_dep_inputs: per_pair(edges.external_edges),
            c8c_external_forward_inputs: per_pair(external_premises_len),
            c8d_input_limit: target_pods,
            c10_external_public_availability: per_pod(external_premises_len),
            c10b_external_public_implies_pod_used: per_pod(external_premises_len),
            c9_symmetry_breaking: target_pods.saturating_sub(1),
            constraints_total: 0,
        };

        estimate.vars_total = [
            estimate.vars_prove,
            estimate.vars_public,
            estimate.vars_public_external,
            estimate.vars_pod_used,
            estimate.vars_batch_used,
            estimate.vars_anchored_key_used,
            estimate.vars_uses_input,
            estimate.vars_uses_external,
            estimate.vars_content_group_used,
        ]
        .iter()
        .sum();

        estimate.constraints_total = [
            estimate.c1_coverage,
            estimate.c2_output_public,
            estimate.c2b_output_privacy,
            estimate.c3_public_implies_proved,
            estimate.c4_pod_existence,
            estimate.c5_internal_dependencies,
            estimate.c5_external_dependencies,
            estimate.c6_pre_content_group,
            estimate.c6_resource_limits,
            estimate.c7_batch_cardinality,
            estimate.c7b_anchored_key_tracking,
            estimate.c8a_internal_inputs,
            estimate.c8b_external_dep_inputs,
            estimate.c8c_external_forward_inputs,
            estimate.c8d_input_limit,
            estimate.c10_external_public_availability,
            estimate.c10b_external_public_implies_pod_used,
            estimate.c9_symmetry_breaking,
        ]
        .iter()
        .sum();

        estimate
    }
}
/// Tallies the dependency edges of `deps`, split into internal and external.
///
/// Every entry of every statement's dependency list counts as one edge.
fn dependency_stats(deps: &DependencyGraph) -> DependencyStats {
    deps.statement_deps
        .iter()
        .flatten()
        .fold(DependencyStats::default(), |mut tally, source| {
            match source {
                StatementSource::Internal(_) => tally.internal_edges += 1,
                StatementSource::External(_) => tally.external_edges += 1,
            }
            tally
        })
}
/// Minimum number of PODs needed to hold `total` units of a resource when each
/// POD can hold at most `per_pod_limit` units.
///
/// Returns `Some(0)` when nothing is required, `None` when something is
/// required but the per-POD capacity is zero (no finite bound exists), and
/// otherwise the ceiling of `total / per_pod_limit`.
fn lower_bound_from_total(total: usize, per_pod_limit: usize) -> Option<usize> {
    match (total, per_pod_limit) {
        (0, _) => Some(0),
        (_, 0) => None,
        (needed, capacity) => Some(needed.div_ceil(capacity)),
    }
}
/// Renders a lower bound for log output: the number itself, or `"inf"` when
/// no finite bound exists (`None`).
fn format_lower_bound(lb: Option<usize>) -> String {
    match lb {
        Some(pods) => pods.to_string(),
        None => String::from("inf"),
    }
}
/// Solution from the MILP solver. /// Solution from the MILP solver.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct MultiPodSolution { pub struct MultiPodSolution {
@ -70,6 +279,24 @@ pub struct MultiPodSolution {
/// For each POD, which statement indices are public in it. /// For each POD, which statement indices are public in it.
pub pod_public_statements: Vec<BTreeSet<usize>>, pub pod_public_statements: Vec<BTreeSet<usize>>,
/// For each POD, which earlier internal PODs are used as inputs.
pub pod_internal_inputs: Vec<BTreeSet<usize>>,
/// External input POD hashes referenced by the solution.
/// `pod_external_inputs[*]` stores indices into this vector.
pub external_pod_hashes: Vec<Hash>,
/// For each POD, which external input PODs are used as inputs.
/// Indices are into `external_pod_hashes`.
pub pod_external_inputs: Vec<BTreeSet<usize>>,
/// Unique external premises referenced by statement dependencies.
pub external_premises: Vec<ExternalDependency>,
/// For each POD, which external premises are exposed as public statements.
/// Indices are into `external_premises`.
pub pod_public_external_premises: Vec<BTreeSet<usize>>,
} }
/// Input to the MILP solver. /// Input to the MILP solver.
@ -167,10 +394,130 @@ pub fn solve(input: &SolverInput) -> Result<MultiPodSolution> {
.unique() .unique()
.collect(); .collect();
// Collect all unique external POD hashes and external premises referenced by dependencies.
let mut external_pods: Vec<Hash> = Vec::new();
let mut external_pod_to_idx: HashMap<Hash, usize> = HashMap::new();
let mut external_premises: Vec<ExternalDependency> = Vec::new();
let mut external_premise_to_idx: HashMap<ExternalDependency, usize> = HashMap::new();
for deps in &input.deps.statement_deps {
for dep in deps {
if let StatementSource::External(ext) = dep {
if let std::collections::hash_map::Entry::Vacant(e) =
external_pod_to_idx.entry(ext.pod_hash)
{
e.insert(external_pods.len());
external_pods.push(ext.pod_hash);
}
if !external_premise_to_idx.contains_key(ext) {
external_premise_to_idx.insert(ext.clone(), external_premises.len());
external_premises.push(ext.clone());
}
}
}
}
let dep_stats = dependency_stats(input.deps);
let batch_memberships: usize = input.costs.iter().map(|c| c.custom_batch_ids.len()).sum();
let anchored_key_memberships: usize = input.costs.iter().map(|c| c.anchored_keys.len()).sum();
let debug_ctx = SolveDebugContext {
dep_stats,
batch_memberships,
anchored_key_memberships,
};
if log::log_enabled!(log::Level::Debug) {
let resource_totals = ResourceTotals::from_costs(input.costs);
let lb_statement_groups =
lower_bound_from_total(input.statement_content_groups.len(), max_stmts_per_pod);
let lb_merkle = lower_bound_from_total(
resource_totals.merkle_proofs,
input.params.max_merkle_proofs_containers,
);
let lb_merkle_transitions = lower_bound_from_total(
resource_totals.merkle_state_transitions,
input
.params
.max_merkle_tree_state_transition_proofs_containers,
);
let lb_custom_pred_verifications = lower_bound_from_total(
resource_totals.custom_pred_verifications,
input.params.max_custom_predicate_verifications,
);
let lb_signed_by =
lower_bound_from_total(resource_totals.signed_by, input.params.max_signed_by);
let lb_public_key_of = lower_bound_from_total(
resource_totals.public_key_of,
input.params.max_public_key_of,
);
let lower_bound_candidates = [
("statements_raw", Some(min_pods_by_statements)),
("merkle_proofs", lb_merkle),
("merkle_state_transitions", lb_merkle_transitions),
("custom_pred_verifications", lb_custom_pred_verifications),
("signed_by", lb_signed_by),
("public_key_of", lb_public_key_of),
];
let dominant_lb = lower_bound_candidates
.iter()
.max_by_key(|(_, lb)| lb.unwrap_or(usize::MAX))
.expect("non-empty lower-bound candidate list");
log::debug!(
"MILP summary: statements={} output_public={} content_groups={} anchored_keys={} \
batches={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \
external_premises={} search_min_pods={} max_pods={}",
n,
num_output_public,
input.statement_content_groups.len(),
input.all_anchored_keys.len(),
all_batches.len(),
dep_stats.internal_edges,
dep_stats.external_edges,
external_pods.len(),
external_premises.len(),
min_pods,
input.max_pods
);
log::debug!(
"MILP resource totals: merkle_proofs={} merkle_state_transitions={} \
custom_pred_verifications={} signed_by={} public_key_of={} \
batch_memberships={} anchored_key_memberships={}",
resource_totals.merkle_proofs,
resource_totals.merkle_state_transitions,
resource_totals.custom_pred_verifications,
resource_totals.signed_by,
resource_totals.public_key_of,
batch_memberships,
anchored_key_memberships
);
log::debug!(
"MILP lower bounds (pods): statements_raw={} statements_dedup={} merkle_proofs={} \
merkle_state_transitions={} custom_pred_verifications={} signed_by={} \
public_key_of={} dominant={}({})",
min_pods_by_statements,
format_lower_bound(lb_statement_groups),
format_lower_bound(lb_merkle),
format_lower_bound(lb_merkle_transitions),
format_lower_bound(lb_custom_pred_verifications),
format_lower_bound(lb_signed_by),
format_lower_bound(lb_public_key_of),
dominant_lb.0,
format_lower_bound(dominant_lb.1)
);
}
// Incremental approach: try solving with increasing POD counts // Incremental approach: try solving with increasing POD counts
// Start with min_pods and increment until we find a feasible solution // Start with min_pods and increment until we find a feasible solution
for target_pods in min_pods..=input.max_pods { for target_pods in min_pods..=input.max_pods {
if let Some(solution) = try_solve_with_pods(input, target_pods, &all_batches)? { log::debug!("Trying to solve with {} PODs", target_pods);
if let Some(solution) = try_solve_with_pods(
input,
target_pods,
&all_batches,
&external_pods,
&external_premises,
&debug_ctx,
)? {
return Ok(solution); return Ok(solution);
} }
// Infeasible with target_pods, try more // Infeasible with target_pods, try more
@ -194,7 +541,12 @@ fn try_solve_with_pods(
input: &SolverInput, input: &SolverInput,
target_pods: usize, target_pods: usize,
all_batches: &[CustomBatchId], all_batches: &[CustomBatchId],
external_pods: &[Hash],
external_premises: &[ExternalDependency],
debug_ctx: &SolveDebugContext,
) -> Result<Option<MultiPodSolution>> { ) -> Result<Option<MultiPodSolution>> {
let attempt_started_at = Instant::now();
// Create variables // Create variables
let mut vars = ProblemVariables::new(); let mut vars = ProblemVariables::new();
let n = input.num_statements; let n = input.num_statements;
@ -250,22 +602,6 @@ fn try_solve_with_pods(
.map(|p| (0..p).map(|_| vars.add(variable().binary())).collect()) .map(|p| (0..p).map(|_| vars.add(variable().binary())).collect())
.collect(); .collect();
// Collect all external POD hashes that statements depend on.
// These are user-provided input PODs referenced by statements.
use crate::middleware::Hash;
let external_pods: Vec<Hash> = input
.deps
.statement_deps
.iter()
.flat_map(|deps| deps.iter())
.filter_map(|dep| match dep {
StatementSource::External(h) => Some(*h),
StatementSource::Internal(_) => None,
})
.collect::<BTreeSet<_>>()
.into_iter()
.collect();
// uses_external[p][e] - POD p uses external POD e as an input // uses_external[p][e] - POD p uses external POD e as an input
let uses_external: Vec<Vec<Variable>> = (0..target_pods) let uses_external: Vec<Vec<Variable>> = (0..target_pods)
.map(|_| { .map(|_| {
@ -275,13 +611,28 @@ fn try_solve_with_pods(
}) })
.collect(); .collect();
// public_external[u][p] - external premise u is exposed publicly in POD p
let public_external: Vec<Vec<Variable>> = (0..external_premises.len())
.map(|_| {
(0..target_pods)
.map(|_| vars.add(variable().binary()))
.collect()
})
.collect();
// Map from external POD hash to index in uses_external // Map from external POD hash to index in uses_external
let external_to_idx: std::collections::HashMap<Hash, usize> = external_pods let external_to_idx: HashMap<Hash, usize> = external_pods
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, h)| (*h, i)) .map(|(i, h)| (*h, i))
.collect(); .collect();
let external_premise_to_idx: HashMap<ExternalDependency, usize> = external_premises
.iter()
.enumerate()
.map(|(i, ext)| (ext.clone(), i))
.collect();
// content_group_used[g][p] - content group g has at least one statement proved in POD p // content_group_used[g][p] - content group g has at least one statement proved in POD p
// When multiple statements have identical content, they share a slot in the POD. // When multiple statements have identical content, they share a slot in the POD.
// This variable tracks whether at least one statement from each content group is proved. // This variable tracks whether at least one statement from each content group is proved.
@ -294,6 +645,58 @@ fn try_solve_with_pods(
}) })
.collect(); .collect();
if log::log_enabled!(log::Level::Debug) {
let estimate = ModelSizeEstimate::for_target_pods(
input,
target_pods,
all_batches.len(),
external_pods.len(),
external_premises.len(),
debug_ctx,
);
log::debug!(
"MILP(k={}) model estimate vars_total={} [prove={} public={} pod_used={} \
public_external={} batch_used={} anchored_key_used={} uses_input={} \
uses_external={} content_group_used={}]",
target_pods,
estimate.vars_total,
estimate.vars_prove,
estimate.vars_public,
estimate.vars_pod_used,
estimate.vars_public_external,
estimate.vars_batch_used,
estimate.vars_anchored_key_used,
estimate.vars_uses_input,
estimate.vars_uses_external,
estimate.vars_content_group_used
);
log::debug!(
"MILP(k={}) model estimate constraints_total={} [c1={} c2={} c2b={} c3={} c4={} \
c5i={} c5e={} c6_pre={} c6_limits={} c7={} c7b={} c8a={} c8b={} c8c={} \
c8d={} c9={} c10={} c10b={}]",
target_pods,
estimate.constraints_total,
estimate.c1_coverage,
estimate.c2_output_public,
estimate.c2b_output_privacy,
estimate.c3_public_implies_proved,
estimate.c4_pod_existence,
estimate.c5_internal_dependencies,
estimate.c5_external_dependencies,
estimate.c6_pre_content_group,
estimate.c6_resource_limits,
estimate.c7_batch_cardinality,
estimate.c7b_anchored_key_tracking,
estimate.c8a_internal_inputs,
estimate.c8b_external_dep_inputs,
estimate.c8c_external_forward_inputs,
estimate.c8d_input_limit,
estimate.c9_symmetry_breaking,
estimate.c10_external_public_availability,
estimate.c10b_external_public_implies_pod_used
);
}
// No optimization objective needed - we use an incremental approach that tries // No optimization objective needed - we use an incremental approach that tries
// min_pods first and increments until feasible. Combined with symmetry breaking // min_pods first and increments until feasible. Combined with symmetry breaking
// (Constraint 9), this finds the minimum number of PODs without needing MILP // (Constraint 9), this finds the minimum number of PODs without needing MILP
@ -336,29 +739,61 @@ fn try_solve_with_pods(
} }
} }
// Constraint 5: Dependencies (works with Constraint 8 to enforce input POD tracking) // Constraint 5: Dependency availability.
// //
// If s depends on d (internal), and s is proved in p, then either: // Internal dependency (s depends on d):
// - d is proved in p (local availability), OR // prove[s][p] <= prove[d][p] + sum_{pp < p} public[d][pp]
// - d is public in some earlier POD p' < p (cross-POD availability)
// //
// This constraint ensures dependencies are AVAILABLE. It does NOT track which // External dependency (s depends on external premise u from external POD e):
// earlier PODs are actually used as inputs - that's handled by Constraint 8. // prove[s][p] <= uses_external[p][e] + sum_{pp < p} public_external[u][pp]
// Together: //
// - Constraint 5 ensures the dependency CAN be satisfied // This captures the intended non-sticky semantics for external premises:
// - Constraint 8 ensures that when we use a statement from earlier POD pp, // a consumer POD can use the external POD directly, OR consume an earlier POD
// we count pp as an input to pod p (for max_input_pods enforcement) // that made the external premise public.
for s in 0..n { for s in 0..n {
for dep in &input.deps.statement_deps[s] { for dep in &input.deps.statement_deps[s] {
if let StatementSource::Internal(d) = dep { match dep {
for p in 0..target_pods { StatementSource::Internal(d) => {
// prove[s][p] <= prove[d][p] + sum_{p' < p} public[d][p'] for p in 0..target_pods {
let mut rhs: Expression = prove[*d][p].into(); let mut rhs: Expression = prove[*d][p].into();
for pp in 0..p { for pp in 0..p {
rhs += public[*d][pp]; rhs += public[*d][pp];
}
model.add_constraint(constraint!(prove[s][p] <= rhs));
} }
model.add_constraint(constraint!(prove[s][p] <= rhs));
} }
StatementSource::External(ext) => {
if let (Some(&e), Some(&u)) = (
external_to_idx.get(&ext.pod_hash),
external_premise_to_idx.get(ext),
) {
for p in 0..target_pods {
let mut rhs: Expression = uses_external[p][e].into();
for pp in 0..p {
rhs += public_external[u][pp];
}
model.add_constraint(constraint!(prove[s][p] <= rhs));
}
}
}
}
}
}
// Constraint 10: External-public availability and pod usage.
//
// An external premise can be made public in POD p only if it is available there:
// either directly from its source external input POD, or from an earlier POD
// that already exposed it publicly.
for (u, ext) in external_premises.iter().enumerate() {
if let Some(&e) = external_to_idx.get(&ext.pod_hash) {
for p in 0..target_pods {
let mut rhs: Expression = uses_external[p][e].into();
for pp in 0..p {
rhs += public_external[u][pp];
}
model.add_constraint(constraint!(public_external[u][p] <= rhs));
model.add_constraint(constraint!(public_external[u][p] <= pod_used[p]));
} }
} }
} }
@ -394,10 +829,14 @@ fn try_solve_with_pods(
<= (input.params.max_priv_statements() as f64) * pod_used[p] <= (input.params.max_priv_statements() as f64) * pod_used[p]
)); ));
// 6b: Public statement count // 6b: Public statement count (internal public statements + forwarded external premises)
let pub_sum: Expression = (0..n).map(|s| public[s][p]).sum(); let pub_sum_internal: Expression = (0..n).map(|s| public[s][p]).sum();
let pub_sum_external: Expression = (0..external_premises.len())
.map(|u| public_external[u][p])
.sum();
model.add_constraint(constraint!( model.add_constraint(constraint!(
pub_sum <= (input.params.max_public_statements as f64) * pod_used[p] pub_sum_internal + pub_sum_external
<= (input.params.max_public_statements as f64) * pod_used[p]
)); ));
// 6c: Merkle proofs // 6c: Merkle proofs
@ -509,14 +948,9 @@ fn try_solve_with_pods(
} }
} }
// Constraint 8a: Internal input POD tracking using uses_input // Constraint 8a: Internal input POD tracking using uses_input.
// uses_input[p][pp] >= prove[s][p] + public[d][pp] - prove[d][p] - 1 // If s is proved in p and depends on internal d exposed by pp, then pp must be counted
// for each dependency (s depends on d) // as an input unless d is also proved locally in p.
//
// If s is proved in p and d is public in pp, we need pp as input UNLESS d is also
// proved locally in p. Subtracting prove[d][p] ensures that when d is re-proved
// locally (prove[d][p] = 1), the constraint becomes uses_input >= 0, which is
// always satisfied without forcing the input relationship.
for s in 0..n { for s in 0..n {
for dep in &input.deps.statement_deps[s] { for dep in &input.deps.statement_deps[s] {
if let StatementSource::Internal(d) = dep { if let StatementSource::Internal(d) = dep {
@ -531,22 +965,50 @@ fn try_solve_with_pods(
} }
} }
// Constraint 8b: External input POD tracking using uses_external // Constraint 8b: External dependency input tracking via earlier PODs.
// If statement s is proved in POD p and s depends on external POD e, then uses_external[p][e] = 1 // If s is proved in p, and external premise u is provided by earlier POD pp
// (i.e., public_external[u][pp] = 1), then pp must be counted as an input unless
// p uses the source external POD directly.
for s in 0..n { for s in 0..n {
for dep in &input.deps.statement_deps[s] { for dep in &input.deps.statement_deps[s] {
if let StatementSource::External(h) = dep { if let StatementSource::External(ext) = dep {
if let Some(&e) = external_to_idx.get(h) { if let (Some(&e), Some(&u)) = (
for p in 0..target_pods { external_to_idx.get(&ext.pod_hash),
// If s is proved in p, then uses_external[p][e] = 1 external_premise_to_idx.get(ext),
model.add_constraint(constraint!(uses_external[p][e] >= prove[s][p])); ) {
for p in 1..target_pods {
for pp in 0..p {
model.add_constraint(constraint!(
uses_input[p][pp]
>= prove[s][p] + public_external[u][pp]
- uses_external[p][e]
- 1.0
));
}
} }
} }
} }
} }
} }
// Constraint 8c: Total input PODs (internal + external) must not exceed max_input_pods // Constraint 8c: Forwarding an external premise as public also consumes an internal input
// unless the forwarding POD uses the source external POD directly.
for (u, ext) in external_premises.iter().enumerate() {
if let Some(&e) = external_to_idx.get(&ext.pod_hash) {
for p in 1..target_pods {
for pp in 0..p {
model.add_constraint(constraint!(
uses_input[p][pp]
>= public_external[u][p] + public_external[u][pp]
- uses_external[p][e]
- 1.0
));
}
}
}
}
// Constraint 8d: Total input PODs (internal + external) must not exceed max_input_pods
// For each POD p, the total number of inputs is: // For each POD p, the total number of inputs is:
// - Internal inputs: PODs pp < p that provide public statements used by p // - Internal inputs: PODs pp < p that provide public statements used by p
// - External inputs: User-provided PODs referenced by statements in p // - External inputs: User-provided PODs referenced by statements in p
@ -569,10 +1031,32 @@ fn try_solve_with_pods(
} }
// Solve // Solve
let solve_started_at = Instant::now();
let solution = match model.solve() { let solution = match model.solve() {
Ok(sol) => sol, Ok(sol) => {
Err(_) => { log::debug!(
// Infeasible with this number of PODs, try more "MILP(k={}) result=feasible solve_ms={} total_ms={}",
target_pods,
solve_started_at.elapsed().as_millis(),
attempt_started_at.elapsed().as_millis()
);
sol
}
Err(err) => {
let status = match err {
ResolutionError::Infeasible => "infeasible",
ResolutionError::Unbounded => "unbounded",
ResolutionError::Other(_) | ResolutionError::Str(_) => "error",
};
log::debug!(
"MILP(k={}) result={} solve_ms={} total_ms={} detail={}",
target_pods,
status,
solve_started_at.elapsed().as_millis(),
attempt_started_at.elapsed().as_millis(),
err
);
// Infeasible or solver error with this number of PODs, try more
return Ok(None); return Ok(None);
} }
}; };
@ -589,6 +1073,9 @@ fn try_solve_with_pods(
let mut statement_to_pods: Vec<Vec<usize>> = vec![vec![]; n]; let mut statement_to_pods: Vec<Vec<usize>> = vec![vec![]; n];
let mut pod_statements: Vec<Vec<usize>> = vec![vec![]; pod_count]; let mut pod_statements: Vec<Vec<usize>> = vec![vec![]; pod_count];
let mut pod_public_statements: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); pod_count]; let mut pod_public_statements: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); pod_count];
let mut pod_internal_inputs: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); pod_count];
let mut pod_external_inputs: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); pod_count];
let mut pod_public_external_premises: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); pod_count];
for s in 0..n { for s in 0..n {
for p in 0..pod_count { for p in 0..pod_count {
@ -602,17 +1089,47 @@ fn try_solve_with_pods(
} }
} }
for p in 0..pod_count {
for pp in 0..p {
if solution.value(uses_input[p][pp]) > SOLVER_BINARY_THRESHOLD {
pod_internal_inputs[p].insert(pp);
}
}
for e in 0..external_pods.len() {
if solution.value(uses_external[p][e]) > SOLVER_BINARY_THRESHOLD {
pod_external_inputs[p].insert(e);
}
}
}
for u in 0..external_premises.len() {
for p in 0..pod_count {
if solution.value(public_external[u][p]) > SOLVER_BINARY_THRESHOLD {
pod_public_external_premises[p].insert(u);
}
}
}
Ok(Some(MultiPodSolution { Ok(Some(MultiPodSolution {
pod_count, pod_count,
statement_to_pods, statement_to_pods,
pod_statements, pod_statements,
pod_public_statements, pod_public_statements,
pod_internal_inputs,
external_pod_hashes: external_pods.to_vec(),
pod_external_inputs,
external_premises: external_premises.to_vec(),
pod_public_external_premises,
})) }))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{
frontend::multi_pod::deps::ExternalDependency,
middleware::{Statement, Value, ValueRef},
};
#[test] #[test]
fn test_no_public_statements_error() { fn test_no_public_statements_error() {
@ -642,4 +1159,67 @@ mod tests {
.to_string() .to_string()
.contains("No public statements requested")); .contains("No public statements requested"));
} }
#[test]
fn test_external_dependency_can_be_forwarded_to_reduce_input_pressure() {
    // Synthetic scenario with two statements and one external premise E:
    //   - s0 depends only on E
    //   - s1 (the single requested public output) depends on s0 and on E
    //
    // The limits are chosen so that a POD may take only one input
    // (`max_input_pods = 1`) and the private-statement budget is 1
    // (NOTE(review): presumably derived from `max_statements` and
    // `max_public_statements` — confirm against `Params`). The expected layout is:
    //   - POD0 proves s0 and exposes both s0 and E publicly,
    //   - POD1 takes POD0 as its only input and never touches the external POD.

    // The external premise: a trivially true equality living in an external POD.
    let literal = || ValueRef::Literal(Value::from(42_i64));
    let premise = ExternalDependency {
        pod_hash: Hash::default(),
        statement: Statement::Equal(literal(), literal()),
    };

    // Dependency lists, indexed by statement.
    let s0_deps = vec![StatementSource::External(premise.clone())];
    let s1_deps = vec![
        StatementSource::Internal(0),
        StatementSource::External(premise),
    ];
    let graph = DependencyGraph {
        statement_deps: vec![s0_deps, s1_deps],
    };

    let limits = Params {
        max_statements: 3,
        max_public_statements: 2,
        max_input_pods: 1,
        ..Params::default()
    };
    let unit_costs = vec![StatementCost::default(), StatementCost::default()];
    // Each statement forms its own content group.
    let content_groups = vec![vec![0], vec![1]];
    // Only s1 is requested as a public output.
    let requested_public = vec![1];

    let solver_input = SolverInput {
        num_statements: 2,
        costs: &unit_costs,
        deps: &graph,
        output_public_indices: &requested_public,
        params: &limits,
        max_pods: 4,
        all_anchored_keys: &[],
        anchored_key_producers: &[],
        statement_content_groups: &content_groups,
    };

    let layout = solve(&solver_input).expect("solver should find a feasible forwarding layout");
    assert_eq!(layout.pod_count, 2);
    assert_eq!(layout.external_premises.len(), 1);

    // POD1 must list POD0 as an internal input and have no direct external input.
    let pod1_internal_inputs = &layout.pod_internal_inputs[1];
    let pod1_external_inputs = &layout.pod_external_inputs[1];
    assert!(pod1_internal_inputs.contains(&0));
    assert!(pod1_external_inputs.is_empty());

    // POD0 must forward the external premise (index 0) in its public set so that
    // POD1 can obtain it through POD0.
    let pod0_forwarded_premises = &layout.pod_public_external_premises[0];
    assert!(pod0_forwarded_premises.contains(&0));
}
} }