Fix pod builder (#496)
Several fixes and code simplifications:
- MainPodBuilder
- Fix: It was not tracking Contains statements inherited via input pods (via public statements) when automatically generating Contains statements for Entry arguments.
- Enhancement: Deduplicate statements
- MultiPodBuilder
- Simplify: Remove the "statement groups" logic and instead deduplicate statements in the MainPodBuilder (which is much simpler to do)
- Remove the "anchored key" explicit dependency tracking and instead rely on regular dependency tracking by using all the implicit operations and statements generated by MainPodBuilder as input to the solver.
- Fix: Count and constrain custom predicates used in a pod instead of batches used
This commit is contained in:
parent
1e592e11cf
commit
a4069bcc55
7 changed files with 162 additions and 454 deletions
|
|
@ -48,12 +48,12 @@
|
|||
//! [`MainPodBuilder`]: crate::frontend::MainPodBuilder
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
collections::{BTreeSet, HashMap},
|
||||
fmt,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
frontend::{MainPod, MainPodBuilder, Operation, OperationArg},
|
||||
frontend::{MainPod, MainPodBuilder, Operation},
|
||||
middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value},
|
||||
};
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ mod cost;
|
|||
mod deps;
|
||||
mod solver;
|
||||
|
||||
use cost::{AnchoredKeyId, StatementCost};
|
||||
use cost::StatementCost;
|
||||
use deps::{DependencyGraph, StatementSource};
|
||||
pub use solver::MultiPodSolution;
|
||||
|
||||
|
|
@ -168,12 +168,8 @@ pub struct MultiPodBuilder {
|
|||
options: Options,
|
||||
/// External input PODs (already proved).
|
||||
input_pods: Vec<MainPod>,
|
||||
/// Statements created by this builder.
|
||||
statements: Vec<Statement>,
|
||||
/// Operations that produce each statement.
|
||||
operations: Vec<Operation>,
|
||||
/// Optional initial wildcard values for custom operations
|
||||
operations_wildcard_values: Vec<Vec<(usize, Value)>>,
|
||||
operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
|
||||
/// Indices of statements that should be public in output PODs.
|
||||
/// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted.
|
||||
output_public_indices: Vec<usize>,
|
||||
|
|
@ -193,7 +189,7 @@ pub struct SolvedMultiPod {
|
|||
statements: Vec<Statement>,
|
||||
operations: Vec<Operation>,
|
||||
output_public_indices: Vec<usize>,
|
||||
operations_wildcard_values: Vec<Vec<(usize, Value)>>,
|
||||
operations_wildcard_values: HashMap<usize, Vec<(usize, Value)>>,
|
||||
solution: MultiPodSolution,
|
||||
deps: DependencyGraph,
|
||||
}
|
||||
|
|
@ -260,56 +256,27 @@ impl SolvedMultiPod {
|
|||
let statements_sorted: BTreeSet<usize> = statements_in_this_pod.iter().copied().collect();
|
||||
let public_set = &solution.pod_public_statements[pod_idx];
|
||||
|
||||
// Track statements proved locally in this POD for argument remapping.
|
||||
// We index by statement content so duplicate statements can reuse a single
|
||||
// built statement slot in MainPodBuilder.
|
||||
let mut added_statements_by_content: HashMap<Statement, Statement> = HashMap::new();
|
||||
|
||||
for &stmt_idx in &statements_sorted {
|
||||
let original_stmt = self.statements[stmt_idx].clone();
|
||||
|
||||
// If this statement content was already built in this POD, reuse it instead
|
||||
// of replaying the operation. If any duplicate is public, reveal the
|
||||
// already-built statement.
|
||||
if let Some(_existing_stmt) = added_statements_by_content.get(&original_stmt) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut op = self.operations[stmt_idx].clone();
|
||||
let wildcard_values = self.operations_wildcard_values[stmt_idx].clone();
|
||||
|
||||
// Remap Statement arguments that reference locally-proved statements.
|
||||
// For external dependencies (from input PODs including earlier generated PODs),
|
||||
// the original Statement is used directly - MainPodBuilder will find it in
|
||||
// the input POD's public statements via find_op_arg.
|
||||
for arg in &mut op.1 {
|
||||
if let OperationArg::Statement(ref orig_stmt) = arg {
|
||||
if let Some(remapped_stmt) = added_statements_by_content.get(orig_stmt) {
|
||||
*arg = OperationArg::Statement(remapped_stmt.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
let op = self.operations[stmt_idx].clone();
|
||||
let wildcard_values = self
|
||||
.operations_wildcard_values
|
||||
.get(&stmt_idx)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
let stmt = builder.op(false, wildcard_values, op)?;
|
||||
|
||||
added_statements_by_content.insert(original_stmt, stmt);
|
||||
assert_eq!(stmt, self.statements[stmt_idx]); // Sanity check
|
||||
}
|
||||
|
||||
// For the output pod, make statements public in the original order.
|
||||
// Intermediate pods use the solver-selected public set.
|
||||
if pod_idx == solution.pod_count - 1 {
|
||||
for idx in &self.output_public_indices {
|
||||
let stmt = added_statements_by_content
|
||||
.get(&self.statements[*idx])
|
||||
.expect("exists");
|
||||
builder.reveal(stmt);
|
||||
builder.reveal(&self.statements[*idx])?;
|
||||
}
|
||||
} else {
|
||||
for idx in public_set {
|
||||
let stmt = added_statements_by_content
|
||||
.get(&self.statements[*idx])
|
||||
.expect("exists");
|
||||
builder.reveal(stmt);
|
||||
builder.reveal(&self.statements[*idx])?;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -317,7 +284,7 @@ impl SolvedMultiPod {
|
|||
// for this POD. These do not require local proving in this POD.
|
||||
for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] {
|
||||
let ext_premise = &solution.external_premises[*ext_premise_idx];
|
||||
builder.reveal(&ext_premise.statement);
|
||||
builder.reveal(&ext_premise.statement)?;
|
||||
}
|
||||
|
||||
// Step 4: Prove the POD
|
||||
|
|
@ -456,9 +423,7 @@ impl MultiPodBuilder {
|
|||
options,
|
||||
builder,
|
||||
input_pods: Vec::new(),
|
||||
statements: Vec::new(),
|
||||
operations: Vec::new(),
|
||||
operations_wildcard_values: Vec::new(),
|
||||
operations_wildcard_values: HashMap::new(),
|
||||
output_public_indices: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
|
@ -480,6 +445,16 @@ impl MultiPodBuilder {
|
|||
self.op(false, vec![], op)
|
||||
}
|
||||
|
||||
// Find the index of a statement that has been added. Panics if the statement doesn't
|
||||
// exist.
|
||||
fn stmt_index(&self, stmt: &Statement) -> usize {
|
||||
self.builder
|
||||
.statements
|
||||
.iter()
|
||||
.position(|s| s == stmt)
|
||||
.expect("exists")
|
||||
}
|
||||
|
||||
pub fn op(
|
||||
&mut self,
|
||||
public: bool,
|
||||
|
|
@ -488,8 +463,10 @@ impl MultiPodBuilder {
|
|||
) -> Result<Statement> {
|
||||
let stmt = self.add_operation(wildcard_values, op)?;
|
||||
if public {
|
||||
// Index is always new (just added), so push without duplicate check
|
||||
self.output_public_indices.push(self.statements.len() - 1);
|
||||
let index = self.stmt_index(&stmt);
|
||||
if !self.output_public_indices.contains(&index) {
|
||||
self.output_public_indices.push(index);
|
||||
}
|
||||
}
|
||||
Ok(stmt)
|
||||
}
|
||||
|
|
@ -510,10 +487,8 @@ impl MultiPodBuilder {
|
|||
let stmt = self
|
||||
.builder
|
||||
.op(false, wildcard_values.clone(), op.clone())?;
|
||||
|
||||
self.statements.push(stmt.clone());
|
||||
self.operations.push(op);
|
||||
self.operations_wildcard_values.push(wildcard_values);
|
||||
self.operations_wildcard_values
|
||||
.insert(self.stmt_index(&stmt), wildcard_values.clone());
|
||||
|
||||
Ok(stmt)
|
||||
}
|
||||
|
|
@ -523,7 +498,7 @@ impl MultiPodBuilder {
|
|||
/// Returns an error if the statement was not found in the builder.
|
||||
/// Calling this multiple times on the same statement is idempotent.
|
||||
pub fn reveal(&mut self, stmt: &Statement) -> Result<()> {
|
||||
if let Some(idx) = self.statements.iter().position(|s| s == stmt) {
|
||||
if let Some(idx) = self.builder.statements.iter().position(|s| s == stmt) {
|
||||
if !self.output_public_indices.contains(&idx) {
|
||||
self.output_public_indices.push(idx);
|
||||
}
|
||||
|
|
@ -536,8 +511,8 @@ impl MultiPodBuilder {
|
|||
}
|
||||
|
||||
/// Get the number of statements.
|
||||
pub fn num_statements(&self) -> usize {
|
||||
self.statements.len()
|
||||
pub fn stmt_len(&self) -> usize {
|
||||
self.builder.stmt_len()
|
||||
}
|
||||
|
||||
/// Solve the packing problem and return a solved builder ready for proving.
|
||||
|
|
@ -545,66 +520,31 @@ impl MultiPodBuilder {
|
|||
/// This runs the MILP solver to find the optimal POD assignment.
|
||||
/// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved.
|
||||
pub fn solve(self) -> Result<SolvedMultiPod> {
|
||||
let MainPodBuilder {
|
||||
statements,
|
||||
operations,
|
||||
..
|
||||
} = self.builder;
|
||||
// Compute costs for each statement
|
||||
let costs: Vec<StatementCost> = self
|
||||
.operations
|
||||
let costs: Vec<StatementCost> = operations
|
||||
.iter()
|
||||
.map(StatementCost::from_operation)
|
||||
.collect();
|
||||
|
||||
// Collect all unique anchored keys from the costs
|
||||
let all_anchored_keys: Vec<AnchoredKeyId> = costs
|
||||
.iter()
|
||||
.flat_map(|c| c.anchored_keys.iter().cloned())
|
||||
.collect::<std::collections::BTreeSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
// Build map from anchored key to its producing statement index (if any).
|
||||
// A Contains statement with literal (dict, key, value) "produces" that anchored key.
|
||||
let mut ak_to_producer: HashMap<AnchoredKeyId, usize> = HashMap::new();
|
||||
for (stmt_idx, stmt) in self.statements.iter().enumerate() {
|
||||
if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) {
|
||||
// First producer wins (shouldn't have duplicates in practice)
|
||||
ak_to_producer.entry(ak).or_insert(stmt_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i]
|
||||
let anchored_key_producers: Vec<Option<usize>> = all_anchored_keys
|
||||
.iter()
|
||||
.map(|ak| ak_to_producer.get(ak).copied())
|
||||
.collect();
|
||||
|
||||
// Build external POD statement mapping
|
||||
let external_pod_statements = build_external_statement_map(&self.input_pods);
|
||||
|
||||
// Build dependency graph
|
||||
let deps =
|
||||
DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements);
|
||||
|
||||
// Build statement content groups for deduplication.
|
||||
// Statements with identical content share a single slot in the POD.
|
||||
// Keep groups ordered by first occurrence index for deterministic solver input.
|
||||
let mut first_idx_by_stmt: HashMap<&Statement, usize> = HashMap::new();
|
||||
let mut groups_by_first_idx: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
|
||||
for (idx, stmt) in self.statements.iter().enumerate() {
|
||||
let first_idx = *first_idx_by_stmt.entry(stmt).or_insert(idx);
|
||||
groups_by_first_idx.entry(first_idx).or_default().push(idx);
|
||||
}
|
||||
let statement_content_groups: Vec<Vec<usize>> = groups_by_first_idx.into_values().collect();
|
||||
let deps = DependencyGraph::build(&statements, &operations, &external_pod_statements);
|
||||
|
||||
// Run solver
|
||||
let input = solver::SolverInput {
|
||||
num_statements: self.statements.len(),
|
||||
num_statements: statements.len(),
|
||||
costs: &costs,
|
||||
deps: &deps,
|
||||
output_public_indices: &self.output_public_indices,
|
||||
params: &self.params,
|
||||
max_pods: self.options.max_pods,
|
||||
all_anchored_keys: &all_anchored_keys,
|
||||
anchored_key_producers: &anchored_key_producers,
|
||||
statement_content_groups: &statement_content_groups,
|
||||
};
|
||||
|
||||
let solution = solver::solve(&input)?;
|
||||
|
|
@ -613,8 +553,8 @@ impl MultiPodBuilder {
|
|||
params: self.params,
|
||||
vd_set: self.vd_set,
|
||||
input_pods: self.input_pods,
|
||||
statements: self.statements,
|
||||
operations: self.operations,
|
||||
statements,
|
||||
operations,
|
||||
output_public_indices: self.output_public_indices,
|
||||
operations_wildcard_values: self.operations_wildcard_values,
|
||||
solution,
|
||||
|
|
@ -845,33 +785,13 @@ mod tests {
|
|||
let solution = solved.solution();
|
||||
|
||||
// Expected: exactly 2 PODs
|
||||
// - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public
|
||||
// - POD 1 (output): statement 2 (b_out); b_out is public
|
||||
// The output POD accesses a_out from POD 0 to satisfy b_out's dependency.
|
||||
assert_eq!(
|
||||
solution.pod_count, 2,
|
||||
"Expected exactly 2 PODs for 3-statement chain with max_priv=2"
|
||||
);
|
||||
|
||||
// POD 0 should contain statements 0 and 1 (contains and a_out)
|
||||
assert!(
|
||||
solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1),
|
||||
"POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}",
|
||||
solution.pod_statements[0]
|
||||
);
|
||||
|
||||
// Statement 1 (a_out) should be public in POD 0 so POD 1 can access it
|
||||
assert!(
|
||||
solution.pod_public_statements[0].contains(&1),
|
||||
"Statement 1 (a_out) should be public in POD 0"
|
||||
);
|
||||
|
||||
// POD 1 (output) should contain statement 2 (b_out)
|
||||
assert!(
|
||||
solution.pod_statements[1].contains(&2),
|
||||
"POD 1 should contain statement 2 (b_out), got {:?}",
|
||||
solution.pod_statements[1]
|
||||
);
|
||||
// Solution A:
|
||||
// - POD 0 (intermediate): public statements 0 (contains)
|
||||
// - POD 1 (output): inherits statement 0 (contains) from POD0, statement 1 (a_out),
|
||||
// public statement 2 (b_out)
|
||||
// Solution B:
|
||||
// - POD 0 (intermediate): statements 0 (contains), public statement 1 (a_out)
|
||||
// - POD 1 (output): inherits statement 1 (a_out) from POD0, public statement 2 (b_out)
|
||||
|
||||
// Statement 2 (b_out) should be public in POD 1 (it's output-public)
|
||||
assert!(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue