From 32f45872d71aef30a6a39b683a529daf62cecf95 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 11 Mar 2026 16:32:42 +0100 Subject: [PATCH 01/10] Re-implement merkletree with persistent storage (key-value db) (#487) * refactor merkletree to work with disk keyvalue database (wip) * various fixes post reimplementation; pending delete leaf * add delete operation case for the new in db tree approach * polish tree update & delete; everything works (pending polishing) * polish panics into errs, prints, etc * Implement iterator * Lint * fix case no-siblings * case delete with semi-empty branch * polishing * starting to add rocksdb & heeddb for the DB & Txn traits * Satisfy the borrow checker * abstract merkletree tests to use the various available DBs * update store_node interface (rm hash input), rm heed.rs * polishing * typos * Ditch transactions * add feature for rocksdb, return errs at new_with_db, remove empty leaf case in Leaf::new * intermediate instead of leaf in empty node when deleting leaf --------- Co-authored-by: Ahmad --- Cargo.toml | 5 +- src/backends/plonky2/circuits/mainpod.rs | 8 +- .../plonky2/primitives/merkletree/circuit.rs | 13 +- .../plonky2/primitives/merkletree/db/mod.rs | 109 ++ .../plonky2/primitives/merkletree/db/rocks.rs | 58 + .../plonky2/primitives/merkletree/mod.rs | 1212 +++++++++++------ 6 files changed, 974 insertions(+), 431 deletions(-) create mode 100644 src/backends/plonky2/primitives/merkletree/db/mod.rs create mode 100644 src/backends/plonky2/primitives/merkletree/db/rocks.rs diff --git a/Cargo.toml b/Cargo.toml index a1f7511..c3329ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ good_lp = { version = "1.8", default-features = false, features = [ "scip_bundled", ] } annotate-snippets = "0.11" +rocksdb = { version = "0.24.0", optional = true } # keyvalue database for merkletree # Uncomment for debugging with https://github.com/ed255/plonky2/ at branch `feat/debug`. The repo directory needs to be checked out next to the pod2 repo directory. # [patch."https://github.com/0xPARC/plonky2"] @@ -57,12 +58,13 @@ annotate-snippets = "0.11" pretty_assertions = "1.4.1" # Used only for testing JSON Schema generation and validation. jsonschema = "0.30.0" +tempfile = "3" [build-dependencies] vergen-gitcl = { version = "1.0.0", features = ["build"] } [features] -default = ["backend_plonky2", "zk", "mem_cache"] +default = ["backend_plonky2", "zk", "mem_cache", "db_rocksdb"] backend_plonky2 = ["plonky2"] zk = [] metrics = [] @@ -70,6 +72,7 @@ time = [] examples = [] disk_cache = ["directories", "minicbor-serde"] mem_cache = [] +db_rocksdb = ["rocksdb"] # Uncomment in order to enable debug information in the release builds. This allows getting panic backtraces with a performance similar to regular release. # [profile.release] diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index ebe77b4..b0c8f48 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -2276,9 +2276,9 @@ mod tests { ] .into_iter() .for_each(|(op, st)| { - let check = std::panic::catch_unwind(|| { + let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - }); + })); match check { Err(e) => { let err_string = e.downcast_ref::().unwrap(); @@ -2689,9 +2689,9 @@ mod tests { ); let prev_statements = [Statement::None.into()]; - let check = std::panic::catch_unwind(|| { + let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - }); + })); match check { Err(e) => { let err_string = e.downcast_ref::().unwrap(); diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index 0c5978f..2c54b8b 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -32,7 +32,7 @@ use crate::{ circuits::common::{CircuitBuilderPod, ValueTarget}, error::{Error, Result}, primitives::merkletree::{ - MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, + MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, MAX_DEPTH, }, }, measure_gates_begin, measure_gates_end, @@ -703,10 +703,13 @@ impl MerkleTreeStateTransitionProofTarget { { pw.set_hash_target(self.siblings[i], HashOut::from_vec(sibling.0.to_vec()))?; } - pw.set_target( - self.divergence_level, - F::from_canonical_u64((new_siblings.len() - 1) as u64), - )?; + let div_lvl = if new_siblings.is_empty() { + // don't subtract since it would underflow, use MAX_DEPTH + MAX_DEPTH as u64 + } else { + (new_siblings.len() - 1) as u64 + }; + pw.set_target(self.divergence_level, F::from_canonical_u64(div_lvl))?; Ok(()) } diff --git a/src/backends/plonky2/primitives/merkletree/db/mod.rs b/src/backends/plonky2/primitives/merkletree/db/mod.rs new file mode 100644 index 0000000..3402d1d --- /dev/null +++ b/src/backends/plonky2/primitives/merkletree/db/mod.rs @@ -0,0 +1,109 @@ +//! Module that implements the key-value DB used at the MerkleTree module. + +use std::{ + collections::HashMap, + fmt::Debug, + sync::{Arc, Mutex}, +}; + +use anyhow::{anyhow, bail, Result}; +use dyn_clone::DynClone; + +use crate::{ + backends::plonky2::primitives::merkletree::{Leaf, Node}, + middleware::{RawValue, EMPTY_VALUE}, +}; + +#[cfg(feature = "db_rocksdb")] +pub mod rocks; + +pub trait DB: Debug + DynClone + Sync + Send { + fn load_node(&self, hash: RawValue) -> Result; + fn store_node(&mut self, node: Node) -> Result<()>; +} +dyn_clone::clone_trait_object!(DB); + +/// MemDB implements the DB trait in a in-memory HashMap. +#[derive(Clone, Debug, Default)] +pub(crate) struct MemDB { + inner: Arc>>, +} + +impl MemDB { + pub fn new() -> Self { + Self::default() + } +} + +impl DB for MemDB { + fn load_node(&self, hash: RawValue) -> Result { + let db = self + .inner + .lock() + .map_err(|e| anyhow!("failed to acquire memdb lock for read: {}", e))?; + + if let Some(node) = db.get(&hash) { + return Ok(node.clone()); + } + + if hash == EMPTY_VALUE { + return Ok(Node::Leaf(Leaf::new(hash, EMPTY_VALUE))); + } + + bail!("MemDB error: node not found: {}", hash); + } + + fn store_node(&mut self, node: Node) -> Result<()> { + let mut db = self + .inner + .lock() + .map_err(|e| anyhow!("failed to acquire memdb lock for write: {}", e))?; + db.insert(node.hash().into(), node); + Ok(()) + } +} + +// NOTE: this can be replaced by `.to_bytes` & `from_bytes` optimized methods at `Node` +#[allow(dead_code)] +fn encode_node(node: &Node) -> Result> { + serde_json::to_vec(node).map_err(|e| anyhow!("failed to serialize node: {e}")) +} +#[allow(dead_code)] +fn decode_node(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(|e| anyhow!("failed to deserialize node: {e}")) +} + +#[cfg(test)] +pub mod tests { + + use super::*; + + #[test] + fn test_db() -> Result<()> { + let mut db = MemDB::new(); + test_db_opt(&mut db)?; + + #[cfg(feature = "db_rocksdb")] + { + let path = "/tmp/rocksdb"; + let mut db = rocks::RocksDB::open(path)?; + test_db_opt(&mut db)?; + } + + Ok(()) + } + + fn test_db_opt(db: &mut dyn DB) -> Result<()> { + let node = Leaf::new(1.into(), 1.into()); + db.store_node(Node::Leaf(node.clone()))?; + + let obtained_node = db.load_node(node.hash.into())?; + let leaf = match obtained_node { + Node::Leaf(l) => l, + _ => panic!("expected a leaf"), + }; + assert_eq!(leaf.hash, node.hash); + + Ok(()) + } +} diff --git a/src/backends/plonky2/primitives/merkletree/db/rocks.rs b/src/backends/plonky2/primitives/merkletree/db/rocks.rs new file mode 100644 index 0000000..47a1739 --- /dev/null +++ b/src/backends/plonky2/primitives/merkletree/db/rocks.rs @@ -0,0 +1,58 @@ +use std::{fmt, path::Path, sync::Arc}; + +use anyhow::{anyhow, Result}; +use rocksdb::{Options, TransactionDB, TransactionDBOptions}; + +use super::DB; +use crate::{ + backends::plonky2::primitives::merkletree::{Leaf, Node}, + middleware::{RawValue, EMPTY_VALUE}, +}; + +#[derive(Clone)] +pub struct RocksDB(Arc); + +#[allow(dead_code)] +impl RocksDB { + pub fn open(path: impl AsRef) -> Result { + let mut options = Options::default(); + options.create_if_missing(true); + let txn_options = TransactionDBOptions::default(); + let inner = + TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?; + Ok(Self(Arc::new(inner))) + } +} + +impl fmt::Debug for RocksDB { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "RocksDB") + } +} + +impl DB for RocksDB { + fn load_node(&self, hash: RawValue) -> Result { + if hash == EMPTY_VALUE { + return Ok(Node::Leaf(Leaf::new(hash, EMPTY_VALUE))); + } + + let maybe_node_bytes = self + .0 + .get(hash.to_bytes()) + .map_err(|e| anyhow!("rocksdb transaction get failed: {e}"))?; + + match maybe_node_bytes { + Some(bytes) => super::decode_node(&bytes), + None => Err(anyhow!("rocksdb: node not found")), + } + } + + fn store_node(&mut self, node: Node) -> Result<()> { + self.0 + .put( + RawValue::from(node.hash()).to_bytes(), + super::encode_node(&node)?, + ) + .map_err(|e| anyhow!("rocksdb transaction put failed: {e}")) + } +} diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index 35c4c11..72e7c23 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -2,6 +2,7 @@ //! . use std::{collections::HashMap, fmt, iter::IntoIterator}; +use anyhow::{anyhow, Result}; use itertools::zip_eq; use plonky2::{ field::types::Field, @@ -15,6 +16,8 @@ use crate::middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F}; pub mod circuit; pub use circuit::*; +mod db; +use db::DB; pub mod error; pub use error::{TreeError, TreeResult}; @@ -25,7 +28,8 @@ const MAX_DEPTH: usize = 256; /// #[derive(Clone, Debug)] pub struct MerkleTree { - root: Node, + root: Hash, + db: Box, } impl PartialEq for MerkleTree { @@ -38,28 +42,191 @@ impl Eq for MerkleTree {} impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values pub fn new(kvs: &HashMap) -> Self { + let db = db::MemDB::new(); + Self::new_with_db(Box::new(db), kvs).unwrap() + } + pub fn new_with_db(db: Box, kvs: &HashMap) -> Result { // Start with an empty node as root. - let mut root = Node::None; + let (root, db) = { + let mut db = db; - // Iterate over key-value pairs (if any) and add them. - for (k, v) in kvs.iter() { - root.apply_op(MerkleTreeOp::Insert, *k, Some(*v)).unwrap(); - } + // Iterate over key-value pairs (if any) and add them. + let mut root = EMPTY_HASH; + for (k, v) in kvs.iter() { + root = Self::apply_op(db.as_mut(), MerkleTreeOp::Insert, root, *k, Some(*v))?; + } + (root, db) + }; - // Fill in hashes. - let _ = root.compute_hash(); - Self { root } + Ok(Self { root, db }) + } + + pub fn empty_with_db(db: Box) -> Self { + Self::from_db(EMPTY_HASH, db) + } + + pub fn from_db(root: Hash, db: Box) -> Self { + Self { root, db } } /// returns the root of the tree pub fn root(&self) -> Hash { - self.root.hash() + self.root + } + + /// Goes down from the current node until it encounters a terminal node, + /// viz. a leaf or empty node, or until it reaches the maximum depth. The + /// `siblings` parameter is used to store the siblings while going down to + /// the leaf, if the given parameter is set to `None`, then no siblings are + /// stored. In this way, the same method `down` can be used by MerkleTree + /// methods `get`, `contains`, `prove` and `prove_nonexistence`. + /// + /// Be aware that this method will return the found leaf at the given path, + /// which may contain a different key and value than the expected one. And + /// while it does not return explicitly a `siblings` variable, the input + /// `siblings` is modified adding there the siblings found along the path. + fn down( + db: &dyn DB, + path_and_lvl: (Vec, usize), // path and lvl + curr_node_hash: Hash, // hash of current level node + new_key: RawValue, // key to be added/found at the leaf + mut siblings: Option<&mut Vec>, + op: MerkleTreeOp, + ) -> TreeResult> { + let (path, lvl) = path_and_lvl; + + if lvl > MAX_DEPTH { + return Err(TreeError::max_depth()); + } + + if curr_node_hash == EMPTY_HASH { + return Ok(None); + } + + let node = db.load_node(curr_node_hash.into())?; + match node { + Node::Intermediate(n) => { + if path[lvl] { + if let Some(s) = siblings.as_mut() { + s.push(n.left); + } + Self::down(db, (path, lvl + 1), n.right, new_key, siblings, op) + } else { + if let Some(s) = siblings.as_mut() { + s.push(n.right); + } + Self::down(db, (path, lvl + 1), n.left, new_key, siblings, op) + } + } + Node::Leaf(old_leaf) => { + if op == MerkleTreeOp::ReadOnly { + return Ok(Some((old_leaf.key, old_leaf.value))); + } + + if new_key == old_leaf.key { + if op == MerkleTreeOp::Insert { + // in Insert, key should not exist + return Err(TreeError::key_exists()); + } + // we're at the operation Update/Delete case + return Ok(Some((old_leaf.key, old_leaf.value))); + } + + Self::down_till_divergence( + lvl, + curr_node_hash.into(), + old_leaf.path, + path, + siblings.ok_or(anyhow!("expected siblings, got None"))?, + )?; + Ok(Some((old_leaf.key, old_leaf.value))) + } + } + } + + /// goes down through a 'virtual' path till finding a divergence. This + /// method is used for when adding a new leaf another already existing leaf + /// is found, so that both leaves (new and old) are pushed down the path + /// till their keys diverge. + fn down_till_divergence( + lvl: usize, + old_key: RawValue, + old_path: Vec, + new_path: Vec, + siblings: &mut Vec, + ) -> TreeResult<()> { + if lvl > MAX_DEPTH { + return Err(TreeError::max_depth()); + } + if old_path[lvl] == new_path[lvl] { + siblings.push(EMPTY_HASH); + return Self::down_till_divergence(lvl + 1, old_key, old_path, new_path, siblings); + } + // reached the divergence + siblings.push(old_key.into()); + Ok(()) + } + + /// go up recursively updating the intermediate nodes + fn up( + db: &mut dyn DB, + path: Vec, + curr_lvl: usize, + key: Hash, + siblings: Vec, + op: MerkleTreeOp, + // first_zeroes should be set to `true` when calling `up` from outside + // the method itself. It is used internally to know when to go up + // 'virtually' for the first batch of zeroes. + first_zeroes: bool, + ) -> Result { + // recall, in the delete case, the `key` is the `remaining_key` + let key_node = db.load_node(key.into())?; + if op == MerkleTreeOp::Delete + && first_zeroes + && matches!(key_node, Node::Leaf(..)) + && siblings[curr_lvl] == EMPTY_HASH + { + // - if we're at operation delete, the node that we're holding is a leaf, + // and we're at the first consecutive zero siblings + // - in operation Delete, go up till the first non-zero sibling and + // pair the given key with that sibling. + // This is only done for the first batch of zero siblings, that is, + // after a non-zero sibling, no matter how many zero siblings it + // has, don't do this logic anymore. + if curr_lvl == 0 { + return Ok(key); + } + return Self::up(db, path, curr_lvl - 1, key, siblings, op, true); + } + + let node = if path[curr_lvl] { + Intermediate::new(siblings[curr_lvl], key) + } else { + Intermediate::new(key, siblings[curr_lvl]) + }; + let node_hash = node.hash; // variable to avoid cloning `node` later + + // store in db + db.store_node(Node::Intermediate(node))?; + + if curr_lvl == 0 { + return Ok(node_hash); + } + Self::up(db, path, curr_lvl - 1, node_hash, siblings, op, false) } /// returns the value at the given key pub fn get(&self, key: &RawValue) -> TreeResult { let path = keypath(*key); - let (key_resolution, _) = self.root.down(0, path, None); + let key_resolution = Self::down( + self.db.as_ref(), + (path, 0), + self.root, + *key, + None, + MerkleTreeOp::ReadOnly, + )?; match key_resolution { Some((k, v)) if &k == key => Ok(v), _ => Err(TreeError::key_not_found()), @@ -69,8 +236,15 @@ impl MerkleTree { /// returns a boolean indicating whether the key exists in the tree pub fn contains(&self, key: &RawValue) -> TreeResult { let path = keypath(*key); - match self.root.down(0, path, None) { - (Some((k, _)), _) if &k == key => Ok(true), + match Self::down( + self.db.as_ref(), + (path, 0), + self.root, + *key, + None, + MerkleTreeOp::ReadOnly, + )? { + Some((k, _)) if &k == key => Ok(true), _ => Ok(false), } } @@ -82,10 +256,15 @@ impl MerkleTree { ) -> TreeResult { let proof_non_existence = self.prove_nonexistence(key)?; - let old_root: Hash = self.root.hash(); - self.root - .apply_op(MerkleTreeOp::Insert, *key, Some(*value))?; - let new_root = self.root.compute_hash(); + let old_root: Hash = self.root; + + self.root = Self::apply_op( + self.db.as_mut(), + MerkleTreeOp::Insert, + self.root, + *key, + Some(*value), + )?; let (v, proof) = self.prove(key)?; assert!(proof.existence); @@ -96,7 +275,7 @@ impl MerkleTree { op: MerkleTreeOp::Insert, // insertion old_root, op_proof: proof_non_existence, - new_root, + new_root: self.root, op_key: *key, op_value: *value, value: None, @@ -111,10 +290,14 @@ impl MerkleTree { ) -> TreeResult { let (old_value, old_proof) = self.prove(key)?; - let old_root: Hash = self.root.hash(); - self.root - .apply_op(MerkleTreeOp::Update, *key, Some(*value))?; - let new_root = self.root.compute_hash(); + let old_root: Hash = self.root; + self.root = Self::apply_op( + self.db.as_mut(), + MerkleTreeOp::Update, + self.root, + *key, + Some(*value), + )?; let (v, proof) = self.prove(key)?; assert!(proof.existence); @@ -125,7 +308,7 @@ impl MerkleTree { op: MerkleTreeOp::Update, old_root, op_proof: old_proof, - new_root, + new_root: self.root, op_key: *key, op_value: *value, value: Some(old_value), @@ -136,9 +319,14 @@ impl MerkleTree { pub fn delete(&mut self, key: &RawValue) -> TreeResult { let (value, proof_existence) = self.prove(key)?; - let old_root: Hash = self.root.hash(); - self.root.apply_op(MerkleTreeOp::Delete, *key, None)?; - let new_root = self.root.compute_hash(); + let old_root: Hash = self.root; + self.root = Self::apply_op( + self.db.as_mut(), + MerkleTreeOp::Delete, + self.root, + *key, + None, + )?; let proof = self.prove_nonexistence(key)?; assert!(!proof.existence); @@ -147,7 +335,7 @@ impl MerkleTree { op: MerkleTreeOp::Delete, old_root, op_proof: proof, - new_root, + new_root: self.root, op_key: *key, op_value: value, value: None, @@ -162,9 +350,15 @@ impl MerkleTree { let path = keypath(*key); let mut siblings: Vec = Vec::new(); - - match self.root.down(0, path, Some(&mut siblings)) { - (Some((k, v)), _) if &k == key => Ok(( + match Self::down( + self.db.as_ref(), + (path, 0), + self.root, + *key, + Some(&mut siblings), + MerkleTreeOp::ReadOnly, + )? { + Some((k, v)) if &k == key => Ok(( v, MerkleProof { existence: true, @@ -186,15 +380,22 @@ impl MerkleTree { let mut siblings: Vec = Vec::new(); // note: non-existence of a key can be in 2 cases: - match self.root.down(0, path, Some(&mut siblings)) { + match Self::down( + self.db.as_ref(), + (path, 0), + self.root, + *key, + Some(&mut siblings), + MerkleTreeOp::ReadOnly, + )? { // case i) the expected leaf does not exist - (None, _) => Ok(MerkleProof { + None => Ok(MerkleProof { existence: false, siblings, other_leaf: None, }), // case ii) the expected leaf does exist in the tree, but it has a different `key` - (Some((k, v)), _) if &k != key => Ok(MerkleProof { + Some((k, v)) if &k != key => Ok(MerkleProof { existence: false, siblings, other_leaf: Some((k, v)), @@ -257,12 +458,17 @@ impl MerkleTree { Self::verify_state_transition(&equivalent_insertion_proof) } MerkleTreeOp::Update => { + if proof.value.is_none() { + return Err(TreeError::state_transition_fail( + "Invalid proof of update: proof.value should not be None".to_string(), + )); + } // check that for the old_root, (op_key, value) *does* exist in the tree Self::verify( proof.old_root, &proof.op_proof, &proof.op_key, - &proof.value.unwrap(), + &proof.value.unwrap(), // unrawp is safe due prev `is_none` check )?; // check that for the new_root, (op_key, op_value) *does* exist in the tree Self::verify( @@ -366,14 +572,127 @@ impl MerkleTree { } Ok(()) } + _ => Err(TreeError::invalid_proof("proof.op".to_string())), } } +} - /// returns an iterator over the leaves of the tree - pub fn iter(&self) -> Iter<'_> { - Iter { - state: vec![&self.root], +// auxiliary methods +impl MerkleTree { + /// Applies given Merkle tree op. + pub(crate) fn apply_op( + db: &mut dyn DB, + op: MerkleTreeOp, + root: Hash, + k: RawValue, + maybe_value: Option, + ) -> TreeResult { + // Rule out invalid arguments + match (op, maybe_value) { + (MerkleTreeOp::Insert, None) | (MerkleTreeOp::Update, None) => { + Err(TreeError::invalid_state_transition_proof_arg(format!( + "{:?} op requires a value argument.", + op + ))) + } + (MerkleTreeOp::Delete, Some(_)) => { + Err(TreeError::invalid_state_transition_proof_arg(format!( + "{:?} op requires no value argument, yet one was provided.", + op + ))) + } + (MerkleTreeOp::ReadOnly, _) => { + Err(TreeError::invalid_state_transition_proof_arg(format!( + "{:?} 'read only' op should not reach the 'apply_op' method", + op + ))) + } + _ => Ok(()), + }?; + + // go down, update the leaf, go up storing new hashes in the db + let path = keypath(k); + let mut siblings: Vec = Vec::new(); + let _ = Self::down( + db, + (path.clone(), 0), // from lvl 0 + root, + k, + Some(&mut siblings), + op, + )?; + + let node: Node = match (op, maybe_value) { + (MerkleTreeOp::Insert, Some(value)) | (MerkleTreeOp::Update, Some(value)) => { + Node::Leaf(Leaf::new(k, value)) + } + (MerkleTreeOp::Delete, None) => { + // return an intermediate node whose hash is 'empty', to + // indicate that there is no leaf + Node::Intermediate(Intermediate { + hash: EMPTY_HASH, + left: EMPTY_HASH, + right: EMPTY_HASH, + }) + } + _ => { + return Err(TreeError::invalid_state_transition_proof_arg(format!( + "{:?} op has invalid value type: {:?}", + op, maybe_value + ))) + } + }; + let node_hash = node.hash(); // variable to avoid cloning `node` later + db.store_node(node)?; + if siblings.is_empty() { + // return the node's hash as root + return Ok(node_hash); } + + let new_root = if op == MerkleTreeOp::Delete { + if siblings.len() == 1 { + // we're at the root-1 level, there is only a sibling, and we're + // removing the current leaf. + // If the sibling is a Leaf, the sibling (leaf) is now the new root + let sibling_node = db.load_node(siblings[0].into())?; + if matches!(sibling_node, Node::Leaf(..)) { + return Ok(siblings[0]); + } + // if the sibling is an Intermediate node, it means that the + // branch goes deeper, so don't short the path going up and pair + // it with an empty hash. + let node = if path[0] { + Intermediate::new(siblings[0], EMPTY_HASH) + } else { + Intermediate::new(EMPTY_HASH, siblings[0]) + }; + let node_hash = node.hash; // variable to avoid cloning `node` later + + // store in db + db.store_node(Node::Intermediate(node))?; + return Ok(node_hash); + } + // use the last sibling as the key that we will push up from + let l = siblings.len() - 1; + let remaining_key = siblings[l]; + siblings[l] = EMPTY_HASH; + // invert the last sibling level + let mut path = path.clone(); + path[siblings.len() - 1] = !path[siblings.len() - 1]; + Self::up( + db, + path, + siblings.len() - 1, + remaining_key, + siblings, + op, + true, + )? + } else { + Self::up(db, path, siblings.len() - 1, node_hash, siblings, op, true)? + }; + + Ok(new_root) } } @@ -422,14 +741,60 @@ fn hash_with_flag(flag: F, inputs: &[F]) -> Hash { } } +impl MerkleTree { + /// returns an iterator over the leaves of the tree + pub fn iter(&self) -> Iter<'_> { + Iter { + state: if self.root == EMPTY_HASH { + vec![] + } else { + vec![self.root] + }, + db: self.db.as_ref(), + } + } +} impl<'a> IntoIterator for &'a MerkleTree { - type Item = (&'a RawValue, &'a RawValue); + type Item = (RawValue, RawValue); type IntoIter = Iter<'a>; - fn into_iter(self) -> Self::IntoIter { self.iter() } } +pub struct Iter<'a> { + state: Vec, + db: &'a dyn DB, +} +impl<'a> Iterator for Iter<'a> { + type Item = (RawValue, RawValue); + fn next(&mut self) -> Option { + let node_hash = self.state.pop()?; + + // Inspect node + let node = self.db.load_node(node_hash.into()).ok()?; + + match node { + Node::Leaf(Leaf { + hash: _, + path: _, + key, + value, + }) => Some((key, value)), + Node::Intermediate(Intermediate { + hash: _, + left, + right, + }) => { + [right, left].into_iter().for_each(|h| { + if h != EMPTY_HASH { + self.state.push(h) + } + }); + self.next() + } + } + } +} impl fmt::Display for MerkleTree { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -439,11 +804,53 @@ impl fmt::Display for MerkleTree { )?; writeln!(f, "digraph hierarchy {{")?; writeln!(f, "node [fontname=Monospace,fontsize=10,shape=box]")?; - write!(f, "{}", self.root)?; + print_graph_viz(f, self.db.as_ref(), self.root)?; writeln!(f, "\n}}\n-----") } } +fn print_graph_viz(f: &mut fmt::Formatter<'_>, db: &dyn DB, hash: Hash) -> fmt::Result { + if hash == EMPTY_HASH { + return Ok(()); + } + + let node = db.load_node(hash.into()).map_err(|_| fmt::Error)?; + match node { + Node::Intermediate(n) => { + let left_hash: String = if n.left == EMPTY_HASH { + writeln!( + f, + "\"{}_child_of_{}\" [label=\"{}\"]", + n.left, n.hash, n.left + )?; + format!("\"{}_child_of_{}\"", n.left, n.hash) + } else { + writeln!(f, "\"{}\"", n.left)?; + format!("\"{}\"", n.left) + }; + let right_hash = if n.right == EMPTY_HASH { + writeln!( + f, + "\"{}_child_of_{}\" [label=\"{}\"]", + n.right, n.hash, n.right + )?; + format!("\"{}_child_of_{}\"", n.right, n.hash) + } else { + writeln!(f, "\"{}\"", n.right,)?; + format!("\"{}\"", n.right) + }; + writeln!(f, "\"{}\" -> {{ {} {} }}", n.hash, left_hash, right_hash,)?; + print_graph_viz(f, db, n.left)?; + print_graph_viz(f, db, n.right) + } + Node::Leaf(l) => { + writeln!(f, "\"{}\" [style=filled]", l.hash)?; + writeln!(f, "\"k:{}\\nv:{}\" [style=dashed]", l.key, l.value)?; + writeln!(f, "\"{}\" -> {{ \"k:{}\\nv:{}\" }}", l.hash, l.key, l.value,) + } + } +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct MerkleProof { // note: currently we don't use the `_existence` field, we would use if we merge the methods @@ -477,6 +884,9 @@ impl MerkleProof { self.compute_root_from_node(&h, path) } fn compute_root_from_node(&self, node_hash: &Hash, path: Vec) -> TreeResult { + if self.siblings.len() > MAX_DEPTH { + return Err(TreeError::max_depth()); + } let mut h = *node_hash; for (i, sibling) in self.siblings.iter().enumerate().rev() { let input: Vec = if path[i] { @@ -532,6 +942,7 @@ pub enum MerkleTreeOp { Insert = 0, Update, Delete, + ReadOnly, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -579,357 +990,56 @@ impl MerkleTreeStateTransitionProof { } } -#[derive(Clone, Debug)] -enum Node { - None, +// NOTE: currently we use automatic serialization/deserialization, which is +// used when storing the node into the DB; but we could manually implement it +// for more disk-space efficiency. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum Node { Leaf(Leaf), Intermediate(Intermediate), } - -impl fmt::Display for Node { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Intermediate(n) => { - let left_hash: String = if n.left.is_empty() { - writeln!( - f, - "\"{}_child_of_{}\" [label=\"{}\"]", - n.left.hash(), - n.hash(), - n.left.hash() - )?; - format!("\"{}_child_of_{}\"", n.left.hash(), n.hash()) - } else { - writeln!(f, "\"{}\"", n.left.hash(),)?; - format!("\"{}\"", n.left.hash()) - }; - let right_hash = if n.right.is_empty() { - writeln!( - f, - "\"{}_child_of_{}\" [label=\"{}\"]", - n.right.hash(), - n.hash(), - n.right.hash() - )?; - format!("\"{}_child_of_{}\"", n.right.hash(), n.hash()) - } else { - writeln!(f, "\"{}\"", n.right.hash(),)?; - format!("\"{}\"", n.right.hash()) - }; - writeln!(f, "\"{}\" -> {{ {} {} }}", n.hash(), left_hash, right_hash,)?; - write!(f, "{}", n.left)?; - write!(f, "{}", n.right) - } - Self::Leaf(l) => { - writeln!(f, "\"{}\" [style=filled]", l.hash())?; - writeln!(f, "\"k:{}\\nv:{}\" [style=dashed]", l.key, l.value)?; - writeln!( - f, - "\"{}\" -> {{ \"k:{}\\nv:{}\" }}", - l.hash(), - l.key, - l.value, - ) - } - Self::None => Ok(()), - } - } -} - impl Node { - fn is_empty(&self) -> bool { + pub fn hash(&self) -> Hash { match self { - Self::None => true, - Self::Leaf(_l) => false, - Self::Intermediate(_n) => false, + Node::Leaf(Leaf { + hash, + path: _, + key: _, + value: _, + }) => *hash, + Node::Intermediate(Intermediate { + hash, + left: _, + right: _, + }) => *hash, } } - fn is_intermediate(&self) -> bool { - match self { - Self::None => false, - Self::Leaf(_l) => false, - Self::Intermediate(_n) => true, - } - } - fn compute_hash(&mut self) -> Hash { - match self { - Self::None => EMPTY_HASH, - Self::Leaf(l) => l.compute_hash(), - Self::Intermediate(n) => n.compute_hash(), - } - } - fn hash(&self) -> Hash { - match self { - Self::None => EMPTY_HASH, - Self::Leaf(l) => l.hash(), - Self::Intermediate(n) => n.hash(), - } - } - - /// Goes down from the current node until it encounters a terminal node, - /// viz. a leaf or empty node, or until it reaches the maximum depth. The - /// `siblings` parameter is used to store the siblings while going down to - /// the leaf, if the given parameter is set to `None`, then no siblings are - /// stored. In this way, the same method `down` can be used by MerkleTree - /// methods `get`, `contains`, `prove` and `prove_nonexistence`. - /// - /// Be aware that this method will return the found leaf at the given path, - /// which may contain a different key and value than the expected one. - fn down( - &self, - lvl: usize, - path: Vec, - mut siblings: Option<&mut Vec>, - ) -> (Option<(RawValue, RawValue)>, usize) { - match self { - Self::Intermediate(n) => { - if path[lvl] { - if let Some(s) = siblings.as_mut() { - s.push(n.left.hash()); - } - n.right.down(lvl + 1, path, siblings) - } else { - if let Some(s) = siblings.as_mut() { - s.push(n.right.hash()); - } - n.left.down(lvl + 1, path, siblings) - } - } - Self::Leaf(Leaf { - key, - value, - path: _p, - hash: _h, - }) => (Some((*key, *value)), lvl), - _ => (None, lvl), - } - } - - /// Applies given Merkle tree op without computing hashes. - pub(crate) fn apply_op( - &mut self, - op: MerkleTreeOp, - key: RawValue, - maybe_value: Option, - ) -> TreeResult<()> { - let key_path = keypath(key); - // Rule out invalid arguments - match (op, maybe_value) { - (MerkleTreeOp::Insert, None) | (MerkleTreeOp::Update, None) => { - Err(TreeError::invalid_state_transition_proof_arg(format!( - "{:?} op requires a value argument.", - op - ))) - } - (MerkleTreeOp::Delete, Some(_)) => { - Err(TreeError::invalid_state_transition_proof_arg(format!( - "{:?} op requires no value argument, yet one was provided.", - op - ))) - } - _ => Ok(()), - }?; - - // Loop through to leaf. - self.apply_op_loop(0, op, key, &key_path, maybe_value)?; - - // If we are dealing with a deletion, normalise along key - // path. - if let MerkleTreeOp::Delete = op { - self.normalise_path(&key_path); - } - - Ok(()) - } - - /// Normalises a Merkle tree along a specified path. Useful - /// post-deletion. - fn normalise_path(&mut self, key_path: &[bool]) { - match self { - Self::Leaf(_) | Self::None => (), - Self::Intermediate(Intermediate { - hash: _h, - left, - right, - }) => { - if key_path[0] { - right.normalise_path(&key_path[1..]); - } else { - left.normalise_path(&key_path[1..]); - } - - // If we have a branch with children (NIL, X) or (X, - // NIL) where X is not a branch, then replace with X. - if left.is_empty() && !right.is_intermediate() { - *self = *right.clone(); - } else if right.is_empty() && !left.is_intermediate() { - *self = *left.clone(); - } - } - } - } - - fn apply_op_loop( - &mut self, - lvl: usize, - op: MerkleTreeOp, - key: RawValue, - key_path: &[bool], - maybe_value: Option, - ) -> TreeResult<()> { - match self { - Self::Intermediate(n) => { - if key_path[lvl] { - n.right - .apply_op_loop(lvl + 1, op, key, key_path, maybe_value) - } else { - n.left - .apply_op_loop(lvl + 1, op, key, key_path, maybe_value) - } - } - _ => { - *self = Self::op_node_check(lvl, self, op, key, key_path, maybe_value)?; - Ok(()) - } - } - } - - /// Checks the terminal node against the desired op and returns a - /// suitable replacement. - /// - /// - Insertion => Node should be empty or contain a different - /// key. A leaf is inserted in the right place. - /// - Update/Deletion => Node should contain the given key. The - /// value is replaced in the case of an update and the leaf removed - /// in the case of a deletion. - pub(crate) fn op_node_check( - lvl: usize, - node: &Node, - op: MerkleTreeOp, - key: RawValue, - key_path: &[bool], - maybe_value: Option, - ) -> TreeResult { - use MerkleTreeOp::*; - - // Invalid args are assumed to have been ruled out. - match (op, node, maybe_value) { - // Insertion case - (Insert, Node::None, Some(value)) => Ok(Node::Leaf(Leaf::new(key, value))), - (Insert, Node::Leaf(l), Some(value)) => { - // in this case, it means that we found a leaf in the new-leaf - // path, thus we need to push both leaves (old-leaf and - // new-leaf) down the path till their paths diverge. - - // first check that keys of both leaves are different - // (l=old-leaf, leaf=new-leaf) - if l.key == key { - // Note: current approach returns an error when trying to - // add to a leaf where the key already exists. We could also - // ignore it if needed. - Err(TreeError::key_exists()) - } else { - let old_leaf = l.clone(); - // set new node as an intermediate node - let mut new_node = Node::Intermediate(Intermediate::empty()); - new_node.down_till_divergence( - lvl, - old_leaf, - Leaf { - hash: None, - path: key_path.to_vec(), - key, - value, - }, - )?; - Ok(new_node) - } - } - // Update case - (Update, Node::Leaf(l), Some(value)) if l.key == key => { - Ok(Node::Leaf(Leaf::new(key, value))) - } - // Deletion case - (Delete, Node::Leaf(l), None) if l.key == key => Ok(Node::None), - // Case of terminal node that does not match. - _ => Err(TreeError::state_transition_fail(format!( - "{:?} op requires key {} to be present in the tree, yet it is not.", - op, key - ))), - } - } - - /// goes down through a 'virtual' path till finding a divergence. This - /// method is used for when adding a new leaf another already existing leaf - /// is found, so that both leaves (new and old) are pushed down the path - /// till their keys diverge. - fn down_till_divergence( - &mut self, - lvl: usize, - old_leaf: Leaf, - new_leaf: Leaf, - ) -> TreeResult<()> { - if let Node::Intermediate(ref mut n) = self { - if old_leaf.path[lvl] != new_leaf.path[lvl] { - // reached divergence in next level, set the leaves as children - // at the current node - if new_leaf.path[lvl] { - n.left = Box::new(Node::Leaf(old_leaf)); - n.right = Box::new(Node::Leaf(new_leaf)); - } else { - n.left = Box::new(Node::Leaf(new_leaf)); - n.right = Box::new(Node::Leaf(old_leaf)); - } - return Ok(()); - } - - // no divergence yet, continue going down - if new_leaf.path[lvl] { - n.right = Box::new(Node::Intermediate(Intermediate::empty())); - return n.right.down_till_divergence(lvl + 1, old_leaf, new_leaf); - } else { - n.left = Box::new(Node::Intermediate(Intermediate::empty())); - return n.left.down_till_divergence(lvl + 1, old_leaf, new_leaf); - } - } - Ok(()) - } } -#[derive(Clone, Debug)] -struct Intermediate { - hash: Option, - left: Box, - right: Box, +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Intermediate { + hash: Hash, + left: Hash, + right: Hash, } impl Intermediate { - fn empty() -> Self { - Self { - hash: None, - left: Box::new(Node::None), - right: Box::new(Node::None), + fn new(left: Hash, right: Hash) -> Self { + if left == EMPTY_HASH && right == EMPTY_HASH { + return Self { + hash: EMPTY_HASH, + left, + right, + }; } - } - fn compute_hash(&mut self) -> Hash { - if self.left.clone().is_empty() && self.right.clone().is_empty() { - self.hash = Some(EMPTY_HASH); - return EMPTY_HASH; - } - let l_hash = self.left.compute_hash(); - let r_hash = self.right.compute_hash(); - let input: Vec = [l_hash.0.to_vec(), r_hash.0.to_vec()].concat(); - let h = hash_with_flag(F::TWO, &input); - self.hash = Some(h); - h - } - fn hash(&self) -> Hash { - self.hash.expect("Hash has not been computed.") + let input: Vec = [left.0.to_vec(), right.0.to_vec()].concat(); + let hash = hash_with_flag(F::TWO, &input); + Self { hash, left, right } } } -#[derive(Clone, Debug)] -pub(crate) struct Leaf { - pub(crate) hash: Option, +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Leaf { + pub(crate) hash: Hash, pub(crate) path: Vec, pub(crate) key: RawValue, pub(crate) value: RawValue, @@ -937,20 +1047,12 @@ pub(crate) struct Leaf { impl Leaf { fn new(key: RawValue, value: RawValue) -> Self { Self { - hash: None, + hash: kv_hash(&key, Some(value)), path: keypath(key), key, value, } } - fn compute_hash(&mut self) -> Hash { - let h = kv_hash(&self.key, Some(self.value)); - self.hash = Some(h); - h - } - fn hash(&self) -> Hash { - self.hash.expect("Hash has not been computed.") - } } // NOTE 1: think if maybe the length of the returned vector can be <256 @@ -968,37 +1070,6 @@ pub(crate) fn keypath(k: RawValue) -> Vec { .collect() } -pub struct Iter<'a> { - state: Vec<&'a Node>, -} - -impl<'a> Iterator for Iter<'a> { - type Item = (&'a RawValue, &'a RawValue); - - fn next(&mut self) -> Option { - let node = self.state.pop(); - match node { - Some(Node::None) => self.next(), - Some(Node::Leaf(Leaf { - hash: _, - path: _, - key, - value, - })) => Some((key, value)), - Some(Node::Intermediate(Intermediate { - hash: _, - left, - right, - })) => { - self.state.push(right); - self.state.push(left); - self.next() - } - _ => None, - } - } -} - #[cfg(test)] pub mod tests { use std::cmp::Ordering; @@ -1009,6 +1080,20 @@ pub mod tests { #[test] fn test_merkletree() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_merkletree_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_merkletree_opt(db)?; + } + + Ok(()) + } + fn test_merkletree_opt(db: Box) -> TreeResult<()> { let mut kvs = HashMap::new(); for i in 0..8 { if i == 1 { @@ -1020,7 +1105,7 @@ pub mod tests { let value = RawValue::from(1013); kvs.insert(key, value); - let tree = MerkleTree::new(&kvs); + let tree = MerkleTree::new_with_db(db, &kvs)?; // when printing the tree, it should print the same tree as in // https://0xparc.github.io/pod2/merkletree.html#example-2 println!("{}", tree); @@ -1073,8 +1158,8 @@ pub mod tests { }; let sorted_kvs = kvs - .iter() - .sorted_by(|(k1, _), (k2, _)| cmp(**k1, **k2)) + .into_iter() + .sorted_by(|(k1, _), (k2, _)| cmp(*k1, *k2)) .collect::>(); assert_eq!(collected_kvs, sorted_kvs); @@ -1082,14 +1167,299 @@ pub mod tests { Ok(()) } + #[test] + fn test_delete_to_empty() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_delete_to_empty_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_delete_to_empty_opt(db)?; + } + + Ok(()) + } + fn test_delete_to_empty_opt(db: Box) -> TreeResult<()> { + let mut tree = MerkleTree::new_with_db(db, &HashMap::new())?; + + let (key, value) = (RawValue::from(2), RawValue::from(1002)); + let _ = tree.insert(&key, &value)?; + + let (key, value) = (RawValue::from(6), RawValue::from(1006)); + let _ = tree.insert(&key, &value)?; + + let (key, value) = (RawValue::from(3), RawValue::from(1003)); + let _ = tree.insert(&key, &value)?; + + let (key, value) = (RawValue::from(7), RawValue::from(1007)); + let _ = tree.insert(&key, &value)?; + + let _ = tree.delete(&RawValue::from(3))?; + let _ = tree.delete(&RawValue::from(7))?; + let _ = tree.delete(&RawValue::from(6))?; + assert_eq!( + tree.root, + Leaf::new(RawValue::from(2), RawValue::from(1002)).hash + ); + + let _ = tree.delete(&RawValue::from(2))?; + assert_eq!(tree.root, EMPTY_HASH); + + Ok(()) + } + + #[test] + fn test_prove_verify() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_prove_verify_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_prove_verify_opt(db)?; + } + + Ok(()) + } + fn test_prove_verify_opt(db: Box) -> TreeResult<()> { + let kvs = [ + (1.into(), 55.into()), + (2.into(), 88.into()), + (175.into(), 0.into()), + ] + .into_iter() + .collect(); + let tree = MerkleTree::new_with_db(db, &kvs)?; + + let (key, value) = (175.into(), 0.into()); + let (v, proof) = tree.prove(&key)?; + assert_eq!(v, value); + MerkleTree::verify(tree.root(), &proof, &key, &value)?; + + let (key, value) = (2.into(), 88.into()); + let (v, proof) = tree.prove(&key)?; + assert_eq!(v, value); + MerkleTree::verify(tree.root(), &proof, &key, &value)?; + + let (key, value) = (175.into(), 0.into()); + let (v, proof) = tree.prove(&key)?; + assert_eq!(v, value); + MerkleTree::verify(tree.root(), &proof, &key, &value)?; + + Ok(()) + } + + #[test] + fn test_update_leaf() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_update_leaf_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_update_leaf_opt(db)?; + } + + Ok(()) + } + fn test_update_leaf_opt(db: Box) -> TreeResult<()> { + let kvs = [ + (1.into(), 1.into()), + (9.into(), 9.into()), + (7.into(), 7.into()), + (15.into(), 15.into()), + ] + .into_iter() + .collect(); + let mut tree = MerkleTree::new_with_db(db.clone(), &kvs)?; + let state_transition_proof = tree.update(&7.into(), &0.into())?; + MerkleTree::verify_state_transition(&state_transition_proof)?; + + let kvs = [ + (1.into(), 1.into()), + (9.into(), 9.into()), + (7.into(), 0.into()), + (15.into(), 15.into()), + ] + .into_iter() + .collect(); + let tree2 = MerkleTree::new_with_db(db, &kvs)?; + + assert_eq!(tree.root, tree2.root); + + // update the other leaves + let state_transition_proof = tree.update(&1.into(), &0.into())?; + MerkleTree::verify_state_transition(&state_transition_proof)?; + let state_transition_proof = tree.update(&9.into(), &0.into())?; + MerkleTree::verify_state_transition(&state_transition_proof)?; + let state_transition_proof = tree.update(&15.into(), &0.into())?; + MerkleTree::verify_state_transition(&state_transition_proof) + } + + #[test] + fn test_update_delete_leaf() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_update_delete_leaf_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_update_delete_leaf_opt(db)?; + } + + Ok(()) + } + fn test_update_delete_leaf_opt(db: Box) -> TreeResult<()> { + let kvs: HashMap = (0..10) + .map(|i| (i.into(), i.into())) + .collect::>(); + let mut mt = MerkleTree::new_with_db(db, &kvs)?; + + // insert + (11..20) + .map(|i| (i.into(), i.into())) + .try_for_each(|(k, v)| { + let mtp = mt.insert(&k, &v).unwrap(); + MerkleTree::verify_state_transition(&mtp) + })?; + // update + (11..20) + .map(|i| (i.into(), (i + 1).into())) + .try_for_each(|(k, v)| { + let mtp = mt.update(&k, &v).unwrap(); + MerkleTree::verify_state_transition(&mtp) + })?; + // delete + (11..20).map(|i| i.into()).try_for_each(|k| { + let mtp = mt.delete(&k).unwrap(); + MerkleTree::verify_state_transition(&mtp) + })?; + + Ok(()) + } + + #[test] + fn test_delete_leaf() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_delete_leaf_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_delete_leaf_opt(db)?; + } + + Ok(()) + } + fn test_delete_leaf_opt(db: Box) -> TreeResult<()> { + let kvs = [(1.into(), 1.into()), (9.into(), 9.into())] + .into_iter() + .collect(); + let tree = MerkleTree::new_with_db(db.clone(), &kvs)?; + let expected_root = tree.root; + + let kvs = [ + (1.into(), 1.into()), + (9.into(), 9.into()), + (7.into(), 7.into()), + (15.into(), 15.into()), + ] + .into_iter() + .collect(); + let mut tree = MerkleTree::new_with_db(db.clone(), &kvs)?; + let state_transition_proof = tree.delete(&15.into())?; + MerkleTree::verify_state_transition(&state_transition_proof)?; + + let kvs = [ + (1.into(), 1.into()), + (9.into(), 9.into()), + (7.into(), 7.into()), + ] + .into_iter() + .collect(); + let tree2 = MerkleTree::new_with_db(db, &kvs)?; + + assert_eq!(tree.root, tree2.root); + + // delete the leaf '7', which when deleted will leave an entire branch + // empty + let state_transition_proof = tree.delete(&7.into())?; + MerkleTree::verify_state_transition(&state_transition_proof)?; + + assert_eq!(tree.root, expected_root); + + Ok(()) + } + + #[test] + fn test_delete_from_two_leaves() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_delete_from_two_leaves_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_delete_from_two_leaves_opt(db)?; + } + + Ok(()) + } + fn test_delete_from_two_leaves_opt(db: Box) -> TreeResult<()> { + // tree with two leaves whose keys diverge at the first bit, so that when + // deleting one key leads to a tree with a single Leaf as a root + let mut kvs = HashMap::new(); + kvs.insert(RawValue::from(0), RawValue::from(1000)); + kvs.insert(RawValue::from(1), RawValue::from(1001)); + + let mut tree = MerkleTree::new_with_db(db.clone(), &kvs)?; + tree.delete(&RawValue::from(1))?; + + // the expected_tree has a single leaf, which should match the tree that + // started from two leaves and got one removed + let expected = [(RawValue::from(0), RawValue::from(1000))] + .into_iter() + .collect::>(); + let expected_tree = MerkleTree::new_with_db(db, &expected)?; + + assert_eq!(tree.root(), expected_tree.root()); + Ok(()) + } + #[test] fn test_state_transition() -> TreeResult<()> { + let db = Box::new(db::MemDB::new()); + test_state_transition_opt(db)?; + + #[cfg(feature = "db_rocksdb")] + { + let db = Box::new(db::rocks::RocksDB::open( + tempfile::TempDir::new().unwrap().path(), + )?); + test_state_transition_opt(db)?; + } + + Ok(()) + } + fn test_state_transition_opt(db: Box) -> TreeResult<()> { let mut kvs = HashMap::new(); for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let mut tree = MerkleTree::new(&kvs); + let mut tree = MerkleTree::new_with_db(db, &kvs)?; let old_root = tree.root(); // key=37 shares path with key=5, till the level 6, needing 2 extra From 13cabdb511475ce7793ef776e0310cd43a6c5bc3 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Mon, 23 Mar 2026 12:31:28 +0100 Subject: [PATCH 02/10] Support persistent storage in Containers (#493) Extend the work of https://github.com/0xPARC/pod2/pull/487 to the Containers (Dictionary, Set, Array). The merkle tree only stores `RawValue` for both the key and the value, so it is the responsibility of the Container to store the rich value. In order to handle containers with persistent storage efficiently (which means, cloning them or updating them should not cause an O(n) data copy) I figured we need to have a database of `Value`s indexed by their raw value; as this gives us deduplication and free cloning of containers. The issue with this approach is that in the current design we have collisions between Value's of different types: https://github.com/0xPARC/pod2/issues/426 and the current API relies on the single type of values. To resolve this issue I decided to change the API, instead of assuming that a Value has a fixed type, let the value be possibly multiple compatible types and let the user of the library try casting the Value to a particular type. For this I deprecated the public access of everything related to `TypedValue` and I propose for it to be considered an implementation detail and a blackbox from the external developer point of view. The `Value` type is now used like this: - To create a new Value use `Value::from(...)` where you can pass any compatible type (the same types as before) - To access the Value in typed form you cast it like `value.as_foo()` which returns `Option`. Previously we had a collision between `true` and `1` (and `false` and `0`). Now it doesn't matter whether a value holds a `true` or a `1`, both should be seen as the same and both return `Some` when doing `as_int` and `as_bool`. Similarly we had collisions with containers. For example `set(0, 1, 2) == array[0, 1, 2]` and `set("a", "b") = dict("a": "a", "b": "b")`. Now any container can be casted to any of `set, array, dict`. There's a caveat here: each of these types expects a particular encoding of keys, so casting to the wrong type will return errors on some operations. With this design it no longer matters what is being stored and recovered because the API requires the user to express the expected type and any type with collisions for particular values can be casted to the right type. There's only one case where it's not desirable to swap one `TypedValue` for another: the `TypedValue::Raw`. If a non-`RawValue` in the DB is replaced by the corresponding `RawValue` we erase the required information to recover the rich value. For this reason the implementations of the database treat the `RawValue` as a special case: if an value is stored in non-`RawValue`, the corresponding `RawValue` can never overwrite it. If a value is stored in `RawValue`, a matching non-`RawValue` will overwrite it (promoting it to a rich value). This way we never lose data. A consequence of this is that the serialization, `Display` and `Debug` of a container is not stable. At any point any of the entries can be swapped for a "compatible" one if they share the storage with other containers that introduce collisions. I rewrote all containers as wrapper to a generic `Container` which holds a `Map` from `Value` to `Value`. The serialization of each container now uses the single implementation of the generic `Container`. --- .github/workflows/build.yml | 2 + .github/workflows/tests.yml | 3 +- Cargo.toml | 2 +- src/backends/plonky2/error.rs | 4 +- src/backends/plonky2/mainpod/mod.rs | 13 +- src/backends/plonky2/primitives/ec/curve.rs | 2 +- .../plonky2/primitives/merkletree/db/mod.rs | 42 +- .../plonky2/primitives/merkletree/db/rocks.rs | 35 +- .../plonky2/primitives/merkletree/error.rs | 17 +- .../plonky2/primitives/merkletree/mod.rs | 295 ++++---- src/examples/mod.rs | 6 +- src/frontend/mod.rs | 49 +- src/frontend/operation.rs | 11 +- src/frontend/serialization.rs | 57 +- src/lang/pretty_print.rs | 2 +- src/middleware/containers.rs | 661 +++++++++++++----- src/middleware/db/mem.rs | 60 ++ src/middleware/db/mod.rs | 30 + src/middleware/db/rocks.rs | 107 +++ src/middleware/error.rs | 8 +- src/middleware/mod.rs | 359 +++++----- src/middleware/operation.rs | 43 +- 22 files changed, 1187 insertions(+), 621 deletions(-) create mode 100644 src/middleware/db/mem.rs create mode 100644 src/middleware/db/mod.rs create mode 100644 src/middleware/db/rocks.rs diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d3741ab..c34d0ea 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,6 +24,8 @@ jobs: run: cargo build --features metrics - name: Build time run: cargo build --features time + - name: Build db_rocksdb + run: cargo build --features db_rocksdb - name: Build disk_cache run: cargo build --no-default-features --features backend_plonky2,zk,disk_cache diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3d1ba0e..b3b389a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,4 +17,5 @@ jobs: - name: Set up Rust uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run tests - run: cargo test --release + # RocksDB is disabled by default but we still want to test it. + run: cargo test --release --features db_rocksdb diff --git a/Cargo.toml b/Cargo.toml index c3329ca..704fe89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,7 @@ tempfile = "3" vergen-gitcl = { version = "1.0.0", features = ["build"] } [features] -default = ["backend_plonky2", "zk", "mem_cache", "db_rocksdb"] +default = ["backend_plonky2", "zk", "mem_cache"] backend_plonky2 = ["plonky2"] zk = [] metrics = [] diff --git a/src/backends/plonky2/error.rs b/src/backends/plonky2/error.rs index 355eaf1..6d57568 100644 --- a/src/backends/plonky2/error.rs +++ b/src/backends/plonky2/error.rs @@ -61,8 +61,8 @@ macro_rules! new { } use InnerError::*; impl Error { - pub fn custom(s: String) -> Self { - new!(Custom(s)) + pub fn custom(s: impl Into) -> Self { + new!(Custom(s.into())) } pub fn plonky2_proof_fail(context: impl Into, e: anyhow::Error) -> Self { Self::Plonky2ProofFail(context.into(), e) diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 341e295..8e6ed46 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -225,11 +225,10 @@ pub(crate) fn extract_public_key_of( ) = (op, st) { let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); - let sk = SecretKey::try_from( - value_from_op(sk_s, sk_ref) - .ok_or_else(deduction_err)? - .typed(), - )?; + let value = value_from_op(sk_s, sk_ref).ok_or_else(deduction_err)?; + let sk = value + .as_secret_key() + .ok_or_else(|| Error::custom("{value} not SecretKey"))?; aux_list[i] = OperationAux::PublicKeyOfIndex(table.len()); table.push(sk); } @@ -283,7 +282,9 @@ pub(crate) fn extract_signatures( aux_list[i] = OperationAux::SignedByIndex(table.len()); table.push(SignedBy { msg: msg.raw(), - pk: PublicKey::try_from(pk.typed())?, + pk: pk + .as_public_key() + .ok_or_else(|| Error::custom(format!("{pk} is not PublicKey")))?, sig: sig.clone(), }); } diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index caf3727..67b7513 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -207,7 +207,7 @@ impl Point { u: *u, }); points.find(|p| p.is_in_subgroup()).ok_or(Error::custom( - "One of the points must lie in the EC subgroup.".into(), + "One of the points must lie in the EC subgroup.", )) } pub fn as_bytes_from_subgroup(&self) -> Result, Error> { diff --git a/src/backends/plonky2/primitives/merkletree/db/mod.rs b/src/backends/plonky2/primitives/merkletree/db/mod.rs index 3402d1d..7082eaa 100644 --- a/src/backends/plonky2/primitives/merkletree/db/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/db/mod.rs @@ -6,19 +6,20 @@ use std::{ sync::{Arc, Mutex}, }; -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, Result}; use dyn_clone::DynClone; use crate::{ - backends::plonky2::primitives::merkletree::{Leaf, Node}, - middleware::{RawValue, EMPTY_VALUE}, + backends::plonky2::primitives::merkletree::{Intermediate, Node}, + middleware::{Hash, EMPTY_HASH}, }; #[cfg(feature = "db_rocksdb")] pub mod rocks; pub trait DB: Debug + DynClone + Sync + Send { - fn load_node(&self, hash: RawValue) -> Result; + /// Must always return the empty intermediate node when hash is EMPTY_HASH + fn load_node(&self, hash: Hash) -> Result>; fn store_node(&mut self, node: Node) -> Result<()>; } dyn_clone::clone_trait_object!(DB); @@ -26,7 +27,7 @@ dyn_clone::clone_trait_object!(DB); /// MemDB implements the DB trait in a in-memory HashMap. #[derive(Clone, Debug, Default)] pub(crate) struct MemDB { - inner: Arc>>, + inner: Arc>>, } impl MemDB { @@ -36,21 +37,18 @@ impl MemDB { } impl DB for MemDB { - fn load_node(&self, hash: RawValue) -> Result { + fn load_node(&self, hash: Hash) -> Result> { let db = self .inner .lock() .map_err(|e| anyhow!("failed to acquire memdb lock for read: {}", e))?; - if let Some(node) = db.get(&hash) { - return Ok(node.clone()); + if hash == EMPTY_HASH { + return Ok(Some(Node::Intermediate(Intermediate::new( + EMPTY_HASH, EMPTY_HASH, + )))); } - - if hash == EMPTY_VALUE { - return Ok(Node::Leaf(Leaf::new(hash, EMPTY_VALUE))); - } - - bail!("MemDB error: node not found: {}", hash); + Ok(db.get(&hash).cloned()) } fn store_node(&mut self, node: Node) -> Result<()> { @@ -58,25 +56,15 @@ impl DB for MemDB { .inner .lock() .map_err(|e| anyhow!("failed to acquire memdb lock for write: {}", e))?; - db.insert(node.hash().into(), node); + db.insert(node.hash(), node); Ok(()) } } -// NOTE: this can be replaced by `.to_bytes` & `from_bytes` optimized methods at `Node` -#[allow(dead_code)] -fn encode_node(node: &Node) -> Result> { - serde_json::to_vec(node).map_err(|e| anyhow!("failed to serialize node: {e}")) -} -#[allow(dead_code)] -fn decode_node(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(|e| anyhow!("failed to deserialize node: {e}")) -} - #[cfg(test)] pub mod tests { - use super::*; + use super::{super::Leaf, *}; #[test] fn test_db() -> Result<()> { @@ -97,7 +85,7 @@ pub mod tests { let node = Leaf::new(1.into(), 1.into()); db.store_node(Node::Leaf(node.clone()))?; - let obtained_node = db.load_node(node.hash.into())?; + let obtained_node = db.load_node(node.hash)?.unwrap(); let leaf = match obtained_node { Node::Leaf(l) => l, _ => panic!("expected a leaf"), diff --git a/src/backends/plonky2/primitives/merkletree/db/rocks.rs b/src/backends/plonky2/primitives/merkletree/db/rocks.rs index 47a1739..0601983 100644 --- a/src/backends/plonky2/primitives/merkletree/db/rocks.rs +++ b/src/backends/plonky2/primitives/merkletree/db/rocks.rs @@ -3,10 +3,9 @@ use std::{fmt, path::Path, sync::Arc}; use anyhow::{anyhow, Result}; use rocksdb::{Options, TransactionDB, TransactionDBOptions}; -use super::DB; use crate::{ - backends::plonky2::primitives::merkletree::{Leaf, Node}, - middleware::{RawValue, EMPTY_VALUE}, + backends::plonky2::primitives::merkletree::{self, db}, + middleware::{Hash, RawValue, EMPTY_HASH}, }; #[derive(Clone)] @@ -30,29 +29,27 @@ impl fmt::Debug for RocksDB { } } -impl DB for RocksDB { - fn load_node(&self, hash: RawValue) -> Result { - if hash == EMPTY_VALUE { - return Ok(Node::Leaf(Leaf::new(hash, EMPTY_VALUE))); +impl db::DB for RocksDB { + fn load_node(&self, hash: Hash) -> Result> { + if hash == EMPTY_HASH { + return Ok(Some(merkletree::Node::Intermediate( + merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH), + ))); } - let maybe_node_bytes = self + match self .0 - .get(hash.to_bytes()) - .map_err(|e| anyhow!("rocksdb transaction get failed: {e}"))?; - - match maybe_node_bytes { - Some(bytes) => super::decode_node(&bytes), - None => Err(anyhow!("rocksdb: node not found")), + .get(RawValue::from(hash).to_bytes()) + .map_err(|e| anyhow!("rocksdb: get failed: {e}"))? + { + None => Ok(None), + Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)), } } - fn store_node(&mut self, node: Node) -> Result<()> { + fn store_node(&mut self, node: merkletree::Node) -> Result<()> { self.0 - .put( - RawValue::from(node.hash()).to_bytes(), - super::encode_node(&node)?, - ) + .put(RawValue::from(node.hash()).to_bytes(), node.encode()?) .map_err(|e| anyhow!("rocksdb transaction put failed: {e}")) } } diff --git a/src/backends/plonky2/primitives/merkletree/error.rs b/src/backends/plonky2/primitives/merkletree/error.rs index 2eb3198..9345700 100644 --- a/src/backends/plonky2/primitives/merkletree/error.rs +++ b/src/backends/plonky2/primitives/merkletree/error.rs @@ -2,12 +2,16 @@ use std::{backtrace::Backtrace, fmt::Debug}; +use crate::middleware::Hash; + pub type TreeResult = core::result::Result; #[derive(Debug, thiserror::Error)] pub enum TreeInnerError { #[error("key not found")] KeyNotFound, + #[error("node with hash {0} not found")] + NodeNotFound(Hash), #[error("key already exists")] KeyExists, #[error("max depth reached")] @@ -22,6 +26,9 @@ pub enum TreeInnerError { StateTransitionProofFail(String), #[error("circuit max_depth {0} is smaller than proof depth {1}")] CircuitDepthTooSmall(usize, usize), + // Other + #[error("{0}")] + Custom(String), } #[derive(thiserror::Error)] @@ -31,8 +38,8 @@ pub enum TreeError { inner: Box, backtrace: Box, }, - #[error("anyhow::Error: {0}")] - Anyhow(#[from] anyhow::Error), + #[error("database error: {0}")] + Database(anyhow::Error), } impl Debug for TreeError { @@ -60,6 +67,9 @@ impl TreeError { pub(crate) fn key_not_found() -> Self { new!(KeyNotFound) } + pub(crate) fn node_not_found(hash: Hash) -> Self { + new!(NodeNotFound(hash)) + } pub(crate) fn key_exists() -> Self { new!(KeyExists) } @@ -81,4 +91,7 @@ impl TreeError { pub(crate) fn circuit_depth_too_small(circuit_depth: usize, proof_depth: usize) -> Self { new!(CircuitDepthTooSmall(circuit_depth, proof_depth)) } + pub(crate) fn custom(s: impl Into) -> Self { + new!(Custom(s.into())) + } } diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index 72e7c23..0e29e14 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -2,7 +2,7 @@ //! . use std::{collections::HashMap, fmt, iter::IntoIterator}; -use anyhow::{anyhow, Result}; +use anyhow::anyhow; use itertools::zip_eq; use plonky2::{ field::types::Field, @@ -16,10 +16,15 @@ use crate::middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F}; pub mod circuit; pub use circuit::*; -mod db; -use db::DB; +pub mod db; +pub use db::DB; pub mod error; pub use error::{TreeError, TreeResult}; +use error::{TreeError as Error, TreeResult as Result}; + +// TODO: Replace all `&RawValue` for `RawValue`. This type is very small and `Copy` so there's +// no benefit in passing a reference instead of a copy. Moreover, most of the times the value is +// being copied in methods that receive the reference: see all `*key` and `*value` in the code. /// Theoretical max depth of a merkle tree. This limits appears because we store keys of 256 bits. const MAX_DEPTH: usize = 256; @@ -39,6 +44,20 @@ impl PartialEq for MerkleTree { } impl Eq for MerkleTree {} +pub(crate) fn load_node(db: &dyn DB, hash: Hash) -> Result { + match db.load_node(hash) { + Err(e) => Err(Error::Database(e)), + Ok(None) => Err(Error::node_not_found(hash)), + Ok(Some(node)) => Ok(node), + } +} +fn store_node(db: &mut dyn DB, node: Node) -> Result<()> { + match db.store_node(node) { + Ok(_) => Ok(()), + Err(e) => Err(Error::Database(e)), + } +} + impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values pub fn new(kvs: &HashMap) -> Self { @@ -92,18 +111,18 @@ impl MerkleTree { new_key: RawValue, // key to be added/found at the leaf mut siblings: Option<&mut Vec>, op: MerkleTreeOp, - ) -> TreeResult> { + ) -> Result> { let (path, lvl) = path_and_lvl; if lvl > MAX_DEPTH { - return Err(TreeError::max_depth()); + return Err(Error::max_depth()); } if curr_node_hash == EMPTY_HASH { return Ok(None); } - let node = db.load_node(curr_node_hash.into())?; + let node = load_node(db, curr_node_hash)?; match node { Node::Intermediate(n) => { if path[lvl] { @@ -126,7 +145,7 @@ impl MerkleTree { if new_key == old_leaf.key { if op == MerkleTreeOp::Insert { // in Insert, key should not exist - return Err(TreeError::key_exists()); + return Err(Error::key_exists()); } // we're at the operation Update/Delete case return Ok(Some((old_leaf.key, old_leaf.value))); @@ -137,7 +156,7 @@ impl MerkleTree { curr_node_hash.into(), old_leaf.path, path, - siblings.ok_or(anyhow!("expected siblings, got None"))?, + siblings.ok_or(Error::custom("expected siblings, got None"))?, )?; Ok(Some((old_leaf.key, old_leaf.value))) } @@ -154,9 +173,9 @@ impl MerkleTree { old_path: Vec, new_path: Vec, siblings: &mut Vec, - ) -> TreeResult<()> { + ) -> Result<()> { if lvl > MAX_DEPTH { - return Err(TreeError::max_depth()); + return Err(Error::max_depth()); } if old_path[lvl] == new_path[lvl] { siblings.push(EMPTY_HASH); @@ -181,7 +200,7 @@ impl MerkleTree { first_zeroes: bool, ) -> Result { // recall, in the delete case, the `key` is the `remaining_key` - let key_node = db.load_node(key.into())?; + let key_node = load_node(db, key)?; if op == MerkleTreeOp::Delete && first_zeroes && matches!(key_node, Node::Leaf(..)) @@ -208,7 +227,7 @@ impl MerkleTree { let node_hash = node.hash; // variable to avoid cloning `node` later // store in db - db.store_node(Node::Intermediate(node))?; + store_node(db, Node::Intermediate(node))?; if curr_lvl == 0 { return Ok(node_hash); @@ -217,7 +236,7 @@ impl MerkleTree { } /// returns the value at the given key - pub fn get(&self, key: &RawValue) -> TreeResult { + pub fn get(&self, key: &RawValue) -> Result> { let path = keypath(*key); let key_resolution = Self::down( self.db.as_ref(), @@ -228,13 +247,13 @@ impl MerkleTree { MerkleTreeOp::ReadOnly, )?; match key_resolution { - Some((k, v)) if &k == key => Ok(v), - _ => Err(TreeError::key_not_found()), + Some((k, v)) if &k == key => Ok(Some(v)), + _ => Ok(None), } } /// returns a boolean indicating whether the key exists in the tree - pub fn contains(&self, key: &RawValue) -> TreeResult { + pub fn contains(&self, key: &RawValue) -> Result { let path = keypath(*key); match Self::down( self.db.as_ref(), @@ -253,7 +272,7 @@ impl MerkleTree { &mut self, key: &RawValue, value: &RawValue, - ) -> TreeResult { + ) -> Result { let proof_non_existence = self.prove_nonexistence(key)?; let old_root: Hash = self.root; @@ -287,7 +306,7 @@ impl MerkleTree { &mut self, key: &RawValue, value: &RawValue, - ) -> TreeResult { + ) -> Result { let (old_value, old_proof) = self.prove(key)?; let old_root: Hash = self.root; @@ -316,7 +335,7 @@ impl MerkleTree { }) } - pub fn delete(&mut self, key: &RawValue) -> TreeResult { + pub fn delete(&mut self, key: &RawValue) -> Result { let (value, proof_existence) = self.prove(key)?; let old_root: Hash = self.root; @@ -346,7 +365,7 @@ impl MerkleTree { /// returns a proof of existence, which proves that the given key exists in /// the tree. It returns the `value` of the leaf at the given `key`, and the /// `MerkleProof`. - pub fn prove(&self, key: &RawValue) -> TreeResult<(RawValue, MerkleProof)> { + pub fn prove(&self, key: &RawValue) -> Result<(RawValue, MerkleProof)> { let path = keypath(*key); let mut siblings: Vec = Vec::new(); @@ -366,7 +385,7 @@ impl MerkleTree { other_leaf: None, }, )), - _ => Err(TreeError::key_not_found()), + _ => Err(Error::key_not_found()), } } @@ -374,7 +393,7 @@ impl MerkleTree { /// `key` does not exist in the tree. The return value specifies /// the key-value pair in the leaf reached as a result of /// resolving `key` as well as a `MerkleProof`. - pub fn prove_nonexistence(&self, key: &RawValue) -> TreeResult { + pub fn prove_nonexistence(&self, key: &RawValue) -> Result { let path = keypath(*key); let mut siblings: Vec = Vec::new(); @@ -400,22 +419,17 @@ impl MerkleTree { siblings, other_leaf: Some((k, v)), }), - _ => Err(TreeError::key_exists()), + _ => Err(Error::key_exists()), } // both cases prove that the given key don't exist in the tree. } /// verifies an inclusion proof for the given `key` and `value` - pub fn verify( - root: Hash, - proof: &MerkleProof, - key: &RawValue, - value: &RawValue, - ) -> TreeResult<()> { + pub fn verify(root: Hash, proof: &MerkleProof, key: &RawValue, value: &RawValue) -> Result<()> { let h = proof.compute_root_from_leaf(key, Some(*value))?; if h != root { - Err(TreeError::proof_fail("inclusion".to_string())) + Err(Error::proof_fail("inclusion".to_string())) } else { Ok(()) } @@ -423,18 +437,16 @@ impl MerkleTree { /// verifies a non-inclusion proof for the given `key`, that is, the given /// `key` does not exist in the tree - pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &RawValue) -> TreeResult<()> { + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &RawValue) -> Result<()> { match proof.other_leaf { - Some((k, _v)) if &k == key => { - Err(TreeError::invalid_proof("non-existence".to_string())) - } + Some((k, _v)) if &k == key => Err(Error::invalid_proof("non-existence".to_string())), _ => { let k = proof.other_leaf.map(|(k, _)| k).unwrap_or(*key); let v: Option = proof.other_leaf.map(|(_, v)| v); let h = proof.compute_root_from_leaf(&k, v)?; if h != root { - Err(TreeError::proof_fail("exclusion".to_string())) + Err(Error::proof_fail("exclusion".to_string())) } else { Ok(()) } @@ -442,7 +454,7 @@ impl MerkleTree { } } - pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> TreeResult<()> { + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { let mut old_siblings = proof.op_proof.siblings.clone(); let new_siblings = proof.siblings.clone(); @@ -459,7 +471,7 @@ impl MerkleTree { } MerkleTreeOp::Update => { if proof.value.is_none() { - return Err(TreeError::state_transition_fail( + return Err(Error::state_transition_fail( "Invalid proof of update: proof.value should not be None".to_string(), )); } @@ -485,7 +497,7 @@ impl MerkleTree { // All siblings should agree (proof.siblings == proof.op_proof.siblings) .then_some(()) - .ok_or(TreeError::state_transition_fail(format!( + .ok_or(Error::state_transition_fail(format!( "Invalid proof of update for key {}: Siblings don't match.", proof.op_key ))) @@ -514,11 +526,11 @@ impl MerkleTree { let divergence_lvl: usize = match zip_eq(old_path, new_path).position(|(x, y)| x != y) { Some(d) => d, - None => return Err(TreeError::max_depth()), + None => return Err(Error::max_depth()), }; if divergence_lvl != new_siblings.len() - 1 { - return Err(TreeError::state_transition_fail( + return Err(Error::state_transition_fail( "paths divergence does not match".to_string(), )); } @@ -534,7 +546,7 @@ impl MerkleTree { if new_siblings.is_empty() { return (old_siblings.is_empty() && proof.old_root == EMPTY_HASH) .then_some(()) - .ok_or(TreeError::state_transition_fail( + .ok_or(Error::state_transition_fail( "new tree has no siblings yet old tree is not the empty tree" .to_string(), )); @@ -544,14 +556,14 @@ impl MerkleTree { old_siblings.resize(d + 1, EMPTY_HASH); for i in 0..d { if old_siblings[i] != new_siblings[i] { - return Err(TreeError::state_transition_fail( + return Err(Error::state_transition_fail( "siblings don't match: old[i]!=new[i] ∀ i (except at i==d)".to_string(), )); } } if old_siblings[d] != new_siblings[d] { if old_siblings[d] != EMPTY_HASH { - return Err(TreeError::state_transition_fail( + return Err(Error::state_transition_fail( "siblings don't match: old[d]!=empty".to_string(), )); } @@ -559,20 +571,20 @@ impl MerkleTree { .op_proof .other_leaf .map(|(k, _)| k) - .ok_or(TreeError::state_transition_fail( + .ok_or(Error::state_transition_fail( "proof.proof_non_existence.other_leaf can not be empty for the case old_siblings[d]!=new_siblings[d]".to_string() ))?; let v: Option = proof.op_proof.other_leaf.map(|(_, v)| v); let old_leaf_hash = kv_hash(&k, v); if new_siblings[d] != old_leaf_hash { - return Err(TreeError::state_transition_fail( + return Err(Error::state_transition_fail( "siblings don't match: new[d]!=old_leaf_hash".to_string(), )); } } Ok(()) } - _ => Err(TreeError::invalid_proof("proof.op".to_string())), + _ => Err(Error::invalid_proof("proof.op".to_string())), } } } @@ -586,27 +598,25 @@ impl MerkleTree { root: Hash, k: RawValue, maybe_value: Option, - ) -> TreeResult { + ) -> Result { // Rule out invalid arguments match (op, maybe_value) { (MerkleTreeOp::Insert, None) | (MerkleTreeOp::Update, None) => { - Err(TreeError::invalid_state_transition_proof_arg(format!( + Err(Error::invalid_state_transition_proof_arg(format!( "{:?} op requires a value argument.", op ))) } (MerkleTreeOp::Delete, Some(_)) => { - Err(TreeError::invalid_state_transition_proof_arg(format!( + Err(Error::invalid_state_transition_proof_arg(format!( "{:?} op requires no value argument, yet one was provided.", op ))) } - (MerkleTreeOp::ReadOnly, _) => { - Err(TreeError::invalid_state_transition_proof_arg(format!( - "{:?} 'read only' op should not reach the 'apply_op' method", - op - ))) - } + (MerkleTreeOp::ReadOnly, _) => Err(Error::invalid_state_transition_proof_arg(format!( + "{:?} 'read only' op should not reach the 'apply_op' method", + op + ))), _ => Ok(()), }?; @@ -627,8 +637,7 @@ impl MerkleTree { Node::Leaf(Leaf::new(k, value)) } (MerkleTreeOp::Delete, None) => { - // return an intermediate node whose hash is 'empty', to - // indicate that there is no leaf + // return a node whose hash is 'empty', to indicate that there is no leaf Node::Intermediate(Intermediate { hash: EMPTY_HASH, left: EMPTY_HASH, @@ -636,16 +645,16 @@ impl MerkleTree { }) } _ => { - return Err(TreeError::invalid_state_transition_proof_arg(format!( + return Err(Error::invalid_state_transition_proof_arg(format!( "{:?} op has invalid value type: {:?}", op, maybe_value ))) } }; - let node_hash = node.hash(); // variable to avoid cloning `node` later - db.store_node(node)?; + let node_hash = node.hash(); // variable to avoid cloning `leaf` later + store_node(db, node)?; if siblings.is_empty() { - // return the node's hash as root + // return the leaf's hash as root return Ok(node_hash); } @@ -654,7 +663,7 @@ impl MerkleTree { // we're at the root-1 level, there is only a sibling, and we're // removing the current leaf. // If the sibling is a Leaf, the sibling (leaf) is now the new root - let sibling_node = db.load_node(siblings[0].into())?; + let sibling_node = load_node(db, siblings[0])?; if matches!(sibling_node, Node::Leaf(..)) { return Ok(siblings[0]); } @@ -669,7 +678,7 @@ impl MerkleTree { let node_hash = node.hash; // variable to avoid cloning `node` later // store in db - db.store_node(Node::Intermediate(node))?; + store_node(db, Node::Intermediate(node))?; return Ok(node_hash); } // use the last sibling as the key that we will push up from @@ -743,48 +752,39 @@ fn hash_with_flag(flag: F, inputs: &[F]) -> Hash { impl MerkleTree { /// returns an iterator over the leaves of the tree - pub fn iter(&self) -> Iter<'_> { + pub fn iter(&self) -> Iter { Iter { state: if self.root == EMPTY_HASH { vec![] } else { vec![self.root] }, - db: self.db.as_ref(), + db: self.db.clone(), } } } -impl<'a> IntoIterator for &'a MerkleTree { +impl IntoIterator for &MerkleTree { type Item = (RawValue, RawValue); - type IntoIter = Iter<'a>; + type IntoIter = Iter; fn into_iter(self) -> Self::IntoIter { self.iter() } } -pub struct Iter<'a> { +pub struct Iter { state: Vec, - db: &'a dyn DB, + db: Box, } -impl<'a> Iterator for Iter<'a> { +impl Iterator for Iter { type Item = (RawValue, RawValue); fn next(&mut self) -> Option { let node_hash = self.state.pop()?; // Inspect node - let node = self.db.load_node(node_hash.into()).ok()?; + let node = load_node(self.db.as_ref(), node_hash).ok()?; match node { - Node::Leaf(Leaf { - hash: _, - path: _, - key, - value, - }) => Some((key, value)), - Node::Intermediate(Intermediate { - hash: _, - left, - right, - }) => { + Node::Leaf(Leaf { key, value, .. }) => Some((key, value)), + Node::Intermediate(Intermediate { left, right, .. }) => { [right, left].into_iter().for_each(|h| { if h != EMPTY_HASH { self.state.push(h) @@ -814,7 +814,7 @@ fn print_graph_viz(f: &mut fmt::Formatter<'_>, db: &dyn DB, hash: Hash) -> fmt:: return Ok(()); } - let node = db.load_node(hash.into()).map_err(|_| fmt::Error)?; + let node = load_node(db, hash).map_err(|_| fmt::Error)?; match node { Node::Intermediate(n) => { let left_hash: String = if n.left == EMPTY_HASH { @@ -878,14 +878,14 @@ impl MerkleProof { /// Computes the root of the Merkle tree suggested by a Merkle proof given a /// key & value. If a value is not provided, the terminal node is assumed to /// be empty. - fn compute_root_from_leaf(&self, key: &RawValue, value: Option) -> TreeResult { + fn compute_root_from_leaf(&self, key: &RawValue, value: Option) -> Result { let path = keypath(*key); let h = kv_hash(key, value); self.compute_root_from_node(&h, path) } - fn compute_root_from_node(&self, node_hash: &Hash, path: Vec) -> TreeResult { + fn compute_root_from_node(&self, node_hash: &Hash, path: Vec) -> Result { if self.siblings.len() > MAX_DEPTH { - return Err(TreeError::max_depth()); + return Err(Error::max_depth()); } let mut h = *node_hash; for (i, sibling) in self.siblings.iter().enumerate().rev() { @@ -1001,19 +1001,17 @@ pub enum Node { impl Node { pub fn hash(&self) -> Hash { match self { - Node::Leaf(Leaf { - hash, - path: _, - key: _, - value: _, - }) => *hash, - Node::Intermediate(Intermediate { - hash, - left: _, - right: _, - }) => *hash, + Node::Leaf(Leaf { hash, .. }) => *hash, + Node::Intermediate(Intermediate { hash, .. }) => *hash, } } + // NOTE: this can be replaced by `.to_bytes` & `from_bytes` optimized methods at `Node` + pub fn encode(&self) -> Result, anyhow::Error> { + serde_json::to_vec(self).map_err(|e| anyhow!("failed to serialize node: {e}")) + } + pub fn decode(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(|e| anyhow!("failed to deserialize node: {e}")) + } } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -1023,7 +1021,7 @@ pub struct Intermediate { right: Hash, } impl Intermediate { - fn new(left: Hash, right: Hash) -> Self { + pub fn new(left: Hash, right: Hash) -> Self { if left == EMPTY_HASH && right == EMPTY_HASH { return Self { hash: EMPTY_HASH, @@ -1045,7 +1043,7 @@ pub struct Leaf { pub(crate) value: RawValue, } impl Leaf { - fn new(key: RawValue, value: RawValue) -> Self { + pub fn new(key: RawValue, value: RawValue) -> Self { Self { hash: kv_hash(&key, Some(value)), path: keypath(key), @@ -1079,21 +1077,21 @@ pub mod tests { use super::*; #[test] - fn test_merkletree() -> TreeResult<()> { + fn test_merkletree() -> Result<()> { let db = Box::new(db::MemDB::new()); test_merkletree_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_merkletree_opt(db)?; } Ok(()) } - fn test_merkletree_opt(db: Box) -> TreeResult<()> { + fn test_merkletree_opt(db: Box) -> Result<()> { let mut kvs = HashMap::new(); for i in 0..8 { if i == 1 { @@ -1168,21 +1166,40 @@ pub mod tests { } #[test] - fn test_delete_to_empty() -> TreeResult<()> { + fn test_key_not_found() -> Result<()> { + let db = Box::new(db::MemDB::new()); + let mut tree = MerkleTree::empty_with_db(db.clone()); + let value_option = tree.get(&RawValue::from(5)).unwrap(); + assert_eq!(None, value_option); + + tree.insert(&RawValue::from(1), &RawValue::from(42))?; + let value_option = tree.get(&RawValue::from(5)).unwrap(); + assert_eq!(None, value_option); + + // If the root doesn't exist there should be an error + let tree = MerkleTree::from_db(Hash::from(RawValue::from(42)), db); + let result = tree.get(&RawValue::from(5)); + assert!(result.is_err()); + + Ok(()) + } + + #[test] + fn test_delete_to_empty() -> Result<()> { let db = Box::new(db::MemDB::new()); test_delete_to_empty_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_delete_to_empty_opt(db)?; } Ok(()) } - fn test_delete_to_empty_opt(db: Box) -> TreeResult<()> { + fn test_delete_to_empty_opt(db: Box) -> Result<()> { let mut tree = MerkleTree::new_with_db(db, &HashMap::new())?; let (key, value) = (RawValue::from(2), RawValue::from(1002)); @@ -1212,21 +1229,21 @@ pub mod tests { } #[test] - fn test_prove_verify() -> TreeResult<()> { + fn test_prove_verify() -> Result<()> { let db = Box::new(db::MemDB::new()); test_prove_verify_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_prove_verify_opt(db)?; } Ok(()) } - fn test_prove_verify_opt(db: Box) -> TreeResult<()> { + fn test_prove_verify_opt(db: Box) -> Result<()> { let kvs = [ (1.into(), 55.into()), (2.into(), 88.into()), @@ -1255,21 +1272,21 @@ pub mod tests { } #[test] - fn test_update_leaf() -> TreeResult<()> { + fn test_update_leaf() -> Result<()> { let db = Box::new(db::MemDB::new()); test_update_leaf_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_update_leaf_opt(db)?; } Ok(()) } - fn test_update_leaf_opt(db: Box) -> TreeResult<()> { + fn test_update_leaf_opt(db: Box) -> Result<()> { let kvs = [ (1.into(), 1.into()), (9.into(), 9.into()), @@ -1304,21 +1321,21 @@ pub mod tests { } #[test] - fn test_update_delete_leaf() -> TreeResult<()> { + fn test_update_delete_leaf() -> Result<()> { let db = Box::new(db::MemDB::new()); test_update_delete_leaf_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_update_delete_leaf_opt(db)?; } Ok(()) } - fn test_update_delete_leaf_opt(db: Box) -> TreeResult<()> { + fn test_update_delete_leaf_opt(db: Box) -> Result<()> { let kvs: HashMap = (0..10) .map(|i| (i.into(), i.into())) .collect::>(); @@ -1348,21 +1365,21 @@ pub mod tests { } #[test] - fn test_delete_leaf() -> TreeResult<()> { + fn test_delete_leaf() -> Result<()> { let db = Box::new(db::MemDB::new()); test_delete_leaf_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_delete_leaf_opt(db)?; } Ok(()) } - fn test_delete_leaf_opt(db: Box) -> TreeResult<()> { + fn test_delete_leaf_opt(db: Box) -> Result<()> { let kvs = [(1.into(), 1.into()), (9.into(), 9.into())] .into_iter() .collect(); @@ -1403,21 +1420,21 @@ pub mod tests { } #[test] - fn test_delete_from_two_leaves() -> TreeResult<()> { + fn test_delete_from_two_leaves() -> Result<()> { let db = Box::new(db::MemDB::new()); test_delete_from_two_leaves_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_delete_from_two_leaves_opt(db)?; } Ok(()) } - fn test_delete_from_two_leaves_opt(db: Box) -> TreeResult<()> { + fn test_delete_from_two_leaves_opt(db: Box) -> Result<()> { // tree with two leaves whose keys diverge at the first bit, so that when // deleting one key leads to a tree with a single Leaf as a root let mut kvs = HashMap::new(); @@ -1439,21 +1456,21 @@ pub mod tests { } #[test] - fn test_state_transition() -> TreeResult<()> { + fn test_state_transition() -> Result<()> { let db = Box::new(db::MemDB::new()); test_state_transition_opt(db)?; #[cfg(feature = "db_rocksdb")] { - let db = Box::new(db::rocks::RocksDB::open( - tempfile::TempDir::new().unwrap().path(), - )?); + let db = Box::new( + db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(), + ); test_state_transition_opt(db)?; } Ok(()) } - fn test_state_transition_opt(db: Box) -> TreeResult<()> { + fn test_state_transition_opt(db: Box) -> Result<()> { let mut kvs = HashMap::new(); for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); diff --git a/src/examples/mod.rs b/src/examples/mod.rs index 2b490f9..0780c7e 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -180,11 +180,7 @@ impl EthDosHelper { }; assert_eq!(int, Value::from(int_attestation.public_key)); - let n_i64 = if let TypedValue::Int(x) = n.typed() { - *x - } else { - panic!("distance value is not Int") - }; + let n_i64 = n.as_int().unwrap(); // eth_dos src->dst dist=n+1 self.n_plus_1(&mut pod, eth_dos_int_to_dst, int_attestation, n_i64)?; diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 04fe1ed..98f280e 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -13,10 +13,11 @@ use serde::{Deserialize, Serialize}; pub use serialization::SerializedMainPod; use crate::middleware::{ - self, check_custom_pred, containers::Dictionary, fill_wildcard_values, hash_op, max_op, - prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, - OperationAux, OperationType, Params, PublicKey, RawValue, Signature, Signer, Statement, - StatementArg, VDSet, Value, ValueRef, + self, check_custom_pred, + containers::{Container, Dictionary}, + fill_wildcard_values, hash_op, max_op, prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, + MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, RawValue, + Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, EMPTY_VALUE, }; mod custom; @@ -92,8 +93,11 @@ impl fmt::Display for SignedDict { // https://0xparc.github.io/pod2/merkletree.html will not need it since it will be // deterministic based on the keys values not on the order of the keys when added into the // tree. - for (k, v) in self.dict.kvs().iter().sorted_by_key(|kv| kv.0.hash()) { - writeln!(f, " - {} = {}", k, v)?; + for kv in self.dict.iter() { + match kv { + Ok((k, v)) => writeln!(f, " - {} = {}", k, v)?, + Err(e) => writeln!(f, " - ERR: {}", e)?, + } } Ok(()) } @@ -106,16 +110,13 @@ impl SignedDict { .then_some(()) .ok_or(Error::custom("Invalid signature!")) } - pub fn kvs(&self) -> &HashMap { - self.dict.kvs() - } - pub fn get(&self, key: impl Into) -> Option<&Value> { - self.kvs().get(&key.into()) + pub fn get(&self, key: impl Into) -> Option { + self.dict.get(&key.into()).unwrap() } // Returns the Contains statement that defines key if it exists. pub fn get_statement(&self, key: impl Into) -> Option { let key: Key = key.into(); - self.kvs().get(&key).map(|value| { + self.dict.get(&key).unwrap().map(|value| { Statement::Contains( ValueRef::Literal(Value::from(self.dict.clone())), ValueRef::Literal(Value::from(key.name())), @@ -156,6 +157,11 @@ impl fmt::Display for MainPodBuilder { } } +fn as_container_or_err(v: &Value) -> Result { + v.as_container() + .ok_or_else(|| Error::custom(format!("{v} not a container"))) +} + impl MainPodBuilder { pub fn new(params: &Params, vd_set: &VDSet) -> Self { Self { @@ -347,11 +353,12 @@ impl MainPodBuilder { .ok_or(Error::custom(format!( "Invalid key argument for op {}.", op - )))?; + )))? + .raw(); let proof = if op_type == &Native(ContainsFromEntries) { - container.prove_existence(key)?.1 + as_container_or_err(container)?.prove(key)?.1 } else { - container.prove_nonexistence(key)? + as_container_or_err(container)?.prove_nonexistence(key)? }; Ok(Operation(op_type.clone(), op.1, OpAux::MerkleProof(proof))) } @@ -375,18 +382,16 @@ impl MainPodBuilder { let value = op.1.get(3) .and_then(|arg| arg.value()) - .ok_or(Error::custom(format!( - "Invalid key argument for op {}.", - op - ))); + .cloned() + .unwrap_or(Value::from(EMPTY_VALUE)); let proof = match op_type { Native(ContainerInsertFromEntries) => { - old_container.prove_insertion(key, value?)? + as_container_or_err(old_container)?.insert(key.clone(), value)? } Native(ContainerUpdateFromEntries) => { - old_container.prove_update(key, value?)? + as_container_or_err(old_container)?.update(key.raw(), value)? } - _ => old_container.prove_deletion(key)?, + _ => as_container_or_err(old_container)?.delete(key.raw())?, }; Ok(Operation( op_type.clone(), diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index a61623c..a1045a5 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -4,7 +4,7 @@ use crate::{ frontend::SignedDict, middleware::{ containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux, - OperationType, Signature, Statement, TypedValue, Value, ValueRef, + OperationType, Signature, Statement, Value, ValueRef, }, }; @@ -39,10 +39,9 @@ impl OperationArg { } pub(crate) fn int_value_and_ref(&self) -> Option<(ValueRef, i64)> { - self.value_and_ref().and_then(|(r, v)| match v.typed() { - &TypedValue::Int(i) => Some((r, i)), - _ => None, - }) + self.value_and_ref() + .and_then(|(r, v)| v.as_int().map(|i| Some((r, i)))) + .flatten() } } @@ -71,7 +70,7 @@ impl From<&Value> for OperationArg { impl From<(&Dictionary, &str)> for OperationArg { fn from((dict, key): (&Dictionary, &str)) -> Self { // TODO: Use TryFrom - let value = dict.get(&key.into()).cloned().unwrap(); + let value = dict.get(&key.into()).unwrap().unwrap(); Self::Statement(Statement::Contains( dict.clone().into(), key.into(), diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 8a47db3..1def7c3 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -83,7 +83,7 @@ mod tests { middleware::{ self, containers::{Array, Dictionary, Set}, - Params, Signer as _, TypedValue, DEFAULT_VD_LIST, + Params, Signer as _, Value, DEFAULT_VD_LIST, }, }; @@ -91,48 +91,46 @@ mod tests { fn test_value_serialization() { // Pairs of values and their expected serialized representations let values = vec![ - (TypedValue::String("hello".to_string()), "\"hello\""), - (TypedValue::Int(42), "{\"Int\":\"42\"}"), - (TypedValue::Bool(true), "true"), + (Value::from("hello"), "\"hello\""), + (Value::from(42), "{\"Int\":\"42\"}"), + (Value::from(true), r#"{"Int":"1"}"#), ( - TypedValue::Array(Array::new(vec!["foo".into(), false.into()])), - "{\"array\":[\"foo\",false]}", + Value::from(Array::new(vec![Value::from("foo"), Value::from(false)])), + r#"{"inner":[[{"Int":"0"},"foo"],[{"Int":"1"},{"Int":"0"}]]}"#, ), ( - TypedValue::Dictionary( - Dictionary::new(HashMap::from([ - // The set of valid keys is equal to the set of valid JSON keys - ("foo".into(), 123.into()), - // Empty strings are valid JSON keys - (("".into()), "baz".into()), - // Keys can contain whitespace - ((" hi".into()), false.into()), - // Keys can contain special characters - (("!@£$%^&&*()".into()), "".into()), - // Keys can contain _very_ special characters - (("\0".into()), "".into()), - // Keys can contain emojis - (("🥳".into()), "party time!".into()), - ])) - ), - "{\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}", + Value::from(Dictionary::new(HashMap::from([ + // The set of valid keys is equal to the set of valid JSON keys + ("foo".into(), 123.into()), + // Empty strings are valid JSON keys + (("".into()), "baz".into()), + // Keys can contain whitespace + ((" hi".into()), false.into()), + // Keys can contain special characters + (("!@£$%^&&*()".into()), "".into()), + // Keys can contain _very_ special characters + (("\0".into()), "".into()), + // Keys can contain emojis + (("🥳".into()), "party time!".into()), + ]))), + r#"{"inner":[["!@£$%^&&*()",""],["🥳","party time!"],[" hi",{"Int":"0"}],["foo",{"Int":"123"}],["\u0000",""],["","baz"]]}"#, ), ( - TypedValue::Set(Set::new(HashSet::from(["foo".into(), "bar".into()]))), - "{\"set\":[\"bar\",\"foo\"]}", + Value::from(Set::new(HashSet::from(["foo".into(), "bar".into()]))), + r#"{"inner":[["bar"],["foo"]]}"#, ), ]; for (value, expected) in values { let serialized = serde_json::to_string(&value).unwrap(); assert_eq!(serialized, expected); - let deserialized: TypedValue = serde_json::from_str(&serialized).unwrap(); + let deserialized: Value = serde_json::from_str(&serialized).unwrap(); assert_eq!( value, deserialized, "value {:#?} should equal deserialized {:#?}", value, deserialized ); - let expected_deserialized: TypedValue = serde_json::from_str(expected).unwrap(); + let expected_deserialized: Value = serde_json::from_str(expected).unwrap(); assert_eq!(value, expected_deserialized); } } @@ -177,7 +175,10 @@ mod tests { "deserialized: {}", serde_json::to_string_pretty(&deserialized).unwrap() ); - assert_eq!(signed_dict.dict.kvs(), deserialized.dict.kvs()); + assert_eq!( + signed_dict.dict.dump().unwrap(), + deserialized.dict.dump().unwrap() + ); assert_eq!(signed_dict.public_key, deserialized.public_key); assert_eq!(signed_dict.signature, deserialized.signature); assert_eq!(signed_dict.verify().is_ok(), deserialized.verify().is_ok()); diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs index efca5c9..bd912cb 100644 --- a/src/lang/pretty_print.rs +++ b/src/lang/pretty_print.rs @@ -131,7 +131,7 @@ impl CustomPredicateBatch { impl PrettyPrint for Value { fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { - write!(w, "{}", self.typed()) + write!(w, "{}", self.typed) } } diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index d01f43f..7c8e744 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -1,29 +1,260 @@ //! This file implements the types defined at //! . -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + fmt::{self, Debug}, +}; use schemars::JsonSchema; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{ + de::{Error as _, SeqAccess, Visitor}, + ser, Deserialize, Deserializer, Serialize, +}; -use super::serialization::{ordered_map, ordered_set}; #[cfg(feature = "backend_plonky2")] -use crate::backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree}; +use crate::backends::plonky2::primitives::merkletree::{self, MerkleProof, MerkleTree}; use crate::{ backends::plonky2::primitives::merkletree::MerkleTreeStateTransitionProof, - middleware::{Error, Hash, Key, RawValue, Result, Value}, + middleware::{ + db::{mem::MemDB, DB}, + Error, Hash, Key, RawValue, Result, TypedValue, Value, EMPTY_HASH, + }, }; +#[derive(Clone, Debug)] +pub struct Container { + root: Hash, + db: Box, +} + +impl JsonSchema for Container { + fn schema_name() -> String { + "Container".to_string() + } + + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + // Just use the schema of Vec> since that's what we're actually serializing + Vec::>::json_schema(gen) + } +} + +impl Serialize for Container { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut pairs = self + .iter() + .collect::>>() + .map_err(ser::Error::custom)?; + pairs.sort_by(|(k1, _), (k2, _)| k1.raw().cmp(&k2.raw())); + // Serialize as an array + use serde::ser::SerializeSeq; + let mut seq = serializer.serialize_seq(Some(pairs.len()))?; + for (k, v) in pairs { + if k == v { + seq.serialize_element(&[&v])?; + } else { + seq.serialize_element(&[&k, &v])?; + } + } + seq.end() + } +} + +struct ContainerVisitor; + +impl<'de> Visitor<'de> for ContainerVisitor { + type Value = HashMap; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of `[Value]` or `[Value, Value]`") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut kvs = HashMap::::new(); + while let Some(mut elem) = seq.next_element::>()? { + match elem.len() { + 1 => { + let v = elem.pop().unwrap(); + kvs.insert(v.clone(), v); + } + 2 => { + let (v, k) = (elem.pop().unwrap(), elem.pop().unwrap()); + kvs.insert(k, v); + } + n => { + return Err(A::Error::custom(format!( + "invalid vec length of {n} in container entry" + ))) + } + } + } + + Ok(kvs) + } +} + +impl<'de> Deserialize<'de> for Container { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let kvs = deserializer.deserialize_seq(ContainerVisitor)?; + Ok(Container::new(kvs)) + } +} + +impl PartialEq for Container { + fn eq(&self, other: &Self) -> bool { + self.root == other.root + } +} +impl Eq for Container {} + +fn store_container_mt(db: &mut dyn DB, container: &Container) -> Result<()> { + match db.load_node(container.root) { + Err(e) => return Err(Error::Database(e)), + // Container already exists in the DB + Ok(Some(_)) => return Ok(()), + // Container not existing, we need to save it + Ok(None) => {} + }; + let mut container_copy = Container::empty_with_db(db.clone_box()); + for kv_result in container.iter() { + let (k, v) = kv_result?; + container_copy.insert(k, v)?; + } + Ok(()) +} + +fn store_value(db: &mut dyn DB, v: Value) -> Result<()> { + match &v.typed { + TypedValue::Set(Set { inner }) + | TypedValue::Dictionary(Dictionary { inner }) + | TypedValue::Array(Array { inner }) => { + if db.is_persistent() { + store_container_mt(db, inner)?; + } + db.store_value(v).map_err(Error::Database)? + } + _ => db.store_value(v).map_err(Error::Database)?, + } + Ok(()) +} + +fn load_value(db: &dyn DB, value_raw: RawValue) -> Result { + match db.load_value(value_raw) { + Err(e) => Err(Error::Database(e)), + Ok(Some(v)) => Ok(v), + Ok(None) => Err(Error::custom(format!( + "Value from {value_raw} not found in DB" + ))), + } +} + +impl Container { + fn mt(&self) -> MerkleTree { + MerkleTree::from_db(self.root, self.db.clone()) + } + pub fn new(kvs: HashMap) -> Self { + let db = Box::new(MemDB::new()); + let mut container = Self::empty_with_db(db); + for (k, v) in kvs { + container.insert(k, v).expect("no duplicates, no db errors"); + } + container + } + pub fn empty_with_db(db: Box) -> Self { + Self::from_db(EMPTY_HASH, db).expect("EMPTY_HASH exists implicitly") + } + pub fn from_db(root: Hash, db: Box) -> Result { + // Make sure the root exists in the db + let _ = merkletree::load_node(db.as_ref(), root)?; + Ok(Self { root, db }) + } + pub fn commitment(&self) -> Hash { + self.root + } + pub fn get(&self, key_raw: RawValue) -> Result> { + Ok(match self.mt().get(&key_raw)? { + Some(value_raw) => Some(load_value(self.db.as_ref(), value_raw)?), + None => None, + }) + } + pub fn prove(&self, key_raw: RawValue) -> Result<(Value, MerkleProof)> { + let (value_raw, mtp) = self.mt().prove(&key_raw)?; + let value = load_value(self.db.as_ref(), value_raw)?; + Ok((value, mtp)) + } + pub fn prove_nonexistence(&self, key_raw: RawValue) -> Result { + Ok(self.mt().prove_nonexistence(&key_raw)?) + } + pub fn insert(&mut self, key: Value, value: Value) -> Result { + let (key_raw, value_raw) = (key.raw(), value.raw()); + store_value(self.db.as_mut(), key)?; + store_value(self.db.as_mut(), value)?; + let mut mt = self.mt(); + let mtp = mt.insert(&key_raw, &value_raw)?; + self.root = mt.root(); + Ok(mtp) + } + pub fn update( + &mut self, + key_raw: RawValue, + value: Value, + ) -> Result { + let value_raw = value.raw(); + store_value(self.db.as_mut(), value)?; + let mut mt = self.mt(); + let mtp = mt.update(&key_raw, &value_raw)?; + self.root = mt.root(); + Ok(mtp) + } + pub fn delete(&mut self, key_raw: RawValue) -> Result { + let mut mt = self.mt(); + let mtp = mt.delete(&key_raw)?; + self.root = mt.root(); + Ok(mtp) + } + pub fn verify( + root: Hash, + proof: &MerkleProof, + key_raw: RawValue, + value_raw: RawValue, + ) -> Result<()> { + Ok(MerkleTree::verify(root, proof, &key_raw, &value_raw)?) + } + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key_raw: RawValue) -> Result<()> { + Ok(MerkleTree::verify_nonexistence(root, proof, &key_raw)?) + } + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { + MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) + } + pub fn iter(&self) -> impl Iterator> { + let db = self.db.clone(); + self.mt().iter().map(move |(key_raw, value_raw)| { + let key = load_value(db.as_ref(), key_raw)?; + let value = load_value(db.as_ref(), value_raw)?; + Ok((key, value)) + }) + } + /// This is an expensive operation + pub fn dump(&self) -> Result> { + self.iter().collect() + } +} + /// Dictionary: the user original keys and values are hashed to be used in the leaf. /// leaf.key=hash(original_key) /// leaf.value=hash(original_value) -#[derive(Clone, Debug, Serialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct Dictionary { - #[serde(skip)] - #[schemars(skip)] - mt: MerkleTree, - #[serde(serialize_with = "ordered_map")] - kvs: HashMap, + pub(crate) inner: Container, } #[macro_export] @@ -34,255 +265,371 @@ macro_rules! dict { ({ $($key:expr => $val:expr),* }) => ({ let mut map = ::std::collections::HashMap::new(); $( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )* - $crate::middleware::containers::Dictionary::new( map) + $crate::middleware::containers::Dictionary::new(map) }); } +// TODO: Replace all methods that receive a `&Key` by either `impl Into` for write +// methods and `impl AsRef` for read methods. +// TODO: Replace all methods that receive a `&Value` in write methods for `Value`. Consider a +// trait? + impl Dictionary { pub fn new(kvs: HashMap) -> Self { - let kvs_raw: HashMap = - kvs.iter().map(|(k, v)| (k.raw(), v.raw())).collect(); Self { - mt: MerkleTree::new(&kvs_raw), - kvs, + inner: Container::new( + kvs.into_iter() + .map(|(k, v)| (Value::from(k.name), v)) + .collect(), + ), } } + pub fn empty_with_db(db: Box) -> Self { + Self { + inner: Container::empty_with_db(db), + } + } + pub fn from_db(root: Hash, db: Box) -> Result { + Ok(Self { + inner: Container::from_db(root, db)?, + }) + } pub fn commitment(&self) -> Hash { - self.mt.root() + self.inner.commitment() } - pub fn get(&self, key: &Key) -> Result<&Value> { - self.kvs - .get(key) - .ok_or_else(|| Error::custom(format!("key \"{}\" not found", key.name()))) + pub fn get(&self, key: &Key) -> Result> { + self.inner.get(key.raw()) } - pub fn prove(&self, key: &Key) -> Result<(&Value, MerkleProof)> { - let (_, mtp) = self.mt.prove(&key.raw())?; - let value = self.kvs.get(key).expect("key exists"); - Ok((value, mtp)) + pub fn prove(&self, key: &Key) -> Result<(Value, MerkleProof)> { + self.inner.prove(key.raw()) } pub fn prove_nonexistence(&self, key: &Key) -> Result { - Ok(self.mt.prove_nonexistence(&key.raw())?) + self.inner.prove_nonexistence(key.raw()) } pub fn insert(&mut self, key: &Key, value: &Value) -> Result { - let mtp = self.mt.insert(&key.raw(), &value.raw())?; - self.kvs.insert(key.clone(), value.clone()); - Ok(mtp) + self.inner + .insert(Value::from(key.name.clone()), value.clone()) } pub fn update(&mut self, key: &Key, value: &Value) -> Result { - let mtp = self.mt.update(&key.raw(), &value.raw())?; - self.kvs.insert(key.clone(), value.clone()); - Ok(mtp) + self.inner.update(key.raw(), value.clone()) } pub fn delete(&mut self, key: &Key) -> Result { - let mtp = self.mt.delete(&key.raw())?; - self.kvs.remove(key); - Ok(mtp) + self.inner.delete(key.raw()) } pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> { - let key = key.raw(); - Ok(MerkleTree::verify(root, proof, &key, &value.raw())?) + Container::verify(root, proof, key.raw(), value.raw()) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> { - let key = key.raw(); - Ok(MerkleTree::verify_nonexistence(root, proof, &key)?) + Container::verify_nonexistence(root, proof, key.raw()) } pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { - MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) + Container::verify_state_transition(proof) } - // TODO: Rename to dict to be consistent maybe? - pub fn kvs(&self) -> &HashMap { - &self.kvs + pub fn iter(&self) -> impl Iterator> + use<'_> { + self.inner.iter().map(|r| match r { + Ok((key, value)) => Ok(( + key.as_string() + .ok_or_else(|| Error::custom("dictionary: key is not string"))?, + value, + )), + Err(e) => Err(e), + }) + } + /// This is an expensive operation + pub fn dump(&self) -> Result> { + self.iter().collect() } } impl PartialEq for Dictionary { fn eq(&self, other: &Self) -> bool { - self.mt.root() == other.mt.root() + self.inner.eq(&other.inner) } } impl Eq for Dictionary {} -impl<'de> Deserialize<'de> for Dictionary { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - struct Aux { - #[serde(serialize_with = "ordered_map")] - kvs: HashMap, - } - let aux = Aux::deserialize(deserializer)?; - Ok(Dictionary::new(aux.kvs)) - } -} - /// Set: the value field of the leaf is unused, and the key contains the hash of the element. /// leaf.key=hash(original_value) /// leaf.value=0 -#[derive(Clone, Debug, Serialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct Set { - #[serde(skip)] - #[schemars(skip)] - mt: MerkleTree, - #[serde(serialize_with = "ordered_set")] - set: HashSet, + pub(crate) inner: Container, } impl Set { pub fn new(set: HashSet) -> Self { - let kvs_raw: HashMap = set - .iter() - .map(|e| { - let rv = e.raw(); - (rv, rv) - }) - .collect(); Self { - mt: MerkleTree::new(&kvs_raw), - set, + inner: Container::new(set.into_iter().map(|v| (v.clone(), v)).collect()), } } - pub fn commitment(&self) -> Hash { - self.mt.root() + pub fn empty_with_db(db: Box) -> Self { + Self { + inner: Container::empty_with_db(db), + } } - pub fn contains(&self, value: &Value) -> bool { - self.set.contains(value) + pub fn from_db(root: Hash, db: Box) -> Result { + Ok(Self { + inner: Container::from_db(root, db)?, + }) + } + pub fn commitment(&self) -> Hash { + self.inner.commitment() + } + pub fn contains(&self, value: &Value) -> Result { + Ok(self.inner.get(value.raw())?.is_some()) } pub fn prove(&self, value: &Value) -> Result { - let rv = value.raw(); - let (_, proof) = self.mt.prove(&rv)?; + let (_, proof) = self.inner.prove(value.raw())?; Ok(proof) } pub fn prove_nonexistence(&self, value: &Value) -> Result { - let rv = value.raw(); - Ok(self.mt.prove_nonexistence(&rv)?) + self.inner.prove_nonexistence(value.raw()) } pub fn insert(&mut self, value: &Value) -> Result { - let raw_value = value.raw(); - let mtp = self.mt.insert(&raw_value, &raw_value)?; - self.set.insert(value.clone()); - Ok(mtp) + self.inner.insert(value.clone(), value.clone()) } pub fn delete(&mut self, value: &Value) -> Result { - let mtp = self.mt.delete(&value.raw())?; - self.set.remove(value); - Ok(mtp) + self.inner.delete(value.raw()) } pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - let rv = value.raw(); - Ok(MerkleTree::verify(root, proof, &rv, &rv)?) + Container::verify(root, proof, value.raw(), value.raw()) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - let rv = value.raw(); - Ok(MerkleTree::verify_nonexistence(root, proof, &rv)?) + Container::verify_nonexistence(root, proof, value.raw()) } pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { - MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) + Container::verify_state_transition(proof) } - pub fn set(&self) -> &HashSet { - &self.set + pub fn iter(&self) -> impl Iterator> + use<'_> { + self.inner.iter().map(|r| match r { + Ok((key, value)) => { + if key != value { + return Err(Error::custom("set: key != value")); + } + Ok(value) + } + Err(e) => Err(e), + }) + } + /// This is an expensive operation + pub fn dump(&self) -> Result> { + self.iter().collect() } } impl PartialEq for Set { fn eq(&self, other: &Self) -> bool { - self.mt.root() == other.mt.root() + self.inner.eq(&other.inner) } } impl Eq for Set {} -impl<'de> Deserialize<'de> for Set { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize, JsonSchema)] - struct Aux { - #[serde(serialize_with = "ordered_set")] - set: HashSet, - } - let aux = Aux::deserialize(deserializer)?; - Ok(Set::new(aux.set)) - } -} - /// Array: the elements are placed at the value field of each leaf, and the key field is just the /// array index (integer). /// leaf.key=i /// leaf.value=original_value -#[derive(Clone, Debug, Serialize, JsonSchema)] +/// Due to its construction this should be seen as a sparse array, where there can be gaps +/// (unused indices). +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct Array { - #[serde(skip)] - #[schemars(skip)] - mt: MerkleTree, - array: Vec, + pub(crate) inner: Container, } impl Array { pub fn new(array: Vec) -> Self { - let kvs_raw: HashMap = array - .iter() - .enumerate() - .map(|(i, e)| (RawValue::from(i as i64), e.raw())) - .collect(); - Self { - mt: MerkleTree::new(&kvs_raw), - array, + inner: Container::new( + array + .into_iter() + .enumerate() + .map(|(i, v)| (Value::from(i as i64), v)) + .collect(), + ), } } - pub fn commitment(&self) -> Hash { - self.mt.root() + pub fn empty_with_db(db: Box) -> Self { + Self { + inner: Container::empty_with_db(db), + } } - pub fn get(&self, i: usize) -> Result<&Value> { - self.array.get(i).ok_or_else(|| { - Error::custom(format!("index {} out of bounds 0..{}", i, self.array.len())) + pub fn from_db(root: Hash, db: Box) -> Result { + Ok(Self { + inner: Container::from_db(root, db)?, }) } - pub fn prove(&self, i: usize) -> Result<(&Value, MerkleProof)> { - let (_, mtp) = self.mt.prove(&RawValue::from(i as i64))?; - let value = self.array.get(i).expect("valid index"); - Ok((value, mtp)) + pub fn commitment(&self) -> Hash { + self.inner.commitment() + } + pub fn get(&self, i: usize) -> Result> { + self.inner.get(Value::from(i as i64).raw()) + } + pub fn prove(&self, i: usize) -> Result<(Value, MerkleProof)> { + self.inner.prove(Value::from(i as i64).raw()) + } + pub fn insert(&mut self, i: usize, value: Value) -> Result { + self.inner.insert(Value::from(i as i64), value) + } + pub fn delete(&mut self, i: usize) -> Result { + self.inner.delete(Value::from(i as i64).raw()) } pub fn update(&mut self, i: usize, value: &Value) -> Result { - let mtp = self.mt.update(&(i as i64).into(), &value.raw())?; - self.array[i] = value.clone(); - Ok(mtp) + self.inner + .update(Value::from(i as i64).raw(), value.clone()) } pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> { - Ok(MerkleTree::verify( - root, - proof, - &RawValue::from(i as i64), - &value.raw(), - )?) + Container::verify(root, proof, Value::from(i as i64).raw(), value.raw()) } pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { - MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) + Container::verify_state_transition(proof) } - pub fn array(&self) -> &[Value] { - &self.array + pub fn iter(&self) -> impl Iterator> + use<'_> { + self.inner.iter().map(|r| match r { + Ok((key, value)) => { + let index = key + .as_int() + .ok_or_else(|| Error::custom("array: key is not int"))?; + Ok((index as usize, value)) + } + Err(e) => Err(e), + }) + } + /// This is an expensive operation + pub fn dump(&self) -> Result> { + self.iter().collect() } } impl PartialEq for Array { fn eq(&self, other: &Self) -> bool { - self.mt.root() == other.mt.root() + self.inner.eq(&other.inner) } } impl Eq for Array {} -impl<'de> Deserialize<'de> for Array { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize, JsonSchema)] - struct Aux { - array: Vec, +#[cfg(test)] +mod tests { + use super::*; + use crate::middleware::db::mem::MemDB; + + fn test_databases(test_fn: &dyn Fn(Box)) { + let db = MemDB::new(); + test_fn(Box::new(db)); + #[cfg(feature = "db_rocksdb")] + { + use crate::middleware::db; + let db = db::rocks::RocksDB::open(tempfile::TempDir::new().unwrap().path()).unwrap(); + test_fn(Box::new(db)); } - let aux = Aux::deserialize(deserializer)?; - Ok(Array::new(aux.array)) + } + + fn _test_dict(db: Box) { + let mut dict0 = Dictionary::empty_with_db(db.clone()); + dict0.insert(&Key::from("a"), &Value::from(1)).unwrap(); + dict0.insert(&Key::from("b"), &Value::from(2)).unwrap(); + dict0.update(&Key::from("a"), &Value::from(3)).unwrap(); + dict0.insert(&Key::from("c"), &Value::from(4)).unwrap(); + dict0.delete(&Key::from("c")).unwrap(); + let kvs0 = dict0.dump().unwrap(); + assert_eq!( + kvs0, + [ + ("a".to_string(), Value::from(3)), + ("b".to_string(), Value::from(2)) + ] + .into_iter() + .collect() + ); + let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap(); + let kvs1 = dict1.dump().unwrap(); + assert_eq!(kvs0, kvs1); + } + + fn _test_set(db: Box) { + let mut set0 = Set::empty_with_db(db.clone()); + set0.insert(&Value::from(1)).unwrap(); + set0.insert(&Value::from(2)).unwrap(); + set0.insert(&Value::from(3)).unwrap(); + set0.delete(&Value::from(2)).unwrap(); + + let s0 = set0.dump().unwrap(); + assert_eq!(s0, [Value::from(1), Value::from(3)].into_iter().collect()); + let set1 = Set::from_db(set0.commitment(), db).unwrap(); + let s1 = set1.dump().unwrap(); + assert_eq!(s0, s1); + } + + fn _test_array(db: Box) { + let mut arr0 = Array::empty_with_db(db.clone()); + arr0.insert(0, Value::from("a")).unwrap(); + arr0.insert(1, Value::from("b")).unwrap(); + arr0.insert(2, Value::from("c")).unwrap(); + arr0.delete(1).unwrap(); + + let a0 = arr0.dump().unwrap(); + assert_eq!( + a0, + [(0, Value::from("a")), (2, Value::from("c"))] + .into_iter() + .collect() + ); + let arr1 = Array::from_db(arr0.commitment(), db).unwrap(); + let a1 = arr1.dump().unwrap(); + assert_eq!(a0, a1); + } + + fn _test_nested(db: Box) { + let mut nested = Dictionary::empty_with_db(db.clone()); + nested.insert(&Key::from("a"), &Value::from(1)).unwrap(); + nested.insert(&Key::from("b"), &Value::from(2)).unwrap(); + let nested_kvs0 = nested.dump().unwrap(); + + let mut dict0 = Dictionary::empty_with_db(db.clone()); + dict0.insert(&Key::from("x"), &Value::from(1)).unwrap(); + dict0 + .insert(&Key::from("y"), &Value::from(nested.clone())) + .unwrap(); + let kvs0 = dict0.dump().unwrap(); + + assert_eq!( + kvs0, + [ + ("x".to_string(), Value::from(1)), + ("y".to_string(), Value::from(nested)) + ] + .into_iter() + .collect() + ); + + let dict1 = Dictionary::from_db(dict0.commitment(), db).unwrap(); + let kvs1 = dict1.dump().unwrap(); + assert_eq!(kvs0, kvs1); + + match &kvs1["y"].typed { + TypedValue::Dictionary(d) => { + let nested_kvs1 = d.dump().unwrap(); + assert_eq!(nested_kvs0, nested_kvs1); + } + _ => unreachable!(), + } + } + + #[test] + fn test_dict() { + test_databases(&_test_dict); + } + + #[test] + fn test_set() { + test_databases(&_test_set); + } + + #[test] + fn test_array() { + test_databases(&_test_array); + } + + #[test] + fn test_nested() { + test_databases(&_test_nested); } } diff --git a/src/middleware/db/mem.rs b/src/middleware/db/mem.rs new file mode 100644 index 0000000..53ab91e --- /dev/null +++ b/src/middleware/db/mem.rs @@ -0,0 +1,60 @@ +use super::*; + +/// MemDB implements the DB trait in a in-memory HashMap. +#[derive(Clone, Debug, Default)] +pub struct MemDB { + nodes: Arc>>, + values: Arc>>, +} + +impl MemDB { + pub fn new() -> Self { + Self::default() + } +} + +impl merkletree::db::DB for MemDB { + fn load_node(&self, hash: Hash) -> anyhow::Result> { + let nodes = self.nodes.read().expect("lock not poisoned"); + + if hash == EMPTY_HASH { + return Ok(Some(merkletree::Node::Intermediate( + merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH), + ))); + } + + Ok(nodes.get(&hash).cloned()) + } + + fn store_node(&mut self, node: merkletree::Node) -> anyhow::Result<()> { + let mut nodes = self.nodes.write().expect("lock not poisoned"); + nodes.insert(node.hash(), node); + Ok(()) + } +} + +impl DB for MemDB { + fn load_value(&self, raw: RawValue) -> anyhow::Result> { + let values = self.values.read().expect("lock not poisoned"); + + Ok(values.get(&raw).cloned()) + } + fn store_value(&mut self, value: Value) -> anyhow::Result<()> { + let mut values = self.values.write().expect("lock not poisoned"); + let value_raw = value.raw(); + if let Some(old_value) = values.get(&value_raw) { + // If we had a non-raw value stored never overwrite it with a raw value + if !old_value.is_raw() && value.is_raw() { + return Ok(()); + } + } + values.insert(value_raw, value); + Ok(()) + } + fn is_persistent(&self) -> bool { + false + } + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} diff --git a/src/middleware/db/mod.rs b/src/middleware/db/mod.rs new file mode 100644 index 0000000..bb32a67 --- /dev/null +++ b/src/middleware/db/mod.rs @@ -0,0 +1,30 @@ +use std::{ + collections::HashMap, + fmt::Debug, + sync::{Arc, RwLock}, +}; + +use dyn_clone::DynClone; + +#[cfg(feature = "backend_plonky2")] +use crate::backends::plonky2::primitives::merkletree::{self}; +use crate::middleware::{Hash, RawValue, Value, EMPTY_HASH}; + +pub mod mem; +#[cfg(feature = "db_rocksdb")] +pub mod rocks; + +// Trait for database that stores values. Must be cheap to clone. +pub trait DB: Debug + DynClone + Sync + Send + merkletree::db::DB { + fn load_value(&self, raw: RawValue) -> anyhow::Result>; + // If the DB is persistent, for containers only the root needs to be stored because the + // Container type makes sure the underlying merkle tree is stored in the DB independently, so + // that it can be recovered back just with the root and the DB. + // If the value is RawValue and a previous non-RawValue exists, no store overwrite it. + // should be done. If the value is non-RawValue and a previous RawValue exists, store + // should overwrite it. + fn store_value(&mut self, value: Value) -> anyhow::Result<()>; + fn is_persistent(&self) -> bool; + fn clone_box(&self) -> Box; +} +dyn_clone::clone_trait_object!(DB); diff --git a/src/middleware/db/rocks.rs b/src/middleware/db/rocks.rs new file mode 100644 index 0000000..be5ca4a --- /dev/null +++ b/src/middleware/db/rocks.rs @@ -0,0 +1,107 @@ +use std::{fmt, path::Path, sync::Arc}; + +use anyhow::{anyhow, Result}; +use rocksdb::{Options, TransactionDB, TransactionDBOptions}; + +use super::*; + +fn node_key(hash: Hash) -> Vec { + let mut k = Vec::with_capacity(2 + 4); + k.extend_from_slice(b"n/"); + k.extend_from_slice(&RawValue::from(hash).to_bytes()); + k +} + +fn value_key(raw: RawValue) -> Vec { + let mut k = Vec::with_capacity(2 + 4); + k.extend_from_slice(b"v/"); + k.extend_from_slice(&raw.to_bytes()); + k +} + +#[derive(Clone)] +pub struct RocksDB { + db: Arc, +} + +impl fmt::Debug for RocksDB { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "RocksDB(path: {:?})", self.db.path()) + } +} + +impl RocksDB { + pub fn open(path: impl AsRef) -> Result { + let mut options = Options::default(); + options.create_if_missing(true); + let txn_options = TransactionDBOptions::default(); + let inner = + TransactionDB::open(&options, &txn_options, path).map_err(|e| anyhow!("{e}"))?; + Ok(Self { + db: Arc::new(inner), + }) + } +} + +impl merkletree::db::DB for RocksDB { + fn load_node(&self, hash: Hash) -> Result> { + if hash == EMPTY_HASH { + return Ok(Some(merkletree::Node::Intermediate( + merkletree::Intermediate::new(EMPTY_HASH, EMPTY_HASH), + ))); + } + + match self.db.get(node_key(hash))? { + None => Ok(None), + Some(bytes) => Ok(Some(merkletree::Node::decode(bytes.as_ref())?)), + } + } + + fn store_node(&mut self, node: merkletree::Node) -> Result<()> { + self.db + .put(node_key(node.hash()), node.encode()?) + .map_err(|e| anyhow!("rocksdb transaction put failed: {e}")) + } +} + +impl DB for RocksDB { + fn load_value(&self, raw: RawValue) -> anyhow::Result> { + match self.db.get(value_key(raw))? { + None => Ok(None), + Some(bytes) => Ok(Some({ + if bytes.is_empty() { + Value::from(raw) + } else { + Value::from_bytes(bytes.as_ref(), self.clone_box())? + } + })), + } + } + fn store_value(&mut self, value: Value) -> anyhow::Result<()> { + let value_key = value_key(value.raw()); + let tx = self.db.transaction(); + if let Some(old_value_bytes) = tx.get_for_update(&value_key, true)? { + let is_raw = old_value_bytes.is_empty(); + // If we had a non-RawValue stored don't overwrite it (specially not with a + // RawValue). Also skip redundant RawValue overwrite. + if !is_raw || (is_raw && value.is_raw()) { + return Ok(()); + } + } + let value_bytes = if value.is_raw() { + // For RawValue we store an empty vector because it's a duplicate of the key. + // This way we can easily check for RawValue without decoding. + vec![] + } else { + Value::to_bytes(&value) + }; + tx.put(value_key, value_bytes)?; + Ok(tx.commit()?) + } + fn is_persistent(&self) -> bool { + true + } + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} diff --git a/src/middleware/error.rs b/src/middleware/error.rs index 74605da..f7ad765 100644 --- a/src/middleware/error.rs +++ b/src/middleware/error.rs @@ -72,6 +72,10 @@ pub enum Error { }, #[error(transparent)] Tree(#[from] crate::backends::plonky2::primitives::merkletree::error::TreeError), + #[error(transparent)] + Json(#[from] serde_json::Error), + #[error("database error: {0}")] + Database(anyhow::Error), } impl Debug for Error { @@ -164,7 +168,7 @@ impl Error { pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self { new!(UnsatisfiedCustomPredicateDisjunction(pred)) } - pub(crate) fn custom(s: String) -> Self { - new!(Custom(s)) + pub(crate) fn custom(s: impl Into) -> Self { + new!(Custom(s.into())) } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 542f5b2..19ca2c2 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,16 +1,13 @@ //! The middleware includes the type definitions and the traits used to connect the frontend and //! the backend. -use std::sync::Arc; - use hex::ToHex; -use itertools::Itertools; use strum_macros::FromRepr; mod basetypes; use std::{cmp::PartialEq, hash}; -use containers::{Array, Dictionary, Set}; +use containers::{Array, Container, Dictionary, Set}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; pub mod containers; @@ -22,6 +19,7 @@ pub mod serialization; mod statement; use std::{any::Any, fmt}; +pub mod db; pub use basetypes::*; pub use custom::*; use dyn_clone::DynClone; @@ -31,14 +29,10 @@ pub use pod_deserialization::*; use serialization::*; pub use statement::*; -use crate::backends::plonky2::primitives::merkletree::{ - MerkleProof, MerkleTreeStateTransitionProof, -}; - // TODO: Move all value-related types to to `value.rs` #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] // TODO #[schemars(transform = serialization::transform_value_schema)] -pub enum TypedValue { +pub(crate) enum TypedValue { // Serde cares about the order of the enum variants, with untagged variants // appearing at the end. // Variants without "untagged" will be serialized as "tagged" values by @@ -73,8 +67,6 @@ pub enum TypedValue { Array(Array), #[serde(untagged)] String(String), - #[serde(untagged)] - Bool(bool), } impl From<&str> for TypedValue { @@ -97,7 +89,11 @@ impl From for TypedValue { impl From for TypedValue { fn from(b: bool) -> Self { - TypedValue::Bool(b) + if b { + TypedValue::Int(1) + } else { + TypedValue::Int(0) + } } } @@ -149,70 +145,6 @@ impl From for TypedValue { } } -impl TryFrom<&TypedValue> for i64 { - type Error = Error; - fn try_from(v: &TypedValue) -> std::result::Result { - if let TypedValue::Int(n) = v { - Ok(*n) - } else { - Err(Error::custom("Value not an int".to_string())) - } - } -} - -impl TryFrom<&TypedValue> for String { - type Error = Error; - fn try_from(tv: &TypedValue) -> Result { - match tv { - TypedValue::String(s) => Ok(s.clone()), - _ => Err(Error::custom(format!( - "Value {} cannot be converted to a string.", - tv - ))), - } - } -} - -impl TryFrom<&TypedValue> for Key { - type Error = Error; - fn try_from(tv: &TypedValue) -> Result { - Ok(Key::new(String::try_from(tv)?)) - } -} - -impl TryFrom<&TypedValue> for PublicKey { - type Error = Error; - fn try_from(v: &TypedValue) -> std::result::Result { - if let TypedValue::PublicKey(pk) = v { - Ok(*pk) - } else { - Err(Error::custom("Value not a public key".to_string())) - } - } -} - -impl TryFrom<&TypedValue> for SecretKey { - type Error = Error; - fn try_from(v: &TypedValue) -> std::result::Result { - if let TypedValue::SecretKey(sk) = v { - Ok(sk.clone()) - } else { - Err(Error::custom("Value not a secret key".to_string())) - } - } -} - -impl TryFrom<&TypedValue> for Predicate { - type Error = Error; - fn try_from(v: &TypedValue) -> std::result::Result { - if let TypedValue::Predicate(p) = v { - Ok(p.clone()) - } else { - Err(Error::custom("Value not a Predicate".to_string())) - } - } -} - impl fmt::Display for TypedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -224,36 +156,54 @@ impl fmt::Display for TypedValue { Err(_) => write!(f, "\"{}\"", s), } } - TypedValue::Bool(b) => write!(f, "{}", b), TypedValue::Array(a) => { write!(f, "[")?; - for (i, v) in a.array().iter().enumerate() { + for (i, r) in a.iter().enumerate() { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", v)?; + if i == 8 { + write!(f, "…")?; + break; + } + match r { + Ok((index, value)) => write!(f, "{}: {}", index, value)?, + Err(e) => write!(f, "{e}")?, + } } write!(f, "]") } TypedValue::Dictionary(d) => { write!(f, "{{ ")?; - let kvs: Vec<_> = d.kvs().iter().sorted_by_key(|(k, _)| k.name()).collect(); - for (i, (k, v)) in kvs.iter().enumerate() { + for (i, r) in d.iter().enumerate() { if i > 0 { write!(f, ", ")?; } - write!(f, "{}: {}", k, v)?; + if i == 8 { + write!(f, "…")?; + break; + } + match r { + Ok((key, value)) => write!(f, "{}: {}", key, value)?, + Err(e) => write!(f, "{e}")?, + } } write!(f, " }}") } TypedValue::Set(s) => { write!(f, "#[")?; - let values: Vec<_> = s.set().iter().sorted_by_key(|k| k.raw()).collect(); - for (i, v) in values.iter().enumerate() { + for (i, r) in s.iter().enumerate() { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", v)?; + if i == 8 { + write!(f, "…")?; + break; + } + match r { + Ok(value) => write!(f, "{}", value)?, + Err(e) => write!(f, "{e}")?, + } } write!(f, "]") } @@ -272,7 +222,6 @@ impl From<&TypedValue> for RawValue { match v { TypedValue::String(s) => RawValue::from(hash_str(s)), TypedValue::Int(v) => RawValue::from(*v), - TypedValue::Bool(b) => RawValue::from(*b as i64), TypedValue::Dictionary(d) => RawValue::from(d.commitment()), TypedValue::Set(s) => RawValue::from(s.commitment()), TypedValue::Array(a) => RawValue::from(a.commitment()), @@ -405,9 +354,8 @@ impl JsonSchema for TypedValue { #[derive(Clone, Debug)] pub struct Value { - // The `TypedValue` is under `Arc` so that cloning a `Value` is cheap. - typed: Arc, - raw: RawValue, + pub(crate) typed: TypedValue, + pub(crate) raw: RawValue, } // Values are serialized as their TypedValue. @@ -441,6 +389,55 @@ impl JsonSchema for Value { } } +/// Dual of TypedValue that is not recursive: for container types no entry only the commitment +/// (merkle tree root of underlying data) is available. Used for byte serialization for +/// persistent storage. +#[derive(Serialize, Deserialize)] +enum TypedValueNoRec { + Raw(RawValue), + Int(i64), + PublicKey(PublicKey), + SecretKey(SecretKey), + Predicate(Predicate), + Set(Hash), + Dictionary(Hash), + Array(Hash), + String(String), +} + +// NOTE: byte serialization is using json. Using a byte-native serialization would improve +// performance and storage usage. +impl Value { + pub fn to_bytes(&self) -> Vec { + let v = match &self.typed { + TypedValue::Int(v) => TypedValueNoRec::Int(*v), + TypedValue::Raw(v) => TypedValueNoRec::Raw(*v), + TypedValue::PublicKey(v) => TypedValueNoRec::PublicKey(*v), + TypedValue::SecretKey(v) => TypedValueNoRec::SecretKey(v.clone()), + TypedValue::Predicate(v) => TypedValueNoRec::Predicate(v.clone()), + TypedValue::Set(v) => TypedValueNoRec::Set(v.commitment()), + TypedValue::Dictionary(v) => TypedValueNoRec::Dictionary(v.commitment()), + TypedValue::Array(v) => TypedValueNoRec::Array(v.commitment()), + TypedValue::String(v) => TypedValueNoRec::String(v.clone()), + }; + serde_json::to_vec(&v).expect("json serialization succeeds") + } + pub fn from_bytes(bytes: &[u8], db: Box) -> Result { + let v: TypedValueNoRec = serde_json::from_slice(bytes)?; + Ok(match v { + TypedValueNoRec::Int(v) => Value::from(v), + TypedValueNoRec::Raw(v) => Value::from(v), + TypedValueNoRec::PublicKey(v) => Value::from(v), + TypedValueNoRec::SecretKey(v) => Value::from(v), + TypedValueNoRec::Predicate(v) => Value::from(v), + TypedValueNoRec::Set(v) => Value::from(Set::from_db(v, db)?), + TypedValueNoRec::Dictionary(v) => Value::from(Dictionary::from_db(v, db)?), + TypedValueNoRec::Array(v) => Value::from(Array::from_db(v, db)?), + TypedValueNoRec::String(v) => Value::from(v), + }) + } +} + impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { self.raw == other.raw @@ -462,106 +459,110 @@ impl fmt::Display for Value { } impl Value { - pub fn new(value: TypedValue) -> Self { + pub(crate) fn new(value: TypedValue) -> Self { let raw_value = RawValue::from(&value); Self { - typed: Arc::new(value), + typed: value, raw: raw_value, } } - pub fn typed(&self) -> &TypedValue { - &self.typed - } pub fn raw(&self) -> RawValue { self.raw } - /// Determines Merkle existence proof for `key` in `self` (if applicable). - pub(crate) fn prove_existence<'a>( - &'a self, - key: &'a Value, - ) -> Result<(&'a Value, MerkleProof)> { - match &self.typed() { - TypedValue::Array(a) => match key.typed() { - TypedValue::Int(i) if i >= &0 => a.prove((*i) as usize), - _ => Err(Error::custom(format!( - "Invalid key {} for container {}.", - key, self - )))?, + /// Returns true if the typed value is RawValue, which means it's a generic value with no type + /// information and no extra value data. + pub fn is_raw(&self) -> bool { + matches!(self.typed, TypedValue::Raw(_)) + } + pub fn as_raw(&self) -> RawValue { + self.raw + } + pub fn as_int(&self) -> Option { + match self.typed { + TypedValue::Int(i) => Some(i), + _ => None, + } + } + pub fn as_public_key(&self) -> Option { + match &self.typed { + TypedValue::PublicKey(pk) => Some(*pk), + _ => None, + } + } + pub fn as_secret_key(&self) -> Option { + match &self.typed { + TypedValue::SecretKey(sk) => Some(sk.clone()), + _ => None, + } + } + pub fn as_predicate(&self) -> Option { + match &self.typed { + TypedValue::Predicate(p) => Some(p.clone()), + _ => None, + } + } + pub fn as_set(&self) -> Option { + match &self.typed { + TypedValue::Set(s) => Some(s.clone()), + TypedValue::Dictionary(d) => Some(Set { + inner: d.inner.clone(), + }), + TypedValue::Array(a) => Some(Set { + inner: a.inner.clone(), + }), + _ => None, + } + } + pub fn as_container(&self) -> Option { + match &self.typed { + TypedValue::Set(s) => Some(s.inner.clone()), + TypedValue::Dictionary(d) => Some(d.inner.clone()), + TypedValue::Array(a) => Some(a.inner.clone()), + _ => None, + } + } + pub fn as_dictionary(&self) -> Option { + match &self.typed { + TypedValue::Set(s) => Some(Dictionary { + inner: s.inner.clone(), + }), + TypedValue::Dictionary(d) => Some(d.clone()), + TypedValue::Array(a) => Some(Dictionary { + inner: a.inner.clone(), + }), + _ => None, + } + } + pub fn as_array(&self) -> Option { + match &self.typed { + TypedValue::Set(s) => Some(Array { + inner: s.inner.clone(), + }), + TypedValue::Dictionary(d) => Some(Array { + inner: d.inner.clone(), + }), + TypedValue::Array(a) => Some(a.clone()), + _ => None, + } + } + pub fn as_str(&self) -> Option<&str> { + match &self.typed { + TypedValue::String(s) => Some(s.as_str()), + _ => None, + } + } + pub fn as_string(&self) -> Option { + self.as_str().map(|s| s.to_string()) + } + pub fn as_bool(&self) -> Option { + match self.typed { + TypedValue::Int(i) => match i { + 0 => Some(false), + 1 => Some(true), + _ => None, }, - TypedValue::Dictionary(d) => d.prove(&key.typed().try_into()?), - TypedValue::Set(s) => Ok((key, s.prove(key)?)), - _ => Err(Error::custom(format!( - "Invalid container value {}", - self.typed() - ))), - } - } - /// Determines Merkle non-existence proof for `key` in `self` (if applicable). - pub(crate) fn prove_nonexistence<'a>(&'a self, key: &'a Value) -> Result { - match &self.typed() { - TypedValue::Array(_) => Err(Error::custom( - "Arrays do not support `NotContains` operation.".to_string(), - )), - TypedValue::Dictionary(d) => d.prove_nonexistence(&key.typed().try_into()?), - TypedValue::Set(s) => s.prove_nonexistence(key), - _ => Err(Error::custom(format!( - "Invalid container value {}", - self.typed() - ))), - } - } - /// Returns a Merkle state transition proof for inserting a - /// key-value pair (if applicable). - pub(crate) fn prove_insertion( - &self, - key: &Value, - value: &Value, - ) -> Result { - let container = self.typed().clone(); - match container { - TypedValue::Dictionary(mut d) => d.insert(&key.typed().try_into()?, value), - TypedValue::Set(mut s) => s.insert(value), - _ => Err(Error::custom(format!( - "Invalid container value {}", - self.typed() - ))), - } - } - /// Returns a Merkle state transition proof for updating a - /// key-value pair (if applicable). - pub(crate) fn prove_update( - &self, - key: &Value, - value: &Value, - ) -> Result { - let container = self.typed().clone(); - match container { - TypedValue::Array(mut a) => match key.typed() { - TypedValue::Int(i) if i >= &0 => a.update(*i as usize, value), - _ => Err(Error::custom(format!( - "Invalid key {} for container {}.", - key, self - )))?, - }, - TypedValue::Dictionary(mut d) => d.update(&key.typed().try_into()?, value), - _ => Err(Error::custom(format!( - "Invalid container value {} for update op", - self.typed() - ))), - } - } - /// Returns a Merkle state transition proof for deleting a - /// key (if applicable). - pub(crate) fn prove_deletion(&self, key: &Value) -> Result { - let container = self.typed().clone(); - match container { - TypedValue::Dictionary(mut d) => d.delete(&key.typed().try_into()?), - TypedValue::Set(mut s) => s.delete(key), - _ => Err(Error::custom(format!( - "Invalid container value {}", - self.typed() - ))), + _ => None, } } } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 526ff51..dfdfcfc 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -7,17 +7,14 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::primitives::{ - ec::{ - curve::{Point as PublicKey, GROUP_ORDER}, - schnorr::{SecretKey, Signature}, - }, + ec::{curve::GROUP_ORDER, schnorr::Signature}, merkletree::{MerkleProof, MerkleTree, MerkleTreeOp, MerkleTreeStateTransitionProof}, }, middleware::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result, - Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, - ValueRef, Wildcard, F, + Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef, + Wildcard, F, }, }; @@ -241,6 +238,10 @@ pub(crate) fn hash_op(x: Value, y: Value) -> Value { Value::from(hash_values(&[x, y])) } +fn ok_or_type_err(o: Option, v: &Value, typ: &'static str) -> Result { + o.ok_or_else(|| Error::custom(format!("{v} type is not {typ}"))) +} + impl Operation { pub fn op_type(&self) -> OperationType { type OT = OperationType; @@ -404,20 +405,20 @@ impl Operation { v3: &Value, f: impl FnOnce(i64, i64) -> i64, ) -> Result { - let i1: i64 = v1.typed().try_into()?; - let i2: i64 = v2.typed().try_into()?; - let i3: i64 = v3.typed().try_into()?; + let i1 = ok_or_type_err(v1.as_int(), v1, "Int")?; + let i2 = ok_or_type_err(v2.as_int(), v2, "Int")?; + let i3 = ok_or_type_err(v3.as_int(), v3, "Int")?; Ok(i1 == f(i2, i3)) } pub(crate) fn check_public_key(v1: &Value, v2: &Value) -> Result { - let pk: PublicKey = v1.typed().try_into()?; - let sk: SecretKey = v2.typed().try_into()?; + let pk = ok_or_type_err(v1.as_public_key(), v1, "PublicKey")?; + let sk = ok_or_type_err(v2.as_secret_key(), v2, "SecretKey")?; Ok(sk.0 < *GROUP_ORDER && pk == sk.public_key()) } pub(crate) fn check_signed_by(msg: &Value, pk: &Value, sig: &Signature) -> Result { - let pk: PublicKey = pk.typed().try_into()?; + let pk = ok_or_type_err(pk.as_public_key(), pk, "PublicKey")?; Ok(sig.verify(pk, msg.raw())) } @@ -428,8 +429,8 @@ impl Operation { let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err); let int_val = |v, s| { let v_op = value_from_op(s, v).ok_or_else(deduction_err)?; - match v_op.typed() { - &TypedValue::Int(i) => Ok(i), + match v_op.as_int() { + Some(i) => Ok(i), _ => Err(deduction_err()), } }; @@ -494,8 +495,7 @@ impl Operation { && pf.op_value == value.raw()) .then_some(()) .ok_or(Error::custom( - "The provided Merkle tree state transition proof does not match the claim." - .into(), + "The provided Merkle tree state transition proof does not match the claim.", ))?; MerkleTree::verify_state_transition(pf)?; true @@ -515,8 +515,7 @@ impl Operation { && pf.op_value == value.raw()) .then_some(()) .ok_or(Error::custom( - "The provided Merkle tree state transition proof does not match the claim." - .into(), + "The provided Merkle tree state transition proof does not match the claim.", ))?; MerkleTree::verify_state_transition(pf)?; true @@ -534,8 +533,7 @@ impl Operation { && pf.op_key == key.raw()) .then_some(()) .ok_or(Error::custom( - "The provided Merkle tree state transition proof does not match the claim." - .into(), + "The provided Merkle tree state transition proof does not match the claim.", ))?; MerkleTree::verify_state_transition(pf)?; true @@ -789,9 +787,8 @@ impl fmt::Display for Operation { pub(crate) fn root_key_to_ak(root: &Value, key: &Value) -> Option { let root_hash = Hash::from(root.raw()); - Key::try_from(key.typed()) - .map(|key| AnchoredKey::new(root_hash, key)) - .ok() + key.as_str() + .map(|s| AnchoredKey::new(root_hash, Key::from(s))) } /// Returns the value associated with `output_ref`. From 1e592e11cf802a566b899c767f19dd6b497e693f Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Tue, 24 Mar 2026 14:25:11 +0000 Subject: [PATCH 03/10] Self-referential predicate hashes as statement template args (#494) * Support quoted predicate hashes, including self-referential predicates * Clippy * Review feedback --- src/backends/plonky2/circuits/common.rs | 25 ++- src/backends/plonky2/circuits/mainpod.rs | 144 ++++++++++++++- src/backends/plonky2/mainpod/mod.rs | 62 ++++++- src/frontend/custom.rs | 132 +++++++++++++- src/frontend/mod.rs | 4 +- src/lang/frontend_ast_lower.rs | 3 + src/lang/module.rs | 4 +- src/middleware/custom.rs | 221 ++++++++++++++++++++++- src/middleware/operation.rs | 9 +- 9 files changed, 573 insertions(+), 31 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index db8c32a..7d25786 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -771,7 +771,8 @@ impl CustomPredicateEntryTarget { pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; pw.set_target(self.index, F::from_canonical_usize(predicate.index))?; - // Replace statement templates of batch-self with (id,index) + // Replace BatchSelf predicates with Custom(batch, i), and + // SelfPredicateHash args with Literal(hash(Custom(batch, i))) let batch = &predicate.batch; let predicate = predicate.predicate(); let statements = predicate @@ -788,10 +789,22 @@ impl CustomPredicateEntryTarget { } x => x.clone(), }; - StatementTmpl { - pred_or_wc, - args: st_tmpl.args, - } + let args = st_tmpl + .args + .into_iter() + .map(|arg| match arg { + StatementTmplArg::SelfPredicateHash(i) => { + let pred_hash = Predicate::Custom(CustomPredicateRef { + batch: batch.clone(), + index: i, + }) + .hash(); + StatementTmplArg::Literal(Value::from(pred_hash)) + } + other => other, + }) + .collect(); + StatementTmpl { pred_or_wc, args } }) .collect_vec(); let predicate = CustomPredicate { @@ -2012,7 +2025,7 @@ pub(crate) mod tests { // Empty case let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); _ = cpb_builder.predicate_and("empty", &[], &[], &[])?; - let custom_predicate_batch = cpb_builder.finish(); + let custom_predicate_batch = cpb_builder.finish()?; helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); // Some cases from the examples diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index b0c8f48..68114d2 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -54,8 +54,8 @@ use crate::{ measure_gates_begin, measure_gates_end, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, Params, PredicatePrefix, RawValue, Statement, ToFields, Value, F, - HASH_SIZE, + NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix, + ToFields, Value, F, HASH_SIZE, }, }; // @@ -1534,8 +1534,8 @@ pub fn calculate_statements_hash_circuit( sts_hash } -// Replace predicates of batch-self with the corresponding global custom predicate batch_id and -// index +// Replace BatchSelf predicates with the corresponding Custom(batch_id, index), and +// SelfPredicateHash args with Literal(hash(Custom(batch_id, index))). fn normalize_st_tmpl_circuit( params: &Params, builder: &mut CircuitBuilder, @@ -1564,7 +1564,41 @@ fn normalize_st_tmpl_circuit( ); let pred_hash_or_wc = PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data); - StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone()) + + // Normalize SelfPredicateHash args: replace prefix 4 with Literal containing the resolved + // predicate hash. Same pattern as the predicate normalization above. + let prefix_sph = builder.constant(F::from(StatementTmplArgPrefix::SelfPredicateHash)); + let prefix_literal = builder.constant(F::from(StatementTmplArgPrefix::Literal)); + let zero = builder.zero(); + let normalized_args = st_tmpl + .args + .iter() + .map(|arg| { + let is_sph = builder.is_equal(arg.elements[0], prefix_sph); + + // The predicate index is in elements[1] (same slot as WildcardLiteral). + let pred_index = arg.elements[1]; + + // Compute hash(Custom(batch_id, pred_index)) + let pred_target = PredicateTarget::new_custom(builder, id, pred_index); + let pred_hash = pred_target.hash(builder); + + // Build a Literal-encoded arg: [1, hash[0..4], 0, 0, 0, 0] + let mut literal_elements = [zero; Params::statement_tmpl_arg_size()]; + literal_elements[0] = prefix_literal; + literal_elements[1] = pred_hash.elements[0]; + literal_elements[2] = pred_hash.elements[1]; + literal_elements[3] = pred_hash.elements[2]; + literal_elements[4] = pred_hash.elements[3]; + let normalized = StatementTmplArgTarget { + elements: literal_elements, + }; + + builder.select_flattenable(params, is_sph, &normalized, arg) + }) + .collect(); + + StatementTmplTarget::new(pred_hash_or_wc, normalized_args) } /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as @@ -3262,7 +3296,7 @@ mod tests { &[stb0.clone(), stb1.clone()], )?; let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; - let batch = builder.finish(); + let batch = builder.finish()?; let dict = Hash([F(6), F(7), F(8), F(9)]); @@ -3352,7 +3386,7 @@ mod tests { &[stb0.clone(), stb1.clone()], )?; let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; - let batch = builder.finish(); + let batch = builder.finish()?; let dict = Hash([F(1), F(2), F(3), F(4)]); let secret_dict = Hash([F(6), F(7), F(8), F(9)]); @@ -3570,4 +3604,100 @@ mod tests { Ok(()) } + + #[test] + fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> { + let params = Params::default(); + + // Build a batch with two predicates: + // pred_A: Equal(x, y) + // pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash + use NativePredicate as NP; + let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal) + .arg("x") + .arg("y"); + cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a]) + .unwrap(); + + // Build pred_B's template manually with SelfPredicateHash(0) + let stb_b_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)), + args: vec![ + StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)), + StatementTmplArg::SelfPredicateHash(0), + ], + }; + let pred_b = CustomPredicate::new( + ¶ms, + "pred_B".into(), + true, + vec![stb_b_tmpl], + 1, + vec!["x".to_string()], + ) + .unwrap(); + cpb.predicates.push(pred_b); + let batch = cpb.finish().unwrap(); + + // Compute the expected resolved hash of pred_A + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_a_hash = Predicate::Custom(pred_a_ref).hash(); + let expected_pred_a_value = Value::from(pred_a_hash); + + // Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to + // Literal(pred_a_hash). Then make_statement_from_template_circuit should produce + // a statement with that literal value. + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_tmpl = &pred_b_ref.predicate().statements[0]; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + // Create the template target and batch id target + let st_tmpl_target = builder.add_virtual_statement_tmpl(true); + let batch_id = builder.add_virtual_hash(); + + // Normalize the template (this is what we're testing) + let normalized = + normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id); + + // Feed normalized template into statement generation + let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = + make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target); + + // Connect to expected output + let expected_st_target = builder.add_virtual_statement(false); + builder.connect_flattenable(&expected_st_target, &st_target); + + // Set witness + let mut pw = PartialWitness::::new(); + st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?; + pw.set_target_arr(&batch_id.elements, &batch.id().0)?; + + let some_value = Value::from(42); + // args: first wildcard is "x" = some_value, rest are padding + let mut args_values = vec![some_value.clone()]; + for _ in 1..params.max_custom_predicate_wildcards { + args_values.push(Value::from(EMPTY_VALUE)); + } + for (target, value) in args_target.iter().zip(args_values.iter()) { + target.set_targets(&mut pw, value)?; + } + + // Expected statement: Equal(Literal(some_value), Literal(pred_a_hash)) + let expected_st: crate::backends::plonky2::mainpod::Statement = + Statement::equal(some_value, expected_pred_a_value).into(); + expected_st_target.set_targets(&mut pw, &expected_st)?; + + // Build and verify + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } } diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 8e6ed46..4968316 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -104,8 +104,9 @@ pub(crate) fn extract_custom_predicate_verifications( if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Statement::Custom(st_cpr, st_args) = st { assert_eq!(cpr, st_cpr); + let normalized_pred = cpr.normalized_predicate(); let wildcard_values = - wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) + wildcard_values_from_op_st(params, &normalized_pred, sts, st_args) .expect("resolved wildcards"); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let custom_predicate_table_index = custom_predicates @@ -1096,7 +1097,7 @@ pub mod tests { &[stb0.clone(), stb1.clone()], )?; let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?; - let cpb = cpb_builder.finish(); + let cpb = cpb_builder.finish()?; let cpb_and = CustomPredicateRef::new(cpb.clone(), 0); let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1); @@ -1130,6 +1131,63 @@ pub mod tests { Ok(pod.verify()?) } + #[test] + fn test_main_self_predicate_hash() -> frontend::Result<()> { + use frontend::BuilderArg; + + let params = Params { + max_signed_by: 0, + max_input_pods: 0, + max_statements: 6, + max_public_statements: 2, + max_operation_args: 5, + max_custom_predicate_wildcards: 4, + max_custom_predicate_verifications: 2, + max_merkle_proofs_containers: 0, + max_merkle_tree_state_transition_proofs_containers: 0, + ..Default::default() + }; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(&vds); + + // Build a batch: pred_A references pred_B's hash, pred_B references pred_A's hash + let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb_a = STB::new_from_pred(NP::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_B".into())); + cpb.predicate_and("pred_A", &["x"], &[], &[stb_a])?; + + let stb_b = STB::new_from_pred(NP::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_A".into())); + cpb.predicate_and("pred_B", &["x"], &[], &[stb_b])?; + + let batch = cpb.finish()?; + + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_hash = middleware::Value::from(middleware::Predicate::Custom(pred_b_ref).hash()); + + // Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash) + let mut pod_builder = MainPodBuilder::new(¶ms, &vd_set); + let eq_st = + pod_builder.priv_op(frontend::Operation::eq(pred_b_hash.clone(), pred_b_hash))?; + pod_builder.pub_op(frontend::Operation::custom(pred_a_ref, [eq_st]))?; + + // Mock + let prover = MockProver {}; + let pod = pod_builder.prove(&prover)?; + assert!(pod.pod.verify().is_ok()); + + // Real + let prover = Prover {}; + let pod = pod_builder.prove(&prover)?; + let pod = (pod.pod as Box).downcast::().unwrap(); + + Ok(pod.verify()?) + } + #[test] fn test_set_contains() -> frontend::Result<()> { let params = Params::default(); diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 92fdc4f..f3a8115 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -18,6 +18,8 @@ pub enum BuilderArg { /// Key: (origin, key), where origin is Wildcard and key is Key Key(String, String), WildcardLiteral(String), + /// Reference to a same-batch predicate's identity hash (resolved by name in finish()). + SelfPredicateHash(String), } /// When defining a `BuilderArg`, it can be done from 3 different inputs: @@ -130,6 +132,8 @@ pub struct CustomPredicateBatchBuilder { params: Params, pub name: String, pub predicates: Vec, + /// Forward references to resolve in finish(): (predicate_idx, statement_idx, arg_idx, name) + pending_self_pred_hashes: Vec<(usize, usize, usize, String)>, } impl CustomPredicateBatchBuilder { @@ -138,6 +142,7 @@ impl CustomPredicateBatchBuilder { params, name, predicates: Vec::new(), + pending_self_pred_hashes: Vec::new(), } } @@ -194,14 +199,18 @@ impl CustomPredicateBatchBuilder { )); } + let pred_idx = self.predicates.len(); + let mut pending = Vec::new(); let statements = sts .iter() - .map(|sb| { + .enumerate() + .map(|(stmt_idx, sb)| { let stb = sb.clone().desugar(); let st_tmpl_args = stb .args .iter() - .map(|a| { + .enumerate() + .map(|(arg_idx, a)| { Ok::<_, Error>(match a { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey( @@ -211,6 +220,22 @@ impl CustomPredicateBatchBuilder { BuilderArg::WildcardLiteral(v) => { StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?) } + BuilderArg::SelfPredicateHash(pred_name) => { + // Try backward reference first + match self.predicates.iter().position(|p| p.name == *pred_name) { + Some(index) => StatementTmplArg::SelfPredicateHash(index), + None => { + // Forward reference - placeholder, resolved in finish() + pending.push(( + pred_idx, + stmt_idx, + arg_idx, + pred_name.clone(), + )); + StatementTmplArg::SelfPredicateHash(0) + } + } + } }) }) .collect::>()?; @@ -240,11 +265,27 @@ impl CustomPredicateBatchBuilder { .collect(), )?; self.predicates.push(custom_predicate); + self.pending_self_pred_hashes.extend(pending); Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } - pub fn finish(self) -> Arc { - CustomPredicateBatch::new(self.name, self.predicates) + pub fn finish(mut self) -> Result> { + // Resolve forward references for SelfPredicateHash + for (pred_idx, stmt_idx, arg_idx, ref name) in &self.pending_self_pred_hashes { + let target_idx = self + .predicates + .iter() + .position(|p| p.name == *name) + .ok_or_else(|| { + Error::custom(format!( + "SelfPredicateHash references unknown predicate '{}'", + name + )) + })?; + self.predicates[*pred_idx].statements[*stmt_idx].args[*arg_idx] = + StatementTmplArg::SelfPredicateHash(target_idx); + } + Ok(CustomPredicateBatch::new(self.name, self.predicates)) } } @@ -306,7 +347,7 @@ mod tests { .arg("s2"); builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?; - let batch = builder.finish(); + let batch = builder.finish()?; let batch_clone = batch.clone(); let gt_custom_pred = CustomPredicateRef::new(batch, 0); @@ -356,7 +397,7 @@ mod tests { &[], &[set_contains_stb], )?; - let batch = builder.finish(); + let batch = builder.finish()?; let batch_clone = batch.clone(); let mut mp_builder = MainPodBuilder::new(¶ms, vd_set); @@ -386,4 +427,83 @@ mod tests { Ok(()) } + + #[test] + fn test_builder_self_predicate_hash_unknown_ref() { + let params = Params::default(); + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + + let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("nonexistent".into())); + builder + .predicate_and("pred_A", &["x"], &[], &[stb]) + .unwrap(); + + // finish() should fail because "nonexistent" was never defined + assert!(builder.finish().is_err()); + } + + /// Tests cyclic SelfPredicateHash references end-to-end: + /// pred_A references pred_B's hash (forward ref), pred_B references pred_A's hash (backward + /// ref). Exercises forward reference resolution in finish(), then builds and verifies a POD + /// using pred_A via MockProver. + #[test] + fn test_builder_self_predicate_hash_e2e() -> Result<()> { + let params = Params::default(); + let vd_set = &*MOCK_VD_SET; + + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + + // pred_A references pred_B's hash (forward ref, pred_B not yet defined) + let stb_a = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_B".into())); + builder.predicate_and("pred_A", &["x"], &[], &[stb_a])?; + + // pred_B references pred_A's hash (backward ref, pred_A already defined) + let stb_b = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_A".into())); + builder.predicate_and("pred_B", &["x"], &[], &[stb_b])?; + + let batch = builder.finish()?; + + // Verify resolution: pred_A references pred_B (index 1), pred_B references pred_A (index 0) + assert_eq!( + batch.predicates()[0].statements[0].args[1], + StatementTmplArg::SelfPredicateHash(1) + ); + assert_eq!( + batch.predicates()[1].statements[0].args[1], + StatementTmplArg::SelfPredicateHash(0) + ); + + // Compute concrete hashes + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); + + // Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash) + let mut mp_builder = MainPodBuilder::new(¶ms, vd_set); + let eq_st = mp_builder.priv_op(Operation::eq(pred_b_hash.clone(), pred_b_hash.clone()))?; + mp_builder.pub_op(Operation::custom(pred_a_ref, [eq_st]))?; + + // Prove and verify + let prover = MockProver {}; + let proof = mp_builder.prove(&prover)?; + proof.pod.verify()?; + + // Verify the public statement contains pred_b_hash as its argument + let pub_sts = proof.pod.pub_self_statements(); + let custom_st = pub_sts + .iter() + .find(|s| matches!(s, middleware::Statement::Custom(_, _))) + .expect("should have a custom statement"); + if let middleware::Statement::Custom(_, args) = custom_st { + assert_eq!(args[0], pred_b_hash); + } + + Ok(()) + } } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 98f280e..1ce2795 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -578,7 +578,7 @@ impl MainPodBuilder { } } OperationType::Custom(cpr) => { - let pred = &cpr.batch.predicates()[cpr.index]; + let pred = cpr.normalized_predicate(); if pred.statements.len() != op.1.len() { return Err(Error::custom(format!( "Custom predicate operation needs {} statements but has {}.", @@ -606,7 +606,7 @@ impl MainPodBuilder { } wildcard_map[index] = Some(value); } - fill_wildcard_values(pred, &args, &mut wildcard_map)?; + fill_wildcard_values(&pred, &args, &mut wildcard_map)?; let v_default = Value::from(0); let st_args: Vec<_> = wildcard_map .into_iter() diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index b429f4a..fe9b745 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -346,6 +346,9 @@ impl<'a> Lowerer<'a> { let key = Key::from(key_str.as_str()); MWStatementTmplArg::AnchoredKey(wildcard, key) } + BuilderArg::SelfPredicateHash(_) => { + unreachable!("SelfPredicateHash should not appear in request lowering") + } }; mw_args.push(mw_arg); } diff --git a/src/lang/module.rs b/src/lang/module.rs index 3ff3d6b..78fb22e 100644 --- a/src/lang/module.rs +++ b/src/lang/module.rs @@ -345,7 +345,9 @@ fn build_single_batch( })?; } - Ok(builder.finish()) + builder.finish().map_err(|e| BatchingError::Internal { + message: format!("Failed to finalize batch '{}': {}", batch_name, e), + }) } /// Build a statement template with properly resolved predicate references diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 13cc387..cf6d9be 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -49,6 +49,9 @@ pub enum StatementTmplArg { // AnchoredKey where the origin is a wildcard AnchoredKey(Wildcard, Key), Wildcard(Wildcard), + /// Reference to a same-batch predicate's identity hash, resolved at verification time. + /// The usize is the predicate index within the batch. + SelfPredicateHash(usize), } #[derive(Clone, Copy)] @@ -57,6 +60,7 @@ pub enum StatementTmplArgPrefix { Literal = 1, AnchoredKey = 2, WildcardLiteral = 3, + SelfPredicateHash = 4, } impl From for F { @@ -68,11 +72,12 @@ impl From for F { impl ToFields for StatementTmplArg { fn to_fields(&self) -> Vec { // Encoding: - // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) - // Literal(v) => (1, [v ], 0, 0, 0, 0) - // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) - // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) - // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements + // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) + // Literal(v) => (1, [v ], 0, 0, 0, 0) + // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) + // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) + // SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0) + // In all cases, we pad to 2 * hash_size + 1 = 9 field elements match self { StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None)) .chain(iter::repeat(F::ZERO)) @@ -97,6 +102,13 @@ impl ToFields for StatementTmplArg { .take(Params::statement_tmpl_arg_size()) .collect_vec() } + StatementTmplArg::SelfPredicateHash(index) => { + iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash)) + .chain(iter::once(F::from_canonical_usize(*index))) + .chain(iter::repeat(F::ZERO)) + .take(Params::statement_tmpl_arg_size()) + .collect_vec() + } } } } @@ -113,6 +125,7 @@ impl fmt::Display for StatementTmplArg { write!(f, "]") } Self::Wildcard(v) => v.fmt(f), + Self::SelfPredicateHash(i) => write!(f, "::self.{}", i), } } } @@ -569,6 +582,44 @@ impl CustomPredicateRef { pub fn predicate(&self) -> &CustomPredicate { &self.batch.predicates()[self.index] } + + /// Returns a copy of this predicate with all `SelfPredicateHash(i)` args + /// resolved to `Literal(hash(Custom(batch, i)))`. + pub fn normalized_predicate(&self) -> CustomPredicate { + let pred = self.predicate(); + let normalized_statements = pred + .statements + .iter() + .map(|st_tmpl| { + let args = st_tmpl + .args + .iter() + .map(|arg| match arg { + StatementTmplArg::SelfPredicateHash(i) => { + let pred_hash = Predicate::Custom(CustomPredicateRef { + batch: self.batch.clone(), + index: *i, + }) + .hash(); + StatementTmplArg::Literal(Value::from(pred_hash)) + } + other => other.clone(), + }) + .collect(); + StatementTmpl { + pred_or_wc: st_tmpl.pred_or_wc.clone(), + args, + } + }) + .collect(); + CustomPredicate { + name: pred.name.clone(), + conjunction: pred.conjunction, + statements: normalized_statements, + args_len: pred.args_len, + wildcard_names: pred.wildcard_names.clone(), + } + } } #[cfg(test)] @@ -823,4 +874,164 @@ mod tests { Ok(()) } + + #[test] + fn test_normalized_predicate() -> Result<()> { + let params = Params::default(); + + // Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0)) + let pred_a = CustomPredicate::and( + ¶ms, + "pred_A".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))], + )], + 2, + names(&["x", "y"]), + )?; + let pred_b = CustomPredicate::and( + ¶ms, + "pred_B".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], + )], + 1, + names(&["x"]), + )?; + let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); + + // Compute expected pred_A hash + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash()); + + // Normalize pred_B + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let normalized = pred_b_ref.normalized_predicate(); + + // The second arg should be resolved to Literal(pred_A_hash) + assert_eq!( + normalized.statements[0].args[1], + STA::Literal(expected_hash) + ); + + // First arg should be unchanged (still a wildcard) + assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0))); + + Ok(()) + } + + #[test] + fn test_self_predicate_hash_check() -> Result<()> { + let params = Params::default(); + + // Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0)) + let pred_a = CustomPredicate::and( + ¶ms, + "pred_A".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))], + )], + 2, + names(&["x", "y"]), + )?; + let pred_b = CustomPredicate::and( + ¶ms, + "pred_B".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], + )], + 1, + names(&["x"]), + )?; + let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); + + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash()); + + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + + // Construct a valid operation: Equal(some_value, pred_a_hash) + let some_value = Value::from(42); + let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())]; + + // The output statement + let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]); + + // This should pass + assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(¶ms, &output_st)?); + + // Now try with wrong hash, should fail + let wrong_hash = Value::from(999); + let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)]; + assert!(Operation::Custom(pred_b_ref, bad_op_args) + .check(¶ms, &output_st) + .is_err()); + + Ok(()) + } + + #[test] + fn test_self_predicate_hash_cyclic() -> Result<()> { + let params = Params::default(); + + // Build a batch where pred_A references pred_B's hash and vice versa + // pred_A = Equal(x, SelfPredicateHash(1)) + // pred_B = Equal(x, SelfPredicateHash(0)) + let pred_a = CustomPredicate::and( + ¶ms, + "pred_A".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)], + )], + 1, + names(&["x"]), + )?; + let pred_b = CustomPredicate::and( + ¶ms, + "pred_B".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], + )], + 1, + names(&["x"]), + )?; + let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); + + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash()); + let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); + + // pred_A's normalized form should reference pred_B's hash + let norm_a = pred_a_ref.normalized_predicate(); + assert_eq!( + norm_a.statements[0].args[1], + STA::Literal(pred_b_hash.clone()) + ); + + // pred_B's normalized form should reference pred_A's hash + let norm_b = pred_b_ref.normalized_predicate(); + assert_eq!( + norm_b.statements[0].args[1], + STA::Literal(pred_a_hash.clone()) + ); + + // Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass + let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())]; + let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]); + assert!(Operation::Custom(pred_a_ref, op_a).check(¶ms, &st_a)?); + + // Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass + let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())]; + let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]); + assert!(Operation::Custom(pred_b_ref, op_b).check(¶ms, &st_b)?); + + Ok(()) + } } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index dfdfcfc..1793e4d 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -595,6 +595,11 @@ pub fn check_st_tmpl( (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { wc_check_or_set(v.clone(), wc, wildcard_map) } + (StatementTmplArg::SelfPredicateHash(_), _) => { + unreachable!( + "SelfPredicateHash should be normalized to Literal before template matching" + ) + } _ => Err(Error::mismatched_statement_tmpl_arg( st_tmpl_arg.clone(), st_arg.clone(), @@ -712,7 +717,7 @@ pub(crate) fn check_custom_pred( args: &[Statement], s_args: &[Value], ) -> Result<()> { - let pred = custom_pred_ref.predicate(); + let pred = custom_pred_ref.normalized_predicate(); if pred.statements.len() != args.len() { return Err(Error::diff_amount( "custom predicate operation".to_string(), @@ -731,7 +736,7 @@ pub(crate) fn check_custom_pred( } // Check that the resolved wildcards match the statement arguments. - let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) { + let wc_values = match wildcard_values_from_op_st(params, &pred, args, s_args) { Ok(wc_values) => wc_values, Err(Error::Inner { inner, backtrace }) => match *inner { MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev) From a4069bcc55e86a5f6e53e0374510f49d75b64b37 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 25 Mar 2026 18:48:28 +0100 Subject: [PATCH 04/10] Fix pod builder (#496) Several fixes and code simplifications: - MainPodBuilder - Fix: It was not tracking Contains statements inherited via input pods (via public statements) when automatically generating Contains statements for Entry arguments. - Enhancement: Deduplicate statements - MultiPodBuilder - Simplify: Remove the "statement groups" logic and instead deduplicate statements in the MainPodBuilder (which is much simpler to do) - Remove the "anchored key" explicit dependency tracking and instead rely on regular dependency tracking by using all the implicit operations and statements generated by MainPodBuilder as input to the solver. - Fix: Count and constrain custom predicates used in a pod instead of batches used --- src/backends/plonky2/mainpod/mod.rs | 3 +- src/frontend/mod.rs | 78 ++++++---- src/frontend/multi_pod/cost.rs | 86 ++--------- src/frontend/multi_pod/deps.rs | 37 +---- src/frontend/multi_pod/mod.rs | 184 +++++++---------------- src/frontend/multi_pod/solver.rs | 222 ++++++---------------------- src/middleware/db/mem.rs | 6 +- 7 files changed, 162 insertions(+), 454 deletions(-) diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 4968316..ae1ade3 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1253,7 +1253,8 @@ pub mod tests { cpr, [1, 1, 2].into_iter().map(middleware::Value::from).collect(), ); - builder.insert(true, (st, op)).unwrap(); + builder.insert((st.clone(), op)).unwrap(); + builder.reveal(&st).unwrap(); let prover = Prover {}; builder.prove(&prover).unwrap(); } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 1ce2795..f23e374 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -137,7 +137,7 @@ pub struct MainPodBuilder { pub operations: Vec, pub public_statements: Vec, // Internal state - dict_contains: Vec<(Value, Value)>, // (root, key) + contains: Vec<(RawValue, RawValue)>, // (root, key) } impl fmt::Display for MainPodBuilder { @@ -171,10 +171,16 @@ impl MainPodBuilder { statements: Vec::new(), operations: Vec::new(), public_statements: Vec::new(), - dict_contains: Vec::new(), + contains: Vec::new(), } } + pub fn stmt_len(&self) -> usize { + self.statements.len() + } pub fn add_pod(&mut self, pod: MainPod) -> Result<()> { + for st in &pod.public_statements { + self.track_contains(st); + } self.input_pods.push(pod); match self.input_pods.len() > self.params.max_input_pods { true => Err(Error::too_many_input_pods( @@ -184,31 +190,26 @@ impl MainPodBuilder { _ => Ok(()), } } - pub fn insert(&mut self, public: bool, st_op: (Statement, Operation)) -> Result<()> { - // TODO: Do error handling instead of panic - let (st, op) = st_op; - // If we're adding a Contains statement with literal arguments (an Entry), track it in - // `dict_contains` to avoid adding it again via `Self::add_entries_contains`. + // If we're adding a Contains statement with literal arguments (an Entry), track it in + // `dict_contains` to avoid adding it again via `Self::add_entries_contains`. + fn track_contains(&mut self, st: &Statement) { if let Statement::Contains( ValueRef::Literal(dict), ValueRef::Literal(key), ValueRef::Literal(_), ) = &st { - let root_key = (dict.clone(), key.clone()); - self.dict_contains.push(root_key); + let root_key = (dict.raw(), key.raw()); + self.contains.push(root_key); } + } + + pub fn insert(&mut self, st_op: (Statement, Operation)) -> Result<()> { + // TODO: Do error handling instead of panic + let (st, op) = st_op; + self.track_contains(&st); - if public { - self.public_statements.push(st.clone()); - } - if self.public_statements.len() > self.params.max_public_statements { - return Err(Error::too_many_public_statements( - self.public_statements.len(), - self.params.max_public_statements, - )); - } self.statements.push(st); self.operations.push(op); if self.statements.len() > self.params.max_statements { @@ -404,7 +405,7 @@ impl MainPodBuilder { } fn op_statement( - &mut self, + &self, wildcard_values: Vec<(usize, Value)>, op: Operation, ) -> Result { @@ -621,7 +622,7 @@ impl MainPodBuilder { } /// For every operation that has Entry statements as arguments we add a Contains statement to - /// open the dictionary. + /// open the dictionary (unless such Contains already exists). fn add_entries_contains(&mut self, op: &Operation) -> Result<()> { for arg in &op.1 { if let OperationArg::Statement(Statement::Contains( @@ -630,9 +631,9 @@ impl MainPodBuilder { ValueRef::Literal(v), )) = arg { - let root_key = (dict.clone(), key.clone()); - if !self.dict_contains.contains(&root_key) { - self.dict_contains.push(root_key); + let root_key = (dict.raw(), key.raw()); + if !self.contains.contains(&root_key) { + self.contains.push(root_key); self.priv_op(Operation::dict_contains(dict, key, v))?; } } @@ -650,13 +651,28 @@ impl MainPodBuilder { self.add_entries_contains(&op)?; let op = Self::fill_in_aux(Self::lower_op(op)?)?; let st = self.op_statement(wildcard_values, op.clone())?; - self.insert(public, (st, op))?; + // Skip adding the statement and operation if it already exists + if !self.statements.contains(&st) { + self.insert((st.clone(), op))?; + } + if public { + self.reveal(&st)?; + } - Ok(self.statements[self.statements.len() - 1].clone()) + Ok(st) } - pub fn reveal(&mut self, st: &Statement) { - self.public_statements.push(st.clone()); + pub fn reveal(&mut self, st: &Statement) -> Result<()> { + if !self.public_statements.contains(st) { + self.public_statements.push(st.clone()); + } + if self.public_statements.len() > self.params.max_public_statements { + return Err(Error::too_many_public_statements( + self.public_statements.len(), + self.params.max_public_statements, + )); + } + Ok(()) } pub fn prove(&self, prover: &dyn MainPodProver) -> Result { @@ -1351,11 +1367,9 @@ pub mod tests { OperationAux::None, ); builder - .insert(false, (value_of_a.clone(), op_contains.clone())) - .unwrap(); - builder - .insert(false, (value_of_b.clone(), op_contains)) + .insert((value_of_a.clone(), op_contains.clone())) .unwrap(); + builder.insert((value_of_b.clone(), op_contains)).unwrap(); let st = Statement::equal( AnchoredKey::from((&local, "a")), AnchoredKey::from((&local, "b")), @@ -1368,7 +1382,7 @@ pub mod tests { ], OperationAux::None, ); - builder.insert(false, (st, op)).unwrap(); + builder.insert((st, op)).unwrap(); let prover = MockProver {}; let pod = builder.prove(&prover).unwrap(); diff --git a/src/frontend/multi_pod/cost.rs b/src/frontend/multi_pod/cost.rs index a5d89da..0c0c2ef 100644 --- a/src/frontend/multi_pod/cost.rs +++ b/src/frontend/multi_pod/cost.rs @@ -6,60 +6,20 @@ use std::collections::BTreeSet; use crate::{ - frontend::{Operation, OperationArg}, - middleware::{ - CustomPredicateBatch, Hash, NativeOperation, OperationType, RawValue, Statement, ValueRef, - }, + frontend::Operation, + middleware::{CustomPredicateRef, Hash, NativeOperation, OperationType, Predicate}, }; -/// Unique identifier for a custom predicate batch. +/// Unique identifier for a custom predicate in a module. /// -/// Uses the batch's cryptographic hash as identifier. Two batches with the same +/// Uses the predicate's cryptographic hash as identifier. Two predicates with the same /// hash are considered identical for resource counting purposes. #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct CustomBatchId(pub Hash); +pub struct CustomPredicateId(pub Hash); -impl From<&CustomPredicateBatch> for CustomBatchId { - fn from(batch: &CustomPredicateBatch) -> Self { - Self(batch.id()) - } -} - -/// Unique identifier for an anchored key (dict, key) pair. -/// -/// When a Contains statement is used as an argument to operations like gt(), eq(), etc., -/// the value is accessed via an "anchored key" - a reference to a specific key in a -/// specific dictionary. Each unique anchored key used in a POD requires a Contains -/// statement to be present in that POD (auto-inserted by MainPodBuilder if needed). -/// -/// We use the raw values of the dict and key for comparison, as they uniquely identify -/// the anchored key regardless of the specific Value types involved. -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct AnchoredKeyId { - /// The dictionary root value (raw representation for Ord). - pub dict: RawValue, - /// The key within the dictionary (raw representation for Ord). - pub key: RawValue, -} - -impl AnchoredKeyId { - /// Create a new anchored key ID from raw values. - pub fn new(dict: RawValue, key: RawValue) -> Self { - Self { dict, key } - } - - /// Try to extract an anchored key ID from a Contains statement with all literal values. - pub fn from_contains_statement(stmt: &Statement) -> Option { - if let Statement::Contains( - ValueRef::Literal(dict), - ValueRef::Literal(key), - ValueRef::Literal(_value), - ) = stmt - { - Some(Self::new(dict.raw(), key.raw())) - } else { - None - } +impl From<&CustomPredicateRef> for CustomPredicateId { + fn from(predicate: &CustomPredicateRef) -> Self { + Self(Predicate::Custom(predicate.clone()).hash()) } } @@ -88,17 +48,9 @@ pub struct StatementCost { /// Limit: `params.max_public_key_of` pub public_key_of: usize, - /// Custom predicate batches used (for batch cardinality constraint). - /// Limit: `params.max_custom_predicate_batches` distinct batches per POD. - pub custom_batch_ids: BTreeSet, - - /// Anchored keys referenced by this operation. - /// - /// When a Contains statement with all literal values is used as an argument, - /// the operation references an "anchored key" (dict, key pair). Each unique - /// anchored key used in a POD incurs an additional Contains statement cost, - /// as MainPodBuilder::add_entries_contains will auto-insert it if not already present. - pub anchored_keys: BTreeSet, + /// Custom predicates used (for custom predicate cardinality constraint). + /// Limit: `params.max_custom_predicates` distinct custom predicates per POD. + pub custom_predicates_ids: BTreeSet, } impl StatementCost { @@ -164,20 +116,8 @@ impl StatementCost { } OperationType::Custom(cpr) => { cost.custom_pred_verifications = 1; - cost.custom_batch_ids - .insert(CustomBatchId::from(&*cpr.batch)); - } - } - - // Extract anchored keys from operation arguments. - // Any argument that is a Contains statement with all literal values - // represents an anchored key reference that will require a Contains - // statement in the POD (auto-inserted by MainPodBuilder if needed). - for arg in &op.1 { - if let OperationArg::Statement(stmt) = arg { - if let Some(anchored_key) = AnchoredKeyId::from_contains_statement(stmt) { - cost.anchored_keys.insert(anchored_key); - } + cost.custom_predicates_ids + .insert(CustomPredicateId::from(cpr)); } } diff --git a/src/frontend/multi_pod/deps.rs b/src/frontend/multi_pod/deps.rs index 97b4ef4..9472a1f 100644 --- a/src/frontend/multi_pod/deps.rs +++ b/src/frontend/multi_pod/deps.rs @@ -5,7 +5,6 @@ use std::collections::HashMap; -use super::cost::AnchoredKeyId; use crate::{ frontend::{Operation, OperationArg}, middleware::{Hash, Statement}, @@ -100,11 +99,6 @@ impl DependencyGraph { pod_hash, statement: dep_stmt.clone(), })); - } else if AnchoredKeyId::from_contains_statement(dep_stmt).is_some() { - // Anchored-key Contains args may be implicit requirements that are - // auto-materialized by MainPodBuilder. They are handled by anchored-key - // resource accounting, not by statement dependency edges. - continue; } else { // Statement arguments should either be internal (created earlier) // or from external PODs (except anchored-key implicit Contains). @@ -128,9 +122,8 @@ impl DependencyGraph { mod tests { use super::*; use crate::{ - dict, frontend::Operation as FrontendOp, - middleware::{AnchoredKey, NativeOperation, OperationAux, OperationType, Value, ValueRef}, + middleware::{NativeOperation, OperationAux, OperationType, Value, ValueRef}, }; fn equal_stmt(n: i64) -> Statement { @@ -195,32 +188,4 @@ mod tests { assert_eq!(graph.statement_deps[1], vec![StatementSource::Internal(0)]); assert_eq!(graph.statement_deps[2], vec![StatementSource::Internal(0)]); } - - #[test] - fn test_anchored_key_contains_arg_is_treated_as_implicit_requirement() { - // A literal Contains statement can be used as an anchored-key argument even when - // no explicit producer statement exists in internal/external statements, because - // MainPodBuilder auto-inserts Contains statements for anchored keys. - let dict = dict!({ - "k" => 7_i64 - }); - - let anchored_contains = Statement::Contains( - ValueRef::Literal(Value::from(dict.clone())), - ValueRef::Literal(Value::from("k")), - ValueRef::Literal(Value::from(7_i64)), - ); - let ak = AnchoredKey::from((&dict, "k")); - let produced_statement = Statement::Equal(ValueRef::Key(ak.clone()), ValueRef::Key(ak)); - - // Use a typical frontend operation that consumes entry-like args. - // We're only testing the dependency graph, not the actual proof, so the operation - // just needs to have the right arguments to test what we're looking for. - let statements = vec![produced_statement]; - let operations = vec![FrontendOp::eq(anchored_contains.clone(), anchored_contains)]; - - let graph = DependencyGraph::build(&statements, &operations, &HashMap::new()); - - assert!(graph.statement_deps[0].is_empty()); - } } diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index d25fcce..6bade5b 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -48,12 +48,12 @@ //! [`MainPodBuilder`]: crate::frontend::MainPodBuilder use std::{ - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeSet, HashMap}, fmt, }; use crate::{ - frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, + frontend::{MainPod, MainPodBuilder, Operation}, middleware::{Hash, MainPodProver, Params, Statement, VDSet, Value}, }; @@ -61,7 +61,7 @@ mod cost; mod deps; mod solver; -use cost::{AnchoredKeyId, StatementCost}; +use cost::StatementCost; use deps::{DependencyGraph, StatementSource}; pub use solver::MultiPodSolution; @@ -168,12 +168,8 @@ pub struct MultiPodBuilder { options: Options, /// External input PODs (already proved). input_pods: Vec, - /// Statements created by this builder. - statements: Vec, - /// Operations that produce each statement. - operations: Vec, /// Optional initial wildcard values for custom operations - operations_wildcard_values: Vec>, + operations_wildcard_values: HashMap>, /// Indices of statements that should be public in output PODs. /// Uses Vec since max_public_statements is small (≤8); indices are naturally sorted. output_public_indices: Vec, @@ -193,7 +189,7 @@ pub struct SolvedMultiPod { statements: Vec, operations: Vec, output_public_indices: Vec, - operations_wildcard_values: Vec>, + operations_wildcard_values: HashMap>, solution: MultiPodSolution, deps: DependencyGraph, } @@ -260,56 +256,27 @@ impl SolvedMultiPod { let statements_sorted: BTreeSet = statements_in_this_pod.iter().copied().collect(); let public_set = &solution.pod_public_statements[pod_idx]; - // Track statements proved locally in this POD for argument remapping. - // We index by statement content so duplicate statements can reuse a single - // built statement slot in MainPodBuilder. - let mut added_statements_by_content: HashMap = HashMap::new(); - for &stmt_idx in &statements_sorted { - let original_stmt = self.statements[stmt_idx].clone(); - - // If this statement content was already built in this POD, reuse it instead - // of replaying the operation. If any duplicate is public, reveal the - // already-built statement. - if let Some(_existing_stmt) = added_statements_by_content.get(&original_stmt) { - continue; - } - - let mut op = self.operations[stmt_idx].clone(); - let wildcard_values = self.operations_wildcard_values[stmt_idx].clone(); - - // Remap Statement arguments that reference locally-proved statements. - // For external dependencies (from input PODs including earlier generated PODs), - // the original Statement is used directly - MainPodBuilder will find it in - // the input POD's public statements via find_op_arg. - for arg in &mut op.1 { - if let OperationArg::Statement(ref orig_stmt) = arg { - if let Some(remapped_stmt) = added_statements_by_content.get(orig_stmt) { - *arg = OperationArg::Statement(remapped_stmt.clone()); - } - } - } + let op = self.operations[stmt_idx].clone(); + let wildcard_values = self + .operations_wildcard_values + .get(&stmt_idx) + .cloned() + .unwrap_or_default(); let stmt = builder.op(false, wildcard_values, op)?; - - added_statements_by_content.insert(original_stmt, stmt); + assert_eq!(stmt, self.statements[stmt_idx]); // Sanity check } // For the output pod, make statements public in the original order. // Intermediate pods use the solver-selected public set. if pod_idx == solution.pod_count - 1 { for idx in &self.output_public_indices { - let stmt = added_statements_by_content - .get(&self.statements[*idx]) - .expect("exists"); - builder.reveal(stmt); + builder.reveal(&self.statements[*idx])?; } } else { for idx in public_set { - let stmt = added_statements_by_content - .get(&self.statements[*idx]) - .expect("exists"); - builder.reveal(stmt); + builder.reveal(&self.statements[*idx])?; } } @@ -317,7 +284,7 @@ impl SolvedMultiPod { // for this POD. These do not require local proving in this POD. for ext_premise_idx in &solution.pod_public_external_premises[pod_idx] { let ext_premise = &solution.external_premises[*ext_premise_idx]; - builder.reveal(&ext_premise.statement); + builder.reveal(&ext_premise.statement)?; } // Step 4: Prove the POD @@ -456,9 +423,7 @@ impl MultiPodBuilder { options, builder, input_pods: Vec::new(), - statements: Vec::new(), - operations: Vec::new(), - operations_wildcard_values: Vec::new(), + operations_wildcard_values: HashMap::new(), output_public_indices: Vec::new(), } } @@ -480,6 +445,16 @@ impl MultiPodBuilder { self.op(false, vec![], op) } + // Find the index of a statement that has been added. Panics if the statement doesn't + // exist. + fn stmt_index(&self, stmt: &Statement) -> usize { + self.builder + .statements + .iter() + .position(|s| s == stmt) + .expect("exists") + } + pub fn op( &mut self, public: bool, @@ -488,8 +463,10 @@ impl MultiPodBuilder { ) -> Result { let stmt = self.add_operation(wildcard_values, op)?; if public { - // Index is always new (just added), so push without duplicate check - self.output_public_indices.push(self.statements.len() - 1); + let index = self.stmt_index(&stmt); + if !self.output_public_indices.contains(&index) { + self.output_public_indices.push(index); + } } Ok(stmt) } @@ -510,10 +487,8 @@ impl MultiPodBuilder { let stmt = self .builder .op(false, wildcard_values.clone(), op.clone())?; - - self.statements.push(stmt.clone()); - self.operations.push(op); - self.operations_wildcard_values.push(wildcard_values); + self.operations_wildcard_values + .insert(self.stmt_index(&stmt), wildcard_values.clone()); Ok(stmt) } @@ -523,7 +498,7 @@ impl MultiPodBuilder { /// Returns an error if the statement was not found in the builder. /// Calling this multiple times on the same statement is idempotent. pub fn reveal(&mut self, stmt: &Statement) -> Result<()> { - if let Some(idx) = self.statements.iter().position(|s| s == stmt) { + if let Some(idx) = self.builder.statements.iter().position(|s| s == stmt) { if !self.output_public_indices.contains(&idx) { self.output_public_indices.push(idx); } @@ -536,8 +511,8 @@ impl MultiPodBuilder { } /// Get the number of statements. - pub fn num_statements(&self) -> usize { - self.statements.len() + pub fn stmt_len(&self) -> usize { + self.builder.stmt_len() } /// Solve the packing problem and return a solved builder ready for proving. @@ -545,66 +520,31 @@ impl MultiPodBuilder { /// This runs the MILP solver to find the optimal POD assignment. /// Consumes the builder and returns a [`SolvedMultiPod`] that can be proved. pub fn solve(self) -> Result { + let MainPodBuilder { + statements, + operations, + .. + } = self.builder; // Compute costs for each statement - let costs: Vec = self - .operations + let costs: Vec = operations .iter() .map(StatementCost::from_operation) .collect(); - // Collect all unique anchored keys from the costs - let all_anchored_keys: Vec = costs - .iter() - .flat_map(|c| c.anchored_keys.iter().cloned()) - .collect::>() - .into_iter() - .collect(); - - // Build map from anchored key to its producing statement index (if any). - // A Contains statement with literal (dict, key, value) "produces" that anchored key. - let mut ak_to_producer: HashMap = HashMap::new(); - for (stmt_idx, stmt) in self.statements.iter().enumerate() { - if let Some(ak) = AnchoredKeyId::from_contains_statement(stmt) { - // First producer wins (shouldn't have duplicates in practice) - ak_to_producer.entry(ak).or_insert(stmt_idx); - } - } - - // Build parallel array: anchored_key_producers[i] = producer for all_anchored_keys[i] - let anchored_key_producers: Vec> = all_anchored_keys - .iter() - .map(|ak| ak_to_producer.get(ak).copied()) - .collect(); - // Build external POD statement mapping let external_pod_statements = build_external_statement_map(&self.input_pods); // Build dependency graph - let deps = - DependencyGraph::build(&self.statements, &self.operations, &external_pod_statements); - - // Build statement content groups for deduplication. - // Statements with identical content share a single slot in the POD. - // Keep groups ordered by first occurrence index for deterministic solver input. - let mut first_idx_by_stmt: HashMap<&Statement, usize> = HashMap::new(); - let mut groups_by_first_idx: BTreeMap> = BTreeMap::new(); - for (idx, stmt) in self.statements.iter().enumerate() { - let first_idx = *first_idx_by_stmt.entry(stmt).or_insert(idx); - groups_by_first_idx.entry(first_idx).or_default().push(idx); - } - let statement_content_groups: Vec> = groups_by_first_idx.into_values().collect(); + let deps = DependencyGraph::build(&statements, &operations, &external_pod_statements); // Run solver let input = solver::SolverInput { - num_statements: self.statements.len(), + num_statements: statements.len(), costs: &costs, deps: &deps, output_public_indices: &self.output_public_indices, params: &self.params, max_pods: self.options.max_pods, - all_anchored_keys: &all_anchored_keys, - anchored_key_producers: &anchored_key_producers, - statement_content_groups: &statement_content_groups, }; let solution = solver::solve(&input)?; @@ -613,8 +553,8 @@ impl MultiPodBuilder { params: self.params, vd_set: self.vd_set, input_pods: self.input_pods, - statements: self.statements, - operations: self.operations, + statements, + operations, output_public_indices: self.output_public_indices, operations_wildcard_values: self.operations_wildcard_values, solution, @@ -845,33 +785,13 @@ mod tests { let solution = solved.solution(); // Expected: exactly 2 PODs - // - POD 0 (intermediate): statements 0 (contains), 1 (a_out); a_out is public - // - POD 1 (output): statement 2 (b_out); b_out is public - // The output POD accesses a_out from POD 0 to satisfy b_out's dependency. - assert_eq!( - solution.pod_count, 2, - "Expected exactly 2 PODs for 3-statement chain with max_priv=2" - ); - - // POD 0 should contain statements 0 and 1 (contains and a_out) - assert!( - solution.pod_statements[0].contains(&0) && solution.pod_statements[0].contains(&1), - "POD 0 should contain statements 0 (contains) and 1 (a_out), got {:?}", - solution.pod_statements[0] - ); - - // Statement 1 (a_out) should be public in POD 0 so POD 1 can access it - assert!( - solution.pod_public_statements[0].contains(&1), - "Statement 1 (a_out) should be public in POD 0" - ); - - // POD 1 (output) should contain statement 2 (b_out) - assert!( - solution.pod_statements[1].contains(&2), - "POD 1 should contain statement 2 (b_out), got {:?}", - solution.pod_statements[1] - ); + // Solution A: + // - POD 0 (intermediate): public statements 0 (contains) + // - POD 1 (output): inherits statement 0 (contains) from POD0, statement 1 (a_out), + // public statement 2 (b_out) + // Solution B: + // - POD 0 (intermediate): statements 0 (contains), public statement 1 (a_out) + // - POD 1 (output): inherits statement 1 (a_out) from POD0, public statement 2 (b_out) // Statement 2 (b_out) should be public in POD 1 (it's output-public) assert!( diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs index 9a24fb0..db1502e 100644 --- a/src/frontend/multi_pod/solver.rs +++ b/src/frontend/multi_pod/solver.rs @@ -52,7 +52,7 @@ use itertools::Itertools; use super::Result; use crate::{ frontend::multi_pod::{ - cost::{AnchoredKeyId, CustomBatchId, StatementCost}, + cost::{CustomPredicateId, StatementCost}, deps::{DependencyGraph, ExternalDependency, StatementSource}, }, middleware::{Hash, Params}, @@ -95,7 +95,6 @@ struct DependencyStats { struct SolveDebugContext { dep_stats: DependencyStats, batch_memberships: usize, - anchored_key_memberships: usize, } #[derive(Clone, Copy, Debug, Default)] @@ -105,10 +104,8 @@ struct ModelSizeEstimate { vars_public_external: usize, vars_pod_used: usize, vars_batch_used: usize, - vars_anchored_key_used: usize, vars_uses_input: usize, vars_uses_external: usize, - vars_content_group_used: usize, vars_total: usize, c1_coverage: usize, c2_output_public: usize, @@ -120,7 +117,6 @@ struct ModelSizeEstimate { c6_pre_content_group: usize, c6_resource_limits: usize, c7_batch_cardinality: usize, - c7b_anchored_key_tracking: usize, c8a_internal_inputs: usize, c8b_external_dep_inputs: usize, c8c_external_forward_inputs: usize, @@ -141,8 +137,6 @@ impl ModelSizeEstimate { debug_ctx: &SolveDebugContext, ) -> Self { let n = input.num_statements; - let num_groups = input.statement_content_groups.len(); - let num_anchored_keys = input.all_anchored_keys.len(); let triangular_k = target_pods * target_pods.saturating_sub(1) / 2; let vars_prove = n * target_pods; @@ -150,19 +144,15 @@ impl ModelSizeEstimate { let vars_public_external = external_premises_len * target_pods; let vars_pod_used = target_pods; let vars_batch_used = all_batches_len * target_pods; - let vars_anchored_key_used = num_anchored_keys * target_pods; let vars_uses_input = triangular_k; let vars_uses_external = external_pods_len * target_pods; - let vars_content_group_used = num_groups * target_pods; let vars_total = vars_prove + vars_public + vars_public_external + vars_pod_used + vars_batch_used - + vars_anchored_key_used + vars_uses_input - + vars_uses_external - + vars_content_group_used; + + vars_uses_external; let c1_coverage = n; let c2_output_public = input.output_public_indices.len(); @@ -171,12 +161,10 @@ impl ModelSizeEstimate { let c4_pod_existence = n * target_pods; let c5_internal_dependencies = debug_ctx.dep_stats.internal_edges * target_pods; let c5_external_dependencies = debug_ctx.dep_stats.external_edges * target_pods; - let c6_pre_content_group = (n * target_pods) + (num_groups * target_pods); + let c6_pre_content_group = n * target_pods; let c6_resource_limits = 7 * target_pods; let c7_batch_cardinality = (debug_ctx.batch_memberships * target_pods) + (all_batches_len * target_pods); - let c7b_anchored_key_tracking = - (debug_ctx.anchored_key_memberships * target_pods) + (num_anchored_keys * target_pods); let c8a_internal_inputs = debug_ctx.dep_stats.internal_edges * triangular_k; let c8b_external_dep_inputs = debug_ctx.dep_stats.external_edges * triangular_k; let c8c_external_forward_inputs = external_premises_len * triangular_k; @@ -194,7 +182,6 @@ impl ModelSizeEstimate { + c6_pre_content_group + c6_resource_limits + c7_batch_cardinality - + c7b_anchored_key_tracking + c8a_internal_inputs + c8b_external_dep_inputs + c8c_external_forward_inputs @@ -209,10 +196,8 @@ impl ModelSizeEstimate { vars_public_external, vars_pod_used, vars_batch_used, - vars_anchored_key_used, vars_uses_input, vars_uses_external, - vars_content_group_used, vars_total, c1_coverage, c2_output_public, @@ -224,7 +209,6 @@ impl ModelSizeEstimate { c6_pre_content_group, c6_resource_limits, c7_batch_cardinality, - c7b_anchored_key_tracking, c8a_internal_inputs, c8b_external_dep_inputs, c8c_external_forward_inputs, @@ -300,6 +284,7 @@ pub struct MultiPodSolution { } /// Input to the MILP solver. +#[derive(Debug)] pub struct SolverInput<'a> { /// Number of statements. pub num_statements: usize, @@ -318,28 +303,6 @@ pub struct SolverInput<'a> { /// Maximum number of PODs the solver will consider. pub max_pods: usize, - - /// All unique anchored keys referenced by any statement. - /// - /// Each unique (dict, key) pair that is used as an anchored key reference - /// in any operation. When a Contains statement with literal values is used - /// as an argument, it creates an anchored key reference. - pub all_anchored_keys: &'a [AnchoredKeyId], - - /// For each anchored key, the statement index that produces it (if any). - /// - /// When a Contains statement with literal (dict, key, value) args is explicitly - /// added, it "produces" that anchored key. If the producer is in the same POD - /// as statements using the anchored key, no auto-insertion is needed. - /// `anchored_key_producers[i]` corresponds to `all_anchored_keys[i]`. - pub anchored_key_producers: &'a [Option], - - /// Statement content groups for deduplication. - /// - /// Each inner Vec contains statement indices that have identical content. - /// When multiple statements with the same content are proved in the same POD, - /// they only use one statement slot (the POD deduplicates identical statements). - pub statement_content_groups: &'a [Vec], } /// Solve the MILP problem to find optimal POD packing. @@ -386,11 +349,11 @@ pub fn solve(input: &SolverInput) -> Result { ))); } - // Collect all unique custom batch IDs used - let all_batches: Vec = input + // Collect all unique custom predicate IDs used + let all_custom_predicates: Vec = input .costs .iter() - .flat_map(|c| c.custom_batch_ids.iter().cloned()) + .flat_map(|c| c.custom_predicates_ids.iter().cloned()) .unique() .collect(); @@ -417,18 +380,19 @@ pub fn solve(input: &SolverInput) -> Result { } let dep_stats = dependency_stats(input.deps); - let batch_memberships: usize = input.costs.iter().map(|c| c.custom_batch_ids.len()).sum(); - let anchored_key_memberships: usize = input.costs.iter().map(|c| c.anchored_keys.len()).sum(); + let batch_memberships: usize = input + .costs + .iter() + .map(|c| c.custom_predicates_ids.len()) + .sum(); let debug_ctx = SolveDebugContext { dep_stats, batch_memberships, - anchored_key_memberships, }; if log::log_enabled!(log::Level::Debug) { let resource_totals = ResourceTotals::from_costs(input.costs); - let lb_statement_groups = - lower_bound_from_total(input.statement_content_groups.len(), max_stmts_per_pod); + let lb_statement_groups = lower_bound_from_total(input.num_statements, max_stmts_per_pod); let lb_merkle = lower_bound_from_total( resource_totals.merkle_proofs, input.params.max_merkle_proofs_containers, @@ -463,14 +427,12 @@ pub fn solve(input: &SolverInput) -> Result { .expect("non-empty lower-bound candidate list"); log::debug!( - "MILP summary: statements={} output_public={} content_groups={} anchored_keys={} \ - batches={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \ + "MILP summary: statements={} output_public={} \ + custom_predicates={} deps_internal_edges={} deps_external_edges={} external_input_pods={} \ external_premises={} search_min_pods={} max_pods={}", n, num_output_public, - input.statement_content_groups.len(), - input.all_anchored_keys.len(), - all_batches.len(), + all_custom_predicates.len(), dep_stats.internal_edges, dep_stats.external_edges, external_pods.len(), @@ -481,14 +443,13 @@ pub fn solve(input: &SolverInput) -> Result { log::debug!( "MILP resource totals: merkle_proofs={} merkle_state_transitions={} \ custom_pred_verifications={} signed_by={} public_key_of={} \ - batch_memberships={} anchored_key_memberships={}", + batch_memberships={}", resource_totals.merkle_proofs, resource_totals.merkle_state_transitions, resource_totals.custom_pred_verifications, resource_totals.signed_by, resource_totals.public_key_of, batch_memberships, - anchored_key_memberships ); log::debug!( "MILP lower bounds (pods): statements_raw={} statements_dedup={} merkle_proofs={} \ @@ -513,7 +474,7 @@ pub fn solve(input: &SolverInput) -> Result { if let Some(solution) = try_solve_with_pods( input, target_pods, - &all_batches, + &all_custom_predicates, &external_pods, &external_premises, &debug_ctx, @@ -540,7 +501,7 @@ pub fn solve(input: &SolverInput) -> Result { fn try_solve_with_pods( input: &SolverInput, target_pods: usize, - all_batches: &[CustomBatchId], + all_custom_predicates: &[CustomPredicateId], external_pods: &[Hash], external_premises: &[ExternalDependency], debug_ctx: &SolveDebugContext, @@ -574,21 +535,8 @@ fn try_solve_with_pods( .map(|_| vars.add(variable().binary())) .collect(); - // batch_used[b][p] - custom batch b is used in POD p - let batch_used: Vec> = (0..all_batches.len()) - .map(|_| { - (0..target_pods) - .map(|_| vars.add(variable().binary())) - .collect() - }) - .collect(); - - // anchored_key_used[ak][p] - anchored key ak is used in POD p - // When a statement references an anchored key (via a Contains statement argument), - // that POD must have a Contains statement for that (dict, key) pair. - // MainPodBuilder::add_entries_contains auto-inserts these, and we must account - // for them in the statement count. - let anchored_key_used: Vec> = (0..input.all_anchored_keys.len()) + // custom_predicates[b][p] - custom predicate b is used in POD p + let custom_predicate_used: Vec> = (0..all_custom_predicates.len()) .map(|_| { (0..target_pods) .map(|_| vars.add(variable().binary())) @@ -633,31 +581,19 @@ fn try_solve_with_pods( .map(|(i, ext)| (ext.clone(), i)) .collect(); - // content_group_used[g][p] - content group g has at least one statement proved in POD p - // When multiple statements have identical content, they share a slot in the POD. - // This variable tracks whether at least one statement from each content group is proved. - let num_groups = input.statement_content_groups.len(); - let content_group_used: Vec> = (0..num_groups) - .map(|_| { - (0..target_pods) - .map(|_| vars.add(variable().binary())) - .collect() - }) - .collect(); - if log::log_enabled!(log::Level::Debug) { let estimate = ModelSizeEstimate::for_target_pods( input, target_pods, - all_batches.len(), + all_custom_predicates.len(), external_pods.len(), external_premises.len(), debug_ctx, ); log::debug!( "MILP(k={}) model estimate vars_total={} [prove={} public={} pod_used={} \ - public_external={} batch_used={} anchored_key_used={} uses_input={} \ - uses_external={} content_group_used={}]", + public_external={} batch_used={} uses_input={} \ + uses_external={}]", target_pods, estimate.vars_total, estimate.vars_prove, @@ -665,14 +601,12 @@ fn try_solve_with_pods( estimate.vars_pod_used, estimate.vars_public_external, estimate.vars_batch_used, - estimate.vars_anchored_key_used, estimate.vars_uses_input, estimate.vars_uses_external, - estimate.vars_content_group_used ); log::debug!( "MILP(k={}) model estimate constraints_total={} [c1={} c2={} c2b={} c3={} c4={} \ - c5i={} c5e={} c6_pre={} c6_limits={} c7={} c7b={} c8a={} c8b={} c8c={} \ + c5i={} c5e={} c6_pre={} c6_limits={} c7={} c8a={} c8b={} c8c={} \ c8d={} c9={} c10={} c10b={}]", target_pods, estimate.constraints_total, @@ -686,7 +620,6 @@ fn try_solve_with_pods( estimate.c6_pre_content_group, estimate.c6_resource_limits, estimate.c7_batch_cardinality, - estimate.c7b_anchored_key_tracking, estimate.c8a_internal_inputs, estimate.c8b_external_dep_inputs, estimate.c8c_external_forward_inputs, @@ -798,35 +731,11 @@ fn try_solve_with_pods( } } - // Constraint 6: Resource limits per POD - // - // 6a-pre: Content group tracking for statement deduplication - // When multiple statement indices have identical content, they share a single slot in the POD. - // content_group_used[g][p] = 1 iff at least one statement from group g is proved in POD p. - for (g, group) in input.statement_content_groups.iter().enumerate() { - for p in 0..target_pods { - // Lower bound: if any statement in the group is proved, the group is used - for &s in group { - model.add_constraint(constraint!(content_group_used[g][p] >= prove[s][p])); - } - // Upper bound: if no statements in the group are proved, the group is not used - let group_prove_sum: Expression = group.iter().map(|&s| prove[s][p]).sum(); - model.add_constraint(constraint!(content_group_used[g][p] <= group_prove_sum)); - } - } - for p in 0..target_pods { - // 6a: Unique statement count (unique content groups + anchored key Contains) - // Statements with identical content share a slot, so we count content groups, not indices. - // Anchored key Contains statements are auto-inserted by MainPodBuilder when needed. - // The total must not exceed max_priv_statements (= max_statements - max_public_statements). - let unique_stmt_sum: Expression = (0..num_groups).map(|g| content_group_used[g][p]).sum(); - let anchored_key_sum: Expression = (0..input.all_anchored_keys.len()) - .map(|ak| anchored_key_used[ak][p]) - .sum(); + // 6a: Statement count + let stmt_sum: Expression = (0..n).map(|g| prove[g][p]).sum(); model.add_constraint(constraint!( - unique_stmt_sum + anchored_key_sum - <= (input.params.max_priv_statements() as f64) * pod_used[p] + stmt_sum <= (input.params.max_priv_statements() as f64) * pod_used[p] )); // 6b: Public statement count (internal public statements + forwarded external premises) @@ -885,67 +794,31 @@ fn try_solve_with_pods( } // Constraint 7: Batch cardinality - // batch_used[b][p] >= prove[s][p] for all s that use batch b (batch is used if any statement uses it) - // batch_used[b][p] <= sum of prove[s][p] for all s using batch b (batch is 0 if no statements use it) - for (b, batch_id) in all_batches.iter().enumerate() { + // custom_predicate_used[b][p] >= prove[s][p] for all s that use custom predicate b (custom + // predicate is used if any statement uses it) + // custom_predicate_used[b][p] <= sum of prove[s][p] for all s using custom predicate b (custom + // predicate is 0 if no statements use it) + for (b, predicate_id) in all_custom_predicates.iter().enumerate() { for p in 0..target_pods { let mut sum: Expression = 0.into(); for s in 0..n { - if input.costs[s].custom_batch_ids.contains(batch_id) { - model.add_constraint(constraint!(batch_used[b][p] >= prove[s][p])); + if input.costs[s].custom_predicates_ids.contains(predicate_id) { + model.add_constraint(constraint!(custom_predicate_used[b][p] >= prove[s][p])); sum += prove[s][p]; } } - model.add_constraint(constraint!(batch_used[b][p] <= sum)); + model.add_constraint(constraint!(custom_predicate_used[b][p] <= sum)); } } - // Constraint 7b: Anchored key tracking - // - // anchored_key_used[ak][p] = 1 when auto-insertion of a Contains is needed for anchored key ak in POD p. - // This happens when: some statement using ak is in POD p, AND the producing Contains is NOT in POD p. - // - // If a Contains statement explicitly produces ak (anchored_key_producers[ak] = Some(prod_idx)): - // - Lower: anchored_key_used[ak][p] >= prove[s][p] - prove[prod_idx][p] for all s using ak - // - Upper: anchored_key_used[ak][p] <= 1 - prove[prod_idx][p] - // This ensures overhead is 0 when the producer is in the same POD. - // - // If no Contains produces ak (anchored_key_producers[ak] = None): - // - Lower: anchored_key_used[ak][p] >= prove[s][p] for all s using ak - // - Upper: anchored_key_used[ak][p] <= sum of prove[s][p] for all s using ak - // Auto-insertion is always needed when any user is present. - for (ak_idx, ak) in input.all_anchored_keys.iter().enumerate() { - let producer = input.anchored_key_producers[ak_idx]; - - for p in 0..target_pods { - let mut user_sum: Expression = 0.into(); - for s in 0..n { - if input.costs[s].anchored_keys.contains(ak) { - if let Some(prod_idx) = producer { - // Producer exists: only count overhead if producer not in this POD - model.add_constraint(constraint!( - anchored_key_used[ak_idx][p] >= prove[s][p] - prove[prod_idx][p] - )); - } else { - // No producer: always need auto-insertion if user is present - model.add_constraint(constraint!( - anchored_key_used[ak_idx][p] >= prove[s][p] - )); - } - user_sum += prove[s][p]; - } - } - - if let Some(prod_idx) = producer { - // If producer is in POD, no auto-insertion needed (overhead = 0) - model.add_constraint(constraint!( - anchored_key_used[ak_idx][p] <= 1 - prove[prod_idx][p] - )); - } else { - // No producer: overhead is bounded by whether any user is present - model.add_constraint(constraint!(anchored_key_used[ak_idx][p] <= user_sum)); - } - } + // Custom predicate count per POD + for p in 0..target_pods { + let custom_predicate_sum: Expression = (0..all_custom_predicates.len()) + .map(|b| custom_predicate_used[b][p]) + .sum(); + model.add_constraint(constraint!( + custom_predicate_sum <= (input.params.max_custom_predicates as f64) * pod_used[p] + )); } // Constraint 8a: Internal input POD tracking using uses_input. @@ -1147,9 +1020,6 @@ mod tests { output_public_indices: &[], params: ¶ms, max_pods: 20, - all_anchored_keys: &[], - anchored_key_producers: &[], - statement_content_groups: &[], }; let result = solve(&input); @@ -1195,7 +1065,6 @@ mod tests { }; let costs = vec![StatementCost::default(), StatementCost::default()]; - let statement_content_groups = vec![vec![0], vec![1]]; let output_public = vec![1]; let input = SolverInput { @@ -1205,9 +1074,6 @@ mod tests { output_public_indices: &output_public, params: ¶ms, max_pods: 4, - all_anchored_keys: &[], - anchored_key_producers: &[], - statement_content_groups: &statement_content_groups, }; let solution = solve(&input).expect("solver should find a feasible forwarding layout"); diff --git a/src/middleware/db/mem.rs b/src/middleware/db/mem.rs index 53ab91e..71211fa 100644 --- a/src/middleware/db/mem.rs +++ b/src/middleware/db/mem.rs @@ -43,8 +43,10 @@ impl DB for MemDB { let mut values = self.values.write().expect("lock not poisoned"); let value_raw = value.raw(); if let Some(old_value) = values.get(&value_raw) { - // If we had a non-raw value stored never overwrite it with a raw value - if !old_value.is_raw() && value.is_raw() { + let old_is_raw = old_value.is_raw(); + // If we had a non-RawValue stored don't overwrite it (specially not with a + // RawValue). Also skip redundant RawValue overwrite. + if !old_is_raw || value.is_raw() { return Ok(()); } } From 22d25e5cb261e9f97eef9d98384b9ae3cd845099 Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Mon, 30 Mar 2026 15:16:19 +0100 Subject: [PATCH 05/10] Podlang syntax for quoted predicates (#495) --- src/frontend/custom.rs | 6 ++ src/lang/diagnostics.rs | 11 +++ src/lang/error.rs | 3 + src/lang/frontend_ast.rs | 40 ++++++++++ src/lang/frontend_ast_lower.rs | 103 ++++++++++++++++++++++++-- src/lang/frontend_ast_split.rs | 2 +- src/lang/frontend_ast_validate.rs | 88 +++++++++++++++++++++- src/lang/grammar.pest | 11 ++- src/lang/module.rs | 118 +++++++++++++++++++++++++++++- src/lang/parser.rs | 30 ++++++++ src/lang/pretty_print.rs | 55 +++++++++++++- 11 files changed, 453 insertions(+), 14 deletions(-) diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index f3a8115..a2614a0 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -176,6 +176,12 @@ impl CustomPredicateBatchBuilder { priv_args: &[&str], sts: &[StatementTmplBuilder], ) -> Result { + if self.predicates.iter().any(|p| p.name == name) { + return Err(Error::custom(format!( + "Duplicate predicate name '{}' in batch", + name + ))); + } if self.predicates.len() >= Params::max_custom_batch_size() { return Err(Error::max_length( "self.predicates.len".to_string(), diff --git a/src/lang/diagnostics.rs b/src/lang/diagnostics.rs index ea528ef..0a1d770 100644 --- a/src/lang/diagnostics.rs +++ b/src/lang/diagnostics.rs @@ -287,6 +287,17 @@ fn render_validation_error( ValidationError::NoRequestBlock => { render_title_only(renderer, "requests must contain a REQUEST block") } + + ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { span } => { + render_with_optional_span( + renderer, + source, + path, + "self-referential predicate literal not allowed in requests", + span.as_ref(), + "not allowed here", + ) + } } } diff --git a/src/lang/error.rs b/src/lang/error.rs index 944988c..769faf6 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -165,6 +165,9 @@ pub enum ValidationError { #[error("Modules must contain at least one predicate definition")] NoPredicatesInModule, + #[error("Self-referential predicate literal not allowed in requests")] + SelfReferentialPredicateLiteralNotAllowedInRequests { span: Option }, + #[error("Requests must contain a REQUEST block")] NoRequestBlock, } diff --git a/src/lang/frontend_ast.rs b/src/lang/frontend_ast.rs index 4ca7fe4..dd0052c 100644 --- a/src/lang/frontend_ast.rs +++ b/src/lang/frontend_ast.rs @@ -116,6 +116,8 @@ pub enum StatementTmplArg { Literal(LiteralValue), Wildcard(Identifier), AnchoredKey(AnchoredKey), + /// Hash of a same-module predicate, resolved at batch finalization time. + SelfPredicateHash(Identifier), } /// Anchored key: Var["key"] or Var.key @@ -168,6 +170,13 @@ pub enum LiteralValue { Array(LiteralArray), Set(LiteralSet), Dict(LiteralDict), + /// Hash of a native predicate (resolved immediately). + NativePredicateHash(Identifier), + /// Hash of an external module's predicate (resolved immediately). + ExternalPredicateHash { + module: Identifier, + predicate: Identifier, + }, } /// Integer literal @@ -391,6 +400,9 @@ impl fmt::Display for StatementTmplArg { StatementTmplArg::Literal(lit) => write!(f, "{}", lit), StatementTmplArg::Wildcard(id) => write!(f, "{}", id), StatementTmplArg::AnchoredKey(ak) => write!(f, "{}", ak), + StatementTmplArg::SelfPredicateHash(id) => { + write!(f, "@self_predicate({})", id) + } } } } @@ -422,6 +434,12 @@ impl fmt::Display for LiteralValue { LiteralValue::Array(a) => write!(f, "{}", a), LiteralValue::Set(s) => write!(f, "{}", s), LiteralValue::Dict(d) => write!(f, "{}", d), + LiteralValue::NativePredicateHash(id) => { + write!(f, "@native_predicate({})", id) + } + LiteralValue::ExternalPredicateHash { + module, predicate, .. + } => write!(f, "@external_predicate({}, {})", module, predicate), } } } @@ -769,6 +787,10 @@ pub mod parse { let inner = pair.into_inner().next().unwrap(); match inner.as_rule() { + Rule::predicate_hash_self => { + let id = parse_identifier(inner.into_inner().next().unwrap()); + Ok(StatementTmplArg::SelfPredicateHash(id)) + } Rule::literal_value => Ok(StatementTmplArg::Literal(parse_literal_value(inner)?)), Rule::identifier => Ok(StatementTmplArg::Wildcard(parse_identifier(inner))), Rule::anchored_key => Ok(StatementTmplArg::AnchoredKey(parse_anchored_key(inner)?)), @@ -823,6 +845,16 @@ pub mod parse { Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)), Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)), Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)), + Rule::predicate_hash_native => { + let id = parse_identifier(inner.into_inner().next().unwrap()); + Ok(LiteralValue::NativePredicateHash(id)) + } + Rule::predicate_hash_external => { + let mut parts = inner.into_inner(); + let module = parse_identifier(parts.next().unwrap()); + let predicate = parse_identifier(parts.next().unwrap()); + Ok(LiteralValue::ExternalPredicateHash { module, predicate }) + } _ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()), } } @@ -1104,6 +1136,7 @@ mod tests { AnchoredKeyPath::Dot(id) => id.span = None, } } + StatementTmplArg::SelfPredicateHash(id) => id.span = None, } } } @@ -1139,6 +1172,13 @@ mod tests { clear_literal_spans(&mut pair.value); } } + LiteralValue::NativePredicateHash(id) => id.span = None, + LiteralValue::ExternalPredicateHash { + module, predicate, .. + } => { + module.span = None; + predicate.span = None; + } } } diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index fe9b745..fb00def 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -157,8 +157,10 @@ fn resolve_local_predicate( /// Lower a literal value from AST to middleware Value. /// -/// This is a pure conversion that cannot fail. -pub fn lower_literal(lit: &LiteralValue) -> Value { +/// This is a pure conversion that cannot fail for context-free literals. +/// Panics on ExternalPredicateHash — use `lower_literal_with_context` when +/// external predicate references may appear (e.g. inside containers). +pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { match lit { LiteralValue::Int(i) => Value::from(i.value), LiteralValue::Bool(b) => Value::from(b.value), @@ -190,13 +192,83 @@ pub fn lower_literal(lit: &LiteralValue) -> Value { let dict = containers::Dictionary::new(pairs); Value::from(dict) } + LiteralValue::NativePredicateHash(id) => { + let np = NativePredicate::from_str(&id.name).expect("validated native predicate"); + Value::from(Predicate::Native(np).hash()) + } + LiteralValue::ExternalPredicateHash { .. } => { + unreachable!( + "ExternalPredicateHash must be lowered with context via lower_literal_with_context" + ) + } + } +} + +/// Lower a literal value, resolving external predicate references using the symbol table. +pub fn lower_literal_with_context( + lit: &LiteralValue, + symbols: &SymbolTable, + context: &ResolutionContext, +) -> Result { + match lit { + LiteralValue::ExternalPredicateHash { module, predicate } => { + let pred_or_wc = resolve_predicate_ref( + &PredicateRef::Qualified { + module: module.clone(), + predicate: predicate.clone(), + }, + symbols, + context, + ) + .ok_or_else(|| LoweringError::PredicateNotFound { + name: format!("{}::{}", module.name, predicate.name), + })?; + let pred = match pred_or_wc { + crate::frontend::PredicateOrWildcard::Predicate(p) => p, + _ => unreachable!( + "`resolve_predicate_ref` always returns `PredicateOrWildcard::Predicate` on `PredicateRef::Qualified`" + ) + }; + Ok(Value::from(pred.hash())) + } + LiteralValue::Array(a) => { + let elements: Vec<_> = a + .elements + .iter() + .map(|e| lower_literal_with_context(e, symbols, context)) + .collect::>()?; + Ok(Value::from(containers::Array::new(elements))) + } + LiteralValue::Set(s) => { + let elements: std::collections::HashSet<_> = s + .elements + .iter() + .map(|e| lower_literal_with_context(e, symbols, context)) + .collect::>()?; + Ok(Value::from(containers::Set::new(elements))) + } + LiteralValue::Dict(d) => { + let pairs: HashMap<_, _> = d + .pairs + .iter() + .map(|pair| { + let key = Key::from(pair.key.value.as_str()); + let value = lower_literal_with_context(&pair.value, symbols, context)?; + Ok((key, value)) + }) + .collect::>()?; + Ok(Value::from(containers::Dictionary::new(pairs))) + } + // All other variants are context-free + other => Ok(lower_literal(other)), } } /// Lower a statement argument from AST to BuilderArg. /// -/// This is a pure conversion that cannot fail. -pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { +/// Context-free for most arg types. Panics on ExternalPredicateHash inside literals — +/// use `lower_statement_arg_with_context` when external predicate references may appear. +pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { match arg { StatementTmplArg::Literal(lit) => { let value = lower_literal(lit); @@ -210,6 +282,25 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { }; BuilderArg::Key(ak.root.name.clone(), key_str) } + StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()), + } +} + +/// Lower a statement argument, resolving external predicate references using the symbol table. +pub fn lower_statement_arg_with_context( + arg: &StatementTmplArg, + symbols: &SymbolTable, + context: &ResolutionContext, +) -> Result { + match arg { + StatementTmplArg::Literal(lit) => { + let value = lower_literal_with_context(lit, symbols, context)?; + Ok(BuilderArg::Literal(value)) + } + StatementTmplArg::SelfPredicateHash(id) => { + Ok(BuilderArg::SelfPredicateHash(id.name.clone())) + } + other => Ok(lower_statement_arg(other)), } } @@ -324,7 +415,7 @@ impl<'a> Lowerer<'a> { // Create a builder with the resolved predicate and desugar let mut builder = StatementTmplBuilder::new(predicate.clone()); for arg in &stmt.args { - let builder_arg = lower_statement_arg(arg); + let builder_arg = lower_statement_arg_with_context(arg, symbols, &context)?; builder = builder.arg(builder_arg); } let desugared = builder.desugar(); @@ -402,7 +493,7 @@ impl<'a> Lowerer<'a> { names.push(ak.root.name.clone()); } } - StatementTmplArg::Literal(_) => {} + StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} } } } diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index 0d17217..482db7a 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -123,7 +123,7 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet { StatementTmplArg::AnchoredKey(ak) => { wildcards.insert(ak.root.name.clone()); } - StatementTmplArg::Literal(_) => {} + StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} } } diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index 49575b5..0b7737d 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -559,7 +559,12 @@ impl Validator { } } } - StatementTmplArg::Literal(_) => {} + StatementTmplArg::Literal(lit) => { + self.validate_literal_value(lit)?; + } + StatementTmplArg::SelfPredicateHash(id) => { + self.validate_self_predicate_hash(id, wildcard_context)?; + } } } } else { @@ -588,13 +593,92 @@ impl Validator { } } } - StatementTmplArg::Literal(_) => {} + StatementTmplArg::Literal(lit) => { + self.validate_literal_value(lit)?; + } + StatementTmplArg::SelfPredicateHash(id) => { + self.validate_self_predicate_hash(id, wildcard_context)?; + } } } } Ok(()) } + + /// Validate a @self_predicate reference: the name must be a custom predicate in this module. + fn validate_self_predicate_hash( + &self, + id: &Identifier, + wildcard_context: Option<(&str, &WildcardScope)>, + ) -> Result<(), ValidationError> { + // @self_predicate only makes sense inside module predicate definitions + if wildcard_context.is_none() { + return Err( + ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { + span: id.span, + }, + ); + } + // Must refer to a custom predicate defined in this module (not intro/imported) + match self.symbols.predicates.get(&id.name) { + Some(info) if matches!(info.kind, PredicateKind::Custom { .. }) => Ok(()), + _ => Err(ValidationError::UndefinedPredicate { + name: id.name.clone(), + span: id.span, + }), + } + } + + /// Recursively validate a literal value, checking predicate hash references. + fn validate_literal_value(&self, lit: &LiteralValue) -> Result<(), ValidationError> { + match lit { + LiteralValue::NativePredicateHash(id) => { + if NativePredicate::from_str(&id.name).is_err() { + return Err(ValidationError::UndefinedPredicate { + name: id.name.clone(), + span: id.span, + }); + } + Ok(()) + } + LiteralValue::ExternalPredicateHash { module, predicate } => { + if let Some(imported) = self.symbols.imported_modules.get(&module.name) { + if !imported.predicate_index.contains_key(&predicate.name) { + return Err(ValidationError::UndefinedPredicate { + name: format!("{}::{}", module.name, predicate.name), + span: predicate.span, + }); + } + } else { + return Err(ValidationError::ModuleNotFound { + name: module.name.clone(), + span: module.span, + }); + } + Ok(()) + } + LiteralValue::Array(a) => { + for elem in &a.elements { + self.validate_literal_value(elem)?; + } + Ok(()) + } + LiteralValue::Set(s) => { + for elem in &s.elements { + self.validate_literal_value(elem)?; + } + Ok(()) + } + LiteralValue::Dict(d) => { + for pair in &d.pairs { + self.validate_literal_value(&pair.value)?; + } + Ok(()) + } + _ => Ok(()), + } + } } #[cfg(test)] diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index 3002d15..1c11baa 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -49,7 +49,14 @@ custom_predicate_def = { statement_list = { statement+ } -statement_arg = { literal_value | anchored_key | identifier } +// Predicate hash literals: resolve to the predicate's identity hash as a value. +// @native_predicate and @external_predicate are in literal_value (usable in containers). +// @self_predicate is only in statement_arg (not in containers — deferred resolution). +predicate_hash_native = { "@native_predicate" ~ "(" ~ identifier ~ ")" } +predicate_hash_external = { "@external_predicate" ~ "(" ~ identifier ~ "," ~ identifier ~ ")" } +predicate_hash_self = { "@self_predicate" ~ "(" ~ identifier ~ ")" } + +statement_arg = { predicate_hash_self | literal_value | anchored_key | identifier } statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* } // Predicate reference: either qualified (module::predicate) or local (predicate) @@ -74,6 +81,8 @@ literal_value = { literal_bool | literal_raw | literal_string | + predicate_hash_native | + predicate_hash_external | literal_int } diff --git a/src/lang/module.rs b/src/lang/module.rs index 78fb22e..b926871 100644 --- a/src/lang/module.rs +++ b/src/lang/module.rs @@ -11,7 +11,9 @@ use crate::{ lang::{ error::BatchingError, frontend_ast::{ConjunctionType, CustomPredicateDef}, - frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext}, + frontend_ast_lower::{ + lower_statement_arg_with_context, resolve_predicate_ref, ResolutionContext, + }, frontend_ast_split::{SplitChainInfo, SplitResult}, frontend_ast_validate::SymbolTable, }, @@ -374,7 +376,13 @@ fn build_statement_with_resolved_refs( let mut builder = StatementTmplBuilder::new(pred_or_wc); for arg in &stmt.args { - builder = builder.arg(lower_statement_arg(arg)); + let builder_arg = + lower_statement_arg_with_context(arg, symbols, &context).map_err(|e| { + BatchingError::Internal { + message: format!("Failed to lower argument: {}", e), + } + })?; + builder = builder.arg(builder_arg); } Ok(builder) @@ -670,4 +678,110 @@ mod tests { PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref)) ); } + + #[test] + fn test_self_predicate_hash_podlang() { + let params = Params::default(); + let module = load_module( + r#" + pred_A(x, y) = AND( + Equal(x, y) + ) + + pred_B(x) = AND( + Equal(x, @self_predicate(pred_A)) + ) + "#, + "test", + ¶ms, + &[], + ) + .unwrap(); + + let batch = &module.batch; + + // pred_B is at index 1, its template should have SelfPredicateHash(0) resolved + // to a Literal containing pred_A's hash after normalization + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_a_hash = crate::middleware::Value::from(Predicate::Custom(pred_a_ref).hash()); + + // Use normalized_predicate to resolve + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let normalized = pred_b_ref.normalized_predicate(); + assert_eq!( + normalized.statements[0].args[1], + crate::middleware::StatementTmplArg::Literal(pred_a_hash) + ); + } + + #[test] + fn test_self_predicate_hash_podlang_cyclic() { + let params = Params::default(); + let module = load_module( + r#" + pred_A(x) = AND( + Equal(x, @self_predicate(pred_B)) + ) + + pred_B(x) = AND( + Equal(x, @self_predicate(pred_A)) + ) + "#, + "test", + ¶ms, + &[], + ) + .unwrap(); + + let batch = &module.batch; + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_a_hash = + crate::middleware::Value::from(Predicate::Custom(pred_a_ref.clone()).hash()); + let pred_b_hash = + crate::middleware::Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); + + // pred_A's normalized form should contain pred_B's hash + let norm_a = pred_a_ref.normalized_predicate(); + assert_eq!( + norm_a.statements[0].args[1], + crate::middleware::StatementTmplArg::Literal(pred_b_hash) + ); + + // pred_B's normalized form should contain pred_A's hash + let norm_b = pred_b_ref.normalized_predicate(); + assert_eq!( + norm_b.statements[0].args[1], + crate::middleware::StatementTmplArg::Literal(pred_a_hash) + ); + } + + #[test] + fn test_native_predicate_hash_podlang() { + let params = Params::default(); + let module = load_module( + r#" + pred_C(x) = AND( + Equal(x, @native_predicate(Equal)) + ) + "#, + "test", + ¶ms, + &[], + ) + .unwrap(); + + let batch = &module.batch; + let pred_c_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_c = pred_c_ref.predicate(); + + // The second arg should be a Literal containing Equal's predicate hash + let equal_hash = crate::middleware::Value::from( + Predicate::Native(crate::middleware::NativePredicate::Equal).hash(), + ); + assert_eq!( + pred_c.statements[0].args[1], + crate::middleware::StatementTmplArg::Literal(equal_hash) + ); + } } diff --git a/src/lang/parser.rs b/src/lang/parser.rs index 000e683..1a29113 100644 --- a/src/lang/parser.rs +++ b/src/lang/parser.rs @@ -137,6 +137,9 @@ mod tests { assert_inner(&Rule::anchored_key, "someVar[\"key\"]"); assert_inner(&Rule::literal_value, "true"); assert_inner(&Rule::literal_value, "PublicKey(abc)"); + assert_inner(&Rule::predicate_hash_self, "@self_predicate(foo)"); + assert_inner(&Rule::literal_value, "@native_predicate(Equal)"); + assert_inner(&Rule::literal_value, "@external_predicate(mod_a, pred_b)"); } #[test] @@ -207,6 +210,33 @@ mod tests { "{ \"raw_val\": Raw(0x0000000000000000000000000000000000000000000000000000000000000000) } ", ); assert_fails(Rule::literal_dict, "{ name: \"Alice\" }"); // Key must be string literal with quotes + + // Predicate hash literals + assert_parses(Rule::predicate_hash_native, "@native_predicate(Equal)"); + assert_parses(Rule::predicate_hash_native, "@native_predicate(Lt)"); + assert_parses( + Rule::predicate_hash_external, + "@external_predicate(my_module, my_pred)", + ); + assert_parses(Rule::predicate_hash_self, "@self_predicate(local_pred)"); + + // Predicate hashes inside containers (native and external only) + assert_parses( + Rule::literal_array, + "[1, @native_predicate(Equal), @external_predicate(m, p)]", + ); + assert_parses( + Rule::literal_set, + "#[@native_predicate(Equal), @native_predicate(Lt)]", + ); + assert_parses( + Rule::literal_dict, + "{ \"pred\": @external_predicate(m, p) }", + ); + + // @self_predicate is NOT a literal_value, so it cannot appear inside containers + assert_fails(Rule::test_literal_value, "@self_predicate(local_pred)"); + assert_fails(Rule::literal_array, "[@self_predicate(foo)]"); } #[test] diff --git a/src/lang/pretty_print.rs b/src/lang/pretty_print.rs index bd912cb..8e4819d 100644 --- a/src/lang/pretty_print.rs +++ b/src/lang/pretty_print.rs @@ -92,7 +92,7 @@ impl StatementTmpl { if i > 0 { write!(w, ", ")?; } - arg.fmt_podlang(w)?; + arg.fmt_podlang_with_batch_context(w, batch_context)?; } write!(w, ")")?; @@ -102,7 +102,30 @@ impl StatementTmpl { impl PrettyPrint for StatementTmplArg { fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result { - write!(w, "{}", self) + self.fmt_podlang_with_batch_context(w, None) + } +} + +impl StatementTmplArg { + fn fmt_podlang_with_batch_context( + &self, + w: &mut dyn Write, + batch_context: Option<&CustomPredicateBatch>, + ) -> std::fmt::Result { + match self { + StatementTmplArg::SelfPredicateHash(index) => { + if let Some(batch) = batch_context { + if let Some(predicate) = batch.predicates().get(*index) { + write!(w, "@self_predicate({})", predicate.name) + } else { + write!(w, "@self_predicate(self_{})", index) + } + } else { + write!(w, "@self_predicate(self_{})", index) + } + } + other => write!(w, "{}", other), + } } } @@ -540,6 +563,34 @@ mod tests { assert_round_trip(&input); } + #[test] + fn test_round_trip_self_predicate_hash() { + let input = r#" + pred_A(x, y) = AND( + Equal(x, y) + ) + + pred_B(x) = AND( + Equal(x, @self_predicate(pred_A)) + ) + "#; + assert_round_trip(input); + } + + #[test] + fn test_round_trip_self_predicate_hash_cyclic() { + let input = r#" + pred_A(x) = AND( + Equal(x, @self_predicate(pred_B)) + ) + + pred_B(x) = AND( + Equal(x, @self_predicate(pred_A)) + ) + "#; + assert_round_trip(input); + } + #[test] fn test_pretty_print_demonstration() { let input = r#" From dbd958dcca41d74eaf85499ad0def8246c15b795 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 1 Apr 2026 23:49:29 +0200 Subject: [PATCH 06/10] Allow entries as args in custom statements (#498) - Introduce a new operation ReplaceValueWithEntry that allows taking any statement and replacing literal arguments with entries given a matching Contains statement. - Allow entries as args in custom statements - Circuit optimization: For the public statements slots in the circuit we only support None and Copy which take at most 1 argument; but we were still doing max_statement_args random accesses per slot; so I reduced that to just 1 random access to a previous statement. --- src/backends/plonky2/circuits/common.rs | 34 ++-- src/backends/plonky2/circuits/mainpod.rs | 189 +++++++++++++++++----- src/backends/plonky2/mainpod/mod.rs | 132 +++++++++++++-- src/backends/plonky2/mainpod/statement.rs | 14 +- src/backends/plonky2/mock/mainpod.rs | 3 +- src/frontend/custom.rs | 6 +- src/frontend/mod.rs | 42 ++++- src/frontend/multi_pod/cost.rs | 3 +- src/frontend/operation.rs | 22 ++- src/lang/diagnostics.rs | 12 -- src/lang/error.rs | 6 - src/lang/frontend_ast_validate.rs | 93 +++-------- src/lang/mod.rs | 1 - src/middleware/basetypes.rs | 6 + src/middleware/custom.rs | 46 ++++-- src/middleware/mod.rs | 5 +- src/middleware/operation.rs | 76 ++++++++- src/middleware/statement.rs | 15 +- 18 files changed, 515 insertions(+), 190 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 7d25786..de53ee5 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -37,8 +37,8 @@ use crate::{ hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg, - StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN, - VALUE_SIZE, + StatementTmplArgPrefix, ToFields, Value, BASE_PARAMS, EMPTY_VALUE, F, HASH_SIZE, + STATEMENT_ARG_F_LEN, VALUE_SIZE, }, }; @@ -103,6 +103,20 @@ pub struct StatementArgTarget { pub elements: [Target; STATEMENT_ARG_F_LEN], } +impl Flattenable for StatementArgTarget { + fn flatten(&self) -> Vec { + self.elements.to_vec() + } + fn from_flattened(_params: &Params, vs: &[Target]) -> Self { + Self { + elements: vs.try_into().expect("STATEMENT_ARG_F_LEN elements"), + } + } + fn size(_params: &Params) -> usize { + STATEMENT_ARG_F_LEN + } +} + impl StatementArgTarget { pub fn set_targets(&self, pw: &mut PartialWitness, arg: &StatementArg) -> Result<()> { Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?) @@ -318,7 +332,7 @@ impl OperationTarget { .args() .iter() .chain(iter::repeat(&OperationArg::None)) - .take(params.max_operation_args) + .take(BASE_PARAMS.max_operation_args) .enumerate() { self.args[i].set_targets(pw, arg.as_usize())?; @@ -328,7 +342,7 @@ impl OperationTarget { fn size(params: &Params) -> usize { OperationTypeTarget::size(params) - + params.max_operation_args * IndexTarget::size(params) + + BASE_PARAMS.max_operation_args * IndexTarget::size(params) + IndexTarget::size(params) } } @@ -868,7 +882,7 @@ impl CustomPredicateVerifyEntryTarget { args: (0..params.max_custom_predicate_wildcards) .map(|_| builder.add_virtual_value()) .collect(), - op_args: (0..params.max_operation_args) + op_args: (0..BASE_PARAMS.max_operation_args) .map(|_| builder.add_virtual_statement(false)) .collect(), } @@ -898,7 +912,7 @@ impl CustomPredicateVerifyEntryTarget { cpv.op_args .iter() .chain(iter::repeat(&pad_op_arg)) - .take(params.max_operation_args), + .take(BASE_PARAMS.max_operation_args), ) { op_arg_target.set_targets(pw, op_arg)? } @@ -941,7 +955,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { .expect("len = operation_type_size"), }; let (pos, size) = (pos + size, StatementTarget::size(params)); - let op_args = (0..params.max_operation_args) + let op_args = (0..BASE_PARAMS.max_operation_args) .map(|i| { StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size]) }) @@ -953,7 +967,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { } } fn size(params: &Params) -> usize { - StatementTarget::size(params) * (1 + params.max_operation_args) + StatementTarget::size(params) * (1 + BASE_PARAMS.max_operation_args) + OperationTarget::size(params) } } @@ -1425,7 +1439,7 @@ impl CircuitBuilderPod for CircuitBuilder { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { OperationTarget { op_type: self.add_virtual_operation_type(), - args: (0..params.max_operation_args) + args: (0..BASE_PARAMS.max_operation_args) .map(|_| IndexTarget::new_virtual(params.statement_table_size(), self)) .collect(), aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self), @@ -1735,7 +1749,7 @@ impl CircuitBuilderPod for CircuitBuilder { let num_chunks = array.len().div_ceil(CHUNK_LEN); for chunk in array.chunks(CHUNK_LEN) { let mut index_chunk = i.low; - // I we have several chunks and the last one is smaller (it's index needs less than 6 + // If we have several chunks and the last one is smaller (it's index needs less than 6 // bits), make it zero except when it's used so that the range check over the index // passes. if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 { diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 68114d2..0ac3bec 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -55,7 +55,7 @@ use crate::{ middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix, - ToFields, Value, F, HASH_SIZE, + ToFields, Value, BASE_PARAMS, F, HASH_SIZE, }, }; // @@ -69,30 +69,37 @@ pub const PI_OFFSET_VDSROOT: usize = 4; pub const NUM_PUBLIC_INPUTS: usize = 8; -const MAX_VALUE_ARGS: usize = 4; +const MAX_VALUE_ARGS: usize = 5; struct StatementArgCache { rhs: ValueTarget, lhs: StatementArgTarget, valid: BoolTarget, + pred_is_none: BoolTarget, + is_reference: BoolTarget, + // if `is_reference` then this is the AnchoredKey found in the Contains statement + reference: StatementArgTarget, + // if `is_reference` then this is the value found in the Contains statement + value: ValueTarget, } -struct StatementCache { - equations: [StatementArgCache; MAX_VALUE_ARGS], - first_n_equations_valid: [BoolTarget; MAX_VALUE_ARGS], +struct StatementCache { + equations: [StatementArgCache; MAX_EQS], + first_n_equations_valid: [BoolTarget; MAX_EQS], op_args: Vec, } -impl StatementCache { +impl StatementCache { fn new( params: &Params, + max_operation_args: usize, builder: &mut CircuitBuilder, op: &OperationTarget, st: &StatementTarget, prev_statements: &[StatementTarget], ) -> Self { let op_args = if prev_statements.is_empty() { - (0..params.max_operation_args) + (0..max_operation_args) .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) .collect_vec() } else { @@ -100,10 +107,10 @@ impl StatementCache { // converting a length 1 array into a scalar. op.args .iter() + .take(max_operation_args) .map(|i| builder.vec_ref(params, prev_statements, i)) .collect::>() }; - assert!(params.max_operation_args >= MAX_VALUE_ARGS); assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); let equations = array::from_fn(|i| { let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None); @@ -117,9 +124,9 @@ impl StatementCache { let is_reference = builder.and(pred_is_contains, ref_is_value); let valid = builder.or(is_literal, is_reference); - let rhs_literal = st.args[i].as_value(); - let rhs_reference = op_args[i].args[2].as_value(); - let rhs = builder.select_value(pred_is_none, rhs_literal, rhs_reference); + let rhs_from_literal = st.args[i].as_value(); + let rhs_from_reference = op_args[i].args[2].as_value(); + let rhs = builder.select_value(pred_is_none, rhs_from_literal, rhs_from_reference); let lhs_literal = &st.args[i]; let lhs_reference = StatementArgTarget::anchored_key( builder, @@ -127,10 +134,22 @@ impl StatementCache { &op_args[i].args[1].as_value(), ); let lhs = builder.select_statement_arg(pred_is_none, lhs_literal, &lhs_reference); - StatementArgCache { rhs, lhs, valid } + StatementArgCache { + rhs, + lhs, + valid, + pred_is_none, + is_reference, + reference: lhs_reference, + value: rhs_from_reference, + } }); - let mut first_n_equations_valid = [equations[0].valid; MAX_VALUE_ARGS]; - for i in 1..MAX_VALUE_ARGS { + let mut first_n_equations_valid = if MAX_EQS != 0 { + [equations[0].valid; MAX_EQS] + } else { + [builder._false(); MAX_EQS] + }; + for i in 1..MAX_EQS { first_n_equations_valid[i] = builder.and(equations[i].valid, first_n_equations_valid[i - 1]); } @@ -145,7 +164,7 @@ impl StatementCache { /// /// If the operation argument is a statement of type `None`, then the value /// should be the corresponding argument of the current statement. - /// If the operation argument is a statement of type `Equals`, then the value + /// If the operation argument is a statement of type `Contains`, then the value /// should be the argument at index 1 of that statement. /// If the function successfully interprets the arguments as values, /// returns `True` along with those values. Otherwise, returns `False` @@ -158,6 +177,12 @@ impl StatementCache { } } +/// Statement cache for private statements +type StatementCachePriv = StatementCache; +/// Statement cache for public statements. Since the operations can only be None or Copy, no +/// equation is needed because none of these operations dereference entries. +type StatementCachePub = StatementCache<0>; + /// Specialized implementation of `verify_operation_circuit` for operations that generate public /// statement. This only allows operations to be None, NewEntry or Copy and accounts for the fact /// that public statements in the current implementation are always generated by copying private @@ -169,13 +194,15 @@ fn verify_operation_public_statement_circuit( op: &OperationTarget, prev_statements: &[StatementTarget], ) -> Result<()> { - let measure = measure_gates_begin!(builder, "OpVerify"); + let measure = measure_gates_begin!(builder, "OpVerifyPub"); // Verify that the operation `op` correctly generates the statement `st`. The operation // can reference any of the `prev_statements`. // TODO: Clean this up. let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - let cache = StatementCache::new(params, builder, op, st, prev_statements); + // None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the + // StatementCache requires. + let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements); measure_gates_end!(builder, measure_resolve_op_args); let op_checks = vec![ @@ -406,7 +433,7 @@ fn verify_operation_circuit( prev_statements: &[StatementTarget], aux_table: &MuxTableTarget, ) -> Result<()> { - let measure = measure_gates_begin!(builder, "OpVerify"); + let measure = measure_gates_begin!(builder, "OpVerifyPriv"); let _true = builder._true(); let _false = builder._false(); @@ -414,7 +441,14 @@ fn verify_operation_circuit( // can reference any of the `prev_statements`. // TODO: Clean this up. let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - let cache = StatementCache::new(params, builder, op, st, prev_statements); + let cache = StatementCachePriv::new( + params, + BASE_PARAMS.max_operation_args, + builder, + op, + st, + prev_statements, + ); measure_gates_end!(builder, measure_resolve_op_args); // Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified @@ -442,6 +476,7 @@ fn verify_operation_circuit( verify_sum_of_circuit(params, builder, st, &op.op_type, &cache), verify_product_of_circuit(params, builder, st, &op.op_type, &cache), verify_max_of_circuit(params, builder, st, &op.op_type, &cache), + verify_replace_value_with_entry_circuit(params, builder, st, &op.op_type, &cache), ]); } // Skip these if there are no resolved aux entries @@ -542,7 +577,7 @@ fn verify_contains_from_entries_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); let (aux_tag_ok, resolved_merkle_claim) = @@ -592,7 +627,7 @@ fn verify_not_contains_from_entries_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); let (aux_tag_ok, resolved_merkle_claim) = @@ -639,7 +674,7 @@ fn verify_merkle_insert_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleInsertOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = @@ -714,7 +749,7 @@ fn verify_merkle_update_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleUpdateOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = @@ -789,7 +824,7 @@ fn verify_merkle_delete_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleDeleteOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = @@ -883,7 +918,7 @@ fn verify_eq_neq_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); let eq_op_st_code_ok = { @@ -932,9 +967,9 @@ fn verify_lt_lteq_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries"); + let measure = measure_gates_begin!(builder, "OpLtEqFromEntries"); let zero = ValueTarget::zero(builder); let one = ValueTarget::one(builder); @@ -1000,7 +1035,7 @@ fn verify_hash_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpHashOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); @@ -1033,7 +1068,7 @@ fn verify_public_key_of_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpPublicKeyOf"); let (aux_tag_ok, resolved_pk_sk) = @@ -1069,7 +1104,7 @@ fn verify_signed_by_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpSignedBy"); let (aux_tag_ok, resolved_msg_pk) = @@ -1104,7 +1139,7 @@ fn verify_sum_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpSumOf"); let value_zero = ValueTarget::zero(builder); @@ -1142,7 +1177,7 @@ fn verify_product_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpProductOf"); let value_zero = ValueTarget::zero(builder); @@ -1180,7 +1215,7 @@ fn verify_max_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpMaxOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); @@ -1220,6 +1255,47 @@ fn verify_max_of_circuit( ok } +fn verify_replace_value_with_entry_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCachePriv, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpReplaceValueWithEntry"); + let op_code_ok = op_type.has_native(builder, NativeOperation::ReplaceValueWithEntry); + + let st_in = &cache.op_args[BASE_PARAMS.max_statement_args]; + + let mut args = Vec::new(); + let mut args_ok = builder._true(); + for (arg_in, entry_cache) in zip_eq(&st_in.args, &cache.equations) { + // if the op_arg is None, keep the original argument, if it's a Contains swap the value by + // the reference Entry while checking that the value in Contains matches the original + // argument. + let arg = builder.select_flattenable( + params, + entry_cache.pred_is_none, + arg_in, + &entry_cache.reference, + ); + args.push(arg); + let arg_ref_ok = { + let arg_in_is_value = builder.statement_arg_is_value(arg_in); + let value_eq = builder.is_equal_flattenable(&arg_in.as_value(), &entry_cache.value); + builder.all([entry_cache.is_reference, arg_in_is_value, value_eq]) + }; + let arg_ok = builder.or(entry_cache.pred_is_none, arg_ref_ok); + args_ok = builder.and(args_ok, arg_ok); + } + let expected_statement = StatementTarget::new(*st_in.pred_hash(), args); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); + + let ok = builder.all([op_code_ok, args_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} + fn verify_transitive_eq_circuit( params: &Params, builder: &mut CircuitBuilder, @@ -1429,7 +1505,7 @@ fn make_custom_statement_circuit( ) -> Result<(StatementTarget, OperationTypeTarget)> { let measure = measure_gates_begin!(builder, "CustomOpVerify"); // Some sanity checks - assert_eq!(params.max_operation_args, op_args.len()); + assert_eq!(BASE_PARAMS.max_operation_args, op_args.len()); assert_eq!(params.max_custom_predicate_wildcards, args.len()); let (batch_id, index) = (custom_predicate.id, custom_predicate.index); @@ -1463,7 +1539,6 @@ fn make_custom_statement_circuit( .collect(); // expected_sts.len() == params.max_custom_predicate_arity // op_args.len() == params.max_operation_args; - assert!(Params::max_custom_predicate_arity() <= params.max_operation_args); let sts_eq: Vec<_> = expected_sts .iter() @@ -2076,7 +2151,8 @@ mod tests { frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, middleware::{ hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, - RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, EMPTY_VALUE, + RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard, + BASE_PARAMS, EMPTY_VALUE, }, }; @@ -3068,6 +3144,33 @@ mod tests { operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) } + #[test] + fn test_operation_replace_value_with_entry() -> Result<()> { + let d = dict!({"a" => 42, "b" => 33}); + + // 0: None + // 1: Lt(5, 42) + let st_in: mainpod::Statement = Statement::lt(5, 42).into(); + // 2: Contains(d, "a", 42) + let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into(); + + let st_out: mainpod::Statement = + Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into(); + let mut op_args: Vec<_> = iter::repeat(OperationArg::None) + .take(BASE_PARAMS.max_statement_args + 1) + .collect(); + op_args[1] = OperationArg::Index(2); + op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ReplaceValueWithEntry), + op_args, + OperationAux::None, + ); + + let prev_statements = vec![Statement::None.into(), st_in, st_entry]; + operation_verify(st_out, op, prev_statements, Aux::default()) + } + fn helper_statement_arg_from_template( params: &Params, st_tmpl_arg: StatementTmplArg, @@ -3226,7 +3329,7 @@ mod tests { expected_st: Option, ) -> Result<()> { // Pad - for _ in op_args.len()..params.max_operation_args { + for _ in op_args.len()..BASE_PARAMS.max_operation_args { op_args.push(Statement::None); } for _ in args.len()..params.max_custom_predicate_wildcards { @@ -3275,6 +3378,10 @@ mod tests { Ok(data.verify(proof.clone())?) } + fn value_ref(v: impl Into) -> ValueRef { + v.into() + } + // TODO: Add negative tests #[test] fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { @@ -3309,7 +3416,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(1234)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( @@ -3330,7 +3437,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(0)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( @@ -3351,7 +3458,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(1234)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( @@ -3403,7 +3510,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(secret_dict)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index ae1ade3..5e9df2e 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1,5 +1,5 @@ pub mod operation; -use crate::middleware::{wildcard_values_from_op_st, PodType}; +use crate::middleware::{wildcard_values_from_op_st, PodType, BASE_PARAMS}; pub mod statement; use std::iter; @@ -39,7 +39,7 @@ use crate::{ middleware::{ self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, - ToFields, VDSet, Value, + ToFields, VDSet, Value, ValueRef, }, timed, }; @@ -104,9 +104,20 @@ pub(crate) fn extract_custom_predicate_verifications( if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Statement::Custom(st_cpr, st_args) = st { assert_eq!(cpr, st_cpr); + // The custom operation outputs statements with literal arguments. They can be + // replaced by references later with ReplaceValueWithEntry. + let st_args = st_args + .iter() + .map(|arg| match arg { + ValueRef::Literal(v) => Ok(v.clone()), + _ => Err(Error::custom( + "custom operation cannot output entries as arguments", + )), + }) + .collect::>>()?; let normalized_pred = cpr.normalized_predicate(); let wildcard_values = - wildcard_values_from_op_st(params, &normalized_pred, sts, st_args) + wildcard_values_from_op_st(params, &normalized_pred, sts, &st_args) .expect("resolved wildcards"); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let custom_predicate_table_index = custom_predicates @@ -329,8 +340,8 @@ pub fn pad_statement(s: &mut Statement) { fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args()) } -fn pad_operation_args(params: &Params, args: &mut Vec) { - fill_pad(args, OperationArg::None, params.max_operation_args) +fn pad_operation_args(args: &mut Vec) { + fill_pad(args, OperationArg::None, BASE_PARAMS.max_operation_args) } /// Returns the statements from the given MainPodInputs, padding to the respective max lengths @@ -428,7 +439,7 @@ pub(crate) fn process_private_statements_operations( .map(|mid_arg| find_op_arg(statements, mid_arg)) .collect::>>()?; - pad_operation_args(params, &mut args); + pad_operation_args(&mut args); operations.push(Operation(op.op_type(), args, *aux)); } Ok(operations) @@ -459,7 +470,11 @@ pub(crate) fn process_public_statements_operations( OperationAux::None, ) }; - fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); + fill_pad( + &mut op.1, + OperationArg::None, + BASE_PARAMS.max_operation_args, + ); operations.push(op); } Ok(operations) @@ -469,6 +484,7 @@ pub struct Prover {} impl MainPodProver for Prover { fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result> { + assert_eq!(inputs.statements.len(), inputs.operations.len()); // Pad input recursive pods with empty pods if necessary let empty_pod = if inputs.pods.len() == params.max_input_pods { // We don't need padding so we skip creating an EmptyPod @@ -1005,7 +1021,6 @@ pub mod tests { max_input_pods_public_statements: 2, max_statements: 5, max_public_statements: 2, - max_operation_args: 5, max_custom_predicates: 2, max_custom_predicate_verifications: 2, max_custom_predicate_wildcards: 3, @@ -1070,7 +1085,6 @@ pub mod tests { max_input_pods: 0, max_statements: 9, max_public_statements: 4, - max_operation_args: 5, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, max_merkle_proofs_containers: 3, @@ -1140,7 +1154,6 @@ pub mod tests { max_input_pods: 0, max_statements: 6, max_public_statements: 2, - max_operation_args: 5, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, max_merkle_proofs_containers: 0, @@ -1251,11 +1264,108 @@ pub mod tests { ); let st = middleware::Statement::Custom( cpr, - [1, 1, 2].into_iter().map(middleware::Value::from).collect(), + [1, 1, 2] + .into_iter() + .map(middleware::ValueRef::from) + .collect(), ); builder.insert((st.clone(), op)).unwrap(); builder.reveal(&st).unwrap(); let prover = Prover {}; builder.prove(&prover).unwrap(); } + + #[test] + fn test_replace_value_with_entry() { + let params = middleware::Params::default(); + let vd_set = &*DEFAULT_VD_SET; + let mut builder = MainPodBuilder::new(¶ms, vd_set); + let d = dict!({"a" => 42, "b" => 33}); + builder + .priv_op(frontend::Operation::dict_contains(d.clone(), "a", 42)) + .unwrap(); + let st = builder.priv_op(frontend::Operation::lt(5, 42)).unwrap(); + // Transform `Lt(5, 42)` into `Lt(5, d.a)` by using `DictContains(d, "a", 42)` + builder + .pub_op(frontend::Operation::replace_value_with_entry( + vec![None, Some((&d, "a"))], + st, + )) + .unwrap(); + + // Mock + let prover = MockProver {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap(); + assert_eq!( + middleware::Statement::Lt( + middleware::ValueRef::Literal(Value::from(5)), + middleware::ValueRef::Key(middleware::AnchoredKey { + root: d.commitment(), + key: middleware::Key::from("a") + }) + ), + pod.public_statements[0] + ); + + // Real + let prover = Prover {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap() + } + + #[test] + fn test_entry_custom_statement_arg() { + let params = middleware::Params::default(); + let vd_set = &*DEFAULT_VD_SET; + let input = r#" + PredA(x) = AND( + Lt(x, 100) + ) + + PredB(d) = AND( + PredA(d.x) + ) + "#; + let module = load_module(input, "my_mod", ¶ms, &[]).expect("lang parse"); + let pred_a = module.batch.predicate_ref_by_name("PredA").unwrap(); + let pred_b = module.batch.predicate_ref_by_name("PredB").unwrap(); + + let mut builder = MainPodBuilder::new(¶ms, vd_set); + let d = dict!({"x" => 42, "y" => 33}); + + let st_lt = builder.priv_op(frontend::Operation::lt(42, 100)).unwrap(); + let st_a = builder + .priv_op(frontend::Operation::custom(pred_a, [st_lt])) + .unwrap(); + builder + .priv_op(frontend::Operation::dict_contains(d.clone(), "x", 42)) + .unwrap(); + // Transform `PredA(42)` into `PredA(d.x)` by using `DictContains(d, "x", 42)` + let st_a1 = builder + .priv_op(frontend::Operation::replace_value_with_entry( + vec![Some((&d, "x"))], + st_a, + )) + .unwrap(); + + builder + .pub_op(frontend::Operation::custom(pred_b.clone(), [st_a1])) + .unwrap(); + + // Mock + let prover = MockProver {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap(); + let expected = middleware::Statement::Custom( + pred_b, + vec![middleware::ValueRef::Literal(Value::from(d))], + ); + assert_eq!(expected, pod.public_statements[0]); + + // Real + let prover = Prover {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap() + } } diff --git a/src/backends/plonky2/mainpod/statement.rs b/src/backends/plonky2/mainpod/statement.rs index 27776a6..64fe675 100644 --- a/src/backends/plonky2/mainpod/statement.rs +++ b/src/backends/plonky2/mainpod/statement.rs @@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::error::{Error, Result}, - middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS}, + middleware::{ + self, NativePredicate, Predicate, StatementArg, ToFields, Value, ValueRef, BASE_PARAMS, + }, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -96,15 +98,15 @@ impl TryFrom for middleware::Statement { )))?, }, Predicate::Custom(cpr) => { - let vs: Vec = proper_args + let args: Vec = proper_args .into_iter() .filter_map(|arg| match arg { - SA::None => None, - SA::Literal(v) => Some(v), - _ => unreachable!(), + StatementArg::Literal(v) => Some(ValueRef::Literal(v)), + StatementArg::Key(k) => Some(ValueRef::Key(k)), + StatementArg::None => None, }) .collect(); - S::Custom(cpr, vs) + S::Custom(cpr, args) } Predicate::Intro(ir) => { let vs: Vec = proper_args diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index dcb1355..b8c6a03 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -380,7 +380,8 @@ pub mod tests { great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, MOCK_VD_SET, }, - frontend, middleware, + frontend::{self}, + middleware, middleware::{Signer as _, Value}, }; diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index a2614a0..8de6871 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -316,7 +316,9 @@ mod tests { backends::plonky2::mock::mainpod::MockProver, examples::{custom::eth_dos_batch, MOCK_VD_SET}, frontend::{MainPodBuilder, Operation}, - middleware::{self, containers::Set, CustomPredicateRef, Params, PodType, DEFAULT_VD_SET}, + middleware::{ + self, containers::Set, CustomPredicateRef, Params, PodType, ValueRef, DEFAULT_VD_SET, + }, }; #[test] @@ -507,7 +509,7 @@ mod tests { .find(|s| matches!(s, middleware::Statement::Custom(_, _))) .expect("should have a custom statement"); if let middleware::Statement::Custom(_, args) = custom_st { - assert_eq!(args[0], pred_b_hash); + assert_eq!(args[0], ValueRef::Literal(pred_b_hash)); } Ok(()) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index f23e374..b6e8691 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -4,7 +4,7 @@ use std::{ collections::{HashMap, HashSet}, convert::From, - fmt, + fmt, iter, }; use itertools::Itertools; @@ -15,9 +15,10 @@ pub use serialization::SerializedMainPod; use crate::middleware::{ self, check_custom_pred, containers::{Container, Dictionary}, - fill_wildcard_values, hash_op, max_op, prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, - MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, RawValue, - Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, EMPTY_VALUE, + fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key, + MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, + RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS, + EMPTY_VALUE, }; mod custom; @@ -566,6 +567,37 @@ impl MainPodBuilder { // TODO: validate proof Statement::ContainerDelete(r1, r2, r3) } + (ReplaceValueWithEntry, &args, _) => { + let mut args = args.to_vec(); + if args.len() != BASE_PARAMS.max_statement_args + 1 { + return Err(Error::custom(format!( + "ReplaceValueWithEntry requires exactly {} args but {} were found", + BASE_PARAMS.max_statement_args + 1, + args.len() + ))); + } + let st = match args.pop().expect("valid vec len") { + OperationArg::Statement(st) => st, + _ => return Err(Error::custom("expected statement")), + }; + let new_st_args = iter::zip(st.args().into_iter(), args) + .map(|(st_arg, arg)| match (st_arg, arg) { + (st_arg, OperationArg::Statement(Statement::None)) => Ok(st_arg), + ( + StatementArg::Literal(st_arg_v), + OperationArg::Statement(Statement::Contains( + ValueRef::Literal(root), + ValueRef::Literal(key), + ValueRef::Literal(v), + )), + ) if st_arg_v == v => root_key_to_ak(&root, &key) + .map(StatementArg::Key) + .ok_or_else(native_arg_error), + _ => Err(Error::custom("unexpected operation argument")), + }) + .collect::>>()?; + Statement::from_args(st.predicate(), new_st_args)? + } (t, _, _) => { if t.is_syntactic_sugar() { return Err(Error::custom(format!( @@ -615,7 +647,7 @@ impl MainPodBuilder { .map(|v| v.unwrap_or_else(|| v_default.clone())) .collect(); check_custom_pred(&self.params, &cpr, &args, &st_args)?; - Statement::Custom(cpr, st_args) + Statement::Custom(cpr, st_args.into_iter().map(ValueRef::Literal).collect()) } }; Ok(st) diff --git a/src/frontend/multi_pod/cost.rs b/src/frontend/multi_pod/cost.rs index 0c0c2ef..2839ea8 100644 --- a/src/frontend/multi_pod/cost.rs +++ b/src/frontend/multi_pod/cost.rs @@ -111,7 +111,8 @@ impl StatementCost { // Syntactic sugar variants (lowered before proving) | NativeOperation::GtEqFromEntries | NativeOperation::GtFromEntries - | NativeOperation::GtToNotEqual => {} + | NativeOperation::GtToNotEqual + | NativeOperation::ReplaceValueWithEntry => {} } } OperationType::Custom(cpr) => { diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index a1045a5..9794e60 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,10 +1,10 @@ -use std::fmt; +use std::{fmt, iter}; use crate::{ frontend::SignedDict, middleware::{ containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux, - OperationType, Signature, Statement, Value, ValueRef, + OperationType, Signature, Statement, Value, ValueRef, BASE_PARAMS, }, }; @@ -219,6 +219,24 @@ impl Operation { op_impl_oa!(set_insert, SetInsertFromEntries, 3); op_impl_oa!(set_delete, SetDeleteFromEntries, 3); op_impl_oa!(array_update, ArrayUpdateFromEntries, 4); + pub fn replace_value_with_entry(args: Vec>, st: Statement) -> Self { + assert!(args.len() <= BASE_PARAMS.max_statement_args); + let args = args + .into_iter() + .chain(iter::repeat(None)) + .take(BASE_PARAMS.max_statement_args) + .map(|a| match a { + None => OperationArg::Statement(Statement::None), + Some((dict, key)) => OperationArg::from((dict, key)), + }) + .chain(iter::once(OperationArg::Statement(st))) + .collect(); + Self( + OperationType::Native(NativeOperation::ReplaceValueWithEntry), + args, + OperationAux::None, + ) + } pub fn signed_by( msg: impl Into, pk: impl Into, diff --git a/src/lang/diagnostics.rs b/src/lang/diagnostics.rs index 0a1d770..7807318 100644 --- a/src/lang/diagnostics.rs +++ b/src/lang/diagnostics.rs @@ -174,18 +174,6 @@ fn render_validation_error( "second REQUEST here", ), - ValidationError::InvalidArgumentType { predicate, span } => { - let title = format!("invalid argument type for `{}`", predicate); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "anchored keys not allowed here", - ) - } - ValidationError::DuplicateWildcard { name, span } => { let title = format!("duplicate wildcard: {}", name); render_with_optional_span( diff --git a/src/lang/error.rs b/src/lang/error.rs index 769faf6..792d4d8 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -135,12 +135,6 @@ pub enum ValidationError { span: Option, }, - #[error("Invalid argument type for {predicate}: anchored keys not allowed")] - InvalidArgumentType { - predicate: String, - span: Option, - }, - #[error("Duplicate wildcard in predicate arguments: {name}")] DuplicateWildcard { name: String, span: Option }, diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index 0b7737d..ef3d395 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -522,7 +522,7 @@ impl Validator { } // Validate arguments - self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?; + self.validate_statement_args(stmt, wildcard_context)?; Ok(()) } @@ -530,75 +530,37 @@ impl Validator { fn validate_statement_args( &self, stmt: &StatementTmpl, - pred_info: Option<&PredicateInfo>, wildcard_context: Option<(&str, &WildcardScope)>, ) -> Result<(), ValidationError> { - // For custom predicates, only wildcards and literals are allowed - if matches!( - pred_info.map(|i| &i.kind), - Some(PredicateKind::Custom { .. }) - | Some(PredicateKind::BatchImported { .. }) - | Some(PredicateKind::ModuleImported { .. }) - ) { - for arg in &stmt.args { - match arg { - StatementTmplArg::AnchoredKey(_) => { - return Err(ValidationError::InvalidArgumentType { - predicate: stmt.predicate.predicate_name().to_string(), - span: stmt.span, - }); - } - StatementTmplArg::Wildcard(id) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&id.name) { - return Err(ValidationError::UndefinedWildcard { - name: id.name.clone(), - pred_name: pred_name.to_string(), - span: id.span, - }); - } + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&id.name) { + return Err(ValidationError::UndefinedWildcard { + name: id.name.clone(), + pred_name: pred_name.to_string(), + span: id.span, + }); } } - StatementTmplArg::Literal(lit) => { - self.validate_literal_value(lit)?; - } - StatementTmplArg::SelfPredicateHash(id) => { - self.validate_self_predicate_hash(id, wildcard_context)?; - } } - } - } else { - // Native predicates can have anchored keys - for arg in &stmt.args { - match arg { - StatementTmplArg::Wildcard(id) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&id.name) { - return Err(ValidationError::UndefinedWildcard { - name: id.name.clone(), - pred_name: pred_name.to_string(), - span: id.span, - }); - } + StatementTmplArg::AnchoredKey(ak) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&ak.root.name) { + return Err(ValidationError::UndefinedWildcard { + name: ak.root.name.clone(), + pred_name: pred_name.to_string(), + span: ak.root.span, + }); } } - StatementTmplArg::AnchoredKey(ak) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&ak.root.name) { - return Err(ValidationError::UndefinedWildcard { - name: ak.root.name.clone(), - pred_name: pred_name.to_string(), - span: ak.root.span, - }); - } - } - } - StatementTmplArg::Literal(lit) => { - self.validate_literal_value(lit)?; - } - StatementTmplArg::SelfPredicateHash(id) => { - self.validate_self_predicate_hash(id, wildcard_context)?; - } + } + StatementTmplArg::Literal(lit) => { + self.validate_literal_value(lit)?; + } + StatementTmplArg::SelfPredicateHash(id) => { + self.validate_self_predicate_hash(id, wildcard_context)?; } } } @@ -839,10 +801,7 @@ mod tests { module_hash ); let result = parse_and_validate_request(&input, &available_modules); - assert!(matches!( - result, - Err(ValidationError::InvalidArgumentType { .. }) - )); + assert!(result.is_ok()); } #[test] diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 5674f53..291f7a6 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -578,7 +578,6 @@ mod tests { max_input_pods: 3, max_statements: 31, max_public_statements: 10, - max_operation_args: 5, max_custom_predicate_wildcards: 12, ..Default::default() }; diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index e6af211..0012251 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -169,6 +169,12 @@ pub struct Hash( pub [F; HASH_SIZE], ); +impl Hash { + pub fn raw(self) -> RawValue { + RawValue::from(self) + } +} + impl From for HashOut { fn from(hash: Hash) -> HashOut { HashOut { elements: hash.0 } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index cf6d9be..e5c7285 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -436,7 +436,7 @@ impl fmt::Display for CustomPredicate { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] +#[derive(Clone, PartialEq, Eq, Serialize, JsonSchema)] enum CustomPredicateBatchData { Full { #[serde(skip)] @@ -449,6 +449,20 @@ enum CustomPredicateBatchData { }, } +// Explicit implementation of Debug to skip the merkle tree +impl fmt::Debug for CustomPredicateBatchData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self { + Self::Full { mt, predicates } => f + .debug_struct("Full") + .field("id", &mt.root()) + .field("predicates", &predicates) + .finish(), + Self::Opaque { id } => f.debug_struct("Opaque").field("id", &id).finish(), + } + } +} + // TODO: Rename Batch for Module everywhere in the code base impl CustomPredicateBatchData { fn new_full(predicates: Vec) -> Self { @@ -630,7 +644,7 @@ mod tests { middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl, - StatementTmplArg, + StatementTmplArg, ValueRef, }, }; @@ -653,6 +667,9 @@ mod tests { fn names(names: &[&str]) -> Vec { names.iter().map(|s| s.to_string()).collect() } + fn value_ref(v: impl Into) -> ValueRef { + v.into() + } #[allow(clippy::upper_case_acronyms)] type STA = StatementTmplArg; @@ -701,7 +718,7 @@ mod tests { }); let custom_statement = Statement::Custom( CustomPredicateRef::new(cust_pred_batch.clone(), 0), - vec![Value::from(d0.clone())], + vec![value_ref(d0.clone())], ); let custom_deduction = Operation::Custom( @@ -833,7 +850,7 @@ mod tests { // Example statement let ethdos_example = Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), - vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], + vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)], ); // Copies should work. @@ -842,7 +859,7 @@ mod tests { // This could arise as the inductive step. let ethdos_ind_example = Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1), - vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], + vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)], ); assert!(Operation::Custom( @@ -857,12 +874,12 @@ mod tests { let ethdos_facts = vec![ Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), - vec![Value::from("Alice"), Value::from("Charlie"), Value::from(6)], + vec![value_ref("Alice"), value_ref("Charlie"), value_ref(6)], ), Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)), Statement::Custom( CustomPredicateRef::new(eth_friend_batch.clone(), 0), - vec![Value::from("Charlie"), Value::from("Bob")], + vec![value_ref("Charlie"), value_ref("Bob")], ), ]; @@ -959,7 +976,10 @@ mod tests { let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())]; // The output statement - let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]); + let output_st = Statement::Custom( + pred_b_ref.clone(), + vec![ValueRef::Literal(some_value.clone())], + ); // This should pass assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(¶ms, &output_st)?); @@ -1024,12 +1044,18 @@ mod tests { // Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())]; - let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]); + let st_a = Statement::Custom( + pred_a_ref.clone(), + vec![ValueRef::Literal(pred_b_hash.clone())], + ); assert!(Operation::Custom(pred_a_ref, op_a).check(¶ms, &st_a)?); // Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())]; - let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]); + let st_b = Statement::Custom( + pred_b_ref.clone(), + vec![ValueRef::Literal(pred_a_hash.clone())], + ); assert!(Operation::Custom(pred_b_ref, op_b).check(¶ms, &st_b)?); Ok(()) diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 19ca2c2..82675d7 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -768,6 +768,8 @@ pub struct BaseParams { /// in a custom predicate pub max_custom_predicate_arity: usize, pub max_depth_custom_batch_mt: usize, + // This value depends on `max_custom_predicate_arity` + pub max_operation_args: usize, } pub const BASE_PARAMS: BaseParams = BaseParams { @@ -775,6 +777,7 @@ pub const BASE_PARAMS: BaseParams = BaseParams { max_statement_args: 5, max_custom_predicate_arity: 5, max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch + max_operation_args: 5 + 1, }; /// Params: non dynamic parameters that define the circuit. @@ -785,7 +788,6 @@ pub struct Params { pub max_input_pods_public_statements: usize, pub max_statements: usize, pub max_public_statements: usize, - pub max_operation_args: usize, // max number of different custom predicates that can be used in a MainPod pub max_custom_predicates: usize, // max number of operations using custom predicates that can be verified in the MainPod @@ -815,7 +817,6 @@ impl Default for Params { max_input_pods_public_statements: 8, max_statements: 48, max_public_statements: 8, - max_operation_args: 5, max_custom_predicates: 8, max_custom_predicate_verifications: 8, max_custom_predicate_wildcards: 8, diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 1793e4d..8d3316c 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -14,7 +14,7 @@ use crate::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result, Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef, - Wildcard, F, + Wildcard, BASE_PARAMS, F, }, }; @@ -89,6 +89,7 @@ pub enum NativeOperation { ContainerInsertFromEntries = 16, ContainerUpdateFromEntries = 17, ContainerDeleteFromEntries = 18, + ReplaceValueWithEntry = 19, // Syntactic sugar operations. These operations are not supported by the backend. The // frontend compiler is responsible of translating these operations into the operations above. @@ -164,6 +165,7 @@ impl OperationType { NativeOperation::ContainerDeleteFromEntries => { Some(Predicate::Native(NativePredicate::ContainerDelete)) } + NativeOperation::ReplaceValueWithEntry => None, no => unreachable!("Unexpected syntactic sugar op {:?}", no), }, OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), @@ -219,6 +221,10 @@ pub enum Operation { /* key */ Statement, /* proof */ MerkleTreeStateTransitionProof, ), + ReplaceValueWithEntry( + /* Contains/None len=max_statement_args */ Vec, + /* to copy */ Statement, + ), Custom(CustomPredicateRef, Vec), } @@ -270,6 +276,7 @@ impl Operation { OT::Native(ContainerUpdateFromEntries) } Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries), + Self::ReplaceValueWithEntry(_, _) => OT::Native(ReplaceValueWithEntry), Self::Custom(cpr, _) => OT::Custom(cpr.clone()), } } @@ -295,6 +302,11 @@ impl Operation { Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3], + Self::ReplaceValueWithEntry(args, s) => { + let mut sts = args; + sts.push(s); + sts + } Self::Custom(_, args) => args, } } @@ -377,6 +389,18 @@ impl Operation { &[s1, s2, s3], OA::MerkleTreeStateTransitionProof(pf), ) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf), + (NO::ReplaceValueWithEntry, args, OA::None) => { + let mut args = args.to_vec(); + if args.len() != BASE_PARAMS.max_statement_args + 1 { + return Err(Error::custom(format!( + "ReplaceValueWithEntry requires exactly {} args but {} were found", + BASE_PARAMS.max_statement_args + 1, + args.len() + ))); + } + let st = args.pop().expect("valid vec len"); + Self::ReplaceValueWithEntry(args, st) + } _ => Err(Error::custom(format!( "Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.", op_code, @@ -422,6 +446,38 @@ impl Operation { Ok(sig.verify(pk, msg.raw())) } + fn check_replace_value_with_entry( + entries: &[Statement], + st_in: &Statement, + expected_st_out: &Statement, + ) -> Result { + if entries.len() != BASE_PARAMS.max_statement_args { + return Ok(false); + } + let args = iter::zip(st_in.args(), entries) + .map(|(arg_in, entry)| match (arg_in, entry) { + (arg_in, Statement::None) => Ok(arg_in), + ( + StatementArg::Literal(v_in), + Statement::Contains( + ValueRef::Literal(root), + ValueRef::Literal(key), + ValueRef::Literal(v), + ), + ) if v == &v_in => Ok(StatementArg::Key(AnchoredKey::new( + Hash::from(root.raw()), + Key::from(key.as_str().ok_or_else(|| Error::custom("not a string"))?), + ))), + _ => Err(Error::custom( + "invalid statement argument in ReplaceValueWithEntry", + )), + }) + .collect::>>()?; + + let st_out = Statement::from_args(st_in.predicate(), args)?; + Ok(&st_out == expected_st_out) + } + /// Checks the given operation against a statement. pub fn check(&self, params: &Params, output_statement: &Statement) -> Result { use Statement::*; @@ -541,7 +597,19 @@ impl Operation { (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) if batch == &cpr.batch && index == &cpr.index => { - check_custom_pred(params, cpr, args, s_args).map(|_| true)? + // The custom operation outputs statements with literal arguments. They can be + // replaced by references later with ReplaceValueWithEntry. + let s_args = s_args + .iter() + .map(|arg| match arg { + ValueRef::Literal(v) => Ok(v.clone()), + _ => Err(deduction_err()), + }) + .collect::>>()?; + check_custom_pred(params, cpr, args, &s_args).map(|_| true)? + } + (Self::ReplaceValueWithEntry(entries, st_in), st_out) => { + Self::check_replace_value_with_entry(entries, st_in, st_out)? } _ => return Err(deduction_err()), }; @@ -648,9 +716,9 @@ pub fn wildcard_values_from_op_st( params: &Params, pred: &CustomPredicate, op_args: &[Statement], - st_args: &[Value], + resolved_st_args: &[Value], ) -> Result> { - let mut wildcard_map = st_args + let mut wildcard_map = resolved_st_args .iter() .map(|v| Some(v.clone())) .chain(core::iter::repeat(None)) diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index d3e0534..b5c1f60 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -311,7 +311,7 @@ pub enum Statement { /* old_root */ ValueRef, /* key */ ValueRef, ), - Custom(CustomPredicateRef, Vec), + Custom(CustomPredicateRef, Vec), Intro(IntroPredicateRef, Vec), } @@ -407,7 +407,7 @@ impl Statement { vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()] } Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], - Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Literal)), + Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(StatementArg::from)), Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)), } } @@ -478,14 +478,11 @@ impl Statement { } (BatchSelf(_), _) => unreachable!(), (Custom(cpr), _) => { - let v_args: Result> = args + let v_args = args .iter() - .map(|x| match x { - StatementArg::Literal(v) => Ok(v.clone()), - _ => Err(Error::incorrect_statements_args()), - }) - .collect(); - Self::Custom(cpr, v_args?) + .map(|x| x.try_into()) + .collect::>>()?; + Self::Custom(cpr, v_args) } (Intro(ir), _) => { let v_args: Result> = args From 3203c883e552c53ddbff2c5ef1aada5fa9938f88 Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Mon, 20 Apr 2026 00:07:20 -0700 Subject: [PATCH 07/10] Use windowed ECAddXuGate for PublicKeyOf (#501) --- src/backends/plonky2/circuits/mainpod.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 0ac3bec..c4b891a 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -38,7 +38,8 @@ use crate::{ ec::{ bits::{BigUInt320Target, CircuitBuilderBits}, curve::{ - CircuitBuilderElliptic, Point, PointTarget, WitnessWriteCurve, GROUP_ORDER, + CircuitBuilderElliptic, CircuitBuilderSignature, Point, PointTarget, + WitnessWriteCurve, GROUP_ORDER, }, schnorr::{CircuitBuilderSchnorr, SecretKey, SignatureTarget, WitnessWriteSchnorr}, }, @@ -320,9 +321,10 @@ fn build_operation_aux_table_circuit( } // PublicKeyOf: verify the derivation from a Schnorr secret key to public key + let invgenerator = builder.constant_point(Point::generator().inverse()); + let zero_bits: [BoolTarget; 320] = array::from_fn(|_| builder._false()); for sk in public_key_of_sks { let measure = measure_gates_begin!(builder, "PublicKeyOf"); - let invgenerator = builder.constant_point(Point::generator().inverse()); let group_orderm1 = &*GROUP_ORDER - BigUint::one(); let group_orderm1target = builder.constant_biguint320(&group_orderm1); let compare_ok = list_le_circuit( @@ -333,7 +335,9 @@ fn build_operation_aux_table_circuit( ); builder.assert_one(compare_ok.target); // public_key = g^-secret key - let pk = builder.multiply_point(&sk.bits, &invgenerator); + // Use the windowed ECAddXuGate (3-bit windows, 107 iterations) instead of the + // naive multiply_point (1-bit double-and-add, 320 iterations) for fewer gates. + let pk = builder.linear_combination_point_gen(&zero_bits, &sk.bits, &invgenerator); let sk_hash = builder.hash_n_to_hash_no_pad::(sk.limbs.to_vec()); let pk_hash = builder.hash_n_to_hash_no_pad::( pk.x.components.into_iter().chain(pk.u.components).collect(), From 8844fe124c79e150b8b82dbb62b7451428e7b367 Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Thu, 23 Apr 2026 01:41:29 -0700 Subject: [PATCH 08/10] Diagnostics for MultiPodBuilder (#500) * Diagnostics for MultiPodBuilder * Reduce duplication --- src/frontend/multi_pod/diagnostics.rs | 457 ++++++++++++++++++++++++++ src/frontend/multi_pod/mod.rs | 32 ++ 2 files changed, 489 insertions(+) create mode 100644 src/frontend/multi_pod/diagnostics.rs diff --git a/src/frontend/multi_pod/diagnostics.rs b/src/frontend/multi_pod/diagnostics.rs new file mode 100644 index 0000000..438f379 --- /dev/null +++ b/src/frontend/multi_pod/diagnostics.rs @@ -0,0 +1,457 @@ +//! Diagnostic utilities for multi-POD resource analysis. +//! +//! Provides two views: +//! - [`ResourceSummary`]: Pre-solve aggregate resource demand vs. per-POD limits. +//! Shows which resource category is the bottleneck (requires the most PODs). +//! - [`SolutionBreakdown`]: Post-solve per-POD utilization showing how full each POD is. + +use std::{collections::BTreeSet, fmt}; + +use super::cost::StatementCost; +use crate::middleware::Params; + +/// A single resource category's usage vs. per-POD limit. +/// +/// Used both for pre-solve aggregate demand (in [`ResourceSummary`]) where +/// `used` is the total across all statements, and for post-solve per-POD +/// breakdown (in [`PodUtilization`]) where `used` is the POD's consumption. +#[derive(Clone, Debug)] +pub struct UtilizationRow { + pub name: &'static str, + pub used: usize, + pub limit: usize, +} + +impl UtilizationRow { + /// Utilization as a fraction (0.0 to 1.0). + pub fn utilization(&self) -> f64 { + if self.limit == 0 { + if self.used == 0 { + 0.0 + } else { + f64::INFINITY + } + } else { + self.used as f64 / self.limit as f64 + } + } + + /// Minimum PODs needed for this resource alone: `ceil(used / limit)`. + /// `None` if `limit` is 0 and `used > 0` (infeasible). + pub fn min_pods(&self) -> Option { + lower_bound(self.used, self.limit) + } +} + +/// Aggregate resource usage over a set of statement costs into per-category rows. +/// +/// Single source of truth for the resource categories and their corresponding +/// `Params` limits. Used both for pre-solve totals and per-POD breakdowns. +fn aggregate_rows<'a>( + costs: impl IntoIterator, + params: &Params, +) -> (Vec, usize) { + let mut num_stmts = 0usize; + let mut merkle_proofs = 0usize; + let mut merkle_state_transitions = 0usize; + let mut custom_pred_verifications = 0usize; + let mut signed_by = 0usize; + let mut public_key_of = 0usize; + let mut custom_pred_ids = BTreeSet::new(); + + for c in costs { + num_stmts += 1; + merkle_proofs += c.merkle_proofs; + merkle_state_transitions += c.merkle_state_transitions; + custom_pred_verifications += c.custom_pred_verifications; + signed_by += c.signed_by; + public_key_of += c.public_key_of; + custom_pred_ids.extend(c.custom_predicates_ids.iter().cloned()); + } + + let rows = vec![ + UtilizationRow { + name: "private statements", + used: num_stmts, + limit: params.max_priv_statements(), + }, + UtilizationRow { + name: "merkle proofs", + used: merkle_proofs, + limit: params.max_merkle_proofs_containers, + }, + UtilizationRow { + name: "merkle state transitions", + used: merkle_state_transitions, + limit: params.max_merkle_tree_state_transition_proofs_containers, + }, + UtilizationRow { + name: "custom pred verifications", + used: custom_pred_verifications, + limit: params.max_custom_predicate_verifications, + }, + UtilizationRow { + name: "signed_by", + used: signed_by, + limit: params.max_signed_by, + }, + UtilizationRow { + name: "public_key_of", + used: public_key_of, + limit: params.max_public_key_of, + }, + UtilizationRow { + name: "distinct custom predicates", + used: custom_pred_ids.len(), + limit: params.max_custom_predicates, + }, + ]; + + (rows, num_stmts) +} + +/// Pre-solve aggregate resource summary. +/// +/// Shows total resource demand across all operations and the minimum PODs +/// each resource category would require independently. +#[derive(Clone, Debug)] +pub struct ResourceSummary { + pub rows: Vec, + pub num_statements: usize, +} + +impl ResourceSummary { + /// Compute a resource summary from per-statement costs and params. + pub fn from_costs(costs: &[StatementCost], params: &Params) -> Self { + let (rows, num_statements) = aggregate_rows(costs.iter(), params); + Self { + rows, + num_statements, + } + } + + /// The resource category requiring the most PODs (the bottleneck). + /// Returns `None` only if there are no statements. + pub fn bottleneck(&self) -> Option<&UtilizationRow> { + self.rows + .iter() + .filter(|r| r.used > 0) + .max_by_key(|r| r.min_pods().unwrap_or(usize::MAX)) + } +} + +impl fmt::Display for ResourceSummary { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Resource Summary ({} statements)", self.num_statements)?; + writeln!( + f, + " {:<30} {:>5} {:>9} {:>8}", + "Category", "Total", "Limit/POD", "Min PODs" + )?; + + let bottleneck_name = self.bottleneck().map(|r| r.name); + + for row in &self.rows { + let min_pods_str = match row.min_pods() { + Some(n) => format!("{}", n), + None => "inf".to_string(), + }; + let marker = if Some(row.name) == bottleneck_name && row.used > 0 { + " <<<" + } else { + "" + }; + writeln!( + f, + " {:<30} {:>5} {:>9} {:>8}{}", + row.name, row.used, row.limit, min_pods_str, marker + )?; + } + + Ok(()) + } +} + +/// Per-POD resource utilization in a solved solution. +#[derive(Clone, Debug)] +pub struct PodUtilization { + /// POD index. + pub pod_idx: usize, + /// Whether this is the output POD (last). + pub is_output: bool, + /// Number of statements in this POD. + pub num_statements: usize, + /// Resource usage vs. limits for each category. + pub resources: Vec, +} + +/// Post-solve per-POD resource breakdown. +#[derive(Clone, Debug)] +pub struct SolutionBreakdown { + pub pods: Vec, + pub num_statements: usize, + pub pod_count: usize, +} + +impl SolutionBreakdown { + /// Compute a solution breakdown from per-statement costs, the solution's + /// pod_statements assignment, and params. + pub fn from_solution( + costs: &[StatementCost], + pod_statements: &[Vec], + pod_count: usize, + num_statements: usize, + params: &Params, + ) -> Self { + let pods = (0..pod_count) + .map(|pod_idx| { + let stmts = &pod_statements[pod_idx]; + let (resources, num_stmts) = + aggregate_rows(stmts.iter().map(|&s| &costs[s]), params); + PodUtilization { + pod_idx, + is_output: pod_idx == pod_count - 1, + num_statements: num_stmts, + resources, + } + }) + .collect(); + + Self { + pods, + num_statements, + pod_count, + } + } +} + +impl fmt::Display for SolutionBreakdown { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "Solution Breakdown ({} statements -> {} PODs)", + self.num_statements, self.pod_count + )?; + + for pod in &self.pods { + let role = if pod.is_output { + "output" + } else { + "intermediate" + }; + writeln!(f, " POD {} ({}):", pod.pod_idx, role)?; + + for row in &pod.resources { + // Only show rows with nonzero usage to reduce noise + if row.used > 0 { + let pct = if row.limit > 0 { + format!("({:>3}%)", (row.used * 100) / row.limit) + } else { + "".to_string() + }; + writeln!( + f, + " {:<30} {:>3}/{:<3} {}", + row.name, row.used, row.limit, pct + )?; + } + } + writeln!(f)?; + } + + Ok(()) + } +} + +fn lower_bound(used: usize, limit: usize) -> Option { + if used == 0 { + Some(0) + } else if limit == 0 { + None + } else { + Some(used.div_ceil(limit)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + frontend::multi_pod::cost::CustomPredicateId, + middleware::{Hash, RawValue}, + }; + + fn default_params() -> Params { + Params { + max_statements: 48, + max_public_statements: 8, + max_merkle_proofs_containers: 8, + max_merkle_tree_state_transition_proofs_containers: 4, + max_custom_predicate_verifications: 10, + max_custom_predicates: 2, + max_signed_by: 4, + max_public_key_of: 4, + ..Params::default() + } + } + + #[test] + fn test_resource_summary_bottleneck() { + let params = default_params(); + // max_priv = 48 - 8 = 40 + + // 6 merkle proofs, 3 state transitions, rest zero-cost + let costs: Vec = (0..14) + .map(|i| { + let mut c = StatementCost::default(); + if i < 6 { + c.merkle_proofs = 1; + } else if i < 9 { + c.merkle_state_transitions = 1; + } + c + }) + .collect(); + + let summary = ResourceSummary::from_costs(&costs, ¶ms); + + // 14 statements / 40 per pod = 1 pod for statements + // 6 merkle proofs / 8 per pod = 1 pod + // 3 state transitions / 4 per pod = 1 pod + // All categories need 1 pod, so bottleneck is whichever has the highest min_pods. + // They're all 1, so the first with total > 0 wins in max_by_key (stable). + let bottleneck = summary.bottleneck().unwrap(); + assert_eq!(bottleneck.min_pods(), Some(1)); + + // Verify display doesn't panic + let display = format!("{}", summary); + assert!(display.contains("Resource Summary (14 statements)")); + assert!(display.contains("merkle proofs")); + } + + #[test] + fn test_resource_summary_signed_by_bottleneck() { + let params = Params { + max_statements: 48, + max_public_statements: 8, + max_signed_by: 2, + ..Params::default() + }; + // max_priv = 40 + + // 6 signed_by operations + let costs: Vec = (0..6) + .map(|_| StatementCost { + signed_by: 1, + ..Default::default() + }) + .collect(); + + let summary = ResourceSummary::from_costs(&costs, ¶ms); + let bottleneck = summary.bottleneck().unwrap(); + + assert_eq!(bottleneck.name, "signed_by"); + // 6 / 2 = 3 pods + assert_eq!(bottleneck.min_pods(), Some(3)); + } + + #[test] + fn test_resource_summary_custom_predicates_bottleneck() { + let params = Params { + max_statements: 48, + max_public_statements: 8, + max_custom_predicates: 1, // Only 1 distinct predicate per POD + max_custom_predicate_verifications: 10, + ..Params::default() + }; + + // 3 statements using 3 different custom predicates + let costs: Vec = (0..3) + .map(|i| { + let mut ids = std::collections::BTreeSet::new(); + ids.insert(CustomPredicateId(Hash::from(RawValue::from(i as i64)))); + StatementCost { + custom_pred_verifications: 1, + custom_predicates_ids: ids, + ..Default::default() + } + }) + .collect(); + + let summary = ResourceSummary::from_costs(&costs, ¶ms); + let bottleneck = summary.bottleneck().unwrap(); + + assert_eq!(bottleneck.name, "distinct custom predicates"); + // 3 distinct predicates / 1 per pod = 3 pods + assert_eq!(bottleneck.min_pods(), Some(3)); + } + + #[test] + fn test_solution_breakdown_display() { + let params = default_params(); + + let costs: Vec = (0..8) + .map(|i| { + let mut c = StatementCost::default(); + if i < 4 { + c.merkle_proofs = 1; + } else { + c.merkle_state_transitions = 1; + } + c + }) + .collect(); + + let pod_statements = vec![ + vec![0, 1, 2, 3], // POD 0: 4 merkle proofs + vec![4, 5, 6, 7], // POD 1: 4 state transitions + ]; + + let breakdown = SolutionBreakdown::from_solution(&costs, &pod_statements, 2, 8, ¶ms); + + assert_eq!(breakdown.pods.len(), 2); + assert!(!breakdown.pods[0].is_output); + assert!(breakdown.pods[1].is_output); + + // POD 0 should have 4 merkle proofs + let mp = breakdown.pods[0] + .resources + .iter() + .find(|r| r.name == "merkle proofs") + .unwrap(); + assert_eq!(mp.used, 4); + assert_eq!(mp.limit, 8); + + // POD 1 should have 4 state transitions + let mst = breakdown.pods[1] + .resources + .iter() + .find(|r| r.name == "merkle state transitions") + .unwrap(); + assert_eq!(mst.used, 4); + assert_eq!(mst.limit, 4); + + // Verify display doesn't panic and contains expected content + let display = format!("{}", breakdown); + assert!(display.contains("Solution Breakdown (8 statements -> 2 PODs)")); + assert!(display.contains("POD 0 (intermediate)")); + assert!(display.contains("POD 1 (output)")); + } + + #[test] + fn test_utilization_row_fraction() { + let row = UtilizationRow { + name: "test", + used: 3, + limit: 4, + }; + assert!((row.utilization() - 0.75).abs() < f64::EPSILON); + + let zero_row = UtilizationRow { + name: "test", + used: 0, + limit: 4, + }; + assert!((zero_row.utilization()).abs() < f64::EPSILON); + } +} diff --git a/src/frontend/multi_pod/mod.rs b/src/frontend/multi_pod/mod.rs index 6bade5b..813e333 100644 --- a/src/frontend/multi_pod/mod.rs +++ b/src/frontend/multi_pod/mod.rs @@ -59,10 +59,12 @@ use crate::{ mod cost; mod deps; +pub mod diagnostics; mod solver; use cost::StatementCost; use deps::{DependencyGraph, StatementSource}; +pub use diagnostics::{ResourceSummary, SolutionBreakdown}; pub use solver::MultiPodSolution; /// Error type for multi-POD operations. @@ -200,6 +202,22 @@ impl SolvedMultiPod { &self.solution } + /// Compute a post-solve per-POD resource utilization breakdown. + pub fn solution_breakdown(&self) -> SolutionBreakdown { + let costs: Vec = self + .operations + .iter() + .map(StatementCost::from_operation) + .collect(); + SolutionBreakdown::from_solution( + &costs, + &self.solution.pod_statements, + self.solution.pod_count, + self.statements.len(), + &self.params, + ) + } + /// Build and prove all PODs. /// /// Builds PODs in dependency order (0, 1, ..., k) and proves each one. @@ -515,6 +533,20 @@ impl MultiPodBuilder { self.builder.stmt_len() } + /// Compute a pre-solve resource summary showing aggregate demand vs. per-POD limits. + /// + /// This is useful for understanding which resource category is the bottleneck + /// before running the solver, especially when debugging solver performance issues. + pub fn resource_summary(&self) -> ResourceSummary { + let costs: Vec = self + .builder + .operations + .iter() + .map(StatementCost::from_operation) + .collect(); + ResourceSummary::from_costs(&costs, &self.params) + } + /// Solve the packing problem and return a solved builder ready for proving. /// /// This runs the MILP solver to find the optimal POD assignment. From 111b132a00aa9a32b26981a5d00ec7376e1c422a Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Wed, 29 Apr 2026 00:56:39 -0700 Subject: [PATCH 09/10] Use projected statement lookup for op arg resolution (#503) * Use projected statement lookup for op arg resolution * Add projected op-arg index coverage test * Tidying and reorganising --- src/backends/plonky2/circuits/common.rs | 42 +++++++-- src/backends/plonky2/circuits/mainpod.rs | 105 ++++++++++++++++++--- src/backends/plonky2/circuits/mux_table.rs | 50 ++++++---- 3 files changed, 159 insertions(+), 38 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index de53ee5..dfee8a0 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ basetypes::{CircuitBuilder, CommonCircuitData, D}, - circuits::mainpod::CustomPredicateVerification, + circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator}, error::Result, mainpod::{Operation, OperationArg, OperationAux, Statement}, primitives::merkletree::{ @@ -1362,6 +1362,18 @@ pub trait CircuitBuilderPod, const D: usize> { fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T; /// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target` fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T; + /// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then + /// materializes the selected row via a witness generator and constrains its hash. Cheaper than + /// `vec_ref` when each entry has many fields, since random access runs only over the 4-field + /// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and + /// reusing the same slices across multiple lookups. + fn vec_ref_projected( + &mut self, + params: &Params, + ts_flattened: &[Vec], + ts_hashes: &[HashOutTarget], + i: &IndexTarget, + ) -> T; fn select_flattenable( &mut self, params: &Params, @@ -1764,12 +1776,6 @@ impl CircuitBuilderPod for CircuitBuilder { self.random_access(i.high, chunk_res) } - // TODO: Implement a version of vec_ref for types `T` which are big and support hashing. - // The idea would be the following: Take the array `ts` and hash each element. Then do the - // random access on the hash result. Finally "unhash" to recover the resolved element. - // We don't want to hash each element from the array each time, so we should cache the hashed - // result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and - // then do `ts: &[HashCache]`. fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T { let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec], i| { let num_rows = m.len(); @@ -1793,6 +1799,28 @@ impl CircuitBuilderPod for CircuitBuilder { T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i)) } + fn vec_ref_projected( + &mut self, + params: &Params, + ts_flattened: &[Vec], + ts_hashes: &[HashOutTarget], + i: &IndexTarget, + ) -> T { + assert_eq!(ts_flattened.len(), ts_hashes.len()); + let selected_hash = self.vec_ref(params, ts_hashes, i); + let selected_flattened = self.add_virtual_targets(T::size(params)); + let selected_flattened_hash = + self.hash_n_to_hash_no_pad::(selected_flattened.clone()); + self.connect_hashes(selected_hash, selected_flattened_hash); + let result = T::from_flattened(params, &selected_flattened); + self.add_simple_generator(TableGetGenerator::new( + i.clone(), + ts_flattened.to_vec(), + selected_flattened, + )); + result + } + fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T { let zero = self.zero(); self.vec_ref( diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index c4b891a..0605e07 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -97,9 +97,10 @@ impl StatementCache { builder: &mut CircuitBuilder, op: &OperationTarget, st: &StatementTarget, - prev_statements: &[StatementTarget], + prev_statement_flatteneds: &[Vec], + prev_statement_hashes: &[HashOutTarget], ) -> Self { - let op_args = if prev_statements.is_empty() { + let op_args = if prev_statement_flatteneds.is_empty() { (0..max_operation_args) .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) .collect_vec() @@ -109,7 +110,14 @@ impl StatementCache { op.args .iter() .take(max_operation_args) - .map(|i| builder.vec_ref(params, prev_statements, i)) + .map(|i| { + builder.vec_ref_projected( + params, + prev_statement_flatteneds, + prev_statement_hashes, + i, + ) + }) .collect::>() }; assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); @@ -193,7 +201,8 @@ fn verify_operation_public_statement_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, - prev_statements: &[StatementTarget], + prev_statement_flatteneds: &[Vec], + prev_statement_hashes: &[HashOutTarget], ) -> Result<()> { let measure = measure_gates_begin!(builder, "OpVerifyPub"); @@ -203,7 +212,15 @@ fn verify_operation_public_statement_circuit( let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); // None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the // StatementCache requires. - let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements); + let cache = StatementCachePub::new( + params, + 1, + builder, + op, + st, + prev_statement_flatteneds, + prev_statement_hashes, + ); measure_gates_end!(builder, measure_resolve_op_args); let op_checks = vec![ @@ -434,7 +451,8 @@ fn verify_operation_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, - prev_statements: &[StatementTarget], + prev_statement_flatteneds: &[Vec], + prev_statement_hashes: &[HashOutTarget], aux_table: &MuxTableTarget, ) -> Result<()> { let measure = measure_gates_begin!(builder, "OpVerifyPriv"); @@ -451,7 +469,8 @@ fn verify_operation_circuit( builder, op, st, - prev_statements, + prev_statement_flatteneds, + prev_statement_hashes, ); measure_gates_end!(builder, measure_resolve_op_args); @@ -1837,13 +1856,37 @@ fn verify_main_pod_circuit( // 2. Calculate the Pod Id from the public statements let sts_hash = calculate_statements_hash_circuit(builder, pub_statements); + // Precompute flattened statements and their hashes once, then resolve operation args using + // projected lookups. Reusing the flattened forms avoids re-flattening per op-arg lookup. + let statement_flatteneds: Vec> = statements.iter().map(|st| st.flatten()).collect(); + let statement_hashes = statement_flatteneds + .iter() + .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) + .collect_vec(); + // 5. Verify input statements for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() { - let prev_statements = &statements[..input_statements_offset + i]; + let prev_statement_flatteneds = &statement_flatteneds[..input_statements_offset + i]; + let prev_statement_hashes = &statement_hashes[..input_statements_offset + i]; if i < public_statements_offset { - verify_operation_circuit(params, builder, st, op, prev_statements, &aux_table)?; + verify_operation_circuit( + params, + builder, + st, + op, + prev_statement_flatteneds, + prev_statement_hashes, + &aux_table, + )?; } else { - verify_operation_public_statement_circuit(params, builder, st, op, prev_statements)?; + verify_operation_public_statement_circuit( + params, + builder, + st, + op, + prev_statement_flatteneds, + prev_statement_hashes, + )?; } } @@ -2221,6 +2264,14 @@ mod tests { let prev_statements_target: Vec<_> = (0..prev_statements.len()) .map(|_| builder.add_virtual_statement(false)) .collect(); + let prev_statement_flatteneds_target: Vec> = prev_statements_target + .iter() + .map(|st| st.flatten()) + .collect(); + let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target + .iter() + .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) + .collect(); let merkle_proofs_target: Vec<_> = aux .merkle_proofs @@ -2269,7 +2320,8 @@ mod tests { &mut builder, &st_target, &op_target, - &prev_statements_target, + &prev_statement_flatteneds_target, + &prev_statement_hashes_target, &aux_table, )?; @@ -2711,6 +2763,37 @@ mod tests { }) } + #[test] + fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> { + let local = dict!({ + "a" => 3, + "noise" => 99, + "sum" => 6, + }); + let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into(); + let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into(); + let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into(); + + let st: mainpod::Statement = Statement::sum_of( + AnchoredKey::from((&local, "sum")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "a")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + // Non-monotonic and repeated indices to stress random-access resolution. + OperationArg::Index(2), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::None, + ); + let prev_statements = vec![st_a, st_noise, st_sum]; + operation_verify(st, op, prev_statements, Aux::default()) + } + #[test] fn test_operation_verify_productof() -> Result<()> { I64_TEST_PAIRS diff --git a/src/backends/plonky2/circuits/mux_table.rs b/src/backends/plonky2/circuits/mux_table.rs index 110dac9..c93d0e8 100644 --- a/src/backends/plonky2/circuits/mux_table.rs +++ b/src/backends/plonky2/circuits/mux_table.rs @@ -107,11 +107,11 @@ impl MuxTableTarget { rev_resolved_tagged_flattened.reverse(); let resolved_tagged_flattened = rev_resolved_tagged_flattened; - builder.add_simple_generator(TableGetGenerator { - index: index.clone(), - tagged_entries: self.tagged_entries.clone(), - get_tagged_entry: resolved_tagged_flattened.clone(), - }); + builder.add_simple_generator(TableGetGenerator::new( + index.clone(), + self.tagged_entries.clone(), + resolved_tagged_flattened.clone(), + )); measure_gates_end!(builder, measure); TableEntryTarget { params: self.params.clone(), @@ -123,8 +123,18 @@ impl MuxTableTarget { #[derive(Debug, Clone, Default)] pub struct TableGetGenerator { index: IndexTarget, - tagged_entries: Vec>, - get_tagged_entry: Vec, + entries: Vec>, + revealed_entry: Vec, +} + +impl TableGetGenerator { + pub fn new(index: IndexTarget, entries: Vec>, revealed_entry: Vec) -> Self { + Self { + index, + entries, + revealed_entry, + } + } } impl, const D: usize> SimpleGenerator for TableGetGenerator { @@ -135,7 +145,7 @@ impl, const D: usize> SimpleGenerator for Tab fn dependencies(&self) -> Vec { [self.index.low, self.index.high] .into_iter() - .chain(self.tagged_entries.iter().flatten().copied()) + .chain(self.entries.iter().flatten().copied()) .collect() } @@ -148,12 +158,12 @@ impl, const D: usize> SimpleGenerator for Tab let index_high = witness.get_target(self.index.high); let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64(); - let entry = witness.get_targets(&self.tagged_entries[index as usize]); + let entry = witness.get_targets(&self.entries[index as usize]); - for (target, value) in self.get_tagged_entry.iter().zip( + for (target, value) in self.revealed_entry.iter().zip( entry .iter() - .chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())), + .chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())), ) { out_buffer.set_target(*target, *value)?; } @@ -166,12 +176,12 @@ impl, const D: usize> SimpleGenerator for Tab dst.write_target(self.index.low)?; dst.write_target(self.index.high)?; - dst.write_usize(self.tagged_entries.len())?; - for tagged_entry in &self.tagged_entries { - dst.write_target_vec(tagged_entry)?; + dst.write_usize(self.entries.len())?; + for entry in &self.entries { + dst.write_target_vec(entry)?; } - dst.write_target_vec(&self.get_tagged_entry) + dst.write_target_vec(&self.revealed_entry) } fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { @@ -181,16 +191,16 @@ impl, const D: usize> SimpleGenerator for Tab high: src.read_target()?, }; let len = src.read_usize()?; - let mut tagged_entries = Vec::with_capacity(len); + let mut entries = Vec::with_capacity(len); for _ in 0..len { - tagged_entries.push(src.read_target_vec()?); + entries.push(src.read_target_vec()?); } - let get_tagged_entry = src.read_target_vec()?; + let revealed_entry = src.read_target_vec()?; Ok(Self { index, - tagged_entries, - get_tagged_entry, + entries, + revealed_entry, }) } } From 5e3ac9a101fa3cd82ad6ab99497c158be8f5a69d Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 6 May 2026 12:39:27 +0200 Subject: [PATCH 10/10] Support mixed depth container merkle proofs (#508) * remove enabled flag from merkle tree proofs * add small existence mpt proofs in MainPod * refactor params, add small transition proofs * complete * fix edge case in vdset * fix: use existence only proof for vdset * use consistent order for aux table --- src/backends/plonky2/basetypes.rs | 11 +- src/backends/plonky2/circuits/common.rs | 53 +- .../circuits/{mainpod.rs => mainpod/mod.rs} | 2137 ++--------------- .../plonky2/circuits/mainpod/tests.rs | 1707 +++++++++++++ src/backends/plonky2/mainpod/mod.rs | 149 +- src/backends/plonky2/mainpod/operation.rs | 124 +- src/backends/plonky2/mock/mainpod.rs | 43 +- .../plonky2/primitives/merkletree/circuit.rs | 204 +- .../plonky2/primitives/merkletree/mod.rs | 39 +- src/frontend/multi_pod/diagnostics.rs | 19 +- src/frontend/multi_pod/solver.rs | 14 +- src/middleware/mod.rs | 55 +- 12 files changed, 2366 insertions(+), 2189 deletions(-) rename src/backends/plonky2/circuits/{mainpod.rs => mainpod/mod.rs} (53%) create mode 100644 src/backends/plonky2/circuits/mainpod/tests.rs diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index d7d6b39..f65eb7b 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -51,7 +51,7 @@ use crate::{ mainpod::cache_get_rec_main_pod_verifier_circuit_data, primitives::merkletree::MerkleClaimAndProof, }, - middleware::{containers::Array, Hash, Params, RawValue, Result, Value}, + middleware::{containers::Array, Hash, Params, RawValue, Result, Value, EMPTY_HASH}, }; pub static DEFAULT_VD_LIST: LazyLock> = LazyLock::new(|| { @@ -95,6 +95,12 @@ impl Eq for VDSet {} impl VDSet { fn new_from_vds_hashes(mut vds_hashes: Vec) -> Self { + // If vds_hashes is empty we add an zero entry to be used as padding when verifying merkle + // proofs of inclusion in the vds set. This zero entry can't be abused because no circuit + // exists with a vds_hash = 0. + if vds_hashes.is_empty() { + vds_hashes.push(EMPTY_HASH); + } // before using the hash values, sort them, so that each set of // verifier_datas gets the same VDSet root vds_hashes.sort(); @@ -150,6 +156,9 @@ impl VDSet { ))? .clone()) } + pub fn get_vds_proof_0(&self) -> MerkleClaimAndProof { + self.proofs_map[&self.vds_hashes[0]].clone() + } /// Returns true if the `verifier_data_hash` is in the set pub fn contains(&self, verifier_data_hash: HashOut) -> bool { self.proofs_map diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index dfee8a0..bb194a0 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -30,7 +30,7 @@ use crate::{ mainpod::{Operation, OperationArg, OperationAux, Statement}, primitives::merkletree::{ verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget, - MerkleProof, MerkleTreeStateTransitionProofTarget, + MerkleProof, MerkleProofExistenceTarget, MerkleTreeStateTransitionProofTarget, }, }, middleware::{ @@ -725,7 +725,6 @@ impl CustomPredicateInBatchTarget { let mtp = MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder); let _true = builder._true(); - builder.connect(_true.target, mtp.enabled.target); builder.connect(_true.target, mtp.existence.target); let zero = builder.constant(F(0)); let key = ValueTarget { @@ -763,7 +762,7 @@ impl CustomPredicateInBatchTarget { value: RawValue::from(hash_fields(&predicate.to_fields())), proof: mtp.clone(), }; - self.mtp.set_targets(pw, true, &mtp_claim)?; + self.mtp.set_targets(pw, &mtp_claim)?; Ok(()) } } @@ -987,7 +986,6 @@ pub trait Flattenable { /// elsewhere. #[derive(Copy, Clone)] pub struct MerkleClaimTarget { - pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -997,7 +995,6 @@ pub struct MerkleClaimTarget { impl From for MerkleClaimTarget { fn from(pf: MerkleClaimAndProofTarget) -> Self { Self { - enabled: pf.enabled, root: pf.root, key: pf.key, value: pf.value, @@ -1006,12 +1003,25 @@ impl From for MerkleClaimTarget { } } +impl MerkleClaimTarget { + pub fn from_proof_existence( + builder: &mut CircuitBuilder, + pf: MerkleProofExistenceTarget, + ) -> Self { + Self { + root: pf.root, + key: pf.key, + value: pf.value, + existence: builder._true(), + } + } +} + /// For the purpose of op verification, we need only look up the /// Merkle state transition claim rather than the Merkle state /// transition proof since it is verified elsewhere. #[derive(Copy, Clone)] pub struct MerkleTreeStateTransitionClaimTarget { - pub(crate) enabled: BoolTarget, pub(crate) op: Target, pub(crate) old_root: HashOutTarget, pub(crate) new_root: HashOutTarget, @@ -1022,7 +1032,6 @@ pub struct MerkleTreeStateTransitionClaimTarget { impl From for MerkleTreeStateTransitionClaimTarget { fn from(pf: MerkleTreeStateTransitionProofTarget) -> Self { Self { - enabled: pf.enabled, op: pf.op, old_root: pf.old_root, new_root: pf.new_root, @@ -1063,7 +1072,6 @@ impl Flattenable for ValueTarget { impl Flattenable for MerkleClaimTarget { fn flatten(&self) -> Vec { [ - vec![self.enabled.target], self.root.elements.to_vec(), self.key.elements.to_vec(), self.value.elements.to_vec(), @@ -1075,31 +1083,28 @@ impl Flattenable for MerkleClaimTarget { fn from_flattened(params: &Params, vs: &[Target]) -> Self { assert_eq!(vs.len(), Self::size(params)); Self { - enabled: BoolTarget::new_unsafe(vs[0]), - root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()), - key: ValueTarget::from_slice( - &vs[1 + NUM_HASH_OUT_ELTS..1 + NUM_HASH_OUT_ELTS + VALUE_SIZE], - ), + root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()), + key: ValueTarget::from_slice(&vs[NUM_HASH_OUT_ELTS..NUM_HASH_OUT_ELTS + VALUE_SIZE]), value: ValueTarget::from_slice( - &vs[1 + NUM_HASH_OUT_ELTS + VALUE_SIZE..1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], + &vs[NUM_HASH_OUT_ELTS + VALUE_SIZE..NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], ), - existence: BoolTarget::new_unsafe(vs[1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), + existence: BoolTarget::new_unsafe(vs[NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]), } } fn size(params: &Params) -> usize { - 2 + HashOutTarget::size(params) + 2 * ValueTarget::size(params) + HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1 } } impl Flattenable for MerkleTreeStateTransitionClaimTarget { fn flatten(&self) -> Vec { [ - vec![self.enabled.target, self.op], self.old_root.elements.to_vec(), self.new_root.elements.to_vec(), self.op_key.elements.to_vec(), self.op_value.elements.to_vec(), + vec![self.op], ] .concat() } @@ -1107,24 +1112,22 @@ impl Flattenable for MerkleTreeStateTransitionClaimTarget { fn from_flattened(params: &Params, vs: &[Target]) -> Self { assert_eq!(vs.len(), Self::size(params)); Self { - enabled: BoolTarget::new_unsafe(vs[0]), - op: vs[1], - old_root: HashOutTarget::from_vec(vs[2..2 + NUM_HASH_OUT_ELTS].to_vec()), + old_root: HashOutTarget::from_vec(vs[0..NUM_HASH_OUT_ELTS].to_vec()), new_root: HashOutTarget::from_vec( - vs[2 + NUM_HASH_OUT_ELTS..2 * (1 + NUM_HASH_OUT_ELTS)].to_vec(), + vs[NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS].to_vec(), ), op_key: ValueTarget::from_slice( - &vs[2 * (1 + NUM_HASH_OUT_ELTS)..2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE], + &vs[2 * NUM_HASH_OUT_ELTS..2 * NUM_HASH_OUT_ELTS + VALUE_SIZE], ), op_value: ValueTarget::from_slice( - &vs[2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE - ..2 * (1 + NUM_HASH_OUT_ELTS) + 2 * VALUE_SIZE], + &vs[2 * NUM_HASH_OUT_ELTS + VALUE_SIZE..2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], ), + op: vs[2 * NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE], } } fn size(params: &Params) -> usize { - 2 * (1 + HashOutTarget::size(params)) + 2 * ValueTarget::size(params) + 2 * HashOutTarget::size(params) + 2 * ValueTarget::size(params) + 1 } } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod/mod.rs similarity index 53% rename from src/backends/plonky2/circuits/mainpod.rs rename to src/backends/plonky2/circuits/mainpod/mod.rs index 0605e07..89ed3cf 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod/mod.rs @@ -16,6 +16,9 @@ use plonky2::{ use plonky2_u32::gadgets::multiple_comparison::list_le_circuit; use serde::{Deserialize, Serialize}; +#[cfg(test)] +mod tests; + use crate::{ backends::plonky2::{ basetypes::{CircuitBuilder, VDSet}, @@ -33,7 +36,7 @@ use crate::{ }, emptypod::EmptyPod, error::Result, - mainpod::{self, pad_statement, SignedBy}, + mainpod::{self, pad_statement, MerkleProofs, MerkleTransitionProofs, SignedBy}, primitives::{ ec::{ bits::{BigUInt320Target, CircuitBuilderBits}, @@ -44,8 +47,9 @@ use crate::{ schnorr::{CircuitBuilderSchnorr, SecretKey, SignatureTarget, WitnessWriteSchnorr}, }, merkletree::{ - verify_merkle_proof_circuit, verify_merkle_state_transition_circuit, - MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProof, MerkleTreeOp, + verify_merkle_proof_circuit, verify_merkle_proof_existence_circuit, + verify_merkle_state_transition_circuit, MerkleClaimAndProof, + MerkleClaimAndProofTarget, MerkleProof, MerkleProofExistenceTarget, MerkleTreeOp, MerkleTreeStateTransitionProof, MerkleTreeStateTransitionProofTarget, }, signature::{verify_signature_circuit, SignatureVerifyTarget}, @@ -55,8 +59,8 @@ use crate::{ measure_gates_begin, measure_gates_end, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix, - ToFields, Value, BASE_PARAMS, F, HASH_SIZE, + NativePredicate, Params, PredicatePrefix, Statement, StatementTmplArgPrefix, ToFields, + Value, BASE_PARAMS, F, HASH_SIZE, VALUE_SIZE, }, }; // @@ -238,21 +242,21 @@ fn verify_operation_public_statement_circuit( enum OperationAuxTableTag { None = 0, MerkleProof = 1, - PublicKeyOf = 2, - SignedBy = 3, - MerkleTreeStateTransitionProof = 4, - CustomPredVerify = 5, + MerkleTransitionProof = 2, + CustomPredVerify = 3, + PublicKeyOf = 4, + SignedBy = 5, } fn max_operation_aux_entry_len(params: &Params) -> usize { [ - (params.max_merkle_proofs_containers > 0).then(|| MerkleClaimTarget::size(params)), - (params.max_public_key_of > 0).then(|| PubKeySecKeyTarget::size(params)), - (params.max_signed_by > 0).then(|| MsgPubKeyTarget::size(params)), - (params.max_merkle_tree_state_transition_proofs_containers > 0) + (params.containers.state.max_total() > 0).then(|| MerkleClaimTarget::size(params)), + (params.containers.transition.max_total() > 0) .then(|| MerkleTreeStateTransitionClaimTarget::size(params)), (params.max_custom_predicate_verifications > 0) .then(|| CustomPredicateVerifyQueryTarget::size(params)), + (params.max_public_key_of > 0).then(|| PubKeySecKeyTarget::size(params)), + (params.max_signed_by > 0).then(|| MsgPubKeyTarget::size(params)), ] .into_iter() .flatten() @@ -306,14 +310,59 @@ impl SignedByTarget { } } +fn append_container_proofs_operation_aux_table_circuit( + builder: &mut CircuitBuilder, + table: &mut MuxTableTarget, + merkle_proofs: &MerkleProofsTarget, + merkle_transition_proofs: &MerkleTransitionProofsTarget, +) { + // Small MerkleProofs: verify container merkle proofs (only inclusion) + for merkle_proof in &merkle_proofs.small { + verify_merkle_proof_existence_circuit(builder, merkle_proof); + let entry = MerkleClaimTarget::from_proof_existence(builder, merkle_proof.clone()); + + table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); + } + // Medium MerkleProofs: verify container merkle proofs (inclusion/non-inclusion) + for merkle_proof in &merkle_proofs.medium { + verify_merkle_proof_circuit(builder, merkle_proof); + let entry = MerkleClaimTarget::from(merkle_proof.clone()); + + table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); + } + + // Small Merkle state transition proofs: verify op proof (only update) + for merkle_transition_proof in &merkle_transition_proofs.small { + verify_merkle_state_transition_circuit(builder, merkle_transition_proof); + let entry = MerkleTreeStateTransitionClaimTarget::from(merkle_transition_proof.clone()); + + table.push( + builder, + OperationAuxTableTag::MerkleTransitionProof as u32, + &entry, + ); + } + // Medium Merkle state transition proofs: verify op proof (insert/update/delete) + for merkle_transition_proof in &merkle_transition_proofs.medium { + verify_merkle_state_transition_circuit(builder, merkle_transition_proof); + let entry = MerkleTreeStateTransitionClaimTarget::from(merkle_transition_proof.clone()); + + table.push( + builder, + OperationAuxTableTag::MerkleTransitionProof as u32, + &entry, + ); + } +} + #[allow(clippy::too_many_arguments)] fn build_operation_aux_table_circuit( params: &Params, builder: &mut CircuitBuilder, - merkle_proofs: &[MerkleClaimAndProofTarget], + merkle_proofs: &MerkleProofsTarget, + merkle_transition_proofs: &MerkleTransitionProofsTarget, public_key_of_sks: &[BigUInt320Target], signed_bys: &[SignedByTarget], - merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProofTarget], custom_predicate_verifications: &[CustomPredicateVerifyEntryTarget], custom_predicate_table: &[HashOutTarget], ) -> Result { @@ -322,19 +371,56 @@ fn build_operation_aux_table_circuit( params.max_custom_predicate_verifications, custom_predicate_verifications.len() ); - assert_eq!(params.max_merkle_proofs_containers, merkle_proofs.len()); + assert_eq!(params.containers.state.max_small, merkle_proofs.small.len()); + assert_eq!( + params.containers.state.max_medium, + merkle_proofs.medium.len() + ); let max_entry_len = max_operation_aux_entry_len(params); let mut table = MuxTableTarget::new(params, max_entry_len); // None table.push_flattened(builder, OperationAuxTableTag::None as u32, &[]); - // MerkleProofs: verify container merkle proofs (inclusion/non-inclusion) - for merkle_proof in merkle_proofs { - verify_merkle_proof_circuit(builder, merkle_proof); - let entry = MerkleClaimTarget::from(merkle_proof.clone()); + append_container_proofs_operation_aux_table_circuit( + builder, + &mut table, + merkle_proofs, + merkle_transition_proofs, + ); - table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry); + // CustomPredVerify: verify custom predicate statements verification against operations + for entry in custom_predicate_verifications { + let measure = measure_gates_begin!(builder, "CustomPredVerify"); + // Verify the custom predicate operation + let (statement, op_type) = make_custom_statement_circuit( + params, + builder, + &entry.custom_predicate, + &entry.op_args, + &entry.args, + )?; + + // Check that the batch id is correct by querying the custom predicate batches table + let table_query_hash = builder.vec_ref( + params, + custom_predicate_table, + &entry.custom_predicate_table_index, + ); + let out_query_hash = entry.custom_predicate.hash(builder); + builder.connect_array(table_query_hash.elements, out_query_hash.elements); + + let query = CustomPredicateVerifyQueryTarget { + statement, // output + op_type, // output + op_args: entry.op_args.clone(), // input + }; + table.push( + builder, + OperationAuxTableTag::CustomPredVerify as u32, + &query, + ); + measure_gates_end!(builder, measure); } // PublicKeyOf: verify the derivation from a Schnorr secret key to public key @@ -394,53 +480,6 @@ fn build_operation_aux_table_circuit( measure_gates_end!(builder, measure); } - // Merkle state transition proofs: verify op proof (insert/update/delete) - for merkle_tree_state_transition_proof in merkle_tree_state_transition_proofs { - verify_merkle_state_transition_circuit(builder, merkle_tree_state_transition_proof); - let entry = - MerkleTreeStateTransitionClaimTarget::from(merkle_tree_state_transition_proof.clone()); - - table.push( - builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, - &entry, - ); - } - - // CustomPredVerify: verify custom predicate statements verification against operations - for entry in custom_predicate_verifications { - let measure = measure_gates_begin!(builder, "CustomPredVerify"); - // Verify the custom predicate operation - let (statement, op_type) = make_custom_statement_circuit( - params, - builder, - &entry.custom_predicate, - &entry.op_args, - &entry.args, - )?; - - // Check that the batch id is correct by querying the custom predicate batches table - let table_query_hash = builder.vec_ref( - params, - custom_predicate_table, - &entry.custom_predicate_table_index, - ); - let out_query_hash = entry.custom_predicate.hash(builder); - builder.connect_array(table_query_hash.elements, out_query_hash.elements); - - let query = CustomPredicateVerifyQueryTarget { - statement, // output - op_type, // output - op_args: entry.op_args.clone(), // input - }; - table.push( - builder, - OperationAuxTableTag::CustomPredVerify as u32, - &query, - ); - measure_gates_end!(builder, measure); - } - measure_gates_end!(builder, measure); Ok(table) } @@ -504,7 +543,7 @@ fn verify_operation_circuit( } // Skip these if there are no resolved aux entries if let Some(resolved_aux) = resolved_aux { - if params.max_merkle_proofs_containers > 0 { + if params.containers.state.max_total() > 0 { op_checks.extend_from_slice(&[ verify_contains_from_entries_circuit( params, @@ -544,7 +583,7 @@ fn verify_operation_circuit( &cache, )); } - if params.max_merkle_tree_state_transition_proofs_containers > 0 { + if params.containers.transition.max_total() > 0 { op_checks.extend_from_slice(&[ verify_merkle_insert_circuit( params, @@ -612,8 +651,6 @@ fn verify_contains_from_entries_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle proof must be enabled. */ - resolved_merkle_claim.enabled, /* ...and it must be an existence proof. */ resolved_merkle_claim.existence, /* ...for the root-key-value triple in the resolved op args. */ @@ -661,8 +698,6 @@ fn verify_not_contains_from_entries_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle proof must be enabled. */ - resolved_merkle_claim.enabled, /* ...and it must be a nonexistence proof. */ builder.not(resolved_merkle_claim.existence), /* ...for the root-key pair in the resolved op args. */ @@ -703,7 +738,7 @@ fn verify_merkle_insert_circuit( let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, + OperationAuxTableTag::MerkleTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerInsertFromEntries); @@ -714,8 +749,6 @@ fn verify_merkle_insert_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle transition proof must be enabled. */ - resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be an insertion proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -778,7 +811,7 @@ fn verify_merkle_update_circuit( let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, + OperationAuxTableTag::MerkleTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerUpdateFromEntries); @@ -789,8 +822,6 @@ fn verify_merkle_update_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle transition proof must be enabled. */ - resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be an update proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -853,7 +884,7 @@ fn verify_merkle_delete_circuit( let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = aux.as_type::( builder, - OperationAuxTableTag::MerkleTreeStateTransitionProof as u32, + OperationAuxTableTag::MerkleTransitionProof as u32, ); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerDeleteFromEntries); @@ -864,8 +895,6 @@ fn verify_merkle_delete_circuit( // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ - /* The supplied Merkle transition proof must be enabled. */ - resolved_merkle_tree_state_transition_claim.enabled, /* ...and it must be a deletion proof. */ builder.is_equal( resolved_merkle_tree_state_transition_claim.op, @@ -1774,19 +1803,20 @@ fn verify_main_pod_circuit( // NOTE: We use an EmptyPod for padding input pod slots. The EmptyPod is an introduction // pod that declares a statement with no arguments. - let is_blank_intro = input_pod_self_statements[0].pred_is_blank_intro(builder); + let st0_is_intro = input_pod_self_statements[0].pred_is_blank_intro(builder); // Introduction pods can only have Introduction or None statements - let mut intro_ok = is_blank_intro; + let mut intro_ok = st0_is_intro; for self_st in &input_pod_self_statements[1..] { let st_is_intro = self_st.pred_is_blank_intro(builder); let st_is_none = self_st.has_native_type(builder, NativePredicate::None); let st_is_intro_or_none = builder.or(st_is_intro, st_is_none); intro_ok = builder.and(intro_ok, st_is_intro_or_none); } - builder.connect(is_blank_intro.target, intro_ok.target); + builder.connect(st0_is_intro.target, intro_ok.target); - let is_main = builder.not(is_blank_intro); + let is_not_main = st0_is_intro; + let is_main = builder.not(is_not_main); for self_st in input_pod_self_statements { let normalized_st = normalize_statement_circuit( params, @@ -1805,18 +1835,19 @@ fn verify_main_pod_circuit( // their verifier_data_hash appears in their introduction statement. // - verify_merkle_proof_circuit(builder, vd_mt_proof); + verify_merkle_proof_existence_circuit(builder, vd_mt_proof); - // ensure that mt_proof is enabled if it's a main pod - builder.connect(vd_mt_proof.enabled.target, is_main.target); // connect the vd_mt_proof's root to the actual vds_root, to ensure that the mt proof // verifies against the vds_root builder.connect_hashes(main_pod.vds_root, vd_mt_proof.root); - // connect vd_mt_proof's value with the verified_proof.verifier_data_hash - builder.connect_hashes( - verified_proof.verifier_data_hash, - HashOutTarget::from_vec(vd_mt_proof.value.elements.to_vec()), - ); + // connect vd_mt_proof's value with the verified_proof.verifier_data_hash only when is_main + for i in 0..VALUE_SIZE { + builder.conditional_assert_eq( + is_main.target, + verified_proof.verifier_data_hash.elements[i], + vd_mt_proof.value.elements[i], + ) + } // // Verify that VD array that input pod uses is the same we use now. @@ -1846,9 +1877,9 @@ fn verify_main_pod_circuit( params, builder, &main_pod.merkle_proofs, + &main_pod.merkle_transition_proofs, &main_pod.public_key_of_sks, &main_pod.signed_bys, - &main_pod.merkle_tree_state_transition_proofs, &main_pod.custom_predicate_verifications, &custom_predicate_table, )?; @@ -1894,19 +1925,77 @@ fn verify_main_pod_circuit( Ok(sts_hash) } +#[derive(Clone, Serialize, Deserialize)] +pub struct MerkleProofsTarget { + small: Vec, + medium: Vec, +} + +impl MerkleProofsTarget { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + Self { + small: (0..params.containers.state.max_small) + .map(|_| { + MerkleProofExistenceTarget::new_virtual( + params.containers.max_depth_small, + builder, + ) + }) + .collect(), + medium: (0..params.containers.state.max_medium) + .map(|_| { + MerkleClaimAndProofTarget::new_virtual( + params.containers.max_depth_medium, + builder, + ) + }) + .collect(), + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct MerkleTransitionProofsTarget { + small: Vec, + medium: Vec, +} + +impl MerkleTransitionProofsTarget { + pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + Self { + small: (0..params.containers.transition.max_small) + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.containers.max_depth_small, + builder, + ) + }) + .collect(), + medium: (0..params.containers.transition.max_medium) + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.containers.max_depth_medium, + builder, + ) + }) + .collect(), + } + } +} + #[derive(Clone, Serialize, Deserialize)] pub struct MainPodVerifyTarget { params: Params, vds_root: HashOutTarget, - vd_mt_proofs: Vec, + vd_mt_proofs: Vec, input_pods_self_statements: Vec>, // The KEY_TYPE statement must be the first public one input_statements: Vec, operations: Vec, - merkle_proofs: Vec, + merkle_proofs: MerkleProofsTarget, + merkle_transition_proofs: MerkleTransitionProofsTarget, public_key_of_sks: Vec, signed_bys: Vec, - merkle_tree_state_transition_proofs: Vec, custom_predicates: Vec, custom_predicate_verifications: Vec, } @@ -1917,7 +2006,7 @@ impl MainPodVerifyTarget { params: params.clone(), vds_root: builder.add_virtual_hash(), vd_mt_proofs: (0..params.max_input_pods) - .map(|_| MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_vds, builder)) + .map(|_| MerkleProofExistenceTarget::new_virtual(params.max_depth_mt_vds, builder)) .collect(), input_pods_self_statements: (0..params.max_input_pods) .map(|_| { @@ -1932,26 +2021,14 @@ impl MainPodVerifyTarget { operations: (0..params.max_statements) .map(|_| builder.add_virtual_operation(params)) .collect(), - merkle_proofs: (0..params.max_merkle_proofs_containers) - .map(|_| { - MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, builder) - }) - .collect(), + merkle_proofs: MerkleProofsTarget::new_virtual(params, builder), + merkle_transition_proofs: MerkleTransitionProofsTarget::new_virtual(params, builder), public_key_of_sks: (0..params.max_public_key_of) .map(|_| builder.add_virtual_biguint320_target()) .collect(), signed_bys: (0..params.max_signed_by) .map(|_| SignedByTarget::new_virtual(builder)) .collect(), - merkle_tree_state_transition_proofs: (0..params - .max_merkle_tree_state_transition_proofs_containers) - .map(|_| { - MerkleTreeStateTransitionProofTarget::new_virtual( - params.max_depth_mt_containers, - builder, - ) - }) - .collect(), custom_predicates: (0..params.max_custom_predicates) .map(|_| CustomPredicateInBatchTarget::new_virtual(builder)) .collect(), @@ -1960,6 +2037,64 @@ impl MainPodVerifyTarget { .collect(), } } + + fn set_container_mtp_targets( + &self, + pw: &mut PartialWitness, + input: &MainPodVerifyInput, + ) -> Result<()> { + assert!(input.merkle_proofs.small.len() <= self.params.containers.state.max_small); + for (i, mp) in input.merkle_proofs.small.iter().enumerate() { + self.merkle_proofs.small[i].set_targets(pw, mp)?; + } + // Padding + let pad_mp = MerkleClaimAndProof::pad(); + for i in input.merkle_proofs.small.len()..self.params.containers.state.max_small { + self.merkle_proofs.small[i].set_targets(pw, &pad_mp)?; + } + + assert!(input.merkle_proofs.medium.len() <= self.params.containers.state.max_medium); + for (i, mp) in input.merkle_proofs.medium.iter().enumerate() { + self.merkle_proofs.medium[i].set_targets(pw, mp)?; + } + // Padding + let pad_mp = MerkleClaimAndProof::pad(); + for i in input.merkle_proofs.medium.len()..self.params.containers.state.max_medium { + self.merkle_proofs.medium[i].set_targets(pw, &pad_mp)?; + } + + assert!( + input.merkle_transition_proofs.small.len() + <= self.params.containers.transition.max_small + ); + for (i, mtp) in input.merkle_transition_proofs.small.iter().enumerate() { + self.merkle_transition_proofs.small[i].set_targets(pw, mtp)?; + } + // Padding + let pad_mtp = MerkleTreeStateTransitionProof::pad(); + for i in + input.merkle_transition_proofs.small.len()..self.params.containers.transition.max_small + { + self.merkle_transition_proofs.small[i].set_targets(pw, &pad_mtp)?; + } + + assert!( + input.merkle_transition_proofs.medium.len() + <= self.params.containers.transition.max_medium + ); + for (i, mtp) in input.merkle_transition_proofs.medium.iter().enumerate() { + self.merkle_transition_proofs.medium[i].set_targets(pw, mtp)?; + } + // Padding + let pad_mtp = MerkleTreeStateTransitionProof::pad(); + for i in input.merkle_transition_proofs.medium.len() + ..self.params.containers.transition.max_medium + { + self.merkle_transition_proofs.medium[i].set_targets(pw, &pad_mtp)?; + } + + Ok(()) + } } pub struct CustomPredicateVerification { @@ -1974,15 +2109,14 @@ pub struct MainPodVerifyInput { /// field containing the `vd_mt_proofs` aside from the `vds_set`, because /// inside the MainPodVerifyTarget circuit, since it is the InnerCircuit for /// the RecursiveCircuit, we don't have access to the used verifier_datas. - /// The bool is used as `enabled` and will be false for intro pods. - pub vd_mt_proofs: Vec<(bool, MerkleClaimAndProof)>, + pub vd_mt_proofs: Vec, pub input_pods_pub_self_statements: Vec>, pub statements: Vec, pub operations: Vec, - pub merkle_proofs: Vec, + pub merkle_proofs: MerkleProofs, + pub merkle_transition_proofs: MerkleTransitionProofs, pub public_key_of_sks: Vec, pub signed_bys: Vec, - pub merkle_tree_state_transition_proofs: Vec, pub custom_predicates_with_mpt_proofs: Vec<(CustomPredicateRef, MerkleProof)>, pub custom_predicate_verifications: Vec, } @@ -2038,8 +2172,8 @@ impl InnerCircuit for MainPodVerifyTarget { ); let input_pods_len = input.vd_mt_proofs.len(); assert!(input_pods_len <= self.params.max_input_pods); - for (i, (enable, vd_mt_proof)) in input.vd_mt_proofs.iter().enumerate() { - self.vd_mt_proofs[i].set_targets(pw, *enable, vd_mt_proof)?; + for (i, vd_mt_proof) in input.vd_mt_proofs.iter().enumerate() { + self.vd_mt_proofs[i].set_targets(pw, vd_mt_proof)?; } for (i, pod_pub_statements) in input.input_pods_pub_self_statements.iter().enumerate() { set_targets_input_pods_self_statements( @@ -2053,14 +2187,10 @@ impl InnerCircuit for MainPodVerifyTarget { if input_pods_len != self.params.max_input_pods { let empty_pod = EmptyPod::new_boxed(input.vds_set.clone()); let empty_pod_statements = empty_pod.pub_statements(); - let empty_mt_proof = MerkleClaimAndProof { - root: input.vds_set.root(), - value: RawValue::from(empty_pod.verifier_data_hash()), - ..MerkleClaimAndProof::empty() - }; + let pad_mt_proof = input.vds_set.get_vds_proof_0(); for i in input_pods_len..self.params.max_input_pods { - self.vd_mt_proofs[i].set_targets(pw, false, &empty_mt_proof)?; + self.vd_mt_proofs[i].set_targets(pw, &pad_mt_proof)?; set_targets_input_pods_self_statements( pw, &self.params, @@ -2076,15 +2206,7 @@ impl InnerCircuit for MainPodVerifyTarget { self.operations[i].set_targets(pw, &self.params, op)?; } - assert!(input.merkle_proofs.len() <= self.params.max_merkle_proofs_containers); - for (i, mp) in input.merkle_proofs.iter().enumerate() { - self.merkle_proofs[i].set_targets(pw, true, mp)?; - } - // Padding - let pad_mp = MerkleClaimAndProof::empty(); - for i in input.merkle_proofs.len()..self.params.max_merkle_proofs_containers { - self.merkle_proofs[i].set_targets(pw, false, &pad_mp)?; - } + self.set_container_mtp_targets(pw, input)?; assert!(input.public_key_of_sks.len() <= self.params.max_public_key_of); for (i, sk) in input.public_key_of_sks.iter().enumerate() { @@ -2106,25 +2228,6 @@ impl InnerCircuit for MainPodVerifyTarget { self.signed_bys[i].set_targets(pw, &pad_signed_by)?; } - assert!( - input.merkle_tree_state_transition_proofs.len() - <= self - .params - .max_merkle_tree_state_transition_proofs_containers - ); - for (i, mtp) in input.merkle_tree_state_transition_proofs.iter().enumerate() { - self.merkle_tree_state_transition_proofs[i].set_targets(pw, true, mtp)?; - } - // Padding - let pad_mtp = MerkleTreeStateTransitionProof::empty(); - for i in input.merkle_tree_state_transition_proofs.len() - ..self - .params - .max_merkle_tree_state_transition_proofs_containers - { - self.merkle_tree_state_transition_proofs[i].set_targets(pw, false, &pad_mtp)?; - } - assert!(input.custom_predicates_with_mpt_proofs.len() <= self.params.max_custom_predicates); for (i, (cp, mtp)) in input.custom_predicates_with_mpt_proofs.iter().enumerate() { self.custom_predicates[i].set_targets(pw, cp, mtp)?; @@ -2169,1729 +2272,3 @@ impl InnerCircuit for MainPodVerifyTarget { Ok(()) } } - -#[cfg(test)] -mod tests { - use std::{iter, ops::Not}; - - use num::FromPrimitive; - use plonky2::{ - field::{goldilocks_field::GoldilocksField, types::Field}, - hash::hash_types::HashOut, - iop::witness::WitnessWrite, - plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, - }; - - use super::*; - use crate::{ - backends::plonky2::{ - basetypes::C, - circuits::common::tests::I64_TEST_PAIRS, - mainpod::{calculate_statements_hash, OperationArg, OperationAux}, - primitives::{ - ec::schnorr::SecretKey, - merkletree::{MerkleClaimAndProof, MerkleTree, MerkleTreeStateTransitionProof}, - }, - signer, - }, - dict, - frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, - middleware::{ - hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, - RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard, - BASE_PARAMS, EMPTY_VALUE, - }, - }; - - #[derive(Default)] - struct Aux { - merkle_proofs: Vec, - secret_keys: Vec, - signed_bys: Vec, - merkle_tree_state_transition_proofs: Vec, - } - - impl Aux { - fn merkle_proof(v: MerkleClaimAndProof) -> Self { - Self { - merkle_proofs: vec![v], - ..Default::default() - } - } - fn secret_key(v: SecretKey) -> Self { - Self { - secret_keys: vec![v], - ..Default::default() - } - } - fn signed_by(v: SignedBy) -> Self { - Self { - signed_bys: vec![v], - ..Default::default() - } - } - fn merkle_tree_state_transition_proof(v: MerkleTreeStateTransitionProof) -> Self { - Self { - merkle_tree_state_transition_proofs: vec![v], - ..Default::default() - } - } - } - - fn operation_verify( - st: mainpod::Statement, - op: mainpod::Operation, - prev_statements: Vec, - aux: Aux, - ) -> Result<()> { - let params = Params { - max_merkle_proofs_containers: aux.merkle_proofs.len(), - max_public_key_of: aux.secret_keys.len(), - max_signed_by: aux.signed_bys.len(), - max_merkle_tree_state_transition_proofs_containers: aux - .merkle_tree_state_transition_proofs - .len(), - max_custom_predicate_verifications: 0, - max_custom_predicates: 0, - ..Default::default() - }; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_target = builder.add_virtual_statement(false); - let op_target = builder.add_virtual_operation(¶ms); - let prev_statements_target: Vec<_> = (0..prev_statements.len()) - .map(|_| builder.add_virtual_statement(false)) - .collect(); - let prev_statement_flatteneds_target: Vec> = prev_statements_target - .iter() - .map(|st| st.flatten()) - .collect(); - let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target - .iter() - .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) - .collect(); - - let merkle_proofs_target: Vec<_> = aux - .merkle_proofs - .iter() - .map(|_| { - MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, &mut builder) - }) - .collect(); - - let secret_keys_target: Vec<_> = aux - .secret_keys - .iter() - .map(|sk| builder.constant_biguint320(&sk.0)) - .collect(); - - let signed_by_targets: Vec<_> = aux - .signed_bys - .iter() - .map(|_| SignedByTarget::new_virtual(&mut builder)) - .collect(); - - let merkle_tree_state_transition_proofs_target: Vec<_> = aux - .merkle_tree_state_transition_proofs - .iter() - .map(|_| { - MerkleTreeStateTransitionProofTarget::new_virtual( - params.max_depth_mt_containers, - &mut builder, - ) - }) - .collect(); - - let aux_table = build_operation_aux_table_circuit( - ¶ms, - &mut builder, - &merkle_proofs_target, - &secret_keys_target, - &signed_by_targets, - &merkle_tree_state_transition_proofs_target, - &[], - &[], - )?; - - verify_operation_circuit( - ¶ms, - &mut builder, - &st_target, - &op_target, - &prev_statement_flatteneds_target, - &prev_statement_hashes_target, - &aux_table, - )?; - - let mut pw = PartialWitness::::new(); - st_target.set_targets(&mut pw, &st)?; - op_target.set_targets(&mut pw, ¶ms, &op)?; - for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { - prev_st_target.set_targets(&mut pw, prev_st)?; - } - for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) { - signed_by_target.set_targets(&mut pw, signed_by)? - } - for (merkle_proof_target, merkle_proof) in - merkle_proofs_target.iter().zip(aux.merkle_proofs.iter()) - { - merkle_proof_target.set_targets(&mut pw, true, merkle_proof)? - } - for (merkle_tree_state_transition_proof_target, merkle_tree_state_transition_proof) in - merkle_tree_state_transition_proofs_target - .iter() - .zip(aux.merkle_tree_state_transition_proofs.iter()) - { - merkle_tree_state_transition_proof_target.set_targets( - &mut pw, - true, - merkle_tree_state_transition_proof, - )? - } - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - - Ok(()) - } - - #[test] - fn test_lt_lteq_verify_failures() { - let invalid_int = RawValue([ - GoldilocksField::NEG_ONE, - GoldilocksField::ZERO, - GoldilocksField::ZERO, - GoldilocksField::ZERO, - ]); - - let prev_statements = [Statement::None.into()]; - - [ - // 56 < 55, 55 < 55, 56 <= 55, -55 < -55, -55 < -56, -55 <= -56 should fail to verify - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(55, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(-55, -55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(-55, -56).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(-55, -56).into(), - ), - // 56 < p-1 and p-1 <= p-1 should fail to verify, where p - // is the Goldilocks prime and 'p-1' occupies a single - // limb. - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt(56, invalid_int).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::lt_eq(invalid_int, invalid_int).into(), - ), - ] - .into_iter() - .for_each(|(op, st)| { - let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - })); - match check { - Err(e) => { - let err_string = e.downcast_ref::().unwrap(); - if !err_string.contains("Integer too large to fit") { - panic!("Test failed with an unexpected error: {}", err_string); - } - } - Ok(Err(_)) => {} - _ => panic!("Test passed, yet it should have failed!"), - } - }); - } - - #[test] - fn test_eq_neq_verify_failures() { - let prev_statements = [Statement::None.into()]; - - [ - // 56 == 55, 55 != 55 should fail to verify - ( - mainpod::Operation( - OperationType::Native(NativeOperation::EqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::equal(56, 55).into(), - ), - ( - mainpod::Operation( - OperationType::Native(NativeOperation::NotEqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ), - Statement::not_equal(55, 55).into(), - ), - ] - .into_iter() - .for_each(|(op, st)| { - assert!(operation_verify(st, op, prev_statements.to_vec(), Aux::default()).is_err()) - }); - } - - #[test] - fn test_operation_verify_none() -> Result<()> { - let st: mainpod::Statement = Statement::None.into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::None), - vec![], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_copy() -> Result<()> { - let st: mainpod::Statement = Statement::None.into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::CopyStatement), - vec![OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_eq() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"world" => 55}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 55).into(); - let st: mainpod::Statement = Statement::equal( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "world")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::EqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_neq() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"world" => 75}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 75).into(); - let st: mainpod::Statement = Statement::not_equal( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "world")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::NotEqualFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_lt() -> Result<()> { - let dict1 = dict!({"hello" => 55}); - let dict2 = dict!({"hello" => 56}); - let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); - let st2: mainpod::Statement = Statement::contains(dict2.clone(), "hello", 56).into(); - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict1, "hello")), - AnchoredKey::from((&dict2, "hello")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative < negative - let dict3 = dict!({"hola" => -56}); - let dict4 = dict!({"mundo" => -55}); - let st3: mainpod::Statement = Statement::contains(dict3.clone(), "hola", -56).into(); - let st4: mainpod::Statement = Statement::contains(dict4.clone(), "mundo", -55).into(); - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict3, "hola")), - AnchoredKey::from((&dict4, "mundo")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative < positive - let st: mainpod::Statement = Statement::lt( - AnchoredKey::from((&dict3, "hola")), - AnchoredKey::from((&dict2, "hello")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_lteq() -> Result<()> { - let local = dict!({ - "n55" => 55, - "n56" => 56, - "n_56" => -56, - "n_55" => -55, - }); - let st1: mainpod::Statement = Statement::contains(local.clone(), "n55", 55).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "n56", 56).into(); - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n55")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2.clone()]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative <= negative - let st3: mainpod::Statement = Statement::contains(local.clone(), "n_56", -56).into(); - let st4: mainpod::Statement = Statement::contains(local.clone(), "n_55", -55).into(); - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n_55")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3.clone(), st4]; - operation_verify(st, op, prev_statements, Aux::default())?; - - // Also check negative <= positive - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st3, st2]; - operation_verify(st, op, prev_statements.clone(), Aux::default())?; - - // Also check equality, both positive and negative. - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n_56")), - AnchoredKey::from((&local, "n_56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ); - operation_verify(st, op, prev_statements.clone(), Aux::default())?; - let st: mainpod::Statement = Statement::lt_eq( - AnchoredKey::from((&local, "n56")), - AnchoredKey::from((&local, "n56")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtEqFromEntries), - vec![OperationArg::Index(1), OperationArg::Index(1)], - OperationAux::None, - ); - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_hashof() -> Result<()> { - let input_values = [ - Value::from(RawValue([ - GoldilocksField(1), - GoldilocksField(2), - GoldilocksField(3), - GoldilocksField(4), - ])), - Value::from(512), - ]; - let v1 = hash_values(&input_values); - let [v2, v3] = input_values; - - let local = dict!({ - "hola" => v1, - "mundo" => v2.clone(), - "!" => v3.clone(), - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "hola", v1).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "mundo", v2).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "!", v3).into(); - - let st: mainpod::Statement = Statement::hash_of( - AnchoredKey::from((&local, "hola")), - AnchoredKey::from((&local, "mundo")), - AnchoredKey::from((&local, "!")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::HashOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_sumof() -> Result<()> { - I64_TEST_PAIRS - .into_iter() - .flat_map(|(a, b)| { - let (sum, overflow) = a.overflowing_add(b); - overflow.not().then_some((a, b, sum)) - }) - .try_for_each(|(a, b, sum)| { - let local = dict!({ - "sum" => sum, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "sum", sum).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::sum_of( - AnchoredKey::from((&local, "sum")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SumOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) - } - - #[test] - fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> { - let local = dict!({ - "a" => 3, - "noise" => 99, - "sum" => 6, - }); - let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into(); - let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into(); - let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into(); - - let st: mainpod::Statement = Statement::sum_of( - AnchoredKey::from((&local, "sum")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "a")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SumOf), - vec![ - // Non-monotonic and repeated indices to stress random-access resolution. - OperationArg::Index(2), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::None, - ); - let prev_statements = vec![st_a, st_noise, st_sum]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_productof() -> Result<()> { - I64_TEST_PAIRS - .into_iter() - .flat_map(|(a, b)| { - let (prod, overflow) = a.overflowing_mul(b); - overflow.not().then_some((a, b, prod)) - }) - .try_for_each(|(a, b, prod)| { - let local = dict!({ - "prod" => prod, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = - Statement::contains(local.clone(), "prod", prod).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::product_of( - AnchoredKey::from((&local, "prod")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ProductOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) - } - - #[test] - fn test_operation_verify_maxof() -> Result<()> { - I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { - let max = i64::max(a, b); - let local = dict!({ - "max" => max, - "a" => a, - "b" => b, - }); - - let st1: mainpod::Statement = Statement::contains(local.clone(), "max", max).into(); - let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); - let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); - - let st: mainpod::Statement = Statement::max_of( - AnchoredKey::from((&local, "max")), - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - - let op = mainpod::Operation( - OperationType::Native(NativeOperation::MaxOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::None, - ); - let prev_statements = vec![st1, st2, st3]; - operation_verify(st, op, prev_statements, Aux::default()) - }) - } - - #[test] - fn test_operation_verify_maxof_failures() { - [(5, 3, 4), (5, 5, 8), (3, 4, 5)] - .into_iter() - .for_each(|(max, a, b)| { - let st: mainpod::Statement = Statement::max_of(max, a, b).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::MaxOf), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::None, - ); - let prev_statements = [Statement::None.into()]; - - let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - operation_verify(st, op, prev_statements.to_vec(), Aux::default()) - })); - match check { - Err(e) => { - let err_string = e.downcast_ref::().unwrap(); - if !err_string.contains("Integer too large to fit") { - panic!("Test failed with an unexpected error: {}", err_string); - } - } - Ok(Err(_)) => {} - _ => panic!("Test passed, yet it should have failed!"), - } - }) - } - - #[test] - fn test_operation_verify_lt_to_neq() -> Result<()> { - let local = dict!({ - "a" => 10, - "b" => 20, - }); - let st: mainpod::Statement = Statement::not_equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let st1: mainpod::Statement = Statement::lt( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::LtToNotEqual), - vec![OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![st1]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_transitive_eq() -> Result<()> { - let local = dict!({ - "a" => 10, - "b" => 10, - "c" => 10, - }); - let st: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "c")), - ) - .into(); - let st1: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "a")), - AnchoredKey::from((&local, "b")), - ) - .into(); - let st2: mainpod::Statement = Statement::equal( - AnchoredKey::from((&local, "b")), - AnchoredKey::from((&local, "c")), - ) - .into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::TransitiveEqualFromStatements), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::None, - ); - let prev_statements = vec![st1, st2]; - operation_verify(st, op, prev_statements, Aux::default()) - } - - #[test] - fn test_operation_verify_sintains() -> Result<()> { - let kvs = [ - (1.into(), 55.into()), - (2.into(), 88.into()), - (175.into(), 0.into()), - ] - .into_iter() - .collect(); - let mt = MerkleTree::new(&kvs); - - let root = mt.root(); - let key = Value::from(5); - let local = dict!({ - "merkle_root" => root, - "key" => key.clone(), - }); - let root_ak = AnchoredKey::from((&local, "merkle_root")); - let key_ak = AnchoredKey::from((&local, "key")); - - let no_key_pf = mt.prove_nonexistence(&key.raw())?; - - let root_st: mainpod::Statement = - Statement::contains(local.clone(), "merkle_root", root).into(); - let key_st: mainpod::Statement = - Statement::contains(local.clone(), "key", key.clone()).into(); - let st: mainpod::Statement = Statement::not_contains(root_ak, key_ak).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::NotContainsFromEntries), - vec![OperationArg::Index(0), OperationArg::Index(1)], - OperationAux::MerkleProofIndex(0), - ); - - let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), None, no_key_pf); - let prev_statements = vec![root_st, key_st]; - operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) - } - - #[test] - fn test_operation_verify_contains() -> Result<()> { - let kvs = [ - (1.into(), 55.into()), - (2.into(), 88.into()), - (175.into(), 0.into()), - ] - .into_iter() - .collect(); - let mt = MerkleTree::new(&kvs); - - let root = mt.root(); - let key = Value::from(175); - let (value, key_pf) = mt.prove(&key.raw())?; - let local = dict!({ - "merkle_root" => root, - "key" => key.clone(), - "value" => value, - }); - let root_ak = AnchoredKey::from((&local, "merkle_root")); - let key_ak = AnchoredKey::from((&local, "key")); - let value_ak = AnchoredKey::from((&local, "value")); - - let root_st: mainpod::Statement = - Statement::contains(local.clone(), "merkle_root", root).into(); - let key_st: mainpod::Statement = - Statement::contains(local.clone(), "key", key.clone()).into(); - let value_st: mainpod::Statement = - Statement::contains(local.clone(), "value", value).into(); - - let st: mainpod::Statement = Statement::contains(root_ak, key_ak, value_ak).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainsFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(1), - OperationArg::Index(2), - ], - OperationAux::MerkleProofIndex(0), - ); - - let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), Some(value), key_pf); - let prev_statements = vec![root_st, key_st, value_st]; - operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) - } - - #[test] - fn test_operation_verify_merkle_insert() -> Result<()> { - let mut tree = MerkleTree::new(&[].into()); - - let key = Value::from(175); - let value = Value::from(0); - let state_transition_proof = tree.insert(&key.raw(), &value.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::insert(new_root, old_root, key, value).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerInsertFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTreeStateTransitionProofIndex(0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) - } - - #[test] - fn test_operation_verify_merkle_update() -> Result<()> { - let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); - - let key = Value::from(175); - let value = Value::from(0); - let state_transition_proof = tree.update(&key.raw(), &value.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::update(new_root, old_root, key, value).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerUpdateFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTreeStateTransitionProofIndex(0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) - } - - #[test] - fn test_operation_verify_merkle_delete() -> Result<()> { - let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); - - let key = Value::from(175); - let state_transition_proof = tree.delete(&key.raw())?; - let old_root = state_transition_proof.old_root; - let new_root = state_transition_proof.new_root; - - let st: mainpod::Statement = Statement::delete(new_root, old_root, key).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ContainerDeleteFromEntries), - vec![ - OperationArg::Index(0), - OperationArg::Index(0), - OperationArg::Index(0), - ], - OperationAux::MerkleTreeStateTransitionProofIndex(0), - ); - - let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, aux) - } - - #[test] - fn test_operation_verify_publickeyof_ok() -> Result<()> { - [ - SecretKey(BigUint::one()), - SecretKey::new_rand(), - SecretKey(&*GROUP_ORDER - BigUint::one()), - ] - .into_iter() - .try_for_each(|secret_key| { - let public_key = secret_key.public_key(); - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)) - }) - } - - #[test] - fn test_operation_verify_publickeyof_failure_wrong_key() { - let secret_key = SecretKey(BigUint::one()); - let public_key = SecretKey(BigUint::ZERO).public_key(); - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) - } - - #[test] - fn test_operation_verify_publickeyof_failure_pk_type() { - let secret_key = SecretKey(BigUint::one()); - let public_key = 123i64; - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::None, - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) - } - - #[test] - fn test_operation_verify_publickeyof_failure_sk_type() { - let secret_key = 123i64; - let public_key = SecretKey(BigUint::from(123u32)).public_key(); - - let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - let aux = Aux::secret_key(SecretKey(BigUint::from(123u32))); - assert!(operation_verify(st, op, prev_statements, aux,).is_err()) - } - - #[test] - fn test_operation_verify_publickeyof_failure_sk_size() { - let secret_key = SecretKey(&*GROUP_ORDER - BigUint::ZERO); - let public_key = secret_key.public_key(); - - let st: mainpod::Statement = - Statement::public_key_of(public_key, secret_key.clone()).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::PublicKeyOf), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::PublicKeyOfIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) - } - - #[test] - fn test_operation_verify_signedby_ok() -> Result<()> { - let sk = SecretKey(BigUint::from_u32(0xbadcafe).unwrap()); - let pk = sk.public_key(); - let msg = RawValue([F(1), F(2), F(3), F(4)]); - let nonce = BigUint::from_u32(123).unwrap(); - let sig = signer::Signer(sk).sign_with_nonce(nonce, msg); - let signed_by = SignedBy { - msg, - pk, - sig: sig.clone(), - }; - - let st: mainpod::Statement = Statement::signed_by(msg, pk).into(); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::SignedBy), - vec![OperationArg::Index(0), OperationArg::Index(0)], - OperationAux::SignedByIndex(0), - ); - let prev_statements = vec![Statement::None.into()]; - operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) - } - - #[test] - fn test_operation_replace_value_with_entry() -> Result<()> { - let d = dict!({"a" => 42, "b" => 33}); - - // 0: None - // 1: Lt(5, 42) - let st_in: mainpod::Statement = Statement::lt(5, 42).into(); - // 2: Contains(d, "a", 42) - let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into(); - - let st_out: mainpod::Statement = - Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into(); - let mut op_args: Vec<_> = iter::repeat(OperationArg::None) - .take(BASE_PARAMS.max_statement_args + 1) - .collect(); - op_args[1] = OperationArg::Index(2); - op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1); - let op = mainpod::Operation( - OperationType::Native(NativeOperation::ReplaceValueWithEntry), - op_args, - OperationAux::None, - ); - - let prev_statements = vec![Statement::None.into(), st_in, st_entry]; - operation_verify(st_out, op, prev_statements, Aux::default()) - } - - fn helper_statement_arg_from_template( - params: &Params, - st_tmpl_arg: StatementTmplArg, - args: Vec, - expected_st_arg: StatementArg, - ) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_arg_target = make_statement_arg_from_template_circuit( - params, - &mut builder, - &st_tmpl_arg_target, - &args_target, - ); - // TODO: Instead of connect, assign witness to result - let expected_st_arg_target = builder.add_virtual_statement_arg(); - builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); - - let mut pw = PartialWitness::::new(); - - st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?; - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, arg)?; - } - expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); - - Ok(()) - } - - #[test] - fn test_statement_arg_from_template() -> Result<()> { - let params = Params::default(); - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - // case: None - let st_tmpl_arg = StatementTmplArg::None; - let args = vec![Value::from(1), Value::from(2), Value::from(3)]; - let expected_st_arg = StatementArg::None; - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: Literal - let st_tmpl_arg = StatementTmplArg::Literal(Value::from("foo")); - let args = vec![Value::from(1), Value::from(2), Value::from(3)]; - let expected_st_arg = StatementArg::Literal(Value::from("foo")); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: AnchoredKey(id_wildcard, key_literal) - let st_tmpl_arg = - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("foo")); - let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; - let expected_st_arg = StatementArg::Key(AnchoredKey::new(dict, Key::from("foo"))); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - // case: WildcardLiteral(wildcard) - let st_tmpl_arg = StatementTmplArg::Wildcard(Wildcard::new("a".to_string(), 1)); - let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; - let expected_st_arg = StatementArg::Literal(Value::from("key")); - helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; - - Ok(()) - } - - fn helper_statement_from_template( - params: &Params, - st_tmpl: StatementTmpl, - args: Vec, - expected_st: Statement, - ) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let st_tmpl_target = builder.add_virtual_statement_tmpl(false); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_target = make_statement_from_template_circuit( - params, - &mut builder, - &st_tmpl_target, - &args_target, - ); - // TODO: Instead of connect, assign witness to result - let expected_st_target = builder.add_virtual_statement(false); - builder.connect_flattenable(&expected_st_target, &st_target); - - let mut pw = PartialWitness::::new(); - - st_tmpl_target.set_targets(&mut pw, &st_tmpl)?; - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, arg)?; - } - expected_st_target.set_targets(&mut pw, &expected_st.into())?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof.clone()).unwrap(); - - Ok(()) - } - - #[test] - fn test_statement_from_template() -> Result<()> { - let params = Params::default(); - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - let st_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)), - args: vec![ - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), - StatementTmplArg::Literal(Value::from("value")), - ], - }; - let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; - let expected_st = Statement::equal( - AnchoredKey::new(dict, Key::from("key")), - Value::from("value"), - ); - helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; - - let st_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)), - args: vec![ - StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), - StatementTmplArg::Literal(Value::from("value")), - ], - }; - let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(); - let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; - let expected_st = Statement::not_equal( - AnchoredKey::new(dict, Key::from("key")), - Value::from("value"), - ); - helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; - - Ok(()) - } - - fn helper_custom_operation_verify_gadget( - params: &Params, - custom_predicate: CustomPredicateRef, - mut op_args: Vec, - mut args: Vec, - expected_st: Option, - ) -> Result<()> { - // Pad - for _ in op_args.len()..BASE_PARAMS.max_operation_args { - op_args.push(Statement::None); - } - for _ in args.len()..params.max_custom_predicate_wildcards { - args.push(Value::from(EMPTY_VALUE)); - } - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let custom_predicate_target = builder.add_virtual_custom_predicate_entry(); - let op_args_target: Vec<_> = (0..op_args.len()) - .map(|_| builder.add_virtual_statement(false)) - .collect(); - let args_target: Vec<_> = (0..args.len()) - .map(|_| builder.add_virtual_value()) - .collect(); - let (st_target, op_type_target) = make_custom_statement_circuit( - params, - &mut builder, - &custom_predicate_target, - &op_args_target, - &args_target, - )?; - - let mut pw = PartialWitness::::new(); - - // Input - custom_predicate_target.set_targets(&mut pw, &custom_predicate)?; - for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { - op_arg_target.set_targets(&mut pw, &op_arg.into())?; - } - for (arg_target, arg) in args_target.iter().zip(args.iter()) { - arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; - } - // Expected Output - if let Some(expected_st) = expected_st { - st_target.set_targets(&mut pw, &expected_st.into())?; - } - - let expected_op_type = OperationType::Custom(custom_predicate); - op_type_target.set_targets(&mut pw, &expected_op_type)?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - Ok(data.verify(proof.clone())?) - } - - fn value_ref(v: impl Into) -> ValueRef { - v.into() - } - - // TODO: Add negative tests - #[test] - fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { - let params = Params::default(); - - use NativePredicate as NP; - use StatementTmplBuilder as STB; - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new_from_pred(NP::Equal) - .arg(("id", "score")) - .arg(literal(42)); - let stb1 = STB::new_from_pred(NP::Equal) - .arg(("id", "key")) - .arg("secret"); - let _ = builder.predicate_and( - "pred_and", - &["id"], - &["secret"], - &[stb0.clone(), stb1.clone()], - )?; - let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; - let batch = builder.finish()?; - - let dict = Hash([F(6), F(7), F(8), F(9)]); - - // AND - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), - ]; - let args = vec![Value::from(dict), Value::from(1234)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // OR (1) - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::None, - ]; - let args = vec![Value::from(dict), Value::from(0)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // OR (2) - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![ - Statement::None, - Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), - ]; - let args = vec![Value::from(dict), Value::from(1234)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - Ok(()) - } - - #[test] - fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { - let params = Params::default(); - - use NativePredicate as NP; - use StatementTmplBuilder as STB; - let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new_from_pred(NP::Equal) - .arg(("id", "score")) - .arg(literal(42)); - let stb1 = STB::new_from_pred(NP::Equal) - .arg(("secret_id", "key")) - .arg(("id", "score")); - let _ = builder.predicate_and( - "pred_and", - &["id"], - &["secret_id"], - &[stb0.clone(), stb1.clone()], - )?; - let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; - let batch = builder.finish()?; - - let dict = Hash([F(1), F(2), F(3), F(4)]); - let secret_dict = Hash([F(6), F(7), F(8), F(9)]); - - // AND (0) Sanity check with correct values - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - let expected_st = Statement::Custom( - custom_predicate.clone(), - vec![value_ref(args[0].clone()), value_ref(0)], - ); - - helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - Some(expected_st), - ) - .unwrap(); - - // AND (1) Different dict for same wildcard - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(Hash([F(0), F(5), F(1), F(6)]), Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // AND (2) key doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("BAD")), Value::from(42)), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // AND (3) literal doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal( - AnchoredKey::new(dict, Key::from("score")), - Value::from(0xbad), - ), - Statement::equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // AND (4) predicate doesn't match template - let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); - let op_args = vec![ - Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), - Statement::not_equal( - AnchoredKey::new(secret_dict, Key::from("key")), - AnchoredKey::new(dict, Key::from("score")), - ), - ]; - let args = vec![Value::from(dict), Value::from(secret_dict)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None, - ) - .is_err()); - - // OR (1) Two Nones - let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); - let op_args = vec![Statement::None, Statement::None]; - let args = vec![Value::from(dict), Value::from(0)]; - - assert!(helper_custom_operation_verify_gadget( - ¶ms, - custom_predicate, - op_args, - args, - None - ) - .is_err()); - - Ok(()) - } - - fn helper_calculate_statements_hash(params: &Params, statements: &[Statement]) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - let statements_target = (0..params.max_public_statements) - .map(|_| builder.add_virtual_statement(false)) - .collect_vec(); - let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target); - - let mut pw = PartialWitness::::new(); - - // Input - let statements = statements - .iter() - .map(|st| { - let mut st = mainpod::Statement::from(st.clone()); - pad_statement(&mut st); - st - }) - .collect_vec(); - for (st_target, st) in statements_target.iter().zip(statements.iter()) { - st_target.set_targets(&mut pw, st)?; - } - // Expected Output - let expected_sts_hash = calculate_statements_hash(&statements); - pw.set_hash_target( - sts_hash_target, - HashOut { - elements: expected_sts_hash.0, - }, - )?; - - // generate & verify proof - let data = builder.build::(); - let proof = data.prove(pw)?; - Ok(data.verify(proof.clone())?) - } - - #[test] - fn test_calculate_sts_hash() -> frontend::Result<()> { - assert_eq!(Params::num_public_statements_hash(), 16); - // Case with no public public statements - let params = Params { - max_public_statements: 0, - ..Default::default() - }; - - helper_calculate_statements_hash(¶ms, &[]).unwrap(); - - // Case with number of statements for the sts_hash equal to number of public statements - let params = Params { - max_public_statements: Params::num_public_statements_hash(), - ..Default::default() - }; - - let dict = Hash([F(1), F(2), F(3), F(4)]); - let statements = (0..Params::num_public_statements_hash()) - .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64))) - .collect_vec(); - - helper_calculate_statements_hash(¶ms, &statements).unwrap(); - - // Case with more statements for the sts_hash than the number of public statements - let params = Params { - max_public_statements: 4, - ..Default::default() - }; - - let dict2 = Hash([F(5), F(6), F(7), F(8)]); - let statements = [ - Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)), - Statement::equal( - AnchoredKey::from((dict, "bar")), - AnchoredKey::from((dict, "baz")), - ), - Statement::lt( - AnchoredKey::from((dict2, "one")), - AnchoredKey::from((dict2, "two")), - ), - ] - .into_iter() - .chain(iter::repeat(Statement::None)) - .take(params.max_public_statements) - .collect_vec(); - - helper_calculate_statements_hash(¶ms, &statements).unwrap(); - - Ok(()) - } - - #[test] - fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> { - let params = Params::default(); - - // Build a batch with two predicates: - // pred_A: Equal(x, y) - // pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash - use NativePredicate as NP; - let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal) - .arg("x") - .arg("y"); - cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a]) - .unwrap(); - - // Build pred_B's template manually with SelfPredicateHash(0) - let stb_b_tmpl = StatementTmpl { - pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)), - args: vec![ - StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)), - StatementTmplArg::SelfPredicateHash(0), - ], - }; - let pred_b = CustomPredicate::new( - ¶ms, - "pred_B".into(), - true, - vec![stb_b_tmpl], - 1, - vec!["x".to_string()], - ) - .unwrap(); - cpb.predicates.push(pred_b); - let batch = cpb.finish().unwrap(); - - // Compute the expected resolved hash of pred_A - let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); - let pred_a_hash = Predicate::Custom(pred_a_ref).hash(); - let expected_pred_a_value = Value::from(pred_a_hash); - - // Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to - // Literal(pred_a_hash). Then make_statement_from_template_circuit should produce - // a statement with that literal value. - let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); - let pred_b_tmpl = &pred_b_ref.predicate().statements[0]; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::new(config); - - // Create the template target and batch id target - let st_tmpl_target = builder.add_virtual_statement_tmpl(true); - let batch_id = builder.add_virtual_hash(); - - // Normalize the template (this is what we're testing) - let normalized = - normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id); - - // Feed normalized template into statement generation - let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards) - .map(|_| builder.add_virtual_value()) - .collect(); - let st_target = - make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target); - - // Connect to expected output - let expected_st_target = builder.add_virtual_statement(false); - builder.connect_flattenable(&expected_st_target, &st_target); - - // Set witness - let mut pw = PartialWitness::::new(); - st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?; - pw.set_target_arr(&batch_id.elements, &batch.id().0)?; - - let some_value = Value::from(42); - // args: first wildcard is "x" = some_value, rest are padding - let mut args_values = vec![some_value.clone()]; - for _ in 1..params.max_custom_predicate_wildcards { - args_values.push(Value::from(EMPTY_VALUE)); - } - for (target, value) in args_target.iter().zip(args_values.iter()) { - target.set_targets(&mut pw, value)?; - } - - // Expected statement: Equal(Literal(some_value), Literal(pred_a_hash)) - let expected_st: crate::backends::plonky2::mainpod::Statement = - Statement::equal(some_value, expected_pred_a_value).into(); - expected_st_target.set_targets(&mut pw, &expected_st)?; - - // Build and verify - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - - Ok(()) - } -} diff --git a/src/backends/plonky2/circuits/mainpod/tests.rs b/src/backends/plonky2/circuits/mainpod/tests.rs new file mode 100644 index 0000000..49fe4a0 --- /dev/null +++ b/src/backends/plonky2/circuits/mainpod/tests.rs @@ -0,0 +1,1707 @@ +use std::{iter, ops::Not}; + +use num::FromPrimitive; +use plonky2::{ + field::{goldilocks_field::GoldilocksField, types::Field}, + hash::hash_types::HashOut, + iop::witness::WitnessWrite, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, +}; + +use super::*; +use crate::{ + backends::plonky2::{ + basetypes::C, + circuits::common::tests::I64_TEST_PAIRS, + mainpod::{calculate_statements_hash, OperationArg, OperationAux, Size}, + primitives::{ + ec::schnorr::SecretKey, + merkletree::{MerkleClaimAndProof, MerkleTree, MerkleTreeStateTransitionProof}, + }, + signer, + }, + dict, + frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::{ + self, hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, + RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard, BASE_PARAMS, + EMPTY_VALUE, + }, +}; + +#[derive(Default)] +struct Aux { + merkle_proofs: Vec, + secret_keys: Vec, + signed_bys: Vec, + merkle_transition_proofs: Vec, +} + +impl Aux { + fn merkle_proof(v: MerkleClaimAndProof) -> Self { + Self { + merkle_proofs: vec![v], + ..Default::default() + } + } + fn secret_key(v: SecretKey) -> Self { + Self { + secret_keys: vec![v], + ..Default::default() + } + } + fn signed_by(v: SignedBy) -> Self { + Self { + signed_bys: vec![v], + ..Default::default() + } + } + fn merkle_tree_state_transition_proof(v: MerkleTreeStateTransitionProof) -> Self { + Self { + merkle_transition_proofs: vec![v], + ..Default::default() + } + } +} + +fn operation_verify( + st: mainpod::Statement, + op: mainpod::Operation, + prev_statements: Vec, + aux: Aux, +) -> Result<()> { + let params = Params { + max_public_key_of: aux.secret_keys.len(), + max_signed_by: aux.signed_bys.len(), + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: aux.merkle_proofs.len(), + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: aux.merkle_transition_proofs.len(), + }, + max_depth_small: 8, + max_depth_medium: 32, + }, + max_custom_predicate_verifications: 0, + max_custom_predicates: 0, + ..Default::default() + }; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_target = builder.add_virtual_statement(false); + let op_target = builder.add_virtual_operation(¶ms); + let prev_statements_target: Vec<_> = (0..prev_statements.len()) + .map(|_| builder.add_virtual_statement(false)) + .collect(); + let prev_statement_flatteneds_target: Vec> = prev_statements_target + .iter() + .map(|st| st.flatten()) + .collect(); + let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target + .iter() + .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) + .collect(); + + let merkle_proofs_target = MerkleProofsTarget { + medium: aux + .merkle_proofs + .iter() + .map(|_| { + MerkleClaimAndProofTarget::new_virtual( + params.containers.max_depth_medium, + &mut builder, + ) + }) + .collect(), + small: Vec::new(), + }; + + let secret_keys_target: Vec<_> = aux + .secret_keys + .iter() + .map(|sk| builder.constant_biguint320(&sk.0)) + .collect(); + + let signed_by_targets: Vec<_> = aux + .signed_bys + .iter() + .map(|_| SignedByTarget::new_virtual(&mut builder)) + .collect(); + + let merkle_transition_proofs_target = MerkleTransitionProofsTarget { + medium: aux + .merkle_transition_proofs + .iter() + .map(|_| { + MerkleTreeStateTransitionProofTarget::new_virtual( + params.containers.max_depth_medium, + &mut builder, + ) + }) + .collect(), + small: Vec::new(), + }; + + let aux_table = build_operation_aux_table_circuit( + ¶ms, + &mut builder, + &merkle_proofs_target, + &merkle_transition_proofs_target, + &secret_keys_target, + &signed_by_targets, + &[], + &[], + )?; + + verify_operation_circuit( + ¶ms, + &mut builder, + &st_target, + &op_target, + &prev_statement_flatteneds_target, + &prev_statement_hashes_target, + &aux_table, + )?; + + let mut pw = PartialWitness::::new(); + st_target.set_targets(&mut pw, &st)?; + op_target.set_targets(&mut pw, ¶ms, &op)?; + for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { + prev_st_target.set_targets(&mut pw, prev_st)?; + } + for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) { + signed_by_target.set_targets(&mut pw, signed_by)? + } + for (merkle_proof_target, merkle_proof) in merkle_proofs_target + .medium + .iter() + .zip(aux.merkle_proofs.iter()) + { + merkle_proof_target.set_targets(&mut pw, merkle_proof)? + } + for (merkle_tree_state_transition_proof_target, merkle_tree_state_transition_proof) in + merkle_transition_proofs_target + .medium + .iter() + .zip(aux.merkle_transition_proofs.iter()) + { + merkle_tree_state_transition_proof_target + .set_targets(&mut pw, merkle_tree_state_transition_proof)? + } + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) +} + +#[test] +fn test_lt_lteq_verify_failures() { + let invalid_int = RawValue([ + GoldilocksField::NEG_ONE, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + GoldilocksField::ZERO, + ]); + + let prev_statements = [Statement::None.into()]; + + [ + // 56 < 55, 55 < 55, 56 <= 55, -55 < -55, -55 < -56, -55 <= -56 should fail to verify + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(55, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(-55, -55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(-55, -56).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(-55, -56).into(), + ), + // 56 < p-1 and p-1 <= p-1 should fail to verify, where p + // is the Goldilocks prime and 'p-1' occupies a single + // limb. + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt(56, invalid_int).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::lt_eq(invalid_int, invalid_int).into(), + ), + ] + .into_iter() + .for_each(|(op, st)| { + let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + operation_verify(st, op, prev_statements.to_vec(), Aux::default()) + })); + match check { + Err(e) => { + let err_string = e.downcast_ref::().unwrap(); + if !err_string.contains("Integer too large to fit") { + panic!("Test failed with an unexpected error: {}", err_string); + } + } + Ok(Err(_)) => {} + _ => panic!("Test passed, yet it should have failed!"), + } + }); +} + +#[test] +fn test_eq_neq_verify_failures() { + let prev_statements = [Statement::None.into()]; + + [ + // 56 == 55, 55 != 55 should fail to verify + ( + mainpod::Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::equal(56, 55).into(), + ), + ( + mainpod::Operation( + OperationType::Native(NativeOperation::NotEqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ), + Statement::not_equal(55, 55).into(), + ), + ] + .into_iter() + .for_each(|(op, st)| { + assert!(operation_verify(st, op, prev_statements.to_vec(), Aux::default()).is_err()) + }); +} + +#[test] +fn test_operation_verify_none() -> Result<()> { + let st: mainpod::Statement = Statement::None.into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::None), + vec![], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_copy() -> Result<()> { + let st: mainpod::Statement = Statement::None.into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::CopyStatement), + vec![OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_eq() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 55}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 55).into(); + let st: mainpod::Statement = Statement::equal( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "world")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_neq() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 75}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 75).into(); + let st: mainpod::Statement = Statement::not_equal( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "world")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::NotEqualFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_lt() -> Result<()> { + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"hello" => 56}); + let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); + let st2: mainpod::Statement = Statement::contains(dict2.clone(), "hello", 56).into(); + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict1, "hello")), + AnchoredKey::from((&dict2, "hello")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2.clone()]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative < negative + let dict3 = dict!({"hola" => -56}); + let dict4 = dict!({"mundo" => -55}); + let st3: mainpod::Statement = Statement::contains(dict3.clone(), "hola", -56).into(); + let st4: mainpod::Statement = Statement::contains(dict4.clone(), "mundo", -55).into(); + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict3, "hola")), + AnchoredKey::from((&dict4, "mundo")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3.clone(), st4]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative < positive + let st: mainpod::Statement = Statement::lt( + AnchoredKey::from((&dict3, "hola")), + AnchoredKey::from((&dict2, "hello")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_lteq() -> Result<()> { + let local = dict!({ + "n55" => 55, + "n56" => 56, + "n_56" => -56, + "n_55" => -55, + }); + let st1: mainpod::Statement = Statement::contains(local.clone(), "n55", 55).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "n56", 56).into(); + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n55")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2.clone()]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative <= negative + let st3: mainpod::Statement = Statement::contains(local.clone(), "n_56", -56).into(); + let st4: mainpod::Statement = Statement::contains(local.clone(), "n_55", -55).into(); + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n_55")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3.clone(), st4]; + operation_verify(st, op, prev_statements, Aux::default())?; + + // Also check negative <= positive + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st3, st2]; + operation_verify(st, op, prev_statements.clone(), Aux::default())?; + + // Also check equality, both positive and negative. + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n_56")), + AnchoredKey::from((&local, "n_56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ); + operation_verify(st, op, prev_statements.clone(), Aux::default())?; + let st: mainpod::Statement = Statement::lt_eq( + AnchoredKey::from((&local, "n56")), + AnchoredKey::from((&local, "n56")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtEqFromEntries), + vec![OperationArg::Index(1), OperationArg::Index(1)], + OperationAux::None, + ); + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_hashof() -> Result<()> { + let input_values = [ + Value::from(RawValue([ + GoldilocksField(1), + GoldilocksField(2), + GoldilocksField(3), + GoldilocksField(4), + ])), + Value::from(512), + ]; + let v1 = hash_values(&input_values); + let [v2, v3] = input_values; + + let local = dict!({ + "hola" => v1, + "mundo" => v2.clone(), + "!" => v3.clone(), + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "hola", v1).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "mundo", v2).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "!", v3).into(); + + let st: mainpod::Statement = Statement::hash_of( + AnchoredKey::from((&local, "hola")), + AnchoredKey::from((&local, "mundo")), + AnchoredKey::from((&local, "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::HashOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_sumof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (sum, overflow) = a.overflowing_add(b); + overflow.not().then_some((a, b, sum)) + }) + .try_for_each(|(a, b, sum)| { + let local = dict!({ + "sum" => sum, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "sum", sum).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::sum_of( + AnchoredKey::from((&local, "sum")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) +} + +#[test] +fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> { + let local = dict!({ + "a" => 3, + "noise" => 99, + "sum" => 6, + }); + let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into(); + let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into(); + let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into(); + + let st: mainpod::Statement = Statement::sum_of( + AnchoredKey::from((&local, "sum")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "a")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + // Non-monotonic and repeated indices to stress random-access resolution. + OperationArg::Index(2), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::None, + ); + let prev_statements = vec![st_a, st_noise, st_sum]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_productof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (prod, overflow) = a.overflowing_mul(b); + overflow.not().then_some((a, b, prod)) + }) + .try_for_each(|(a, b, prod)| { + let local = dict!({ + "prod" => prod, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "prod", prod).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::product_of( + AnchoredKey::from((&local, "prod")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ProductOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) +} + +#[test] +fn test_operation_verify_maxof() -> Result<()> { + I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { + let max = i64::max(a, b); + let local = dict!({ + "max" => max, + "a" => a, + "b" => b, + }); + + let st1: mainpod::Statement = Statement::contains(local.clone(), "max", max).into(); + let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); + let st3: mainpod::Statement = Statement::contains(local.clone(), "b", b).into(); + + let st: mainpod::Statement = Statement::max_of( + AnchoredKey::from((&local, "max")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, Aux::default()) + }) +} + +#[test] +fn test_operation_verify_maxof_failures() { + [(5, 3, 4), (5, 5, 8), (3, 4, 5)] + .into_iter() + .for_each(|(max, a, b)| { + let st: mainpod::Statement = Statement::max_of(max, a, b).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::None, + ); + let prev_statements = [Statement::None.into()]; + + let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + operation_verify(st, op, prev_statements.to_vec(), Aux::default()) + })); + match check { + Err(e) => { + let err_string = e.downcast_ref::().unwrap(); + if !err_string.contains("Integer too large to fit") { + panic!("Test failed with an unexpected error: {}", err_string); + } + } + Ok(Err(_)) => {} + _ => panic!("Test passed, yet it should have failed!"), + } + }) +} + +#[test] +fn test_operation_verify_lt_to_neq() -> Result<()> { + let local = dict!({ + "a" => 10, + "b" => 20, + }); + let st: mainpod::Statement = Statement::not_equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let st1: mainpod::Statement = Statement::lt( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::LtToNotEqual), + vec![OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![st1]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_transitive_eq() -> Result<()> { + let local = dict!({ + "a" => 10, + "b" => 10, + "c" => 10, + }); + let st: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "c")), + ) + .into(); + let st1: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "b")), + ) + .into(); + let st2: mainpod::Statement = Statement::equal( + AnchoredKey::from((&local, "b")), + AnchoredKey::from((&local, "c")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::TransitiveEqualFromStatements), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::None, + ); + let prev_statements = vec![st1, st2]; + operation_verify(st, op, prev_statements, Aux::default()) +} + +#[test] +fn test_operation_verify_sintains() -> Result<()> { + let kvs = [ + (1.into(), 55.into()), + (2.into(), 88.into()), + (175.into(), 0.into()), + ] + .into_iter() + .collect(); + let mt = MerkleTree::new(&kvs); + + let root = mt.root(); + let key = Value::from(5); + let local = dict!({ + "merkle_root" => root, + "key" => key.clone(), + }); + let root_ak = AnchoredKey::from((&local, "merkle_root")); + let key_ak = AnchoredKey::from((&local, "key")); + + let no_key_pf = mt.prove_nonexistence(&key.raw())?; + + let root_st: mainpod::Statement = + Statement::contains(local.clone(), "merkle_root", root).into(); + let key_st: mainpod::Statement = Statement::contains(local.clone(), "key", key.clone()).into(); + let st: mainpod::Statement = Statement::not_contains(root_ak, key_ak).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::NotContainsFromEntries), + vec![OperationArg::Index(0), OperationArg::Index(1)], + OperationAux::MerkleProofIndex(Size::Medium, 0), + ); + + let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), None, no_key_pf); + let prev_statements = vec![root_st, key_st]; + operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) +} + +#[test] +fn test_operation_verify_contains() -> Result<()> { + let kvs = [ + (1.into(), 55.into()), + (2.into(), 88.into()), + (175.into(), 0.into()), + ] + .into_iter() + .collect(); + let mt = MerkleTree::new(&kvs); + + let root = mt.root(); + let key = Value::from(175); + let (value, key_pf) = mt.prove(&key.raw())?; + let local = dict!({ + "merkle_root" => root, + "key" => key.clone(), + "value" => value, + }); + let root_ak = AnchoredKey::from((&local, "merkle_root")); + let key_ak = AnchoredKey::from((&local, "key")); + let value_ak = AnchoredKey::from((&local, "value")); + + let root_st: mainpod::Statement = + Statement::contains(local.clone(), "merkle_root", root).into(); + let key_st: mainpod::Statement = Statement::contains(local.clone(), "key", key.clone()).into(); + let value_st: mainpod::Statement = Statement::contains(local.clone(), "value", value).into(); + + let st: mainpod::Statement = Statement::contains(root_ak, key_ak, value_ak).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainsFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::MerkleProofIndex(Size::Medium, 0), + ); + + let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), Some(value), key_pf); + let prev_statements = vec![root_st, key_st, value_st]; + operation_verify(st, op, prev_statements, Aux::merkle_proof(merkle_proof)) +} + +#[test] +fn test_operation_verify_merkle_insert() -> Result<()> { + let mut tree = MerkleTree::new(&[].into()); + + let key = Value::from(175); + let value = Value::from(0); + let state_transition_proof = tree.insert(&key.raw(), &value.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::insert(new_root, old_root, key, value).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerInsertFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) +} + +#[test] +fn test_operation_verify_merkle_update() -> Result<()> { + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); + + let key = Value::from(175); + let value = Value::from(0); + let state_transition_proof = tree.update(&key.raw(), &value.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::update(new_root, old_root, key, value).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerUpdateFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) +} + +#[test] +fn test_operation_verify_merkle_delete() -> Result<()> { + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); + + let key = Value::from(175); + let state_transition_proof = tree.delete(&key.raw())?; + let old_root = state_transition_proof.old_root; + let new_root = state_transition_proof.new_root; + + let st: mainpod::Statement = Statement::delete(new_root, old_root, key).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ContainerDeleteFromEntries), + vec![ + OperationArg::Index(0), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::MerkleTransitionProofIndex(Size::Medium, 0), + ); + + let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, aux) +} + +#[test] +fn test_operation_verify_publickeyof_ok() -> Result<()> { + [ + SecretKey(BigUint::one()), + SecretKey::new_rand(), + SecretKey(&*GROUP_ORDER - BigUint::one()), + ] + .into_iter() + .try_for_each(|secret_key| { + let public_key = secret_key.public_key(); + + let st: mainpod::Statement = + Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)) + }) +} + +#[test] +fn test_operation_verify_publickeyof_failure_wrong_key() { + let secret_key = SecretKey(BigUint::one()); + let public_key = SecretKey(BigUint::ZERO).public_key(); + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) +} + +#[test] +fn test_operation_verify_publickeyof_failure_pk_type() { + let secret_key = SecretKey(BigUint::one()); + let public_key = 123i64; + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::None, + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) +} + +#[test] +fn test_operation_verify_publickeyof_failure_sk_type() { + let secret_key = 123i64; + let public_key = SecretKey(BigUint::from(123u32)).public_key(); + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + let aux = Aux::secret_key(SecretKey(BigUint::from(123u32))); + assert!(operation_verify(st, op, prev_statements, aux,).is_err()) +} + +#[test] +fn test_operation_verify_publickeyof_failure_sk_size() { + let secret_key = SecretKey(&*GROUP_ORDER - BigUint::ZERO); + let public_key = secret_key.public_key(); + + let st: mainpod::Statement = Statement::public_key_of(public_key, secret_key.clone()).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::PublicKeyOf), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::PublicKeyOfIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err()) +} + +#[test] +fn test_operation_verify_signedby_ok() -> Result<()> { + let sk = SecretKey(BigUint::from_u32(0xbadcafe).unwrap()); + let pk = sk.public_key(); + let msg = RawValue([F(1), F(2), F(3), F(4)]); + let nonce = BigUint::from_u32(123).unwrap(); + let sig = signer::Signer(sk).sign_with_nonce(nonce, msg); + let signed_by = SignedBy { + msg, + pk, + sig: sig.clone(), + }; + + let st: mainpod::Statement = Statement::signed_by(msg, pk).into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SignedBy), + vec![OperationArg::Index(0), OperationArg::Index(0)], + OperationAux::SignedByIndex(0), + ); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) +} + +#[test] +fn test_operation_replace_value_with_entry() -> Result<()> { + let d = dict!({"a" => 42, "b" => 33}); + + // 0: None + // 1: Lt(5, 42) + let st_in: mainpod::Statement = Statement::lt(5, 42).into(); + // 2: Contains(d, "a", 42) + let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into(); + + let st_out: mainpod::Statement = + Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into(); + let mut op_args: Vec<_> = iter::repeat(OperationArg::None) + .take(BASE_PARAMS.max_statement_args + 1) + .collect(); + op_args[1] = OperationArg::Index(2); + op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ReplaceValueWithEntry), + op_args, + OperationAux::None, + ); + + let prev_statements = vec![Statement::None.into(), st_in, st_entry]; + operation_verify(st_out, op, prev_statements, Aux::default()) +} + +fn helper_statement_arg_from_template( + params: &Params, + st_tmpl_arg: StatementTmplArg, + args: Vec, + expected_st_arg: StatementArg, +) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_arg_target = make_statement_arg_from_template_circuit( + params, + &mut builder, + &st_tmpl_arg_target, + &args_target, + ); + // TODO: Instead of connect, assign witness to result + let expected_st_arg_target = builder.add_virtual_statement_arg(); + builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); + + let mut pw = PartialWitness::::new(); + + st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) +} + +#[test] +fn test_statement_arg_from_template() -> Result<()> { + let params = Params::default(); + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + // case: None + let st_tmpl_arg = StatementTmplArg::None; + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::None; + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: Literal + let st_tmpl_arg = StatementTmplArg::Literal(Value::from("foo")); + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::Literal(Value::from("foo")); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: AnchoredKey(id_wildcard, key_literal) + let st_tmpl_arg = + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("foo")); + let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; + let expected_st_arg = StatementArg::Key(AnchoredKey::new(dict, Key::from("foo"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: WildcardLiteral(wildcard) + let st_tmpl_arg = StatementTmplArg::Wildcard(Wildcard::new("a".to_string(), 1)); + let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; + let expected_st_arg = StatementArg::Literal(Value::from("key")); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + Ok(()) +} + +fn helper_statement_from_template( + params: &Params, + st_tmpl: StatementTmpl, + args: Vec, + expected_st: Statement, +) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let st_tmpl_target = builder.add_virtual_statement_tmpl(false); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = + make_statement_from_template_circuit(params, &mut builder, &st_tmpl_target, &args_target); + // TODO: Instead of connect, assign witness to result + let expected_st_target = builder.add_virtual_statement(false); + builder.connect_flattenable(&expected_st_target, &st_target); + + let mut pw = PartialWitness::::new(); + + st_tmpl_target.set_targets(&mut pw, &st_tmpl)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_target.set_targets(&mut pw, &expected_st.into())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) +} + +#[test] +fn test_statement_from_template() -> Result<()> { + let params = Params::default(); + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + let st_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)), + args: vec![ + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let args = vec![Value::from(1), Value::from(dict), Value::from(3)]; + let expected_st = Statement::equal( + AnchoredKey::new(dict, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + + let st_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)), + args: vec![ + StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(); + let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)]; + let expected_st = Statement::not_equal( + AnchoredKey::new(dict, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + + Ok(()) +} + +fn helper_custom_operation_verify_gadget( + params: &Params, + custom_predicate: CustomPredicateRef, + mut op_args: Vec, + mut args: Vec, + expected_st: Option, +) -> Result<()> { + // Pad + for _ in op_args.len()..BASE_PARAMS.max_operation_args { + op_args.push(Statement::None); + } + for _ in args.len()..params.max_custom_predicate_wildcards { + args.push(Value::from(EMPTY_VALUE)); + } + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let custom_predicate_target = builder.add_virtual_custom_predicate_entry(); + let op_args_target: Vec<_> = (0..op_args.len()) + .map(|_| builder.add_virtual_statement(false)) + .collect(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let (st_target, op_type_target) = make_custom_statement_circuit( + params, + &mut builder, + &custom_predicate_target, + &op_args_target, + &args_target, + )?; + + let mut pw = PartialWitness::::new(); + + // Input + custom_predicate_target.set_targets(&mut pw, &custom_predicate)?; + for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { + op_arg_target.set_targets(&mut pw, &op_arg.into())?; + } + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; + } + // Expected Output + if let Some(expected_st) = expected_st { + st_target.set_targets(&mut pw, &expected_st.into())?; + } + + let expected_op_type = OperationType::Custom(custom_predicate); + op_type_target.set_targets(&mut pw, &expected_op_type)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) +} + +fn value_ref(v: impl Into) -> ValueRef { + v.into() +} + +// TODO: Add negative tests +#[test] +fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { + let params = Params::default(); + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) + .arg(("id", "key")) + .arg("secret"); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; + let batch = builder.finish()?; + + let dict = Hash([F(6), F(7), F(8), F(9)]); + + // AND + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), + ]; + let args = vec![Value::from(dict), Value::from(1234)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // OR (1) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::None, + ]; + let args = vec![Value::from(dict), Value::from(0)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // OR (2) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::None, + Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234)), + ]; + let args = vec![Value::from(dict), Value::from(1234)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + Ok(()) +} + +#[test] +fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> { + let params = Params::default(); + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new_from_pred(NP::Equal) + .arg(("id", "score")) + .arg(literal(42)); + let stb1 = STB::new_from_pred(NP::Equal) + .arg(("secret_id", "key")) + .arg(("id", "score")); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret_id"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; + let batch = builder.finish()?; + + let dict = Hash([F(1), F(2), F(3), F(4)]); + let secret_dict = Hash([F(6), F(7), F(8), F(9)]); + + // AND (0) Sanity check with correct values + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![value_ref(args[0].clone()), value_ref(0)], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + Some(expected_st), + ) + .unwrap(); + + // AND (1) Different dict for same wildcard + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(Hash([F(0), F(5), F(1), F(6)]), Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // AND (2) key doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("BAD")), Value::from(42)), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // AND (3) literal doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal( + AnchoredKey::new(dict, Key::from("score")), + Value::from(0xbad), + ), + Statement::equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // AND (4) predicate doesn't match template + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42)), + Statement::not_equal( + AnchoredKey::new(secret_dict, Key::from("key")), + AnchoredKey::new(dict, Key::from("score")), + ), + ]; + let args = vec![Value::from(dict), Value::from(secret_dict)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None,) + .is_err() + ); + + // OR (1) Two Nones + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![Statement::None, Statement::None]; + let args = vec![Value::from(dict), Value::from(0)]; + + assert!( + helper_custom_operation_verify_gadget(¶ms, custom_predicate, op_args, args, None) + .is_err() + ); + + Ok(()) +} + +fn helper_calculate_statements_hash(params: &Params, statements: &[Statement]) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + let statements_target = (0..params.max_public_statements) + .map(|_| builder.add_virtual_statement(false)) + .collect_vec(); + let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target); + + let mut pw = PartialWitness::::new(); + + // Input + let statements = statements + .iter() + .map(|st| { + let mut st = mainpod::Statement::from(st.clone()); + pad_statement(&mut st); + st + }) + .collect_vec(); + for (st_target, st) in statements_target.iter().zip(statements.iter()) { + st_target.set_targets(&mut pw, st)?; + } + // Expected Output + let expected_sts_hash = calculate_statements_hash(&statements); + pw.set_hash_target( + sts_hash_target, + HashOut { + elements: expected_sts_hash.0, + }, + )?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + Ok(data.verify(proof.clone())?) +} + +#[test] +fn test_calculate_sts_hash() -> frontend::Result<()> { + assert_eq!(Params::num_public_statements_hash(), 16); + // Case with no public public statements + let params = Params { + max_public_statements: 0, + ..Default::default() + }; + + helper_calculate_statements_hash(¶ms, &[]).unwrap(); + + // Case with number of statements for the sts_hash equal to number of public statements + let params = Params { + max_public_statements: Params::num_public_statements_hash(), + ..Default::default() + }; + + let dict = Hash([F(1), F(2), F(3), F(4)]); + let statements = (0..Params::num_public_statements_hash()) + .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64))) + .collect_vec(); + + helper_calculate_statements_hash(¶ms, &statements).unwrap(); + + // Case with more statements for the sts_hash than the number of public statements + let params = Params { + max_public_statements: 4, + ..Default::default() + }; + + let dict2 = Hash([F(5), F(6), F(7), F(8)]); + let statements = [ + Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)), + Statement::equal( + AnchoredKey::from((dict, "bar")), + AnchoredKey::from((dict, "baz")), + ), + Statement::lt( + AnchoredKey::from((dict2, "one")), + AnchoredKey::from((dict2, "two")), + ), + ] + .into_iter() + .chain(iter::repeat(Statement::None)) + .take(params.max_public_statements) + .collect_vec(); + + helper_calculate_statements_hash(¶ms, &statements).unwrap(); + + Ok(()) +} + +#[test] +fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> { + let params = Params::default(); + + // Build a batch with two predicates: + // pred_A: Equal(x, y) + // pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash + use NativePredicate as NP; + let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal) + .arg("x") + .arg("y"); + cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a]) + .unwrap(); + + // Build pred_B's template manually with SelfPredicateHash(0) + let stb_b_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)), + args: vec![ + StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)), + StatementTmplArg::SelfPredicateHash(0), + ], + }; + let pred_b = CustomPredicate::new( + ¶ms, + "pred_B".into(), + true, + vec![stb_b_tmpl], + 1, + vec!["x".to_string()], + ) + .unwrap(); + cpb.predicates.push(pred_b); + let batch = cpb.finish().unwrap(); + + // Compute the expected resolved hash of pred_A + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_a_hash = Predicate::Custom(pred_a_ref).hash(); + let expected_pred_a_value = Value::from(pred_a_hash); + + // Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to + // Literal(pred_a_hash). Then make_statement_from_template_circuit should produce + // a statement with that literal value. + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_tmpl = &pred_b_ref.predicate().statements[0]; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + // Create the template target and batch id target + let st_tmpl_target = builder.add_virtual_statement_tmpl(true); + let batch_id = builder.add_virtual_hash(); + + // Normalize the template (this is what we're testing) + let normalized = normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id); + + // Feed normalized template into statement generation + let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = + make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target); + + // Connect to expected output + let expected_st_target = builder.add_virtual_statement(false); + builder.connect_flattenable(&expected_st_target, &st_target); + + // Set witness + let mut pw = PartialWitness::::new(); + st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?; + pw.set_target_arr(&batch_id.elements, &batch.id().0)?; + + let some_value = Value::from(42); + // args: first wildcard is "x" = some_value, rest are padding + let mut args_values = vec![some_value.clone()]; + for _ in 1..params.max_custom_predicate_wildcards { + args_values.push(Value::from(EMPTY_VALUE)); + } + for (target, value) in args_target.iter().zip(args_values.iter()) { + target.set_targets(&mut pw, value)?; + } + + // Expected statement: Equal(Literal(some_value), Literal(pred_a_hash)) + let expected_st: crate::backends::plonky2::mainpod::Statement = + Statement::equal(some_value, expected_pred_a_value).into(); + expected_st_target.set_targets(&mut pw, &expected_st)?; + + // Build and verify + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) +} diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 5e9df2e..513b1da 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -148,14 +148,20 @@ pub(crate) fn extract_custom_predicate_verifications( Ok(table) } +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MerkleProofs { + pub(crate) medium: Vec, + pub(crate) small: Vec, +} + /// Extracts Merkle proofs from Contains/NotContains ops. pub(crate) fn extract_merkle_proofs( params: &Params, aux_list: &mut [OperationAux], operations: &[middleware::Operation], statements: &[middleware::Statement], -) -> Result> { - let mut table = Vec::new(); +) -> Result { + let mut tables = MerkleProofs::default(); for (i, (op, st)) in operations.iter().zip(statements.iter()).enumerate() { let deduction_err = || MiddlewareError::invalid_deduction(op.clone(), st.clone()); let (root, key, value, pf) = match (op, st) { @@ -178,31 +184,42 @@ pub(crate) fn extract_merkle_proofs( } _ => continue, }; - aux_list[i] = OperationAux::MerkleProofIndex(table.len()); - table.push(MerkleClaimAndProof::new( - Hash::from(root), - key, - value, - pf.clone(), - )); + let claim_proof = MerkleClaimAndProof::new(Hash::from(root), key, value, pf.clone()); + if pf.existence + // TODO: Make sure there's no off-by-one error here + && pf.siblings.len() <= params.containers.max_depth_small + && tables.small.len() < params.containers.state.max_small + { + aux_list[i] = OperationAux::MerkleProofIndex(Size::Small, tables.small.len()); + tables.small.push(claim_proof); + } else { + aux_list[i] = OperationAux::MerkleProofIndex(Size::Medium, tables.medium.len()); + tables.medium.push(claim_proof); + } } - if table.len() > params.max_merkle_proofs_containers { + if tables.medium.len() > params.containers.state.max_medium { return Err(Error::custom(format!( "The number of required Merkle proofs ({}) exceeds the maximum number ({}).", - table.len(), - params.max_merkle_proofs_containers + tables.medium.len(), + params.containers.state.max_medium ))); } - Ok(table) + Ok(tables) +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MerkleTransitionProofs { + pub(crate) medium: Vec, + pub(crate) small: Vec, } /// Extracts Merkle state transition proofs from container update ops. -pub(crate) fn extract_merkle_tree_state_transition_proofs( +pub(crate) fn extract_merkle_transition_proofs( params: &Params, aux_list: &mut [OperationAux], operations: &[middleware::Operation], -) -> Result> { - let mut table = Vec::new(); +) -> Result { + let mut tables = MerkleTransitionProofs::default(); for (i, op) in operations.iter().enumerate() { let pf = match op { middleware::Operation::ContainerInsertFromEntries(_, _, _, _, pf) @@ -210,17 +227,27 @@ pub(crate) fn extract_merkle_tree_state_transition_proofs( | middleware::Operation::ContainerDeleteFromEntries(_, _, _, pf) => pf.clone(), _ => continue, }; - aux_list[i] = OperationAux::MerkleTreeStateTransitionProofIndex(table.len()); - table.push(pf); + if pf.op_proof.existence + // TODO: Make sure there's no off-by-one error here + && pf.siblings.len() <= params.containers.max_depth_small + && tables.small.len() < params.containers.transition.max_small + { + aux_list[i] = OperationAux::MerkleTransitionProofIndex(Size::Small, tables.small.len()); + tables.small.push(pf); + } else { + aux_list[i] = + OperationAux::MerkleTransitionProofIndex(Size::Medium, tables.medium.len()); + tables.medium.push(pf); + } } - if table.len() > params.max_merkle_tree_state_transition_proofs_containers { + if tables.medium.len() > params.containers.transition.max_medium { return Err(Error::custom(format!( "The number of required Merkle proofs ({}) exceeds the maximum number ({}).", - table.len(), - params.max_merkle_tree_state_transition_proofs_containers + tables.medium.len(), + params.containers.transition.max_medium ))); } - Ok(table) + Ok(tables) } pub(crate) fn extract_public_key_of( @@ -513,6 +540,8 @@ impl MainPodProver for Prover { let mut aux_list = vec![OperationAux::None; params.max_priv_statements()]; let merkle_proofs = extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; + let merkle_transition_proofs = + extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?; let custom_predicates = extract_custom_predicates(params, inputs.operations)?; let custom_predicate_verifications = extract_custom_predicate_verifications( params, @@ -537,9 +566,6 @@ impl MainPodProver for Prover { let signed_bys = extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; - let merkle_tree_state_transition_proofs = - extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?; - let (statements, public_statements) = layout_statements(params, false, &inputs)?; let operations = process_private_statements_operations( params, @@ -572,20 +598,15 @@ impl MainPodProver for Prover { .collect_vec(); let mut vd_mt_proofs = Vec::with_capacity(inputs.pods.len()); + let pad_vd_mt_proof = inputs.vd_set.get_vds_proof_0(); for (pod, vd) in inputs.pods.iter().zip(&verifier_datas) { vd_mt_proofs.push(if pod.is_main() { - (true, inputs.vd_set.get_vds_proof(vd)?) + inputs.vd_set.get_vds_proof(vd)? } else { // For intro pods we don't verify inclusion of their vk into the vd set, so we - // generate a dummy mt proof with expected root and value to pass some constraints - ( - false, - MerkleClaimAndProof { - root: inputs.vd_set.root(), - value: RawValue::from(pod.verifier_data_hash()), - ..MerkleClaimAndProof::empty() - }, - ) + // use a valid vds proof that matches the expected root but not the value to pass + // the constraints + pad_vd_mt_proof.clone() }); } @@ -598,7 +619,7 @@ impl MainPodProver for Prover { merkle_proofs, public_key_of_sks, signed_bys, - merkle_tree_state_transition_proofs, + merkle_transition_proofs, custom_predicates_with_mpt_proofs, custom_predicate_verifications, }; @@ -985,7 +1006,18 @@ pub mod tests { max_statements: 2, max_public_statements: 1, max_input_pods_public_statements: 0, - max_merkle_proofs_containers: 0, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + max_depth_small: 8, + max_depth_medium: 32, + }, max_public_key_of: 0, max_custom_predicate_verifications: 0, max_custom_predicates: 0, @@ -1024,11 +1056,20 @@ pub mod tests { max_custom_predicates: 2, max_custom_predicate_verifications: 2, max_custom_predicate_wildcards: 3, - max_merkle_proofs_containers: 2, - max_merkle_tree_state_transition_proofs_containers: 2, max_public_key_of: 2, - max_depth_mt_containers: 4, max_depth_mt_vds: 6, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 2, + max_medium: 2, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 2, + max_medium: 2, + }, + max_depth_small: 2, + max_depth_medium: 4, + }, }; let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); @@ -1087,8 +1128,18 @@ pub mod tests { max_public_statements: 4, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, - max_merkle_proofs_containers: 3, - max_merkle_tree_state_transition_proofs_containers: 0, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 3, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + max_depth_small: 8, + max_depth_medium: 32, + }, ..Default::default() }; println!("{:#?}", params); @@ -1156,8 +1207,18 @@ pub mod tests { max_public_statements: 2, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, - max_merkle_proofs_containers: 0, - max_merkle_tree_state_transition_proofs_containers: 0, + containers: middleware::ParamsContainers { + state: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + transition: middleware::ParamsMerkleProofs { + max_small: 0, + max_medium: 0, + }, + max_depth_small: 8, + max_depth_medium: 32, + }, ..Default::default() }; let mut vds = DEFAULT_VD_LIST.clone(); diff --git a/src/backends/plonky2/mainpod/operation.rs b/src/backends/plonky2/mainpod/operation.rs index d7b44bb..2060ac7 100644 --- a/src/backends/plonky2/mainpod/operation.rs +++ b/src/backends/plonky2/mainpod/operation.rs @@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ error::{Error, Result}, - mainpod::{SignedBy, Statement}, - primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof}, + mainpod::{MerkleProofs, MerkleTransitionProofs, SignedBy, Statement}, }, middleware::{self, OperationType, Params}, }; @@ -30,50 +29,89 @@ impl OperationArg { } } +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] +pub enum Size { + Small, + Medium, +} + +impl fmt::Display for Size { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Small => write!(f, "small"), + Self::Medium => write!(f, "medium"), + } + } +} + +impl Size { + pub const fn min() -> Self { + Self::Small + } + pub const fn max() -> Self { + Self::Medium + } +} + #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub enum OperationAux { None, - MerkleProofIndex(usize), + MerkleProofIndex(Size, usize), + MerkleTransitionProofIndex(Size, usize), PublicKeyOfIndex(usize), SignedByIndex(usize), - MerkleTreeStateTransitionProofIndex(usize), CustomPredVerifyIndex(usize), } impl OperationAux { - fn table_offset_merkle_proof(_params: &Params) -> usize { - // At index 0 we store a zero entry - 1 + fn table_offset_merkle_proof(params: &Params, size: Size) -> usize { + match size { + // At index 0 we store a zero entry + Size::Small => 1, + Size::Medium => { + Self::table_offset_merkle_proof(params, Size::Small) + + params.containers.state.max_small + } + } + } + fn table_offset_merkle_transition_proof(params: &Params, size: Size) -> usize { + match size { + Size::Small => { + Self::table_offset_merkle_proof(params, Size::min()) + + params.containers.state.max_total() + } + Size::Medium => { + Self::table_offset_merkle_transition_proof(params, Size::Small) + + params.containers.transition.max_small + } + } + } + fn table_offset_custom_pred_verify(params: &Params) -> usize { + Self::table_offset_merkle_transition_proof(params, Size::min()) + + params.containers.transition.max_total() } fn table_offset_public_key_of(params: &Params) -> usize { - Self::table_offset_merkle_proof(params) + params.max_merkle_proofs_containers + Self::table_offset_custom_pred_verify(params) + params.max_custom_predicate_verifications } fn table_offset_signed_by(params: &Params) -> usize { Self::table_offset_public_key_of(params) + params.max_public_key_of } - fn table_offset_merkle_tree_state_transition_proof(params: &Params) -> usize { - Self::table_offset_signed_by(params) + params.max_signed_by - } - fn table_offset_custom_pred_verify(params: &Params) -> usize { - Self::table_offset_merkle_tree_state_transition_proof(params) - + params.max_merkle_tree_state_transition_proofs_containers - } pub(crate) fn table_size(params: &Params) -> usize { - 1 + params.max_merkle_proofs_containers + 1 + params.containers.state.max_total() + + params.containers.transition.max_total() + + params.max_custom_predicate_verifications + params.max_public_key_of + params.max_signed_by - + params.max_merkle_tree_state_transition_proofs_containers - + params.max_custom_predicate_verifications } pub fn table_index(&self, params: &Params) -> usize { match self { Self::None => 0, - Self::MerkleProofIndex(i) => Self::table_offset_merkle_proof(params) + *i, + Self::MerkleProofIndex(size, i) => Self::table_offset_merkle_proof(params, *size) + *i, + Self::MerkleTransitionProofIndex(size, i) => { + Self::table_offset_merkle_transition_proof(params, *size) + *i + } Self::PublicKeyOfIndex(i) => Self::table_offset_public_key_of(params) + *i, Self::SignedByIndex(i) => Self::table_offset_signed_by(params) + *i, - Self::MerkleTreeStateTransitionProofIndex(i) => { - Self::table_offset_merkle_tree_state_transition_proof(params) + *i - } Self::CustomPredVerifyIndex(i) => Self::table_offset_custom_pred_verify(params) + *i, } } @@ -96,8 +134,8 @@ impl Operation { &self, statements: &[Statement], signatures: &[SignedBy], - merkle_proofs: &[MerkleClaimAndProof], - merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProof], + merkle_proofs: &MerkleProofs, + merkle_transition_proofs: &MerkleTransitionProofs, ) -> Result { let deref_args = self .1 @@ -113,17 +151,26 @@ impl Operation { .collect::>>()?; let deref_aux = match self.2 { OperationAux::None => crate::middleware::OperationAux::None, - OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, - OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof( - merkle_proofs - .get(i) - .ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))? - .proof - .clone(), - ), - OperationAux::MerkleTreeStateTransitionProofIndex(i) => { + OperationAux::MerkleProofIndex(size, i) => { + let table = match size { + Size::Small => &merkle_proofs.small, + Size::Medium => &merkle_proofs.medium, + }; + crate::middleware::OperationAux::MerkleProof( + table + .get(i) + .ok_or(Error::custom(format!("Missing Merkle proof index {}", i)))? + .proof + .clone(), + ) + } + OperationAux::MerkleTransitionProofIndex(size, i) => { + let table = match size { + Size::Small => &merkle_transition_proofs.small, + Size::Medium => &merkle_transition_proofs.medium, + }; crate::middleware::OperationAux::MerkleTreeStateTransitionProof( - merkle_tree_state_transition_proofs + table .get(i) .ok_or(Error::custom(format!( "Missing Merkle state transition proof index {}", @@ -132,6 +179,7 @@ impl Operation { .clone(), ) } + OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, OperationAux::SignedByIndex(i) => crate::middleware::OperationAux::Signature( signatures .get(i) @@ -165,12 +213,14 @@ impl fmt::Display for Operation { } match self.2 { OperationAux::None => (), - OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, + OperationAux::MerkleProofIndex(size, i) => { + write!(f, " {}_merkle_proof_{:02}", size, i)? + } OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?, OperationAux::PublicKeyOfIndex(i) => write!(f, " public_key_of_{:02}", i)?, OperationAux::SignedByIndex(i) => write!(f, " signed_by_{:02}", i)?, - OperationAux::MerkleTreeStateTransitionProofIndex(i) => { - write!(f, " merkle_tree_state_transition_proof_{:02}", i)? + OperationAux::MerkleTransitionProofIndex(size, i) => { + write!(f, " {}_merkle_transition_proof_{:02}", size, i)? } } Ok(()) diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index b8c6a03..8dd710a 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -11,13 +11,12 @@ use crate::{ basetypes::{Proof, VerifierOnlyCircuitData}, error::{Error, Result}, mainpod::{ - calculate_statements_hash, extract_merkle_proofs, - extract_merkle_tree_state_transition_proofs, extract_signatures, layout_statements, - process_private_statements_operations, process_public_statements_operations, Operation, + calculate_statements_hash, extract_merkle_proofs, extract_merkle_transition_proofs, + extract_signatures, layout_statements, process_private_statements_operations, + process_public_statements_operations, MerkleProofs, MerkleTransitionProofs, Operation, OperationAux, SignedBy, Statement, }, mock::emptypod::MockEmptyPod, - primitives::merkletree::{MerkleClaimAndProof, MerkleTreeStateTransitionProof}, recursion::hash_verifier_data, }, middleware::{ @@ -45,10 +44,10 @@ pub struct MockMainPod { operations: Vec, // public subset of the `statements` vector public_statements: Vec, - // All Merkle proofs - merkle_proofs_containers: Vec, - // All Merkle tree state transition proofs - merkle_tree_state_transition_proofs_containers: Vec, + // All Merkle proofs for containers + merkle_proofs: MerkleProofs, + // All Merkle tree state transition proofs for containers + merkle_transition_proofs: MerkleTransitionProofs, // All verified signatures signatures: Vec, } @@ -124,8 +123,8 @@ struct Data { public_statements: Vec, operations: Vec, statements: Vec, - merkle_proofs: Vec, - merkle_tree_state_transition_proofs: Vec, + merkle_proofs: MerkleProofs, + merkle_transition_proofs: MerkleTransitionProofs, signatures: Vec, input_pods: Vec<(usize, Params, Hash, VDSet, serde_json::Value)>, } @@ -153,8 +152,8 @@ impl MockMainPod { let merkle_proofs = extract_merkle_proofs(params, &mut aux_list, inputs.operations, inputs.statements)?; // Similarly for Merkle state transition proofs. - let merkle_tree_state_transition_proofs = - extract_merkle_tree_state_transition_proofs(params, &mut aux_list, inputs.operations)?; + let merkle_transition_proofs = + extract_merkle_transition_proofs(params, &mut aux_list, inputs.operations)?; let signatures = extract_signatures(params, &mut aux_list, inputs.operations, inputs.statements)?; @@ -185,8 +184,8 @@ impl MockMainPod { public_statements, statements, operations, - merkle_proofs_containers: merkle_proofs, - merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, + merkle_proofs, + merkle_transition_proofs, signatures, }) } @@ -260,8 +259,8 @@ impl Pod for MockMainPod { .deref( &self.statements[..input_statement_offset + i], &self.signatures, - &self.merkle_proofs_containers, - &self.merkle_tree_state_transition_proofs_containers, + &self.merkle_proofs, + &self.merkle_transition_proofs, )? .check_and_log(&self.params, &s.clone().try_into()?) .map_err(|e| e.into()) @@ -321,10 +320,8 @@ impl Pod for MockMainPod { public_statements: self.public_statements.clone(), operations: self.operations.clone(), statements: self.statements.clone(), - merkle_proofs: self.merkle_proofs_containers.clone(), - merkle_tree_state_transition_proofs: self - .merkle_tree_state_transition_proofs_containers - .clone(), + merkle_proofs: self.merkle_proofs.clone(), + merkle_transition_proofs: self.merkle_transition_proofs.clone(), signatures: self.signatures.clone(), input_pods, }) @@ -344,7 +341,7 @@ impl Pod for MockMainPod { operations, statements, merkle_proofs, - merkle_tree_state_transition_proofs, + merkle_transition_proofs, signatures, input_pods, } = serde_json::from_value(data)?; @@ -362,8 +359,8 @@ impl Pod for MockMainPod { public_statements, operations, statements, - merkle_proofs_containers: merkle_proofs, - merkle_tree_state_transition_proofs_containers: merkle_tree_state_transition_proofs, + merkle_proofs, + merkle_transition_proofs, signatures, }) } diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index 2c54b8b..f53a143 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -42,8 +42,6 @@ use crate::{ #[derive(Clone, Debug, Serialize, Deserialize)] pub struct MerkleClaimAndProofTarget { pub(crate) max_depth: usize, - // `enabled` determines if the merkleproof verification is enabled - pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -121,16 +119,9 @@ pub fn verify_merkle_proof_circuit( let obtained_root = compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); - // check that obtained_root==root (from inputs), when enabled==true - let zero = builder.zero(); - let expected_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) - .collect(); - let computed_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) - .collect(); + // check that obtained_root==root (from inputs) for j in 0..HASH_SIZE { - builder.connect(computed_root[j], expected_root[j]); + builder.connect(obtained_root.elements[j], proof.root.elements[j]); } measure_gates_end!(builder, measure); } @@ -139,7 +130,6 @@ impl MerkleClaimAndProofTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { MerkleClaimAndProofTarget { max_depth, - enabled: builder.add_virtual_bool_target_safe(), root: builder.add_virtual_hash(), key: builder.add_virtual_value(), value: builder.add_virtual_value(), @@ -154,12 +144,7 @@ impl MerkleClaimAndProofTarget { } /// assigns the given values to the targets #[allow(clippy::too_many_arguments)] - pub fn set_targets( - &self, - pw: &mut PartialWitness, - enabled: bool, - mp: &MerkleClaimAndProof, - ) -> Result<()> { + pub fn set_targets(&self, pw: &mut PartialWitness, mp: &MerkleClaimAndProof) -> Result<()> { if mp.proof.siblings.len() > self.max_depth { return Err(Error::Tree(TreeError::circuit_depth_too_small( self.max_depth, @@ -167,7 +152,6 @@ impl MerkleClaimAndProofTarget { ))); } - pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?; @@ -207,8 +191,6 @@ impl MerkleClaimAndProofTarget { #[derive(Clone, Serialize, Deserialize)] pub struct MerkleProofExistenceTarget { max_depth: usize, - // `enabled` determines if the merkleproof verification is enabled - pub(crate) enabled: BoolTarget, pub(crate) root: HashOutTarget, pub(crate) key: ValueTarget, pub(crate) value: ValueTarget, @@ -236,16 +218,9 @@ pub fn verify_merkle_proof_existence_circuit( let obtained_root = compute_root_from_leaf(max_depth, builder, &path, &leaf_hash, &proof.siblings); - // check that obtained_root==root (from inputs), when enabled==true - let zero = builder.zero(); - let expected_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, proof.root.elements[j], zero)) - .collect(); - let computed_root: Vec = (0..HASH_SIZE) - .map(|j| builder.select(proof.enabled, obtained_root.elements[j], zero)) - .collect(); + // check that obtained_root==root (from inputs) for j in 0..HASH_SIZE { - builder.connect(computed_root[j], expected_root[j]); + builder.connect(obtained_root.elements[j], proof.root.elements[j]); } measure_gates_end!(builder, measure); @@ -256,7 +231,6 @@ impl MerkleProofExistenceTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { MerkleProofExistenceTarget { max_depth, - enabled: builder.add_virtual_bool_target_safe(), root: builder.add_virtual_hash(), key: builder.add_virtual_value(), value: builder.add_virtual_value(), @@ -265,12 +239,7 @@ impl MerkleProofExistenceTarget { } } /// assigns the given values to the targets - pub fn set_targets( - &self, - pw: &mut PartialWitness, - enabled: bool, - mp: &MerkleClaimAndProof, - ) -> Result<()> { + pub fn set_targets(&self, pw: &mut PartialWitness, mp: &MerkleClaimAndProof) -> Result<()> { assert!(mp.proof.existence); // sanity check if mp.proof.siblings.len() > self.max_depth { return Err(Error::Tree(TreeError::circuit_depth_too_small( @@ -279,7 +248,6 @@ impl MerkleProofExistenceTarget { ))); } - pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_target_arr(&self.key.elements, &mp.key.0)?; pw.set_target_arr(&self.value.elements, &mp.value.0)?; @@ -456,8 +424,6 @@ fn hash_with_flag_target>( #[derive(Clone, Serialize, Deserialize)] pub struct MerkleTreeStateTransitionProofTarget { pub(crate) max_depth: usize, - // `enabled` determines if the merkleproof state transition verification is enabled - pub(crate) enabled: BoolTarget, pub(crate) op: Target, pub(crate) old_root: HashOutTarget, pub(crate) op_proof: MerkleClaimAndProofTarget, @@ -511,7 +477,6 @@ pub fn verify_merkle_state_transition_circuit( }; let new_key_proof = MerkleProofExistenceTarget { max_depth: proof.max_depth, - enabled: proof.enabled, root, key: proof.op_key, value: proof.op_value, @@ -523,13 +488,7 @@ pub fn verify_merkle_state_transition_circuit( // Insert/Delete: Non-existence // Update: Existence let proof_type = is_update; - builder.conditional_assert_eq( - proof.enabled.target, - proof.op_proof.existence.target, - proof_type.target, - ); - // 3.2) assert that proof.enabled matches with op_proof.enabled - builder.connect(proof.op_proof.enabled.target, proof.enabled.target); + builder.connect(proof.op_proof.existence.target, proof_type.target); // 4) assert proof_non_existence.root corresponds to the root // specified by the op (old_root for Insert/Update and new_root @@ -545,17 +504,9 @@ pub fn verify_merkle_state_transition_circuit( }; for j in 0..HASH_SIZE { // 4.1) assert that proof.proof_non_existence.root == proof.old_root - builder.conditional_assert_eq( - proof.enabled.target, - proof.op_proof.root.elements[j], - claim_root.elements[j], - ); + builder.connect(proof.op_proof.root.elements[j], claim_root.elements[j]); // 4.2) assert that the non-existence proof uses the op_key (value not needed). - builder.conditional_assert_eq( - proof.enabled.target, - proof.op_proof.key.elements[j], - proof.op_key.elements[j], - ); + builder.connect(proof.op_proof.key.elements[j], proof.op_key.elements[j]); } // prepare value for check 5.2) @@ -593,7 +544,7 @@ pub fn verify_merkle_state_transition_circuit( .map(|j| builder.select(is_divergence_level, zero, new_siblings[i].elements[j])) .collect(); for j in 0..HASH_SIZE { - builder.conditional_assert_eq(proof.enabled.target, old_sibling_i[j], new_sibling_i[j]); + builder.connect(old_sibling_i[j], new_sibling_i[j]); } // 5.2) when i==d && if old_siblings[i] != new_siblings[i], check that: @@ -611,7 +562,7 @@ pub fn verify_merkle_state_transition_circuit( let in_case_5_2 = builder.and(old_is_noteq_new, is_divergence_level); // do the case2's checks - let sel = builder.and(proof.enabled, in_case_5_2); + let sel = in_case_5_2; for j in 0..HASH_SIZE { builder.conditional_assert_eq(sel.target, old_siblings[i].elements[j], zero); builder.conditional_assert_eq( @@ -641,7 +592,6 @@ impl MerkleTreeStateTransitionProofTarget { pub fn new_virtual(max_depth: usize, builder: &mut CircuitBuilder) -> Self { Self { max_depth, - enabled: builder.add_virtual_bool_target_safe(), op: builder.add_virtual_target(), old_root: builder.add_virtual_hash(), @@ -661,7 +611,6 @@ impl MerkleTreeStateTransitionProofTarget { pub fn set_targets( &self, pw: &mut PartialWitness, - enabled: bool, mp: &MerkleTreeStateTransitionProof, ) -> Result<()> { let new_siblings = mp.siblings.clone(); @@ -672,13 +621,11 @@ impl MerkleTreeStateTransitionProofTarget { ))); } - pw.set_bool_target(self.enabled, enabled)?; pw.set_target(self.op, F::from_canonical_u8(mp.op as u8))?; pw.set_hash_target(self.old_root, HashOut::from_vec(mp.old_root.0.to_vec()))?; self.op_proof.set_targets( pw, - enabled, &MerkleClaimAndProof { root: if mp.op == MerkleTreeOp::Delete { mp.new_root @@ -859,7 +806,6 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, - true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -871,6 +817,42 @@ pub mod tests { Ok(()) } + #[test] + fn test_merkleproof_pad_valid() -> Result<()> { + // circuit + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::::new(); + + let targets = MerkleClaimAndProofTarget::new_virtual(32, &mut builder); + verify_merkle_proof_circuit(&mut builder, &targets); + targets.set_targets(&mut pw, &MerkleClaimAndProof::pad())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } + + #[test] + fn test_merkleproof_transition_pad_valid() -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::::new(); + + let targets = MerkleTreeStateTransitionProofTarget::new_virtual(32, &mut builder); + verify_merkle_state_transition_circuit(&mut builder, &targets); + targets.set_targets(&mut pw, &MerkleTreeStateTransitionProof::pad())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + Ok(()) + } + #[test] fn test_merkleproof_only_existence_verify() -> Result<()> { for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { @@ -906,7 +888,6 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, - true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -982,7 +963,6 @@ pub mod tests { verify_merkle_proof_circuit(&mut builder, &targets); targets.set_targets( &mut pw, - true, &MerkleClaimAndProof::new(tree.root(), key, Some(value), proof), )?; @@ -1028,32 +1008,15 @@ pub mod tests { let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); verify_merkle_proof_circuit(&mut builder, &targets); - // verification enabled & proof of existence + // proof of existence let mp = MerkleClaimAndProof::new(tree2.root(), key, Some(value), proof); - targets.set_targets(&mut pw, true, &mp)?; + targets.set_targets(&mut pw, &mp)?; // generate proof, expecting it to fail (since we're using the wrong // root) let data = builder.build::(); assert!(data.prove(pw).is_err()); - // Now generate a new proof, using `enabled=false`, which should pass the verification - // despite containing 'wrong' witness. - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleClaimAndProofTarget::new_virtual(max_depth, &mut builder); - verify_merkle_proof_circuit(&mut builder, &targets); - // verification disabled & proof of existence - targets.set_targets(&mut pw, false, &mp)?; - - // generate proof, should pass despite using wrong witness, since the - // `enabled=false` - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - Ok(()) } @@ -1076,7 +1039,7 @@ pub mod tests { let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, true, state_transition_proof)?; + targets.set_targets(&mut pw, state_transition_proof)?; // generate & verify proof let data = builder.build::(); @@ -1273,71 +1236,4 @@ pub mod tests { assert_ne!(state_transition_proof.new_root, tree.root()); // Tamper check Ok(()) } - - #[test] - fn test_state_transition_gadget_disabled() -> Result<()> { - let max_depth: usize = 32; - let mut kvs = HashMap::new(); - for i in 0..8 { - kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); - } - let mut tree = MerkleTree::new(&kvs); - - let key = RawValue::from(37); - let value = RawValue::from(1037); - let _ = tree.insert(&key, &value)?; - - let key = RawValue::from(21); - let value = RawValue::from(1021); - let original_state_transition_proof = tree.insert(&key, &value)?; - - let mut state_transition_proof = original_state_transition_proof.clone(); - - // modify the proof, so that it should fail when `enabled=true`, by - // changing the new_root - state_transition_proof.new_root = state_transition_proof.old_root; - - run_circuit_disabled(max_depth, &state_transition_proof)?; - - // modify the proof, so that it should fail when `enabled=true`, by - // changing the new_sibling at the divergence level, which should not - // pass the verification in the case where we're inserting key=21 - let mut state_transition_proof = original_state_transition_proof.clone(); - state_transition_proof.siblings[4] = EMPTY_HASH; - - run_circuit_disabled(max_depth, &state_transition_proof)?; - - Ok(()) - } - - fn run_circuit_disabled( - max_depth: usize, - state_transition_proof: &MerkleTreeStateTransitionProof, - ) -> Result<()> { - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); - verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, true, state_transition_proof)?; - - // generate proof, and expect it to fail - let data = builder.build::(); - assert!(data.prove(pw).is_err()); // expect prove to fail - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::::new(); - - let targets = MerkleTreeStateTransitionProofTarget::new_virtual(max_depth, &mut builder); - verify_merkle_state_transition_circuit(&mut builder, &targets); - targets.set_targets(&mut pw, false, state_transition_proof)?; - - // generate and expect it to pass - let data = builder.build::(); - let proof = data.prove(pw)?; - data.verify(proof)?; - Ok(()) - } } diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index 0e29e14..e84da20 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -921,6 +921,21 @@ impl MerkleClaimAndProof { }, } } + /// Value used for padding. This is a valid merkle proof. + pub fn pad() -> Self { + let [key, value] = [EMPTY_VALUE, EMPTY_VALUE]; + let root = kv_hash(&key, Some(value)); + Self { + root, + key, + value, + proof: MerkleProof { + existence: true, + siblings: vec![], + other_leaf: None, + }, + } + } pub fn new(root: Hash, key: RawValue, value: Option, proof: MerkleProof) -> Self { Self { root, @@ -974,7 +989,6 @@ pub struct MerkleTreeStateTransitionProof { } impl MerkleTreeStateTransitionProof { - /// Value used for padding. pub fn empty() -> Self { let empty_proof_and_claim = MerkleClaimAndProof::empty(); Self { @@ -988,6 +1002,20 @@ impl MerkleTreeStateTransitionProof { siblings: vec![], } } + /// Value used for padding. This is a valid transition proof. + pub fn pad() -> Self { + let pad_proof_and_claim = MerkleClaimAndProof::pad(); + Self { + op: MerkleTreeOp::Update, + old_root: pad_proof_and_claim.root, + op_proof: pad_proof_and_claim.proof, + new_root: pad_proof_and_claim.root, + op_key: pad_proof_and_claim.key, + op_value: pad_proof_and_claim.value, + value: Some(pad_proof_and_claim.value), + siblings: vec![], + } + } } // NOTE: currently we use automatic serialization/deserialization, which is @@ -1165,6 +1193,15 @@ pub mod tests { Ok(()) } + #[test] + fn test_merkletree_pad() { + let claim = MerkleClaimAndProof::pad(); + MerkleTree::verify(claim.root, &claim.proof, &claim.key, &claim.value).unwrap(); + + let proof = MerkleTreeStateTransitionProof::pad(); + MerkleTree::verify_state_transition(&proof).unwrap(); + } + #[test] fn test_key_not_found() -> Result<()> { let db = Box::new(db::MemDB::new()); diff --git a/src/frontend/multi_pod/diagnostics.rs b/src/frontend/multi_pod/diagnostics.rs index 438f379..f56778f 100644 --- a/src/frontend/multi_pod/diagnostics.rs +++ b/src/frontend/multi_pod/diagnostics.rs @@ -78,12 +78,12 @@ fn aggregate_rows<'a>( UtilizationRow { name: "merkle proofs", used: merkle_proofs, - limit: params.max_merkle_proofs_containers, + limit: params.containers.state.max_medium, }, UtilizationRow { name: "merkle state transitions", used: merkle_state_transitions, - limit: params.max_merkle_tree_state_transition_proofs_containers, + limit: params.containers.transition.max_medium, }, UtilizationRow { name: "custom pred verifications", @@ -278,15 +278,24 @@ mod tests { use super::*; use crate::{ frontend::multi_pod::cost::CustomPredicateId, - middleware::{Hash, RawValue}, + middleware::{Hash, ParamsContainers, ParamsMerkleProofs, RawValue}, }; fn default_params() -> Params { Params { max_statements: 48, max_public_statements: 8, - max_merkle_proofs_containers: 8, - max_merkle_tree_state_transition_proofs_containers: 4, + containers: ParamsContainers { + state: ParamsMerkleProofs { + max_small: 0, + max_medium: 8, + }, + transition: ParamsMerkleProofs { + max_small: 0, + max_medium: 4, + }, + ..Default::default() + }, max_custom_predicate_verifications: 10, max_custom_predicates: 2, max_signed_by: 4, diff --git a/src/frontend/multi_pod/solver.rs b/src/frontend/multi_pod/solver.rs index db1502e..8d81ab3 100644 --- a/src/frontend/multi_pod/solver.rs +++ b/src/frontend/multi_pod/solver.rs @@ -395,13 +395,11 @@ pub fn solve(input: &SolverInput) -> Result { let lb_statement_groups = lower_bound_from_total(input.num_statements, max_stmts_per_pod); let lb_merkle = lower_bound_from_total( resource_totals.merkle_proofs, - input.params.max_merkle_proofs_containers, + input.params.containers.state.max_medium, ); let lb_merkle_transitions = lower_bound_from_total( resource_totals.merkle_state_transitions, - input - .params - .max_merkle_tree_state_transition_proofs_containers, + input.params.containers.transition.max_medium, ); let lb_custom_pred_verifications = lower_bound_from_total( resource_totals.custom_pred_verifications, @@ -753,7 +751,7 @@ fn try_solve_with_pods( .map(|s| (input.costs[s].merkle_proofs as f64) * prove[s][p]) .sum(); model.add_constraint(constraint!( - merkle_sum <= (input.params.max_merkle_proofs_containers as f64) * pod_used[p] + merkle_sum <= (input.params.containers.state.max_medium as f64) * pod_used[p] )); // 6d: Merkle state transitions @@ -761,11 +759,7 @@ fn try_solve_with_pods( .map(|s| (input.costs[s].merkle_state_transitions as f64) * prove[s][p]) .sum(); model.add_constraint(constraint!( - mst_sum - <= (input - .params - .max_merkle_tree_state_transition_proofs_containers as f64) - * pod_used[p] + mst_sum <= (input.params.containers.transition.max_medium as f64) * pod_used[p] )); // 6e: Custom predicate verifications diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 82675d7..d212ca8 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -780,6 +780,50 @@ pub const BASE_PARAMS: BaseParams = BaseParams { max_operation_args: 5 + 1, }; +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] +#[serde(rename_all = "camelCase")] +pub struct ParamsMerkleProofs { + pub max_small: usize, + pub max_medium: usize, +} + +impl ParamsMerkleProofs { + pub fn max_total(&self) -> usize { + self.max_small + self.max_medium + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] +#[serde(rename_all = "camelCase")] +pub struct ParamsContainers { + // Parameters for exists/nonexists container operations. The small set only supports exists + pub state: ParamsMerkleProofs, + // Parameters for transition container operations (insert, delete, update). The small set only + // supports update. + pub transition: ParamsMerkleProofs, + // Max depth of small proofs + pub max_depth_small: usize, + // Max depth of medium proofs + pub max_depth_medium: usize, +} + +impl Default for ParamsContainers { + fn default() -> Self { + Self { + state: ParamsMerkleProofs { + max_small: 22, + max_medium: 8, + }, + transition: ParamsMerkleProofs { + max_small: 12, + max_medium: 6, + }, + max_depth_small: 8, + max_depth_medium: 32, + } + } +} + /// Params: non dynamic parameters that define the circuit. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Hash)] #[serde(rename_all = "camelCase")] @@ -793,12 +837,7 @@ pub struct Params { // max number of operations using custom predicates that can be verified in the MainPod pub max_custom_predicate_verifications: usize, pub max_custom_predicate_wildcards: usize, - // maximum number of merkle proofs used for container operations - pub max_merkle_proofs_containers: usize, - // maximum number of merkle tree state transition proofs used for container update operations - pub max_merkle_tree_state_transition_proofs_containers: usize, - // maximum depth for merkle tree gadget used for container operations - pub max_depth_mt_containers: usize, + pub containers: ParamsContainers, // maximum depth of the merkle tree gadget used for verifier_data membership // check. This allows creating verifying sets of pod circuits of size // 2^max_depth_mt_vds. Limits the number of container operations of the type Contains, @@ -820,9 +859,7 @@ impl Default for Params { max_custom_predicates: 8, max_custom_predicate_verifications: 8, max_custom_predicate_wildcards: 8, - max_merkle_proofs_containers: 20, - max_merkle_tree_state_transition_proofs_containers: 6, - max_depth_mt_containers: 32, + containers: ParamsContainers::default(), max_depth_mt_vds: 6, // up to 64 (2^6) different pod circuits max_public_key_of: 2, max_signed_by: 4,