Fix pod builder (#496)

Several fixes and code simplifications:
- MainPodBuilder
  - Fix: Contains statements inherited from input pods (via their public statements) were not tracked when automatically generating Contains statements for Entry arguments.
  - Enhancement: Deduplicate statements
- MultiPodBuilder
    - Simplify: Remove the "statement groups" logic and instead deduplicate statements in the MainPodBuilder (which is much simpler to do)
    - Remove the "anchored key" explicit dependency tracking and instead rely on regular dependency tracking by using all the implicit operations and statements generated by MainPodBuilder as input to the solver.
    - Fix: Count and constrain the custom predicates used in a pod rather than the batches used.
This commit is contained in:
Eduard S. 2026-03-25 18:48:28 +01:00 committed by GitHub
parent 1e592e11cf
commit a4069bcc55
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 162 additions and 454 deletions

View file

@ -48,12 +48,12 @@
//! [`MainPodBuilder`]: crate::frontend::MainPodBuilder
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
collections::{BTreeSet, HashMap},
fmt,
};
use crate::{
frontend::{MainPod, MainPodBuilder, Operation, OperationArg},
frontend::{MainPod, MainPodBuilder, Operation},
middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value},
};
@ -61,7 +61,7 @@ mod cost;
mod deps;
mod solver;
use cost::{AnchoredKeyId, StatementCost};
use cost::StatementCost;
use deps::{DependencyGraph, StatementSource};
pub use solver::MultiPodSolution;
@ -168,12 +168,8 @@ pub struct MultiPodBuilder {
options: Options,
/// External input PODs (already proved).
input_pods: Vec<MainPod>,
/// Statements created by this builder.
statements: Vec<Statement>,
/// Operations that produce each statement.
operations: Vec<Operation>,
/// Optional initial wildcard values for custom operations
operations_wildcard_values: Vec<Vec<(usize, Value)>>,
operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
/// Indices of statements that should be public in output PODs.
/// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted.
output_public_indices: Vec<usize>,
@ -193,7 +189,7 @@ pub struct SolvedMultiPod {
statements: Vec<Statement>,
operations: Vec<Operation>,
output_public_indices: Vec<usize>,
operations_wildcard_values: Vec<Vec<(usize, Value)>>,
operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
solution: MultiPodSolution,
deps: DependencyGraph,
}
@ -260,56 +256,27 @@ impl SolvedMultiPod {
let statements_sorted: BTreeSet<usize> = statements_in_this_pod.iter().copied().collect();
let public_set = &solution.pod_public_statements[pod_idx];
// Track statements proved locally in this POD for argument remapping.
// We index by statement content so duplicate statements can reuse a single
// built statement slot in MainPodBuilder.
let mut added_statements_by_content: HashMap<Statement, Statement> = HashMap::new();
for &stmt_idx in &statements_sorted {
let original_stmt = self.statements[stmt_idx].clone();
// If this statement content was already built in this POD, reuse it instead
// of replaying the operation. If any duplicate is public, reveal the
// already-built statement.
if let Some(_existing_stmt) = added_statements_by_content.get(&original_stmt) {
continue;
}
let mut op = self.operations[stmt_idx].clone();
let wildcard_values = self.operations_wildcard_values[stmt_idx].clone();
// Remap Statement arguments that reference locally-proved statements.
// For external dependencies (from input PODs including earlier generated PODs),
// the original Statement is used directly - MainPodBuilder will find it in
// the input POD's public statements via find_op_arg.
for arg in &mut op.1 {
if let OperationArg::Statement(ref orig_stmt) = arg {
if let Some(remapped_stmt) = added_statements_by_content.get(orig_stmt) {
*arg = OperationArg::Statement(remapped_stmt.clone());
}
}
}
let op = self.operations[stmt_idx].clone();
let wildcard_values = self
.operations_wildcard_values
.get(&stmt_idx)
.cloned()
.unwrap_or_default();
let stmt = builder.op(false, wildcard_values, op)?;
added_statements_by_content.insert(original_stmt, stmt);
assert_eq!(stmt, self.statements[stmt_idx]); // Sanity check
}
// For the output pod, make statements public in the original order.
// Intermediate pods use the solver-selected public set.
if pod_idx == solution.pod_count - 1 {
for idx in &self.output_public_indices {
let stmt = added_statements_by_content
.get(&self.statements[*idx])
.expect("exists");
builder.reveal(stmt);
builder.reveal(&self.statements[*idx])?;
}
} else {
for idx in public_set {
let stmt = added_statements_by_content
.get(&self.statements[*idx])
.expect("exists");
builder.reveal(stmt);
builder.reveal(&self.statements[*idx])?;
}
}
@ -317,7 +284,7 @@ impl SolvedMultiPod {
// for this POD. These do not require local proving in this POD.
for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] {
let ext_premise = &solution.external_premises[*ext_premise_idx];
builder.reveal(&ext_premise.statement);
builder.reveal(&ext_premise.statement)?;
}
// Step 4: Prove the POD
@ -456,9 +423,7 @@ impl MultiPodBuilder {
options,
builder,
input_pods: Vec::new(),
statements: Vec::new(),
operations: Vec::new(),
operations_wildcard_values: Vec::new(),
operations_wildcard_values: HashMap::new(),
output_public_indices: Vec::new(),
}
}
@ -480,6 +445,16 @@ impl MultiPodBuilder {
self.op(false, vec![], op)
}
/// Returns the position of `stmt` within the statement list held by `self.builder`.
///
/// When several entries compare equal to `stmt`, the first one is reported.
///
/// # Panics
///
/// Panics if no statement in the builder equals `stmt`.
fn stmt_index(&self, stmt: &Statement) -> usize {
    let found = self
        .builder
        .statements
        .iter()
        .enumerate()
        .find_map(|(index, candidate)| (candidate == stmt).then_some(index));
    found.expect("exists")
}
pub fn op(
&mut self,
public: bool,
@ -488,8 +463,10 @@ impl MultiPodBuilder {
) -> Result<Statement> {
let stmt = self.add_operation(wildcard_values, op)?;
if public {
// Index is always new (just added), so push without duplicate check
self.output_public_indices.push(self.statements.len() - 1);
let index = self.stmt_index(&stmt);
if !self.output_public_indices.contains(&index) {
self.output_public_indices.push(index);
}
}
Ok(stmt)
}
@ -510,10 +487,8 @@ impl MultiPodBuilder {
let stmt = self
.builder
.op(false, wildcard_values.clone(), op.clone())?;
self.statements.push(stmt.clone());
self.operations.push(op);
self.operations_wildcard_values.push(wildcard_values);
self.operations_wildcard_values
.insert(self.stmt_index(&stmt), wildcard_values.clone());
Ok(stmt)
}
@ -523,7 +498,7 @@ impl MultiPodBuilder {
/// Returns an error if the statement was not found in the builder.
/// Calling this multiple times on the same statement is idempotent.
pub fn reveal(&mut self, stmt: &Statement) -> Result<()> {
if let Some(idx) = self.statements.iter().position(|s| s == stmt) {
if let Some(idx) = self.builder.statements.iter().position(|s| s == stmt) {
if !self.output_public_indices.contains(&idx) {
self.output_public_indices.push(idx);
}
@ -536,8 +511,8 @@ impl MultiPodBuilder {
}
/// Get the number of statements.
pub fn num_statements(&self) -> usize {
self.statements.len()
pub fn stmt_len(&self) -> usize {
self.builder.stmt_len()
}
/// Solve the packing problem and return a solved builder ready for proving.
@ -545,66 +520,31 @@ impl MultiPodBuilder {
/// This runs the MILP solver to find the optimal POD assignment.
/// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved.
pub fn solve(self) -> Result<SolvedMultiPod> {
let MainPodBuilder {
statements,
operations,
..
} = self.builder;
// Compute costs for each statement
let costs: Vec<StatementCost> = self
.operations
let costs: Vec<StatementCost> = operations
.iter()
.map(StatementCost::from_operation)
.collect();
// Collect all unique anchored keys from the costs
let all_anchored_keys: Vec<AnchoredKeyId> = costs
.iter()
.flat_map(|c| c.anchored_keys.iter().cloned())
.collect::<std::collections::BTreeSet<_>>()
.into_iter()
.collect();
// Build map from anchored key to its producing statement index (if any).
// A Contains statement with literal (dict, key, value) "produces" that anchored key.
let mut ak_to_producer: HashMap<AnchoredKeyId, usize> = HashMap::new();
for (stmt_idx, stmt) in self.statements.iter().enumerate() {
if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) {
// First producer wins (shouldn't have duplicates in practice)
ak_to_producer.entry(ak).or_insert(stmt_idx);
}
}
// Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i]
let anchored_key_producers: Vec<Option<usize>> = all_anchored_keys
.iter()
.map(|ak| ak_to_producer.get(ak).copied())
.collect();
// Build external POD statement mapping
let external_pod_statements = build_external_statement_map(&self.input_pods);
// Build dependency graph
let deps =
DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements);
// Build statement content groups for deduplication.
// Statements with identical content share a single slot in the POD.
// Keep groups ordered by first occurrence index for deterministic solver input.
let mut first_idx_by_stmt: HashMap<&Statement, usize> = HashMap::new();
let mut groups_by_first_idx: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
for (idx, stmt) in self.statements.iter().enumerate() {
let first_idx = *first_idx_by_stmt.entry(stmt).or_insert(idx);
groups_by_first_idx.entry(first_idx).or_default().push(idx);
}
let statement_content_groups: Vec<Vec<usize>> = groups_by_first_idx.into_values().collect();
let deps = DependencyGraph::build(&statements, &operations, &external_pod_statements);
// Run solver
let input = solver::SolverInput {
num_statements: self.statements.len(),
num_statements: statements.len(),
costs: &costs,
deps: &deps,
output_public_indices: &self.output_public_indices,
params: &self.params,
max_pods: self.options.max_pods,
all_anchored_keys: &all_anchored_keys,
anchored_key_producers: &anchored_key_producers,
statement_content_groups: &statement_content_groups,
};
let solution = solver::solve(&input)?;
@ -613,8 +553,8 @@ impl MultiPodBuilder {
params: self.params,
vd_set: self.vd_set,
input_pods: self.input_pods,
statements: self.statements,
operations: self.operations,
statements,
operations,
output_public_indices: self.output_public_indices,
operations_wildcard_values: self.operations_wildcard_values,
solution,
@ -845,33 +785,13 @@ mod tests {
let solution = solved.solution();
// Expected: exactly 2 PODs
// - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public
// - POD 1 (output): statement 2 (b_out); b_out is public
// The output POD accesses a_out from POD 0 to satisfy b_out's dependency.
assert_eq!(
solution.pod_count, 2,
"Expected exactly 2 PODs for 3-statement chain with max_priv=2"
);
// POD 0 should contain statements 0 and 1 (contains and a_out)
assert!(
solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1),
"POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}",
solution.pod_statements[0]
);
// Statement 1 (a_out) should be public in POD 0 so POD 1 can access it
assert!(
solution.pod_public_statements[0].contains(&1),
"Statement 1 (a_out) should be public in POD 0"
);
// POD 1 (output) should contain statement 2 (b_out)
assert!(
solution.pod_statements[1].contains(&2),
"POD 1 should contain statement 2 (b_out), got {:?}",
solution.pod_statements[1]
);
// Solution A:
// - POD 0 (intermediate): contains statement 0 (contains), which is public
// - POD 1 (output): inherits statement 0 (contains) from POD 0; contains statement 1 (a_out)
//   and statement 2 (b_out), of which statement 2 is public
// Solution B:
// - POD 0 (intermediate): contains statement 0 (contains) and statement 1 (a_out), of which statement 1 is public
// - POD 1 (output): inherits statement 1 (a_out) from POD 0; contains statement 2 (b_out), which is public
// Statement 2 (b_out) should be public in POD 1 (it's output-public)
assert!(