diff --git a/src/backends/plonky2/common.rs b/src/backends/plonky2/common.rs index 8272a6a..1522f05 100644 --- a/src/backends/plonky2/common.rs +++ b/src/backends/plonky2/common.rs @@ -1,12 +1,17 @@ //! Common functionality to build Pod circuits with plonky2 -use crate::middleware::STATEMENT_ARG_F_LEN; -use crate::middleware::{Params, Value, HASH_SIZE, VALUE_SIZE}; +use crate::backends::plonky2::mock_main::Statement; +use crate::backends::plonky2::mock_main::{Operation, OperationArg}; +use crate::middleware::{Params, StatementArg, ToFields, Value, F, HASH_SIZE, VALUE_SIZE}; +use crate::middleware::{OPERATION_ARG_F_LEN, STATEMENT_ARG_F_LEN}; +use anyhow::Result; use plonky2::field::extension::Extendable; -use plonky2::field::types::PrimeField64; +use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::hash_types::RichField; use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; +use std::iter; #[derive(Copy, Clone)] pub struct ValueTarget { @@ -15,25 +20,65 @@ pub struct ValueTarget { #[derive(Clone)] pub struct StatementTarget { - pub code: [Target; HASH_SIZE + 2], + pub predicate: [Target; Params::predicate_size()], pub args: Vec<[Target; STATEMENT_ARG_F_LEN]>, } impl StatementTarget { pub fn to_flattened(&self) -> Vec { - self.code + self.predicate .iter() .chain(self.args.iter().flatten()) .cloned() .collect() } + + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + st: &Statement, + ) -> Result<()> { + pw.set_target_arr(&self.predicate, &st.predicate().to_fields(params))?; + for (i, arg) in st + .args() + .iter() + .chain(iter::repeat(&StatementArg::None)) + .take(params.max_statement_args) + .enumerate() + { + pw.set_target_arr(&self.args[i], &arg.to_fields(params))?; + } + Ok(()) + } } // TODO: Implement Operation::to_field to determine the size of each element #[derive(Clone)] pub struct OperationTarget { - pub code: [Target; 6], // TODO: Figure out the length - pub args: Vec<[Target; STATEMENT_ARG_F_LEN]>, // TODO: Figure out the length + pub op_type: [Target; Params::operation_type_size()], + pub args: Vec<[Target; OPERATION_ARG_F_LEN]>, +} + +impl OperationTarget { + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + op: &Operation, + ) -> Result<()> { + pw.set_target_arr(&self.op_type, &op.op_type().to_fields(params))?; + for (i, arg) in op + .args() + .iter() + .chain(iter::repeat(&OperationArg::None)) + .take(params.max_operation_args) + .enumerate() + { + pw.set_target_arr(&self.args[i], &arg.to_fields(params))?; + } + Ok(()) + } } pub trait CircuitBuilderPod, const D: usize> { @@ -70,15 +115,20 @@ impl, const D: usize> CircuitBuilderPod fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget { StatementTarget { - code: self.add_virtual_target_arr::<6>(), + predicate: self.add_virtual_target_arr(), args: (0..params.max_statement_args) - .map(|_| self.add_virtual_target_arr::()) + .map(|_| self.add_virtual_target_arr()) .collect(), } } fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { - todo!() + OperationTarget { + op_type: self.add_virtual_target_arr(), + args: (0..params.max_operation_args) + .map(|_| self.add_virtual_target_arr()) + .collect(), + } } fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget { diff --git a/src/backends/plonky2/main.rs b/src/backends/plonky2/main.rs index 88e37c2..e419aad 100644 --- a/src/backends/plonky2/main.rs +++ b/src/backends/plonky2/main.rs @@ -2,11 +2,14 @@ use crate::backends::plonky2::basetypes::{Hash, Value, D, EMPTY_HASH, EMPTY_VALU use crate::backends::plonky2::common::{ CircuitBuilderPod, OperationTarget, StatementTarget, ValueTarget, }; -use crate::backends::plonky2::primitives::merkletree::MerkleProofExistenceCircuit; +use crate::backends::plonky2::mock_main::Operation; use crate::backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree}; +use crate::backends::plonky2::primitives::merkletree::{ + MerkleProofExistenceGate, MerkleProofExistenceTarget, +}; use crate::middleware::{ - hash_str, AnchoredKey, NativeOperation, NativePredicate, Operation, Params, PodType, Predicate, - Statement, StatementArg, ToFields, KEY_TYPE, SELF, STATEMENT_ARG_F_LEN, + hash_str, AnchoredKey, NativeOperation, NativePredicate, Params, PodType, Predicate, Statement, + StatementArg, ToFields, KEY_TYPE, SELF, STATEMENT_ARG_F_LEN, }; use anyhow::Result; use itertools::Itertools; @@ -23,9 +26,7 @@ use plonky2::{ plonk::circuit_builder::CircuitBuilder, }; use std::collections::HashMap; - -/// MerkleTree Max Depth -const MD: usize = 32; +use std::iter; // // SignedPod verification @@ -41,7 +42,10 @@ impl SignedPodVerifyGate { let id = builder.add_virtual_hash(); let mut mt_proofs = Vec::new(); for _ in 0..self.params.max_signed_pod_values { - let mt_proof = MerkleProofExistenceCircuit::::add_targets(builder)?; + let mt_proof = MerkleProofExistenceGate { + max_depth: self.params.max_depth_mt_gate, + } + .eval(builder)?; builder.connect_hashes(id, mt_proof.root); mt_proofs.push(mt_proof); } @@ -68,7 +72,7 @@ struct SignedPodVerifyTarget { id: HashOutTarget, // The KEY_TYPE entry must be the first one // The KEY_SIGNER entry must be the second one - mt_proofs: Vec>, + mt_proofs: Vec, } struct SignedPodVerifyInput { @@ -94,16 +98,22 @@ impl SignedPodVerifyTarget { fn set_targets(&self, pw: &mut PartialWitness, input: &SignedPodVerifyInput) -> Result<()> { assert!(input.kvs.len() <= self.params.max_signed_pod_values); - let tree = MerkleTree::new(MD, &input.kvs)?; - for (i, (k, v)) in input.kvs.iter().sorted_by_key(|kv| kv.0).enumerate() { + let tree = MerkleTree::new(self.params.max_depth_mt_gate, &input.kvs)?; + + // First handle the type entry, then the rest of the entries, and finally pad with + // repetitions of the type entry (which always exists) + let mut kvs = input.kvs.clone(); + let key_type = Value::from(hash_str(KEY_TYPE)); + let value_type = kvs.remove(&key_type).expect("KEY_TYPE"); + + for (i, (k, v)) in iter::once((key_type, value_type)) + .chain(kvs.into_iter().sorted_by_key(|kv| kv.0)) + .chain(iter::repeat((key_type, value_type))) + .take(self.params.max_signed_pod_values) + .enumerate() + { let (_, proof) = tree.prove(&k)?; - self.mt_proofs[i].set_targets(pw, tree.root(), proof, *k, *v)?; - } - // Padding - for i in input.kvs.len()..self.params.max_signed_pod_values { - // TODO: We need to disable the proofs for the unused slots. We could add a flag - // "enable" to the MerkleTree proof circuit that skips the verification when false. - // self.mt_proofs[i].set_targets(pw, false, EMPTY_HASH, proof, *k, *v)?; + self.mt_proofs[i].set_targets(pw, tree.root(), proof, k, v)?; } Ok(()) } @@ -125,34 +135,55 @@ impl OperationVerifyGate { op: &OperationTarget, prev_statements: &[StatementTarget], ) -> Result { + let _true = builder._true(); + let _false = builder._false(); + let one = builder.constant(F::ONE); + // Verify that the operation `op` correctly generates the statement `st`. The operation // can reference any of the `prev_statements`. // The verification may require aux data which needs to be stored in the // `OperationVerifyTarget` so that we can set during witness generation. - // TODO: Figure out the right encoding of op.code + // For now only support native operations + builder.connect(op.op_type[0], one); + let native_op = op.op_type[1]; + + let mut op_flags = Vec::new(); let op_none = builder.constant(F::from_canonical_u64(NativeOperation::None as u64)); - let is_none = builder.is_equal(op.code[0], op_none); + let is_none = builder.is_equal(native_op, op_none); + op_flags.push(is_none); let op_new_entry = builder.constant(F::from_canonical_u64(NativeOperation::NewEntry as u64)); - let is_new_entry = builder.is_equal(op.code[0], op_new_entry); + let is_new_entry = builder.is_equal(native_op, op_new_entry); + op_flags.push(is_new_entry); let op_copy_statement = builder.constant(F::from_canonical_u64(NativeOperation::CopyStatement as u64)); - let is_copy_statement = builder.is_equal(op.code[0], op_copy_statement); + let is_copy_statement = builder.is_equal(native_op, op_copy_statement); + op_flags.push(is_copy_statement); let op_eq_from_entries = builder.constant(F::from_canonical_u64( NativeOperation::EqualFromEntries as u64, )); - let is_eq_from_entries = builder.is_equal(op.code[0], op_eq_from_entries); - let op_gt_from_entries = - builder.constant(F::from_canonical_u64(NativeOperation::GtFromEntries as u64)); - let is_gt_from_entries = builder.is_equal(op.code[0], op_gt_from_entries); + let is_eq_from_entries = builder.is_equal(native_op, op_eq_from_entries); + op_flags.push(is_eq_from_entries); let op_lt_from_entries = builder.constant(F::from_canonical_u64(NativeOperation::LtFromEntries as u64)); - let is_lt_from_entries = builder.is_equal(op.code[0], op_lt_from_entries); - let op_contains_from_entries = builder.constant(F::from_canonical_u64( - NativeOperation::ContainsFromEntries as u64, + let is_lt_from_entries = builder.is_equal(native_op, op_lt_from_entries); + op_flags.push(is_lt_from_entries); + let op_not_contains_from_entries = builder.constant(F::from_canonical_u64( + NativeOperation::NotContainsFromEntries as u64, )); - let is_contains_from_entries = builder.is_equal(op.code[0], op_contains_from_entries); + let is_not_contains_from_entries = + builder.is_equal(native_op, op_not_contains_from_entries); + op_flags.push(is_not_contains_from_entries); + + // One supported operation must be used. We sum all operation flags and expect the result + // to be 1. Since the flags are boolean and at most one of them is true the sum is + // equivalent to the OR. + let or_op_flags = op_flags + .iter() + .map(|b| b.target) + .fold(_false.target, |acc, x| builder.add(acc, x)); + builder.connect(or_op_flags, _true.target); let ok = builder._true(); let none_ok = self.eval_none(builder, st, op); @@ -160,10 +191,9 @@ impl OperationVerifyGate { let new_entry_ok = self.eval_new_entry(builder, st, op); let ok = builder.select_bool(is_new_entry, new_entry_ok, ok); - let _true = builder._true(); builder.connect(ok.target, _true.target); - todo!() + Ok(OperationVerifyTarget {}) } fn eval_none( @@ -184,14 +214,14 @@ impl OperationVerifyGate { _op: &OperationTarget, ) -> BoolTarget { let value_of_st = &Statement::ValueOf(AnchoredKey(SELF, EMPTY_HASH), EMPTY_VALUE); - let expected_code = + let expected_predicate = builder.constants(&Predicate::Native(NativePredicate::ValueOf).to_fields(&self.params)); - let code_ok = builder.is_equal_slice(&st.code, &expected_code); + let predicate_ok = builder.is_equal_slice(&st.predicate, &expected_predicate); let expected_arg_prefix = builder.constants( &StatementArg::Key(AnchoredKey(SELF, EMPTY_HASH)).to_fields(&self.params)[..VALUE_SIZE], ); let arg_prefix_ok = builder.is_equal_slice(&st.args[0][..VALUE_SIZE], &expected_arg_prefix); - builder.and(code_ok, arg_prefix_ok) + builder.and(predicate_ok, arg_prefix_ok) } } @@ -199,13 +229,14 @@ struct OperationVerifyTarget { // TODO } -struct OperationVerifyInputs { +struct OperationVerifyInput { // TODO } impl OperationVerifyTarget { - fn set_targets(&self, pw: &mut PartialWitness, input: &OperationVerifyInputs) -> Result<()> { - todo!() + fn set_targets(&self, pw: &mut PartialWitness, input: &OperationVerifyInput) -> Result<()> { + // TODO + Ok(()) } } @@ -246,14 +277,13 @@ impl MainPodVerifyGate { // 2. Calculate the Pod Id from the public statements let pub_statements_flattened = pub_statements .iter() - .map(|s| s.code.iter().chain(s.args.iter().flatten())) + .map(|s| s.predicate.iter().chain(s.args.iter().flatten())) .flatten() .cloned() .collect(); let id = builder.hash_n_to_hash_no_pad::(pub_statements_flattened); - // 3. TODO check that all `input_statements` of type `ValueOf` with origin=SELF have unique - // keys (no duplicates) + // 3. TODO check that all `input_statements` of type `ValueOf` with origin=SELF have unique keys (no duplicates). Maybe we can do this via the NewEntry operation (check that the key doesn't exist in a previous statement with ID=SELF) // 4. Verify type let type_statement = &pub_statements[0]; @@ -324,12 +354,12 @@ impl MainPodVerifyTarget { } } -struct MainPodVerifyCircuit { - params: Params, +pub struct MainPodVerifyCircuit { + pub params: Params, } impl MainPodVerifyCircuit { - fn eval(&self, builder: &mut CircuitBuilder) -> Result { + pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { let main_pod = MainPodVerifyGate { params: self.params.clone(), } @@ -338,3 +368,95 @@ impl MainPodVerifyCircuit { Ok(main_pod) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::backends::plonky2::basetypes::C; + use crate::backends::plonky2::mock_main; + use crate::middleware::OperationType; + use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}; + + #[test] + fn test_signed_pod_verify() -> Result<()> { + let params = Params::default(); + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let signed_pod_verify = SignedPodVerifyGate { params }.eval(&mut builder)?; + + let mut pw = PartialWitness::::new(); + let kvs = [ + ( + Value::from(hash_str(KEY_TYPE)), + Value::from(PodType::MockSigned), + ), + (Value::from(hash_str("foo")), Value::from(42)), + ] + .into(); + let input = SignedPodVerifyInput { kvs }; + signed_pod_verify.set_targets(&mut pw, &input)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } + + fn operation_verify( + st: mock_main::Statement, + op: mock_main::Operation, + prev_statements: Vec, + ) -> Result<()> { + let params = Params::default(); + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let st_target = builder.add_virtual_statement(¶ms); + let op_target = builder.add_virtual_operation(¶ms); + let prev_statements_target: Vec<_> = (0..prev_statements.len()) + .map(|_| builder.add_virtual_statement(¶ms)) + .collect(); + + let operation_verify = OperationVerifyGate { + params: params.clone(), + } + .eval( + &mut builder, + &st_target, + &op_target, + &prev_statements_target, + )?; + + let mut pw = PartialWitness::::new(); + st_target.set_targets(&mut pw, ¶ms, &st)?; + op_target.set_targets(&mut pw, ¶ms, &op)?; + for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) { + prev_st_target.set_targets(&mut pw, ¶ms, prev_st)?; + } + let input = OperationVerifyInput {}; + operation_verify.set_targets(&mut pw, &input)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } + + #[test] + fn test_operation_verify() -> Result<()> { + // None + let st: mock_main::Statement = Statement::None.into(); + let op = mock_main::Operation(OperationType::Native(NativeOperation::None), vec![]); + let prev_statements = vec![Statement::None.into()]; + operation_verify(st, op, prev_statements)?; + + Ok(()) + } +} diff --git a/src/backends/plonky2/mock_main/mod.rs b/src/backends/plonky2/mock_main/mod.rs index 030ebe4..4583ab5 100644 --- a/src/backends/plonky2/mock_main/mod.rs +++ b/src/backends/plonky2/mock_main/mod.rs @@ -260,7 +260,7 @@ impl MockMainPod { .map(|mid_arg| Self::find_op_arg(statements, mid_arg)) .collect::>>()?; Self::pad_operation_args(params, &mut args); - operations.push(Operation(op.predicate(), args)); + operations.push(Operation(op.op_type(), args)); } Ok(operations) } diff --git a/src/backends/plonky2/mock_main/operation.rs b/src/backends/plonky2/mock_main/operation.rs index 460f739..631407a 100644 --- a/src/backends/plonky2/mock_main/operation.rs +++ b/src/backends/plonky2/mock_main/operation.rs @@ -1,6 +1,7 @@ use super::Statement; -use crate::middleware::{self, OperationType}; +use crate::middleware::{self, OperationType, Params, ToFields, F}; use anyhow::Result; +use plonky2::field::types::{Field, PrimeField64}; use serde::{Deserialize, Serialize}; use std::fmt; @@ -10,6 +11,16 @@ pub enum OperationArg { Index(usize), } +impl ToFields for OperationArg { + fn to_fields(&self, _params: &Params) -> Vec { + let f = match self { + Self::None => F::ZERO, + Self::Index(i) => F::from_canonical_usize(*i), + }; + vec![f] + } +} + impl OperationArg { pub fn is_none(&self) -> bool { matches!(self, OperationArg::None) @@ -20,6 +31,12 @@ impl OperationArg { pub struct Operation(pub OperationType, pub Vec); impl Operation { + pub fn op_type(&self) -> OperationType { + self.0.clone() + } + pub fn args(&self) -> &[OperationArg] { + &self.1 + } pub fn deref(&self, statements: &[Statement]) -> Result { let deref_args = self .1 diff --git a/src/backends/plonky2/mock_main/statement.rs b/src/backends/plonky2/mock_main/statement.rs index 0bbcf9e..cb0dd6a 100644 --- a/src/backends/plonky2/mock_main/statement.rs +++ b/src/backends/plonky2/mock_main/statement.rs @@ -13,6 +13,9 @@ impl Statement { pub fn is_none(&self) -> bool { self.0 == Predicate::Native(NativePredicate::None) } + pub fn predicate(&self) -> Predicate { + self.0.clone() + } /// Argument method. Trailing Nones are filtered out. pub fn args(&self) -> Vec { let maybe_last_arg_index = (0..self.1.len()).rev().find(|i| !self.1[*i].is_none()); @@ -96,7 +99,7 @@ impl TryFrom for middleware::Statement { impl From for Statement { fn from(s: middleware::Statement) -> Self { - match s.code() { + match s.predicate() { middleware::Predicate::Native(c) => Statement( middleware::Predicate::Native(c), s.args().into_iter().collect(), diff --git a/src/backends/plonky2/primitives/merkletree_circuit.rs b/src/backends/plonky2/primitives/merkletree_circuit.rs index 6da565b..c907a1a 100644 --- a/src/backends/plonky2/primitives/merkletree_circuit.rs +++ b/src/backends/plonky2/primitives/merkletree_circuit.rs @@ -29,11 +29,16 @@ use crate::backends::plonky2::common::{ }; use crate::backends::plonky2::primitives::merkletree::MerkleProof; -/// `MerkleProofCircuit` allows to verify both proofs of existence and proofs +/// `MerkleProofGate` allows to verify both proofs of existence and proofs /// non-existence with the same circuit. -/// If only proofs of existence are needed, use `MerkleProofExistenceCircuit`, +/// If only proofs of existence are needed, use `MerkleProofExistenceGate`, /// which requires less amount of constraints. -pub struct MerkleProofCircuit { +pub struct MerkleProofGate { + pub max_depth: usize, +} + +pub struct MerkleProofTarget { + max_depth: usize, pub root: HashOutTarget, pub key: ValueTarget, pub value: ValueTarget, @@ -44,16 +49,16 @@ pub struct MerkleProofCircuit { pub other_value: ValueTarget, } -impl MerkleProofCircuit { +impl MerkleProofGate { /// creates the targets and defines the logic of the circuit - pub fn add_targets(builder: &mut CircuitBuilder) -> Result { + pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { // create the targets let key = builder.add_virtual_value(); let value = builder.add_virtual_value(); // from proof struct: let existence = builder.add_virtual_bool_target_safe(); - // siblings are padded till MAX_DEPTH length - let siblings = builder.add_virtual_hashes(MAX_DEPTH); + // siblings are padded till max_depth length + let siblings = builder.add_virtual_hashes(self.max_depth); let case_ii_selector = builder.add_virtual_bool_target_safe(); let other_key = builder.add_virtual_value(); @@ -107,16 +112,17 @@ impl MerkleProofCircuit { ); // get key's path - let path = keypath_target::(builder, &key); + let path = keypath_target(self.max_depth, builder, &key); // compute the root for the given siblings and the computed leaf_hash // (this is for the three cases (existence, non-existence case i, and // non-existence case ii). // This root will be assigned in the `set_targets` method, and it is a // public input. - let root = compute_root_from_leaf::(builder, &path, &leaf_hash, &siblings)?; + let root = compute_root_from_leaf(self.max_depth, builder, &path, &leaf_hash, &siblings)?; - Ok(Self { + Ok(MerkleProofTarget { + max_depth: self.max_depth, existence, root, siblings, @@ -127,7 +133,9 @@ impl MerkleProofCircuit { other_value, }) } +} +impl MerkleProofTarget { /// assigns the given values to the targets pub fn set_targets( &self, @@ -143,9 +151,9 @@ impl MerkleProofCircuit { pw.set_target_arr(&self.value.elements, &value.0)?; pw.set_bool_target(self.existence, existence)?; - // pad siblings with zeros to length MAX_DEPTH + // pad siblings with zeros to length max_depth let mut siblings = proof.siblings.clone(); - siblings.resize(MAX_DEPTH, EMPTY_HASH); + siblings.resize(self.max_depth, EMPTY_HASH); assert_eq!(self.siblings.len(), siblings.len()); for (i, sibling) in siblings.iter().enumerate() { @@ -173,41 +181,49 @@ impl MerkleProofCircuit { /// `MerkleProofExistenceCircuit` allows to verify proofs of existence only. If /// proofs of non-existence are needed, use `MerkleProofCircuit`. -pub struct MerkleProofExistenceCircuit { +pub struct MerkleProofExistenceGate { + pub max_depth: usize, +} + +pub struct MerkleProofExistenceTarget { + max_depth: usize, pub root: HashOutTarget, pub key: ValueTarget, pub value: ValueTarget, pub siblings: Vec, } -impl MerkleProofExistenceCircuit { +impl MerkleProofExistenceGate { /// creates the targets and defines the logic of the circuit - pub fn add_targets(builder: &mut CircuitBuilder) -> Result { + pub fn eval(&self, builder: &mut CircuitBuilder) -> Result { // create the targets let key = builder.add_virtual_value(); let value = builder.add_virtual_value(); - // siblings are padded till MAX_DEPTH length - let siblings = builder.add_virtual_hashes(MAX_DEPTH); + // siblings are padded till max_depth length + let siblings = builder.add_virtual_hashes(self.max_depth); // get leaf's hash for the selected k & v let leaf_hash = kv_hash_target(builder, &key, &value); // get key's path - let path = keypath_target::(builder, &key); + let path = keypath_target(self.max_depth, builder, &key); // compute the root for the given siblings and the computed leaf_hash. // This root will be assigned in the `set_targets` method, and it is a // public input. - let root = compute_root_from_leaf::(builder, &path, &leaf_hash, &siblings)?; + let root = compute_root_from_leaf(self.max_depth, builder, &path, &leaf_hash, &siblings)?; - Ok(Self { + Ok(MerkleProofExistenceTarget { + max_depth: self.max_depth, root, siblings, key, value, }) } +} +impl MerkleProofExistenceTarget { /// assigns the given values to the targets pub fn set_targets( &self, @@ -221,9 +237,9 @@ impl MerkleProofExistenceCircuit { pw.set_target_arr(&self.key.elements, &key.0)?; pw.set_target_arr(&self.value.elements, &value.0)?; - // pad siblings with zeros to length MAX_DEPTH + // pad siblings with zeros to length max_depth let mut siblings = proof.siblings.clone(); - siblings.resize(MAX_DEPTH, EMPTY_HASH); + siblings.resize(self.max_depth, EMPTY_HASH); assert_eq!(self.siblings.len(), siblings.len()); for (i, sibling) in siblings.iter().enumerate() { @@ -234,13 +250,14 @@ impl MerkleProofExistenceCircuit { } } -fn compute_root_from_leaf( +fn compute_root_from_leaf( + max_depth: usize, builder: &mut CircuitBuilder, path: &Vec, leaf_hash: &HashOutTarget, siblings: &Vec, ) -> Result { - assert_eq!(siblings.len(), MAX_DEPTH); + assert_eq!(siblings.len(), max_depth); // Convenience constants let zero = builder.zero(); let one = builder.one(); @@ -295,12 +312,13 @@ fn compute_root_from_leaf( // Note: this logic is in its own method for easy of reusability but // specially to be able to test it isolated. -fn keypath_target( +fn keypath_target( + max_depth: usize, builder: &mut CircuitBuilder, key: &ValueTarget, ) -> Vec { - let n_complete_field_elems: usize = MAX_DEPTH / F::BITS; - let n_extra_bits: usize = MAX_DEPTH - n_complete_field_elems * F::BITS; + let n_complete_field_elems: usize = max_depth / F::BITS; + let n_extra_bits: usize = max_depth - n_complete_field_elems * F::BITS; let path: Vec = key .elements @@ -351,41 +369,35 @@ pub mod tests { #[test] fn test_keypath() -> Result<()> { - test_keypath_opt::<10>()?; - test_keypath_opt::<16>()?; - test_keypath_opt::<32>()?; - test_keypath_opt::<40>()?; - test_keypath_opt::<64>()?; - test_keypath_opt::<128>()?; - test_keypath_opt::<130>()?; - test_keypath_opt::<250>()?; - test_keypath_opt::<256>()?; + for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { + test_keypath_opt(max_depth)?; + } Ok(()) } - fn test_keypath_opt() -> Result<()> { + fn test_keypath_opt(max_depth: usize) -> Result<()> { for i in 0..5 { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); let key = Value::from(hash_value(&Value::from(i))); - let expected_path = keypath(MD, key)?; + let expected_path = keypath(max_depth, key)?; // small circuit logic to check // expected_path_targ==keypath_target(key_targ) - let expected_path_targ: Vec = (0..MD) + let expected_path_targ: Vec = (0..max_depth) .map(|_| builder.add_virtual_bool_target_safe()) .collect(); let key_targ = builder.add_virtual_value(); - let computed_path_targ = keypath_target::(&mut builder, &key_targ); - for i in 0..MD { + let computed_path_targ = keypath_target(max_depth, &mut builder, &key_targ); + for i in 0..max_depth { builder.connect(computed_path_targ[i].target, expected_path_targ[i].target); } // assign the input values to the targets pw.set_target_arr(&key_targ.elements, &key.0)?; - for i in 0..MD { + for i in 0..max_depth { pw.set_bool_target(expected_path_targ[i], expected_path[i])?; } @@ -431,40 +443,28 @@ pub mod tests { #[test] fn test_merkleproof_verify_existence() -> Result<()> { - test_merkleproof_verify_opt::<10>(true)?; - test_merkleproof_verify_opt::<16>(true)?; - test_merkleproof_verify_opt::<32>(true)?; - test_merkleproof_verify_opt::<40>(true)?; - test_merkleproof_verify_opt::<64>(true)?; - test_merkleproof_verify_opt::<128>(true)?; - test_merkleproof_verify_opt::<130>(true)?; - test_merkleproof_verify_opt::<250>(true)?; - test_merkleproof_verify_opt::<256>(true)?; + for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { + test_merkleproof_verify_opt(max_depth, true)?; + } Ok(()) } #[test] fn test_merkleproof_verify_nonexistence() -> Result<()> { - test_merkleproof_verify_opt::<10>(false)?; - test_merkleproof_verify_opt::<16>(false)?; - test_merkleproof_verify_opt::<32>(false)?; - test_merkleproof_verify_opt::<40>(false)?; - test_merkleproof_verify_opt::<64>(false)?; - test_merkleproof_verify_opt::<128>(false)?; - test_merkleproof_verify_opt::<130>(false)?; - test_merkleproof_verify_opt::<250>(false)?; - test_merkleproof_verify_opt::<256>(false)?; + for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { + test_merkleproof_verify_opt(max_depth, false)?; + } Ok(()) } // test logic to be reused both by the existence & nonexistence tests - fn test_merkleproof_verify_opt(existence: bool) -> Result<()> { + fn test_merkleproof_verify_opt(max_depth: usize, existence: bool) -> Result<()> { let mut kvs: HashMap = HashMap::new(); for i in 0..10 { kvs.insert(Value::from(hash_value(&Value::from(i))), Value::from(i)); } - let tree = MerkleTree::new(MD, &kvs)?; + let tree = MerkleTree::new(max_depth, &kvs)?; let (key, value, proof) = if existence { let key = Value::from(hash_value(&Value::from(5))); @@ -478,9 +478,9 @@ pub mod tests { assert_eq!(proof.existence, existence); if existence { - MerkleTree::verify(MD, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; } else { - MerkleTree::verify_nonexistence(MD, tree.root(), &proof, &key)?; + MerkleTree::verify_nonexistence(max_depth, tree.root(), &proof, &key)?; } // circuit @@ -488,7 +488,7 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofCircuit::::add_targets(&mut builder)?; + let targets = MerkleProofGate { max_depth }.eval(&mut builder)?; targets.set_targets(&mut pw, existence, tree.root(), proof, key, value)?; // generate & verify proof @@ -501,39 +501,33 @@ pub mod tests { #[test] fn test_merkleproof_only_existence_verify() -> Result<()> { - test_merkleproof_only_existence_verify_opt::<10>()?; - test_merkleproof_only_existence_verify_opt::<16>()?; - test_merkleproof_only_existence_verify_opt::<32>()?; - test_merkleproof_only_existence_verify_opt::<40>()?; - test_merkleproof_only_existence_verify_opt::<64>()?; - test_merkleproof_only_existence_verify_opt::<128>()?; - test_merkleproof_only_existence_verify_opt::<130>()?; - test_merkleproof_only_existence_verify_opt::<250>()?; - test_merkleproof_only_existence_verify_opt::<256>()?; + for max_depth in [10, 16, 32, 40, 64, 128, 130, 250, 256] { + test_merkleproof_only_existence_verify_opt(max_depth)?; + } Ok(()) } - fn test_merkleproof_only_existence_verify_opt() -> Result<()> { + fn test_merkleproof_only_existence_verify_opt(max_depth: usize) -> Result<()> { let mut kvs: HashMap = HashMap::new(); for i in 0..10 { kvs.insert(Value::from(hash_value(&Value::from(i))), Value::from(i)); } - let tree = MerkleTree::new(MD, &kvs)?; + let tree = MerkleTree::new(max_depth, &kvs)?; let key = Value::from(hash_value(&Value::from(5))); let (value, proof) = tree.prove(&key)?; assert_eq!(value, Value::from(5)); assert_eq!(proof.existence, true); - MerkleTree::verify(MD, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; // circuit let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofExistenceCircuit::::add_targets(&mut builder)?; + let targets = MerkleProofExistenceGate { max_depth }.eval(&mut builder)?; targets.set_targets(&mut pw, tree.root(), proof, key, value)?; // generate & verify proof @@ -564,19 +558,19 @@ pub mod tests { kvs.insert(Value::from(5), Value::from(1005)); kvs.insert(Value::from(13), Value::from(1013)); - const MD: usize = 5; - let tree = MerkleTree::new(MD, &kvs)?; + let max_depth = 5; + let tree = MerkleTree::new(max_depth, &kvs)?; // existence - test_merkletree_edgecase_opt::(&tree, Value::from(5))?; + test_merkletree_edgecase_opt(max_depth, &tree, Value::from(5))?; // non-existence case i) expected leaf does not exist - test_merkletree_edgecase_opt::(&tree, Value::from(1))?; + test_merkletree_edgecase_opt(max_depth, &tree, Value::from(1))?; // non-existence case ii) expected leaf does exist but it has a different 'key' - test_merkletree_edgecase_opt::(&tree, Value::from(21))?; + test_merkletree_edgecase_opt(max_depth, &tree, Value::from(21))?; Ok(()) } - fn test_merkletree_edgecase_opt(tree: &MerkleTree, key: Value) -> Result<()> { + fn test_merkletree_edgecase_opt(max_depth: usize, tree: &MerkleTree, key: Value) -> Result<()> { let contains = tree.contains(&key)?; // generate merkleproof let (value, proof) = if contains { @@ -590,9 +584,9 @@ pub mod tests { // verify the proof (non circuit) if proof.existence { - MerkleTree::verify(MD, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; } else { - MerkleTree::verify_nonexistence(MD, tree.root(), &proof, &key)?; + MerkleTree::verify_nonexistence(max_depth, tree.root(), &proof, &key)?; } // circuit @@ -600,7 +594,7 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofCircuit::::add_targets(&mut builder)?; + let targets = MerkleProofGate { max_depth }.eval(&mut builder)?; targets.set_targets(&mut pw, proof.existence, tree.root(), proof, key, value)?; // generate & verify proof @@ -617,8 +611,8 @@ pub mod tests { for i in 0..10 { kvs.insert(Value::from(i), Value::from(i)); } - const MD: usize = 16; - let tree = MerkleTree::new(MD, &kvs)?; + let max_depth = 16; + let tree = MerkleTree::new(max_depth, &kvs)?; let key = Value::from(3); let (value, proof) = tree.prove(&key)?; @@ -626,11 +620,11 @@ pub mod tests { // build another tree with an extra key-value, so that it has a // different root kvs.insert(Value::from(100), Value::from(100)); - let tree2 = MerkleTree::new(MD, &kvs)?; + let tree2 = MerkleTree::new(max_depth, &kvs)?; - MerkleTree::verify(MD, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; assert_eq!( - MerkleTree::verify(MD, tree2.root(), &proof, &key, &value) + MerkleTree::verify(max_depth, tree2.root(), &proof, &key, &value) .unwrap_err() .to_string(), "proof of inclusion does not verify" @@ -641,7 +635,7 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let targets = MerkleProofCircuit::::add_targets(&mut builder)?; + let targets = MerkleProofGate { max_depth }.eval(&mut builder)?; targets.set_targets(&mut pw, true, tree2.root(), proof, key, value)?; // generate proof, expecting it to fail (since we're using the wrong diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 5f92a36..22fd99c 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1091,6 +1091,7 @@ pub mod tests { max_operation_args: 5, max_custom_predicate_arity: 5, max_custom_batch_size: 5, + ..Default::default() }; let mut alice = MockSigner { pk: "Alice".into() }; diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 6f8d5cb..e71a46a 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -5,6 +5,7 @@ use std::{fmt, hash as h, iter, iter::zip}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use super::{ hash_fields, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg, @@ -12,7 +13,6 @@ use super::{ }; use crate::backends::plonky2::basetypes::HASH_SIZE; use crate::util::hashmap_insert_no_dupe; -use serde::{Deserialize, Serialize}; // BEGIN Custom 1b @@ -49,9 +49,9 @@ impl fmt::Display for HashOrWildcard { } impl ToFields for HashOrWildcard { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self, params: &Params) -> Vec { match self { - HashOrWildcard::Hash(h) => h.to_fields(_params), + HashOrWildcard::Hash(h) => h.to_fields(params), HashOrWildcard::Wildcard(w) => (0..HASH_SIZE - 1) .chain(iter::once(*w)) .map(|x| F::from_canonical_u64(x as u64)) @@ -91,7 +91,7 @@ impl StatementTmplArg { } impl ToFields for StatementTmplArg { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self, params: &Params) -> Vec { // None => (0, ...) // Literal(value) => (1, [value], 0, 0, 0, 0) // Key(hash_or_wildcard1, hash_or_wildcard2) @@ -107,15 +107,15 @@ impl ToFields for StatementTmplArg { } StatementTmplArg::Literal(v) => { let fields: Vec = iter::once(F::from_canonical_u64(1)) - .chain(v.to_fields(_params)) + .chain(v.to_fields(params)) .chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .collect(); fields } StatementTmplArg::Key(hw1, hw2) => { let fields: Vec = iter::once(F::from_canonical_u64(2)) - .chain(hw1.to_fields(_params)) - .chain(hw2.to_fields(_params)) + .chain(hw1.to_fields(params)) + .chain(hw2.to_fields(params)) .collect(); fields } @@ -165,7 +165,7 @@ impl StatementTmpl { Err(anyhow!( "Cannot check self-referencing statement templates." )) - } else if self.pred() != &s.code() { + } else if self.pred() != &s.predicate() { Err(anyhow!("Type mismatch between {:?} and {}.", self, s)) } else { zip(self.args(), s.args()) @@ -318,8 +318,8 @@ impl ToFields for CustomPredicateBatch { } impl CustomPredicateBatch { - pub fn hash(&self, _params: &Params) -> Hash { - let input = self.to_fields(_params); + pub fn hash(&self, params: &Params) -> Hash { + let input = self.to_fields(params); hash_fields(&input) } @@ -399,7 +399,7 @@ impl From for Predicate { } impl ToFields for Predicate { - fn to_fields(&self, _params: &Params) -> Vec { + fn to_fields(&self, params: &Params) -> Vec { // serialize: // NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize // BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize @@ -410,13 +410,13 @@ impl ToFields for Predicate { // in every case: pad to (hash_size + 2) field elements let mut fields: Vec = match self { Self::Native(p) => iter::once(F::from_canonical_u64(1)) - .chain(p.to_fields(_params)) + .chain(p.to_fields(params)) .collect(), Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2)) .chain(iter::once(F::from_canonical_usize(*i))) .collect(), Self::Custom(CustomPredicateRef(pb, i)) => iter::once(F::from_canonical_u64(3)) - .chain(pb.hash(_params).0) + .chain(pb.hash(params).0) .chain(iter::once(F::from_canonical_usize(*i))) .collect(), }; diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 6d9881c..b07d478 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -93,6 +93,8 @@ pub struct Params { // in a custom predicate pub max_custom_predicate_arity: usize, pub max_custom_batch_size: usize, + // maximum depth for merkle tree gates + pub max_depth_mt_gate: usize, } impl Default for Params { @@ -107,6 +109,7 @@ impl Default for Params { max_operation_args: 5, max_custom_predicate_arity: 5, max_custom_batch_size: 5, + max_depth_mt_gate: 32, } } } @@ -116,15 +119,27 @@ impl Params { self.max_statements - self.max_public_statements } - pub fn statement_tmpl_arg_size() -> usize { + pub const fn statement_tmpl_arg_size() -> usize { 2 * HASH_SIZE + 1 } - pub fn predicate_size() -> usize { + pub const fn predicate_size() -> usize { HASH_SIZE + 2 } - pub fn statement_tmpl_size(&self) -> usize { + pub const fn operation_type_size() -> usize { + HASH_SIZE + 2 + } + + pub fn statement_size(&self) -> usize { + Self::predicate_size() + STATEMENT_ARG_F_LEN * self.max_statement_args + } + + pub fn operation_size(&self) -> usize { + Self::operation_type_size() + OPERATION_ARG_F_LEN * self.max_operation_args + } + + pub const fn statement_tmpl_size(&self) -> usize { Self::predicate_size() + self.max_statement_args * Self::statement_tmpl_arg_size() } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index da54613..52d3e1c 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,9 +1,11 @@ use anyhow::{anyhow, Result}; use log::error; +use plonky2::field::types::Field; use serde::{Deserialize, Serialize}; use std::fmt; +use std::iter; -use super::{CustomPredicateRef, NativePredicate, Statement, StatementArg}; +use super::{CustomPredicateRef, NativePredicate, Statement, StatementArg, ToFields, F}; use crate::middleware::{AnchoredKey, Params, Predicate, Value, SELF}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -12,6 +14,22 @@ pub enum OperationType { Custom(CustomPredicateRef), } +impl ToFields for OperationType { + fn to_fields(&self, params: &Params) -> Vec { + let mut fields: Vec = match self { + Self::Native(p) => iter::once(F::from_canonical_u64(1)) + .chain(p.to_fields(params)) + .collect(), + Self::Custom(CustomPredicateRef(pb, i)) => iter::once(F::from_canonical_u64(3)) + .chain(pb.hash(params).0) + .chain(iter::once(F::from_canonical_usize(*i))) + .collect(), + }; + fields.resize_with(Params::operation_type_size(), || F::from_canonical_u64(0)); + fields + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum NativeOperation { None = 0, @@ -31,6 +49,12 @@ pub enum NativeOperation { MaxOf = 15, } +impl ToFields for NativeOperation { + fn to_fields(&self, _params: &Params) -> Vec { + vec![F::from_canonical_u64(*self as u64)] + } +} + impl OperationType { /// Gives the type of predicate that the operation will output, if known. /// CopyStatement may output any predicate (it will match the statement copied), @@ -91,7 +115,7 @@ pub enum Operation { } impl Operation { - pub fn predicate(&self) -> OperationType { + pub fn op_type(&self) -> OperationType { type OT = OperationType; use NativeOperation::*; match self { @@ -178,7 +202,7 @@ impl Operation { /// The outer Result is error handling pub fn output_statement(&self) -> Result> { use Statement::*; - let pred: Option = self.predicate().output_predicate(); + let pred: Option = self.op_type().output_predicate(); let st_args: Option> = match self { Self::None => Some(vec![]), @@ -401,10 +425,16 @@ impl Operation { } } +impl ToFields for Operation { + fn to_fields(&self, params: &Params) -> Vec { + todo!() + } +} + impl fmt::Display for Operation { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "middleware::Operation:")?; - writeln!(f, " {:?} ", self.predicate())?; + writeln!(f, " {:?} ", self.op_type())?; for arg in self.args().iter() { writeln!(f, " {}", arg)?; } diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index ba879ec..7454f19 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -5,11 +5,14 @@ use serde::{Deserialize, Serialize}; use std::{fmt, iter}; use strum_macros::FromRepr; -use super::{AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, VALUE_SIZE}; +use super::{ + AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, HASH_SIZE, VALUE_SIZE, +}; pub const KEY_SIGNER: &str = "_signer"; pub const KEY_TYPE: &str = "_type"; pub const STATEMENT_ARG_F_LEN: usize = 8; +pub const OPERATION_ARG_F_LEN: usize = 1; #[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum NativePredicate { @@ -53,7 +56,7 @@ impl Statement { pub fn is_none(&self) -> bool { self == &Self::None } - pub fn code(&self) -> Predicate { + pub fn predicate(&self) -> Predicate { use Predicate::*; match self { Self::None => Native(NativePredicate::None), @@ -184,16 +187,17 @@ impl Statement { } impl ToFields for Statement { - fn to_fields(&self, _params: &Params) -> Vec { - let mut fields = self.code().to_fields(_params); - fields.extend(self.args().iter().flat_map(|arg| arg.to_fields(_params))); + fn to_fields(&self, params: &Params) -> Vec { + let mut fields = self.predicate().to_fields(params); + fields.extend(self.args().iter().flat_map(|arg| arg.to_fields(params))); + fields.resize_with(params.statement_size(), || F::ZERO); fields } } impl fmt::Display for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?} ", self.code())?; + write!(f, "{:?} ", self.predicate())?; for (i, arg) in self.args().iter().enumerate() { if i != 0 { write!(f, " ")?;