From a4069bcc55e86a5f6e53e0374510f49d75b64b37 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 25 Mar 2026 18:48:28 +0100 Subject: [PATCH] Fix pod builder (#496) Several fixes and code simplifications: - MainPodBuilder - Fix: It was not tracking Contains statements inherited via input pods (via public statements) when automatically generating Contains statements for Entry arguments. - Enhancement: Deduplicate statements - MultiPodBuilder - Simplify: Remove the "statement groups" logic and instead deduplicate statements in the MainPodBuilder (which is much simpler to do) - Remove the "anchored key" explicit dependency tracking and instead rely on regular dependency tracking by using all the implicit operations and statements generated by MainPodBuilder as input to the solver. - Fix: Count and constrain custom predicates used in a pod instead of batches used --- src/backends/plonky2/mainpod/mod.rs | 3 +- src/frontend/mod.rs | 78 ++++++---- src/frontend/multi_pod/cost.rs | 86 ++--------- src/frontend/multi_pod/deps.rs | 37 +---- src/frontend/multi_pod/mod.rs | 184 +++++++---------------- src/frontend/multi_pod/solver.rs | 222 ++++++---------------------- src/middleware/db/mem.rs | 6 +- 7 files changed, 162 insertions(+), 454 deletions(-) diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 4968316..ae1ade3 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1253,7 +1253,8 @@ pub mod tests { cpr, [1, 1, 2].into_iter().map(middleware::Value::from).collect(), ); - builder.insert(true, (st, op)).unwrap(); + builder.insert((st.clone(), op)).unwrap(); + builder.reveal(&st).unwrap(); let prover = Prover {}; builder.prove(&prover).unwrap(); } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 1ce2795..f23e374 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -137,7 +137,7 @@ pub struct MainPodBuilder { pub operations: Vec, pub public_statements: Vec, // Internal state 
- dict_contains: Vec<(Value, Value)>, // (root, key) + contains: Vec<(RawValue, RawValue)>, // (root, key) } impl fmt::Display for MainPodBuilder { @@ -171,10 +171,16 @@ impl MainPodBuilder { statements: Vec::new(), operations: Vec::new(), public_statements: Vec::new(), - dict_contains: Vec::new(), + contains: Vec::new(), } } + pub fn stmt_len(&self) -> usize { + self.statements.len() + } pub fn add_pod(&mut self, pod: MainPod) -> Result<()> { + for st in &pod.public_statements { + self.track_contains(st); + } self.input_pods.push(pod); match self.input_pods.len() > self.params.max_input_pods { true => Err(Error::too_many_input_pods( @@ -184,31 +190,26 @@ impl MainPodBuilder { _ => Ok(()), } } - pub fn insert(&mut self, public: bool, st_op: (Statement, Operation)) -> Result<()> { - // TODO: Do error handling instead of panic - let (st, op) = st_op; - // If we're adding a Contains statement with literal arguments (an Entry), track it in - // `dict_contains` to avoid adding it again via `Self::add_entries_contains`. + // If we're adding a Contains statement with literal arguments (an Entry), track it in + // `dict_contains` to avoid adding it again via `Self::add_entries_contains`. 
+ fn track_contains(&mut self, st: &Statement) { if let Statement::Contains( ValueRef::Literal(dict), ValueRef::Literal(key), ValueRef::Literal(_), ) = &st { - let root_key = (dict.clone(), key.clone()); - self.dict_contains.push(root_key); + let root_key = (dict.raw(), key.raw()); + self.contains.push(root_key); } + } + + pub fn insert(&mut self, st_op: (Statement, Operation)) -> Result<()> { + // TODO: Do error handling instead of panic + let (st, op) = st_op; + self.track_contains(&st); - if public { - self.public_statements.push(st.clone()); - } - if self.public_statements.len() > self.params.max_public_statements { - return Err(Error::too_many_public_statements( - self.public_statements.len(), - self.params.max_public_statements, - )); - } self.statements.push(st); self.operations.push(op); if self.statements.len() > self.params.max_statements { @@ -404,7 +405,7 @@ impl MainPodBuilder { } fn op_statement( - &mut self, + &self, wildcard_values: Vec<(usize, Value)>, op: Operation, ) -> Result { @@ -621,7 +622,7 @@ impl MainPodBuilder { } /// For every operation that has Entry statements as arguments we add a Contains statement to - /// open the dictionary. + /// open the dictionary (unless such Contains already exists). 
fn add_entries_contains(&mut self, op: &Operation) -> Result<()> { for arg in &op.1 { if let OperationArg::Statement(Statement::Contains( @@ -630,9 +631,9 @@ impl MainPodBuilder { ValueRef::Literal(v), )) = arg { - let root_key = (dict.clone(), key.clone()); - if !self.dict_contains.contains(&root_key) { - self.dict_contains.push(root_key); + let root_key = (dict.raw(), key.raw()); + if !self.contains.contains(&root_key) { + self.contains.push(root_key); self.priv_op(Operation::dict_contains(dict, key, v))?; } } @@ -650,13 +651,28 @@ impl MainPodBuilder { self.add_entries_contains(&op)?; let op = Self::fill_in_aux(Self::lower_op(op)?)?; let st = self.op_statement(wildcard_values, op.clone())?; - self.insert(public, (st, op))?; + // Skip adding the statement and operation if it already exists + if !self.statements.contains(&st) { + self.insert((st.clone(), op))?; + } + if public { + self.reveal(&st)?; + } - Ok(self.statements[self.statements.len() - 1].clone()) + Ok(st) } - pub fn reveal(&mut self, st: &Statement) { - self.public_statements.push(st.clone()); + pub fn reveal(&mut self, st: &Statement) -> Result<()> { + if !self.public_statements.contains(st) { + self.public_statements.push(st.clone()); + } + if self.public_statements.len() > self.params.max_public_statements { + return Err(Error::too_many_public_statements( + self.public_statements.len(), + self.params.max_public_statements, + )); + } + Ok(()) } pub fn prove(&self, prover: &dyn MainPodProver) -> Result { @@ -1351,11 +1367,9 @@ pub mod tests { OperationAux::None, ); builder - .insert(false, (value_of_a.clone(), op_contains.clone())) - .unwrap(); - builder - .insert(false, (value_of_b.clone(), op_contains)) + .insert((value_of_a.clone(), op_contains.clone())) .unwrap(); + builder.insert((value_of_b.clone(), op_contains)).unwrap(); let st = Statement::equal( AnchoredKey::from((&local, "a")), AnchoredKey::from((&local, "b")), @@ -1368,7 +1382,7 @@ pub mod tests { ], OperationAux::None, ); - 
builder.insert(false, (st, op)).unwrap(); + builder.insert((st, op)).unwrap(); let prover = MockProver {}; let pod = builder.prove(&prover).unwrap(); diff --git a/src/frontend/multi_pod/cost.rs b/src/frontend/multi_pod/cost.rs index a5d89da..0c0c2ef 100644 --- a/src/frontend/multi_pod/cost.rs +++ b/src/frontend/multi_pod/cost.rs @@ -6,60 +6,20 @@ use std::collections::BTreeSet; use crate::{ - frontend::{Operation, OperationArg}, - middleware::{ - CustomPredicateBatch, Hash, NativeOperation, OperationType, RawValue, Statement, ValueRef, - }, + frontend::Operation, + middleware::{CustomPredicateRef, Hash, NativeOperation, OperationType, Predicate}, }; -/// Unique identifier for a custom predicate batch. +/// Unique identifier for a custom predicate in a module. /// -/// Uses the batch's cryptographic hash as identifier. Two batches with the same +/// Uses the predicate's cryptographic hash as identifier. Two predicates with the same /// hash are considered identical for resource counting purposes. #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct CustomBatchId(pub Hash); +pub struct CustomPredicateId(pub Hash); -impl From<&CustomPredicateBatch> for CustomBatchId { - fn from(batch: &CustomPredicateBatch) -> Self { - Self(batch.id()) - } -} - -/// Unique identifier for an anchored key (dict, key) pair. -/// -/// When a Contains statement is used as an argument to operations like gt(), eq(), etc., -/// the value is accessed via an "anchored key" - a reference to a specific key in a -/// specific dictionary. Each unique anchored key used in a POD requires a Contains -/// statement to be present in that POD (auto-inserted by MainPodBuilder if needed). -/// -/// We use the raw values of the dict and key for comparison, as they uniquely identify -/// the anchored key regardless of the specific Value types involved. 
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct AnchoredKeyId { - /// The dictionary root value (raw representation for Ord). - pub dict: RawValue, - /// The key within the dictionary (raw representation for Ord). - pub key: RawValue, -} - -impl AnchoredKeyId { - /// Create a new anchored key ID from raw values. - pub fn new(dict: RawValue, key: RawValue) -> Self { - Self { dict, key } - } - - /// Try to extract an anchored key ID from a Contains statement with all literal values. - pub fn from_contains_statement(stmt: &Statement) -> Option { - if let Statement::Contains( - ValueRef::Literal(dict), - ValueRef::Literal(key), - ValueRef::Literal(_value), - ) = stmt - { - Some(Self::new(dict.raw(), key.raw())) - } else { - None - } +impl From<&CustomPredicateRef> for CustomPredicateId { + fn from(predicate: &CustomPredicateRef) -> Self { + Self(Predicate::Custom(predicate.clone()).hash()) } } @@ -88,17 +48,9 @@ pub struct StatementCost { /// Limit: `params.max_public_key_of` pub public_key_of: usize, - /// Custom predicate batches used (for batch cardinality constraint). - /// Limit: `params.max_custom_predicate_batches` distinct batches per POD. - pub custom_batch_ids: BTreeSet, - - /// Anchored keys referenced by this operation. - /// - /// When a Contains statement with all literal values is used as an argument, - /// the operation references an "anchored key" (dict, key pair). Each unique - /// anchored key used in a POD incurs an additional Contains statement cost, - /// as MainPodBuilder::add_entries_contains will auto-insert it if not already present. - pub anchored_keys: BTreeSet, + /// Custom predicates used (for custom predicate cardinality constraint). + /// Limit: `params.max_custom_predicates` distinct custom predicates per POD. 
+ pub custom_predicates_ids: BTreeSet, } impl StatementCost { @@ -164,20 +116,8 @@ impl StatementCost { } OperationType::Custom(cpr) => { cost.custom_pred_verifications = 1; - cost.custom_batch_ids - .insert(CustomBatchId::from(&*cpr.batch)); - } - } - - // Extract anchored keys from operation arguments. - // Any argument that is a Contains statement with all literal values - // represents an anchored key reference that will require a Contains - // statement in the POD (auto-inserted by MainPodBuilder if needed). - for arg in &op.1 { - if let OperationArg::Statement(stmt) = arg { - if let Some(anchored_key) = AnchoredKeyId::from_contains_statement(stmt) { - cost.anchored_keys.insert(anchored_key); - } + cost.custom_predicates_ids + .insert(CustomPredicateId::from(cpr)); } } diff --git a/src/frontend/multi_pod/deps.rs b/src/frontend/multi_pod/deps.rs index 97b4ef4..9472a1f 100644 --- a/src/frontend/multi_pod/deps.rs +++ b/src/frontend/multi_pod/deps.rs @@ -5,7 +5,6 @@ use std::collections::HashMap; -use super::cost::AnchoredKeyId; use crate::{ frontend::{Operation, OperationArg}, middleware::{Hash, Statement}, @@ -100,11 +99,6 @@ impl DependencyGraph { pod_hash, statement: dep_stmt.clone(), })); - } else if AnchoredKeyId::from_contains_statement(dep_stmt).is_some() { - // Anchored-key Contains args may be implicit requirements that are - // auto-materialized by MainPodBuilder. They are handled by anchored-key - // resource accounting, not by statement dependency edges. - continue; } else { // Statement arguments should either be internal (created earlier) // or from external PODs (except anchored-key implicit Contains). 
@@ -128,9 +122,8 @@ impl DependencyGraph { mod tests { use super::*; use crate::{ - dict, frontend::Operation as FrontendOp, - middleware::{AnchoredKey, NativeOperation, OperationAux, OperationType, Value, ValueRef}, + middleware::{NativeOperation, OperationAux, OperationType, Value, ValueRef}, }; fn equal_stmt(n: i64) -> Statement { @@ -195,32 +188,4 @@ mod tests { assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]); assert_eq!(graph.statement_deps[2], vec![StatementSource::Internal(0)]); } - - #[test] - fn test_anchored_key_contains_arg_is_treated_as_implicit_requirement() { - // A literal Contains statement can be used as an anchored-key argument even when - // no explicit producer statement exists in internal/external statements, because - // MainPodBuilder auto-inserts Contains statements for anchored keys. - let dict = dict!({ - "k" => 7_i64 - }); - - let anchored_contains = Statement::Contains( - ValueRef::Literal(Value::from(dict.clone())), - ValueRef::Literal(Value::from("k")), - ValueRef::Literal(Value::from(7_i64)), - ); - let ak = AnchoredKey::from((&dict, "k")); - let produced_statement = Statement::Equal(ValueRef::Key(ak.clone()), ValueRef::Key(ak)); - - // Use a typical frontend operation that consumes entry-like args. - // We're only testing the dependency graph, not the actual proof, so the operation - // just needs to have the right arguments to test what we're looking for. - let statements = vec![produced_statement]; - let operations = vec![FrontendOp::eq(anchored_contains.clone(), anchored_contains)]; - - let graph = DependencyGraph::build(&statements, &operations, &HashMap::new()); - - assert!(graph.statement_deps[0].is_empty()); - } } diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index d25fcce..6bade5b 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -48,12 +48,12 @@ //! 
[`MainPodBuilder`]: crate::frontend::MainPodBuilder use std::{ - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeSet, HashMap}, fmt, }; use crate::{ - frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, + frontend::{MainPod, MainPodBuilder, Operation}, middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value}, }; @@ -61,7 +61,7 @@ mod cost; mod deps; mod solver; -use cost::{AnchoredKeyId, StatementCost}; +use cost::StatementCost; use deps::{DependencyGraph, StatementSource}; pub use solver::MultiPodSolution; @@ -168,12 +168,8 @@ pub struct MultiPodBuilder { options: Options, /// External input PODs (already proved). input_pods: Vec, - /// Statements created by this builder. - statements: Vec, - /// Operations that produce each statement. - operations: Vec, /// Optional initial wildcard values for custom operations - operations_wildcard_values: Vec>, + operations_wildcard_values: HashMap>, /// Indices of statements that should be public in output PODs. /// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted. output_public_indices: Vec, @@ -193,7 +189,7 @@ pub struct SolvedMultiPod { statements: Vec, operations: Vec, output_public_indices: Vec, - operations_wildcard_values: Vec>, + operations_wildcard_values: HashMap>, solution: MultiPodSolution, deps: DependencyGraph, } @@ -260,56 +256,27 @@ impl SolvedMultiPod { let statements_sorted: BTreeSet = statements_in_this_pod.iter().copied().collect(); let public_set = &solution.pod_public_statements[pod_idx]; - // Track statements proved locally in this POD for argument remapping. - // We index by statement content so duplicate statements can reuse a single - // built statement slot in MainPodBuilder. 
- let mut added_statements_by_content: HashMap = HashMap::new(); - for &stmt_idx in &statements_sorted { - let original_stmt = self.statements[stmt_idx].clone(); - - // If this statement content was already built in this POD, reuse it instead - // of replaying the operation. If any duplicate is public, reveal the - // already-built statement. - if let Some(_existing_stmt) = added_statements_by_content.get(&original_stmt) { - continue; - } - - let mut op = self.operations[stmt_idx].clone(); - let wildcard_values = self.operations_wildcard_values[stmt_idx].clone(); - - // Remap Statement arguments that reference locally-proved statements. - // For external dependencies (from input PODs including earlier generated PODs), - // the original Statement is used directly - MainPodBuilder will find it in - // the input POD's public statements via find_op_arg. - for arg in &mut op.1 { - if let OperationArg::Statement(ref orig_stmt) = arg { - if let Some(remapped_stmt) = added_statements_by_content.get(orig_stmt) { - *arg = OperationArg::Statement(remapped_stmt.clone()); - } - } - } + let op = self.operations[stmt_idx].clone(); + let wildcard_values = self + .operations_wildcard_values + .get(&stmt_idx) + .cloned() + .unwrap_or_default(); let stmt = builder.op(false, wildcard_values, op)?; - - added_statements_by_content.insert(original_stmt, stmt); + assert_eq!(stmt, self.statements[stmt_idx]); // Sanity check } // For the output pod, make statements public in the original order. // Intermediate pods use the solver-selected public set. 
if pod_idx == solution.pod_count - 1 { for idx in &self.output_public_indices { - let stmt = added_statements_by_content - .get(&self.statements[*idx]) - .expect("exists"); - builder.reveal(stmt); + builder.reveal(&self.statements[*idx])?; } } else { for idx in public_set { - let stmt = added_statements_by_content - .get(&self.statements[*idx]) - .expect("exists"); - builder.reveal(stmt); + builder.reveal(&self.statements[*idx])?; } } @@ -317,7 +284,7 @@ impl SolvedMultiPod { // for this POD. These do not require local proving in this POD. for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] { let ext_premise = &solution.external_premises[*ext_premise_idx]; - builder.reveal(&ext_premise.statement); + builder.reveal(&ext_premise.statement)?; } // Step 4: Prove the POD @@ -456,9 +423,7 @@ impl MultiPodBuilder { options, builder, input_pods: Vec::new(), - statements: Vec::new(), - operations: Vec::new(), - operations_wildcard_values: Vec::new(), + operations_wildcard_values: HashMap::new(), output_public_indices: Vec::new(), } } @@ -480,6 +445,16 @@ impl MultiPodBuilder { self.op(false, vec![], op) } + // Find the index of a statement that has been added. Panics if the statement doesn't + // exist. 
+ fn stmt_index(&self, stmt: &Statement) -> usize { + self.builder + .statements + .iter() + .position(|s| s == stmt) + .expect("exists") + } + pub fn op( &mut self, public: bool, @@ -488,8 +463,10 @@ impl MultiPodBuilder { ) -> Result { let stmt = self.add_operation(wildcard_values, op)?; if public { - // Index is always new (just added), so push without duplicate check - self.output_public_indices.push(self.statements.len() - 1); + let index = self.stmt_index(&stmt); + if !self.output_public_indices.contains(&index) { + self.output_public_indices.push(index); + } } Ok(stmt) } @@ -510,10 +487,8 @@ impl MultiPodBuilder { let stmt = self .builder .op(false, wildcard_values.clone(), op.clone())?; - - self.statements.push(stmt.clone()); - self.operations.push(op); - self.operations_wildcard_values.push(wildcard_values); + self.operations_wildcard_values + .insert(self.stmt_index(&stmt), wildcard_values.clone()); Ok(stmt) } @@ -523,7 +498,7 @@ impl MultiPodBuilder { /// Returns an error if the statement was not found in the builder. /// Calling this multiple times on the same statement is idempotent. pub fn reveal(&mut self, stmt: &Statement) -> Result<()> { - if let Some(idx) = self.statements.iter().position(|s| s == stmt) { + if let Some(idx) = self.builder.statements.iter().position(|s| s == stmt) { if !self.output_public_indices.contains(&idx) { self.output_public_indices.push(idx); } @@ -536,8 +511,8 @@ impl MultiPodBuilder { } /// Get the number of statements. - pub fn num_statements(&self) -> usize { - self.statements.len() + pub fn stmt_len(&self) -> usize { + self.builder.stmt_len() } /// Solve the packing problem and return a solved builder ready for proving. @@ -545,66 +520,31 @@ impl MultiPodBuilder { /// This runs the MILP solver to find the optimal POD assignment. /// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved. pub fn solve(self) -> Result { + let MainPodBuilder { + statements, + operations, + .. 
+ } = self.builder; // Compute costs for each statement - let costs: Vec = self - .operations + let costs: Vec = operations .iter() .map(StatementCost::from_operation) .collect(); - // Collect all unique anchored keys from the costs - let all_anchored_keys: Vec = costs - .iter() - .flat_map(|c| c.anchored_keys.iter().cloned()) - .collect::>() - .into_iter() - .collect(); - - // Build map from anchored key to its producing statement index (if any). - // A Contains statement with literal (dict, key, value) "produces" that anchored key. - let mut ak_to_producer: HashMap = HashMap::new(); - for (stmt_idx, stmt) in self.statements.iter().enumerate() { - if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) { - // First producer wins (shouldn't have duplicates in practice) - ak_to_producer.entry(ak).or_insert(stmt_idx); - } - } - - // Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i] - let anchored_key_producers: Vec> = all_anchored_keys - .iter() - .map(|ak| ak_to_producer.get(ak).copied()) - .collect(); - // Build external POD statement mapping let external_pod_statements = build_external_statement_map(&self.input_pods); // Build dependency graph - let deps = - DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements); - - // Build statement content groups for deduplication. - // Statements with identical content share a single slot in the POD. - // Keep groups ordered by first occurrence index for deterministic solver input. 
- let mut first_idx_by_stmt: HashMap<&Statement, usize> = HashMap::new(); - let mut groups_by_first_idx: BTreeMap> = BTreeMap::new(); - for (idx, stmt) in self.statements.iter().enumerate() { - let first_idx = *first_idx_by_stmt.entry(stmt).or_insert(idx); - groups_by_first_idx.entry(first_idx).or_default().push(idx); - } - let statement_content_groups: Vec> = groups_by_first_idx.into_values().collect(); + let deps = DependencyGraph::build(&statements, &operations, &external_pod_statements); // Run solver let input = solver::SolverInput { - num_statements: self.statements.len(), + num_statements: statements.len(), costs: &costs, deps: &deps, output_public_indices: &self.output_public_indices, params: &self.params, max_pods: self.options.max_pods, - all_anchored_keys: &all_anchored_keys, - anchored_key_producers: &anchored_key_producers, - statement_content_groups: &statement_content_groups, }; let solution = solver::solve(&input)?; @@ -613,8 +553,8 @@ impl MultiPodBuilder { params: self.params, vd_set: self.vd_set, input_pods: self.input_pods, - statements: self.statements, - operations: self.operations, + statements, + operations, output_public_indices: self.output_public_indices, operations_wildcard_values: self.operations_wildcard_values, solution, @@ -845,33 +785,13 @@ mod tests { let solution = solved.solution(); // Expected: exactly 2 PODs - // - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public - // - POD 1 (output): statement 2 (b_out); b_out is public - // The output POD accesses a_out from POD 0 to satisfy b_out's dependency. 
- assert_eq!( - solution.pod_count, 2, - "Expected exactly 2 PODs for 3-statement chain with max_priv=2" - ); - - // POD 0 should contain statements 0 and 1 (contains and a_out) - assert!( - solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), - "POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}", - solution.pod_statements[0] - ); - - // Statement 1 (a_out) should be public in POD 0 so POD 1 can access it - assert!( - solution.pod_public_statements[0].contains(&1), - "Statement 1 (a_out) should be public in POD 0" - ); - - // POD 1 (output) should contain statement 2 (b_out) - assert!( - solution.pod_statements[1].contains(&2), - "POD 1 should contain statement 2 (b_out), got {:?}", - solution.pod_statements[1] - ); + // Solution A: + // - POD 0 (intermediate): public statements 0 (contains) + // - POD 1 (output): inherits statement 0 (contains) from POD0, statement 1 (a_out), + // public statement 2 (b_out) + // Solution B: + // - POD 0 (intermediate): statements 0 (contains), public statement 1 (a_out) + // - POD 1 (output): inherits statement 1 (a_out) from POD0, public statement 2 (b_out) // Statement 2 (b_out) should be public in POD 1 (it's output-public) assert!( diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs index 9a24fb0..db1502e 100644 --- a/src/frontend/multi_pod/solver.rs +++ b/src/frontend/multi_pod/solver.rs @@ -52,7 +52,7 @@ use itertools::Itertools; use super::Result; use crate::{ frontend::multi_pod::{ - cost::{AnchoredKeyId, CustomBatchId, StatementCost}, + cost::{CustomPredicateId, StatementCost}, deps::{DependencyGraph, ExternalDependency, StatementSource}, }, middleware::{Hash, Params}, @@ -95,7 +95,6 @@ struct DependencyStats { struct SolveDebugContext { dep_stats: DependencyStats, batch_memberships: usize, - anchored_key_memberships: usize, } #[derive(Clone, Copy, Debug, Default)] @@ -105,10 +104,8 @@ struct ModelSizeEstimate { vars_public_external: usize, 
vars_pod_used: usize, vars_batch_used: usize, - vars_anchored_key_used: usize, vars_uses_input: usize, vars_uses_external: usize, - vars_content_group_used: usize, vars_total: usize, c1_coverage: usize, c2_output_public: usize, @@ -120,7 +117,6 @@ struct ModelSizeEstimate { c6_pre_content_group: usize, c6_resource_limits: usize, c7_batch_cardinality: usize, - c7b_anchored_key_tracking: usize, c8a_internal_inputs: usize, c8b_external_dep_inputs: usize, c8c_external_forward_inputs: usize, @@ -141,8 +137,6 @@ impl ModelSizeEstimate { debug_ctx: &SolveDebugContext, ) -> Self { let n = input.num_statements; - let num_groups = input.statement_content_groups.len(); - let num_anchored_keys = input.all_anchored_keys.len(); let triangular_k = target_pods * target_pods.saturating_sub(1) / 2; let vars_prove = n * target_pods; @@ -150,19 +144,15 @@ impl ModelSizeEstimate { let vars_public_external = external_premises_len * target_pods; let vars_pod_used = target_pods; let vars_batch_used = all_batches_len * target_pods; - let vars_anchored_key_used = num_anchored_keys * target_pods; let vars_uses_input = triangular_k; let vars_uses_external = external_pods_len * target_pods; - let vars_content_group_used = num_groups * target_pods; let vars_total = vars_prove + vars_public + vars_public_external + vars_pod_used + vars_batch_used - + vars_anchored_key_used + vars_uses_input - + vars_uses_external - + vars_content_group_used; + + vars_uses_external; let c1_coverage = n; let c2_output_public = input.output_public_indices.len(); @@ -171,12 +161,10 @@ impl ModelSizeEstimate { let c4_pod_existence = n * target_pods; let c5_internal_dependencies = debug_ctx.dep_stats.internal_edges * target_pods; let c5_external_dependencies = debug_ctx.dep_stats.external_edges * target_pods; - let c6_pre_content_group = (n * target_pods) + (num_groups * target_pods); + let c6_pre_content_group = n * target_pods; let c6_resource_limits = 7 * target_pods; let c7_batch_cardinality = 
(debug_ctx.batch_memberships * target_pods) + (all_batches_len * target_pods); - let c7b_anchored_key_tracking = - (debug_ctx.anchored_key_memberships * target_pods) + (num_anchored_keys * target_pods); let c8a_internal_inputs = debug_ctx.dep_stats.internal_edges * triangular_k; let c8b_external_dep_inputs = debug_ctx.dep_stats.external_edges * triangular_k; let c8c_external_forward_inputs = external_premises_len * triangular_k; @@ -194,7 +182,6 @@ impl ModelSizeEstimate { + c6_pre_content_group + c6_resource_limits + c7_batch_cardinality - + c7b_anchored_key_tracking + c8a_internal_inputs + c8b_external_dep_inputs + c8c_external_forward_inputs @@ -209,10 +196,8 @@ impl ModelSizeEstimate { vars_public_external, vars_pod_used, vars_batch_used, - vars_anchored_key_used, vars_uses_input, vars_uses_external, - vars_content_group_used, vars_total, c1_coverage, c2_output_public, @@ -224,7 +209,6 @@ impl ModelSizeEstimate { c6_pre_content_group, c6_resource_limits, c7_batch_cardinality, - c7b_anchored_key_tracking, c8a_internal_inputs, c8b_external_dep_inputs, c8c_external_forward_inputs, @@ -300,6 +284,7 @@ pub struct MultiPodSolution { } /// Input to the MILP solver. +#[derive(Debug)] pub struct SolverInput<'a> { /// Number of statements. pub num_statements: usize, @@ -318,28 +303,6 @@ pub struct SolverInput<'a> { /// Maximum number of PODs the solver will consider. pub max_pods: usize, - - /// All unique anchored keys referenced by any statement. - /// - /// Each unique (dict, key) pair that is used as an anchored key reference - /// in any operation. When a Contains statement with literal values is used - /// as an argument, it creates an anchored key reference. - pub all_anchored_keys: &'a [AnchoredKeyId], - - /// For each anchored key, the statement index that produces it (if any). - /// - /// When a Contains statement with literal (dict, key, value) args is explicitly - /// added, it "produces" that anchored key. 
If the producer is in the same POD - /// as statements using the anchored key, no auto-insertion is needed. - /// `anchored_key_producers[i]` corresponds to `all_anchored_keys[i]`. - pub anchored_key_producers: &'a [Option], - - /// Statement content groups for deduplication. - /// - /// Each inner Vec contains statement indices that have identical content. - /// When multiple statements with the same content are proved in the same POD, - /// they only use one statement slot (the POD deduplicates identical statements). - pub statement_content_groups: &'a [Vec], } /// Solve the MILP problem to find optimal POD packing. @@ -386,11 +349,11 @@ pub fn solve(input: &SolverInput) -> Result { ))); } - // Collect all unique custom batch IDs used - let all_batches: Vec = input + // Collect all unique custom predicate IDs used + let all_custom_predicates: Vec = input .costs .iter() - .flat_map(|c| c.custom_batch_ids.iter().cloned()) + .flat_map(|c| c.custom_predicates_ids.iter().cloned()) .unique() .collect(); @@ -417,18 +380,19 @@ pub fn solve(input: &SolverInput) -> Result { } let dep_stats = dependency_stats(input.deps); - let batch_memberships: usize = input.costs.iter().map(|c| c.custom_batch_ids.len()).sum(); - let anchored_key_memberships: usize = input.costs.iter().map(|c| c.anchored_keys.len()).sum(); + let batch_memberships: usize = input + .costs + .iter() + .map(|c| c.custom_predicates_ids.len()) + .sum(); let debug_ctx = SolveDebugContext { dep_stats, batch_memberships, - anchored_key_memberships, }; if log::log_enabled!(log::Level::Debug) { let resource_totals = ResourceTotals::from_costs(input.costs); - let lb_statement_groups = - lower_bound_from_total(input.statement_content_groups.len(), max_stmts_per_pod); + let lb_statement_groups = lower_bound_from_total(input.num_statements, max_stmts_per_pod); let lb_merkle = lower_bound_from_total( resource_totals.merkle_proofs, input.params.max_merkle_proofs_containers, @@ -463,14 +427,12 @@ pub fn solve(input: 
&SolverInput) -> Result { .expect("non-empty lower-bound candidate list"); log::debug!( - "MILP summary: statements={} output_public={} content_groups={} anchored_keys={} \ - batches={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \ + "MILP summary: statements={} output_public={} \ + custom_predicates={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \ external_premises={} search_min_pods={} max_pods={}", n, num_output_public, - input.statement_content_groups.len(), - input.all_anchored_keys.len(), - all_batches.len(), + all_custom_predicates.len(), dep_stats.internal_edges, dep_stats.external_edges, external_pods.len(), @@ -481,14 +443,13 @@ pub fn solve(input: &SolverInput) -> Result { log::debug!( "MILP resource totals: merkle_proofs={} merkle_state_transitions={} \ custom_pred_verifications={} signed_by={} public_key_of={} \ - batch_memberships={} anchored_key_memberships={}", + batch_memberships={}", resource_totals.merkle_proofs, resource_totals.merkle_state_transitions, resource_totals.custom_pred_verifications, resource_totals.signed_by, resource_totals.public_key_of, batch_memberships, - anchored_key_memberships ); log::debug!( "MILP lower bounds (pods): statements_raw={} statements_dedup={} merkle_proofs={} \ @@ -513,7 +474,7 @@ pub fn solve(input: &SolverInput) -> Result { if let Some(solution) = try_solve_with_pods( input, target_pods, - &all_batches, + &all_custom_predicates, &external_pods, &external_premises, &debug_ctx, @@ -540,7 +501,7 @@ pub fn solve(input: &SolverInput) -> Result { fn try_solve_with_pods( input: &SolverInput, target_pods: usize, - all_batches: &[CustomBatchId], + all_custom_predicates: &[CustomPredicateId], external_pods: &[Hash], external_premises: &[ExternalDependency], debug_ctx: &SolveDebugContext, @@ -574,21 +535,8 @@ fn try_solve_with_pods( .map(|_| vars.add(variable().binary())) .collect(); - // batch_used[b][p] - custom batch b is used in POD p - let batch_used: Vec> = 
(0..all_batches.len()) - .map(|_| { - (0..target_pods) - .map(|_| vars.add(variable().binary())) - .collect() - }) - .collect(); - - // anchored_key_used[ak][p] - anchored key ak is used in POD p - // When a statement references an anchored key (via a Contains statement argument), - // that POD must have a Contains statement for that (dict, key) pair. - // MainPodBuilder::add_entries_contains auto-inserts these, and we must account - // for them in the statement count. - let anchored_key_used: Vec> = (0..input.all_anchored_keys.len()) + // custom_predicates[b][p] - custom predicate b is used in POD p + let custom_predicate_used: Vec> = (0..all_custom_predicates.len()) .map(|_| { (0..target_pods) .map(|_| vars.add(variable().binary())) @@ -633,31 +581,19 @@ fn try_solve_with_pods( .map(|(i, ext)| (ext.clone(), i)) .collect(); - // content_group_used[g][p] - content group g has at least one statement proved in POD p - // When multiple statements have identical content, they share a slot in the POD. - // This variable tracks whether at least one statement from each content group is proved. 
- let num_groups = input.statement_content_groups.len(); - let content_group_used: Vec> = (0..num_groups) - .map(|_| { - (0..target_pods) - .map(|_| vars.add(variable().binary())) - .collect() - }) - .collect(); - if log::log_enabled!(log::Level::Debug) { let estimate = ModelSizeEstimate::for_target_pods( input, target_pods, - all_batches.len(), + all_custom_predicates.len(), external_pods.len(), external_premises.len(), debug_ctx, ); log::debug!( "MILP(k={}) model estimate vars_total={} [prove={} public={} pod_used={} \ - public_external={} batch_used={} anchored_key_used={} uses_input={} \ - uses_external={} content_group_used={}]", + public_external={} batch_used={} uses_input={} \ + uses_external={}]", target_pods, estimate.vars_total, estimate.vars_prove, @@ -665,14 +601,12 @@ fn try_solve_with_pods( estimate.vars_pod_used, estimate.vars_public_external, estimate.vars_batch_used, - estimate.vars_anchored_key_used, estimate.vars_uses_input, estimate.vars_uses_external, - estimate.vars_content_group_used ); log::debug!( "MILP(k={}) model estimate constraints_total={} [c1={} c2={} c2b={} c3={} c4={} \ - c5i={} c5e={} c6_pre={} c6_limits={} c7={} c7b={} c8a={} c8b={} c8c={} \ + c5i={} c5e={} c6_pre={} c6_limits={} c7={} c8a={} c8b={} c8c={} \ c8d={} c9={} c10={} c10b={}]", target_pods, estimate.constraints_total, @@ -686,7 +620,6 @@ fn try_solve_with_pods( estimate.c6_pre_content_group, estimate.c6_resource_limits, estimate.c7_batch_cardinality, - estimate.c7b_anchored_key_tracking, estimate.c8a_internal_inputs, estimate.c8b_external_dep_inputs, estimate.c8c_external_forward_inputs, @@ -798,35 +731,11 @@ fn try_solve_with_pods( } } - // Constraint 6: Resource limits per POD - // - // 6a-pre: Content group tracking for statement deduplication - // When multiple statement indices have identical content, they share a single slot in the POD. - // content_group_used[g][p] = 1 iff at least one statement from group g is proved in POD p. 
- for (g, group) in input.statement_content_groups.iter().enumerate() { - for p in 0..target_pods { - // Lower bound: if any statement in the group is proved, the group is used - for &s in group { - model.add_constraint(constraint!(content_group_used[g][p] >= prove[s][p])); - } - // Upper bound: if no statements in the group are proved, the group is not used - let group_prove_sum: Expression = group.iter().map(|&s| prove[s][p]).sum(); - model.add_constraint(constraint!(content_group_used[g][p] <= group_prove_sum)); - } - } - for p in 0..target_pods { - // 6a: Unique statement count (unique content groups + anchored key Contains) - // Statements with identical content share a slot, so we count content groups, not indices. - // Anchored key Contains statements are auto-inserted by MainPodBuilder when needed. - // The total must not exceed max_priv_statements (= max_statements - max_public_statements). - let unique_stmt_sum: Expression = (0..num_groups).map(|g| content_group_used[g][p]).sum(); - let anchored_key_sum: Expression = (0..input.all_anchored_keys.len()) - .map(|ak| anchored_key_used[ak][p]) - .sum(); + // 6a: Statement count + let stmt_sum: Expression = (0..n).map(|g| prove[g][p]).sum(); model.add_constraint(constraint!( - unique_stmt_sum + anchored_key_sum - <= (input.params.max_priv_statements() as f64) * pod_used[p] + stmt_sum <= (input.params.max_priv_statements() as f64) * pod_used[p] )); // 6b: Public statement count (internal public statements + forwarded external premises) @@ -885,67 +794,31 @@ fn try_solve_with_pods( } // Constraint 7: Batch cardinality - // batch_used[b][p] >= prove[s][p] for all s that use batch b (batch is used if any statement uses it) - // batch_used[b][p] <= sum of prove[s][p] for all s using batch b (batch is 0 if no statements use it) - for (b, batch_id) in all_batches.iter().enumerate() { + // custom_predicate_used[b][p] >= prove[s][p] for all s that use custom predicate b (custom + // predicate is used if any statement 
uses it) + // custom_predicate_used[b][p] <= sum of prove[s][p] for all s using custom predicate b (custom + // predicate is 0 if no statements use it) + for (b, predicate_id) in all_custom_predicates.iter().enumerate() { for p in 0..target_pods { let mut sum: Expression = 0.into(); for s in 0..n { - if input.costs[s].custom_batch_ids.contains(batch_id) { - model.add_constraint(constraint!(batch_used[b][p] >= prove[s][p])); + if input.costs[s].custom_predicates_ids.contains(predicate_id) { + model.add_constraint(constraint!(custom_predicate_used[b][p] >= prove[s][p])); sum += prove[s][p]; } } - model.add_constraint(constraint!(batch_used[b][p] <= sum)); + model.add_constraint(constraint!(custom_predicate_used[b][p] <= sum)); } } - // Constraint 7b: Anchored key tracking - // - // anchored_key_used[ak][p] = 1 when auto-insertion of a Contains is needed for anchored key ak in POD p. - // This happens when: some statement using ak is in POD p, AND the producing Contains is NOT in POD p. - // - // If a Contains statement explicitly produces ak (anchored_key_producers[ak] = Some(prod_idx)): - // - Lower: anchored_key_used[ak][p] >= prove[s][p] - prove[prod_idx][p] for all s using ak - // - Upper: anchored_key_used[ak][p] <= 1 - prove[prod_idx][p] - // This ensures overhead is 0 when the producer is in the same POD. - // - // If no Contains produces ak (anchored_key_producers[ak] = None): - // - Lower: anchored_key_used[ak][p] >= prove[s][p] for all s using ak - // - Upper: anchored_key_used[ak][p] <= sum of prove[s][p] for all s using ak - // Auto-insertion is always needed when any user is present. 
- for (ak_idx, ak) in input.all_anchored_keys.iter().enumerate() { - let producer = input.anchored_key_producers[ak_idx]; - - for p in 0..target_pods { - let mut user_sum: Expression = 0.into(); - for s in 0..n { - if input.costs[s].anchored_keys.contains(ak) { - if let Some(prod_idx) = producer { - // Producer exists: only count overhead if producer not in this POD - model.add_constraint(constraint!( - anchored_key_used[ak_idx][p] >= prove[s][p] - prove[prod_idx][p] - )); - } else { - // No producer: always need auto-insertion if user is present - model.add_constraint(constraint!( - anchored_key_used[ak_idx][p] >= prove[s][p] - )); - } - user_sum += prove[s][p]; - } - } - - if let Some(prod_idx) = producer { - // If producer is in POD, no auto-insertion needed (overhead = 0) - model.add_constraint(constraint!( - anchored_key_used[ak_idx][p] <= 1 - prove[prod_idx][p] - )); - } else { - // No producer: overhead is bounded by whether any user is present - model.add_constraint(constraint!(anchored_key_used[ak_idx][p] <= user_sum)); - } - } + // Custom predicate count per POD + for p in 0..target_pods { + let custom_predicate_sum: Expression = (0..all_custom_predicates.len()) + .map(|b| custom_predicate_used[b][p]) + .sum(); + model.add_constraint(constraint!( + custom_predicate_sum <= (input.params.max_custom_predicates as f64) * pod_used[p] + )); } // Constraint 8a: Internal input POD tracking using uses_input. 
@@ -1147,9 +1020,6 @@ mod tests { output_public_indices: &[], params: &params, max_pods: 20, - all_anchored_keys: &[], - anchored_key_producers: &[], - statement_content_groups: &[], }; let result = solve(&input); @@ -1195,7 +1065,6 @@ mod tests { }; let costs = vec![StatementCost::default(), StatementCost::default()]; - let statement_content_groups = vec![vec![0], vec![1]]; let output_public = vec![1]; let input = SolverInput { @@ -1205,9 +1074,6 @@ mod tests { output_public_indices: &output_public, params: &params, max_pods: 4, - all_anchored_keys: &[], - anchored_key_producers: &[], - statement_content_groups: &statement_content_groups, }; let solution = solve(&input).expect("solver should find a feasible forwarding layout"); diff --git a/src/middleware/db/mem.rs b/src/middleware/db/mem.rs index 53ab91e..71211fa 100644 --- a/src/middleware/db/mem.rs +++ b/src/middleware/db/mem.rs @@ -43,8 +43,10 @@ impl DB for MemDB { let mut values = self.values.write().expect("lock not poisoned"); let value_raw = value.raw(); if let Some(old_value) = values.get(&value_raw) { - // If we had a non-raw value stored never overwrite it with a raw value - if !old_value.is_raw() && value.is_raw() { + let old_is_raw = old_value.is_raw(); + // If we had a non-RawValue stored don't overwrite it (specially not with a + // RawValue). Also skip redundant RawValue overwrite. + if !old_is_raw || value.is_raw() { return Ok(()); } }