diff --git a/examples/main_pod_points.rs b/examples/main_pod_points.rs index 7c877e3..8eb3cfc 100644 --- a/examples/main_pod_points.rs +++ b/examples/main_pod_points.rs @@ -35,7 +35,7 @@ fn main() -> Result<(), Box> { let mock_prover = MockProver {}; let real_prover = Prover {}; let (vd_set, prover): (_, &dyn MainPodProver) = if mock { - (&VDSet::new(8, &[])?, &mock_prover) + (&VDSet::new(&[]), &mock_prover) } else { println!("Prebuilding circuits to calculate vd_set..."); let vd_set = &*DEFAULT_VD_SET; diff --git a/examples/signed_dict.rs b/examples/signed_dict.rs index f02b921..fc452c8 100644 --- a/examples/signed_dict.rs +++ b/examples/signed_dict.rs @@ -30,10 +30,7 @@ fn main() -> Result<(), Box> { .into_iter() .map(Value::from) .collect(); - builder.insert( - "friends", - Set::new(params.max_merkle_proofs_containers, friends_set)?, - ); + builder.insert("friends", Set::new(friends_set)); // Sign the dict and verify it let signed_dict = builder.sign(&signer)?; diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index 7a94551..d7d6b39 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -65,9 +65,8 @@ pub static DEFAULT_VD_LIST: LazyLock> = LazyLock::n }); pub static DEFAULT_VD_SET: LazyLock = LazyLock::new(|| { - let params = Params::default(); let vds = &*DEFAULT_VD_LIST; - VDSet::new(params.max_depth_mt_vds, vds).unwrap() + VDSet::new(vds) }); /// VDSet is the set of the allowed verifier_data hashes. When proving a @@ -84,35 +83,29 @@ pub struct VDSet { #[serde(skip)] #[schemars(skip)] proofs_map: HashMap, - tree_depth: usize, vds_hashes: Vec, } impl PartialEq for VDSet { fn eq(&self, other: &Self) -> bool { - self.root == other.root - && self.tree_depth == other.tree_depth - && self.vds_hashes == other.vds_hashes + self.root == other.root && self.vds_hashes == other.vds_hashes } } impl Eq for VDSet {} impl VDSet { - fn new_from_vds_hashes(tree_depth: usize, mut vds_hashes: Vec) -> Result { + fn new_from_vds_hashes(mut vds_hashes: Vec) -> Self { // before using the hash values, sort them, so that each set of // verifier_datas gets the same VDSet root vds_hashes.sort(); - let array = Array::new( - tree_depth, - vds_hashes.iter().map(|vd| Value::from(*vd)).collect(), - )?; + let array = Array::new(vds_hashes.iter().map(|vd| Value::from(*vd)).collect()); let root = array.commitment(); let mut proofs_map = HashMap::::new(); for (i, vd) in vds_hashes.iter().enumerate() { - let (value, proof) = array.prove(i)?; + let (value, proof) = array.prove(i).expect("exists"); let p = MerkleClaimAndProof { root, key: RawValue::from(i as i64), @@ -121,15 +114,14 @@ impl VDSet { }; proofs_map.insert(*vd, p); } - Ok(Self { + Self { root, proofs_map, - tree_depth, vds_hashes, - }) + } } /// builds the verifier_datas tree, and returns the root and the proofs - pub fn new(tree_depth: usize, vds: &[VerifierOnlyCircuitData]) -> Result { + pub fn new(vds: &[VerifierOnlyCircuitData]) -> Self { // compute the verifier_data's hashes let vds_hashes: Vec = vds .iter() @@ -141,7 +133,7 @@ impl VDSet { .map(|h| Hash(h.elements)) .collect::>(); - Self::new_from_vds_hashes(tree_depth, vds_hashes) + Self::new_from_vds_hashes(vds_hashes) } pub fn root(&self) -> Hash { self.root @@ -172,10 +164,9 @@ impl<'de> Deserialize<'de> for VDSet { { #[derive(Deserialize)] struct Aux { - tree_depth: usize, vds_hashes: Vec, } let aux = Aux::deserialize(deserializer)?; - VDSet::new_from_vds_hashes(aux.tree_depth, aux.vds_hashes).map_err(serde::de::Error::custom) + Ok(VDSet::new_from_vds_hashes(aux.vds_hashes)) } } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 3e95462..46c56aa 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -2329,8 +2329,8 @@ mod tests { #[test] fn test_operation_verify_eq() -> Result<()> { - let dict1 = dict!(32, {"hello" => 55})?; - let dict2 = dict!(32, {"world" => 55})?; + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 55}); let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 55).into(); let st: mainpod::Statement = Statement::equal( @@ -2349,8 +2349,8 @@ mod tests { #[test] fn test_operation_verify_neq() -> Result<()> { - let dict1 = dict!(32, {"hello" => 55})?; - let dict2 = dict!(32, {"world" => 75})?; + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"world" => 75}); let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); let st2: mainpod::Statement = Statement::contains(dict2.clone(), "world", 75).into(); let st: mainpod::Statement = Statement::not_equal( @@ -2369,8 +2369,8 @@ mod tests { #[test] fn test_operation_verify_lt() -> Result<()> { - let dict1 = dict!(32, {"hello" => 55})?; - let dict2 = dict!(32, {"hello" => 56})?; + let dict1 = dict!({"hello" => 55}); + let dict2 = dict!({"hello" => 56}); let st1: mainpod::Statement = Statement::contains(dict1.clone(), "hello", 55).into(); let st2: mainpod::Statement = Statement::contains(dict2.clone(), "hello", 56).into(); let st: mainpod::Statement = Statement::lt( @@ -2387,8 +2387,8 @@ mod tests { operation_verify(st, op, prev_statements, Aux::default())?; // Also check negative < negative - let dict3 = dict!(32, {"hola" => -56})?; - let dict4 = dict!(32, {"mundo" => -55})?; + let dict3 = dict!({"hola" => -56}); + let dict4 = dict!({"mundo" => -55}); let st3: mainpod::Statement = Statement::contains(dict3.clone(), "hola", -56).into(); let st4: mainpod::Statement = Statement::contains(dict4.clone(), "mundo", -55).into(); let st: mainpod::Statement = Statement::lt( @@ -2421,12 +2421,12 @@ mod tests { #[test] fn test_operation_verify_lteq() -> Result<()> { - let local = dict!(32, { + let local = dict!({ "n55" => 55, "n56" => 56, "n_56" => -56, "n_55" => -55, - })?; + }); let st1: mainpod::Statement = Statement::contains(local.clone(), "n55", 55).into(); let st2: mainpod::Statement = Statement::contains(local.clone(), "n56", 56).into(); let st: mainpod::Statement = Statement::lt_eq( @@ -2511,11 +2511,11 @@ mod tests { let v1 = hash_values(&input_values); let [v2, v3] = input_values; - let local = dict!(32, { + let local = dict!({ "hola" => v1, "mundo" => v2.clone(), "!" => v3.clone(), - })?; + }); let st1: mainpod::Statement = Statement::contains(local.clone(), "hola", v1).into(); let st2: mainpod::Statement = Statement::contains(local.clone(), "mundo", v2).into(); @@ -2549,11 +2549,11 @@ mod tests { overflow.not().then_some((a, b, sum)) }) .try_for_each(|(a, b, sum)| { - let local = dict!(32, { + let local = dict!({ "sum" => sum, "a" => a, "b" => b, - })?; + }); let st1: mainpod::Statement = Statement::contains(local.clone(), "sum", sum).into(); let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); @@ -2588,11 +2588,11 @@ mod tests { overflow.not().then_some((a, b, prod)) }) .try_for_each(|(a, b, prod)| { - let local = dict!(32, { + let local = dict!({ "prod" => prod, "a" => a, "b" => b, - })?; + }); let st1: mainpod::Statement = Statement::contains(local.clone(), "prod", prod).into(); @@ -2623,11 +2623,11 @@ mod tests { fn test_operation_verify_maxof() -> Result<()> { I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { let max = i64::max(a, b); - let local = dict!(32, { + let local = dict!({ "max" => max, "a" => a, "b" => b, - })?; + }); let st1: mainpod::Statement = Statement::contains(local.clone(), "max", max).into(); let st2: mainpod::Statement = Statement::contains(local.clone(), "a", a).into(); @@ -2689,10 +2689,10 @@ mod tests { #[test] fn test_operation_verify_lt_to_neq() -> Result<()> { - let local = dict!(32,{ + let local = dict!({ "a" => 10, "b" => 20, - })?; + }); let st: mainpod::Statement = Statement::not_equal( AnchoredKey::from((&local, "a")), AnchoredKey::from((&local, "b")), @@ -2714,11 +2714,11 @@ mod tests { #[test] fn test_operation_verify_transitive_eq() -> Result<()> { - let local = dict!(32,{ + let local = dict!({ "a" => 10, "b" => 10, "c" => 10, - })?; + }); let st: mainpod::Statement = Statement::equal( AnchoredKey::from((&local, "a")), AnchoredKey::from((&local, "c")), @@ -2745,8 +2745,6 @@ mod tests { #[test] fn test_operation_verify_sintains() -> Result<()> { - let params = Params::default(); - let kvs = [ (1.into(), 55.into()), (2.into(), 88.into()), @@ -2754,14 +2752,14 @@ mod tests { ] .into_iter() .collect(); - let mt = MerkleTree::new(params.max_depth_mt_containers, &kvs)?; + let mt = MerkleTree::new(&kvs); let root = mt.root(); let key = Value::from(5); - let local = dict!(32,{ + let local = dict!({ "merkle_root" => root, "key" => key.clone(), - })?; + }); let root_ak = AnchoredKey::from((&local, "merkle_root")); let key_ak = AnchoredKey::from((&local, "key")); @@ -2785,8 +2783,6 @@ mod tests { #[test] fn test_operation_verify_contains() -> Result<()> { - let params = Params::default(); - let kvs = [ (1.into(), 55.into()), (2.into(), 88.into()), @@ -2794,16 +2790,16 @@ mod tests { ] .into_iter() .collect(); - let mt = MerkleTree::new(params.max_depth_mt_containers, &kvs)?; + let mt = MerkleTree::new(&kvs); let root = mt.root(); let key = Value::from(175); let (value, key_pf) = mt.prove(&key.raw())?; - let local = dict!(32,{ + let local = dict!({ "merkle_root" => root, "key" => key.clone(), "value" => value, - })?; + }); let root_ak = AnchoredKey::from((&local, "merkle_root")); let key_ak = AnchoredKey::from((&local, "key")); let value_ak = AnchoredKey::from((&local, "value")); @@ -2833,9 +2829,7 @@ mod tests { #[test] fn test_operation_verify_merkle_insert() -> Result<()> { - let params = Params::default(); - - let mut tree = MerkleTree::new(params.max_depth_mt_containers, &[].into())?; + let mut tree = MerkleTree::new(&[].into()); let key = Value::from(175); let value = Value::from(0); @@ -2862,12 +2856,7 @@ mod tests { #[test] fn test_operation_verify_merkle_update() -> Result<()> { - let params = Params::default(); - - let mut tree = MerkleTree::new( - params.max_depth_mt_containers, - &[(175.into(), 55.into())].into(), - )?; + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); let key = Value::from(175); let value = Value::from(0); @@ -2894,12 +2883,7 @@ mod tests { #[test] fn test_operation_verify_merkle_delete() -> Result<()> { - let params = Params::default(); - - let mut tree = MerkleTree::new( - params.max_depth_mt_containers, - &[(175.into(), 55.into())].into(), - )?; + let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into()); let key = Value::from(175); let state_transition_proof = tree.delete(&key.raw())?; diff --git a/src/backends/plonky2/emptypod.rs b/src/backends/plonky2/emptypod.rs index 2bc8b29..5e0be7c 100644 --- a/src/backends/plonky2/emptypod.rs +++ b/src/backends/plonky2/emptypod.rs @@ -260,7 +260,7 @@ pub mod tests { fn test_empty_pod() { let params = Params::default(); - let empty_pod = EmptyPod::new_boxed(¶ms, VDSet::new(8, &[]).unwrap()); + let empty_pod = EmptyPod::new_boxed(¶ms, VDSet::new(&[])); empty_pod.verify().unwrap(); } } diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 98e4fa9..68334b2 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -851,7 +851,7 @@ pub mod tests { println!("{:#?}", params); let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); - let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); + let vd_set = VDSet::new(&vds); let (gov_id_builder, pay_stub_builder) = zu_kyc_sign_dict_builders(¶ms); let signer = Signer(SecretKey(BigUint::one())); @@ -875,7 +875,7 @@ pub mod tests { env_logger::init(); let params = Params::default(); println!("{:#?}", params); - let vd_set = VDSet::new(params.max_depth_mt_vds, &[]).unwrap(); + let vd_set = VDSet::new(&[]); // Calculate rec common first to avoid duplicate metrics in `pod_builder.prove` let _rec_common_circuit_data = cache_get_standard_rec_main_pod_common_circuit_data(); @@ -912,7 +912,7 @@ pub mod tests { }; let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); - let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); + let vd_set = VDSet::new(&vds); let mut gov_id_builder = frontend::SignedDictBuilder::new(¶ms); gov_id_builder.insert("idNumber", "4242424242"); @@ -970,7 +970,7 @@ pub mod tests { }; let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); - let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); + let vd_set = VDSet::new(&vds); let builder = frontend::MainPodBuilder::new(¶ms, &vd_set); println!("{}", builder); @@ -1014,7 +1014,7 @@ pub mod tests { }; let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); - let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); + let vd_set = VDSet::new(&vds); let pod_builder = frontend::MainPodBuilder::new(¶ms, &vd_set); @@ -1080,7 +1080,7 @@ pub mod tests { println!("{:#?}", params); let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); - let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); + let vd_set = VDSet::new(&vds); let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into()); let stb0 = STB::new(NP::Contains) @@ -1104,8 +1104,8 @@ pub mod tests { let mut pod_builder = MainPodBuilder::new(¶ms, &vd_set); - let dict = dict!(32, {"score" => 42})?; - let secret_dict = dict!(32, {"key" => 42})?; + let dict = dict!({"score" => 42}); + let secret_dict = dict!({"key" => 42}); let st0 = pod_builder.priv_op(frontend::Operation::dict_contains( dict.clone(), "score", @@ -1136,7 +1136,7 @@ pub mod tests { let params = Params::default(); let mut builder = MainPodBuilder::new(¶ms, &DEFAULT_VD_SET); let set: HashSet<_> = [1, 2, 3].into_iter().map(|n| n.into()).collect(); - let set = Set::new(params.max_depth_mt_containers, set).unwrap(); + let set = Set::new(set); builder.pub_op(frontend::Operation::set_contains(set, 1))?; let prover = Prover {}; diff --git a/src/backends/plonky2/mock/emptypod.rs b/src/backends/plonky2/mock/emptypod.rs index d8c5611..fff0b5c 100644 --- a/src/backends/plonky2/mock/emptypod.rs +++ b/src/backends/plonky2/mock/emptypod.rs @@ -105,7 +105,7 @@ pub mod tests { fn test_mock_empty_pod() { let params = Params::default(); - let empty_pod = MockEmptyPod::new_boxed(¶ms, VDSet::new(8, &[]).unwrap()); + let empty_pod = MockEmptyPod::new_boxed(¶ms, VDSet::new(&[])); empty_pod.verify().unwrap(); } } diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index 7456fa2..8b34999 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -29,9 +29,9 @@ use crate::{ backends::plonky2::{ basetypes::D, circuits::common::{CircuitBuilderPod, ValueTarget}, - error::Result, + error::{Error, Result}, primitives::merkletree::{ - MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, + MerkleClaimAndProof, MerkleTreeOp, MerkleTreeStateTransitionProof, TreeError, }, }, measure_gates_begin, measure_gates_end, @@ -159,6 +159,13 @@ impl MerkleClaimAndProofTarget { enabled: bool, mp: &MerkleClaimAndProof, ) -> Result<()> { + if mp.proof.siblings.len() > self.max_depth { + return Err(Error::Tree(TreeError::circuit_depth_too_small( + self.max_depth, + mp.proof.siblings.len(), + ))); + } + pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; pw.set_target_arr(&self.key.elements, &mp.key.0)?; @@ -166,7 +173,6 @@ impl MerkleClaimAndProofTarget { pw.set_bool_target(self.existence, mp.proof.existence)?; // pad siblings with zeros to length max_depth - assert!(mp.proof.siblings.len() <= self.max_depth); for (i, sibling) in mp .proof .siblings @@ -265,6 +271,12 @@ impl MerkleProofExistenceTarget { mp: &MerkleClaimAndProof, ) -> Result<()> { assert!(mp.proof.existence); // sanity check + if mp.proof.siblings.len() > self.max_depth { + return Err(Error::Tree(TreeError::circuit_depth_too_small( + self.max_depth, + mp.proof.siblings.len(), + ))); + } pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(mp.root.0.to_vec()))?; @@ -272,7 +284,6 @@ impl MerkleProofExistenceTarget { pw.set_target_arr(&self.value.elements, &mp.value.0)?; // pad siblings with zeros to length max_depth - assert!(mp.proof.siblings.len() <= self.max_depth); for (i, sibling) in mp .proof .siblings @@ -610,6 +621,14 @@ impl MerkleTreeStateTransitionProofTarget { enabled: bool, mp: &MerkleTreeStateTransitionProof, ) -> Result<()> { + let new_siblings = mp.siblings.clone(); + if new_siblings.len() > self.max_depth { + return Err(Error::Tree(TreeError::circuit_depth_too_small( + self.max_depth, + new_siblings.len(), + ))); + } + pw.set_bool_target(self.enabled, enabled)?; pw.set_target(self.op, F::from_canonical_u8(mp.op as u8))?; @@ -633,9 +652,6 @@ impl MerkleTreeStateTransitionProofTarget { pw.set_target_arr(&self.op_key.elements, &mp.op_key.0)?; pw.set_target_arr(&self.op_value.elements, &mp.op_value.0)?; - let new_siblings = mp.siblings.clone(); - - assert!(new_siblings.len() <= self.max_depth); for (i, sibling) in new_siblings .iter() .chain(iter::repeat(&EMPTY_HASH)) @@ -683,7 +699,7 @@ pub mod tests { let mut pw = PartialWitness::::new(); let key = RawValue::from(hash_value(&RawValue::from(i))); - let expected_path = keypath(max_depth, key)?; + let expected_path = keypath(key); // small circuit logic to check // expected_path_targ==keypath_target(key_targ) @@ -769,7 +785,7 @@ pub mod tests { ); } - let tree = MerkleTree::new(max_depth, &kvs)?; + let tree = MerkleTree::new(&kvs); let (key, value, proof) = if existence { let key = RawValue::from(hash_value(&RawValue::from(5))); @@ -783,9 +799,9 @@ pub mod tests { assert_eq!(proof.existence, existence); if existence { - MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(tree.root(), &proof, &key, &value)?; } else { - MerkleTree::verify_nonexistence(max_depth, tree.root(), &proof, &key)?; + MerkleTree::verify_nonexistence(tree.root(), &proof, &key)?; } // circuit @@ -826,14 +842,14 @@ pub mod tests { ); } - let tree = MerkleTree::new(max_depth, &kvs)?; + let tree = MerkleTree::new(&kvs); let key = RawValue::from(hash_value(&RawValue::from(5))); let (value, proof) = tree.prove(&key)?; assert_eq!(value, RawValue::from(5)); assert!(proof.existence); - MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(tree.root(), &proof, &key, &value)?; // circuit let config = CircuitConfig::standard_recursion_config(); @@ -877,7 +893,7 @@ pub mod tests { kvs.insert(RawValue::from(13), RawValue::from(1013)); let max_depth = 5; - let tree = MerkleTree::new(max_depth, &kvs)?; + let tree = MerkleTree::new(&kvs); // existence test_merkletree_edgecase_opt(max_depth, &tree, RawValue::from(5))?; // non-existence case i) expected leaf does not exist @@ -906,9 +922,9 @@ pub mod tests { // verify the proof (non circuit) if proof.existence { - MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(tree.root(), &proof, &key, &value)?; } else { - MerkleTree::verify_nonexistence(max_depth, tree.root(), &proof, &key)?; + MerkleTree::verify_nonexistence(tree.root(), &proof, &key)?; } // circuit @@ -939,7 +955,7 @@ pub mod tests { kvs.insert(RawValue::from(i), RawValue::from(i)); } let max_depth = 16; - let tree = MerkleTree::new(max_depth, &kvs)?; + let tree = MerkleTree::new(&kvs); let key = RawValue::from(3); let (value, proof) = tree.prove(&key)?; @@ -947,11 +963,11 @@ pub mod tests { // build another tree with an extra key-value, so that it has a // different root kvs.insert(RawValue::from(100), RawValue::from(100)); - let tree2 = MerkleTree::new(max_depth, &kvs)?; + let tree2 = MerkleTree::new(&kvs); - MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(tree.root(), &proof, &key, &value)?; assert_eq!( - MerkleTree::verify(max_depth, tree2.root(), &proof, &key, &value) + MerkleTree::verify(tree2.root(), &proof, &key, &value) .unwrap_err() .inner() .unwrap() @@ -1002,10 +1018,10 @@ pub mod tests { ) -> Result<()> { // sanity check, run the out-circuit proof verification if expect_pass { - MerkleTree::verify_state_transition(max_depth, state_transition_proof)?; + MerkleTree::verify_state_transition(state_transition_proof)?; } else { // expect out-circuit verification to fail - let _ = MerkleTree::verify_state_transition(max_depth, state_transition_proof).is_err(); + let _ = MerkleTree::verify_state_transition(state_transition_proof).is_err(); } let config = CircuitConfig::standard_recursion_config(); @@ -1034,7 +1050,7 @@ pub mod tests { for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let mut tree = MerkleTree::new(max_depth, &kvs)?; + let mut tree = MerkleTree::new(&kvs); // key=37 shares path with key=5, till the level 6, needing 2 extra // 'empty' nodes between the original position of key=5 with the new @@ -1093,7 +1109,7 @@ pub mod tests { for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let mut tree = MerkleTree::new(max_depth, &kvs)?; + let mut tree = MerkleTree::new(&kvs); let old_root = tree.root(); let key = RawValue::from(4294967295); // 0xffffffff @@ -1157,7 +1173,7 @@ pub mod tests { for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let mut tree = MerkleTree::new(max_depth, &kvs)?; + let mut tree = MerkleTree::new(&kvs); // key=37 shares path with key=5, till the level 6, needing 2 extra // 'empty' nodes between the original position of key=5 with the new @@ -1200,7 +1216,6 @@ pub mod tests { other_leaf: None, }; let altered_root = altered_proof.compute_root_from_leaf( - max_depth, &state_transition_proof.op_key, Some(state_transition_proof.op_value), )?; @@ -1220,7 +1235,7 @@ pub mod tests { for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let mut tree = MerkleTree::new(max_depth, &kvs)?; + let mut tree = MerkleTree::new(&kvs); let key = RawValue::from(37); let value = RawValue::from(1037); diff --git a/src/backends/plonky2/primitives/merkletree/error.rs b/src/backends/plonky2/primitives/merkletree/error.rs index e8fcfbc..2eb3198 100644 --- a/src/backends/plonky2/primitives/merkletree/error.rs +++ b/src/backends/plonky2/primitives/merkletree/error.rs @@ -20,8 +20,8 @@ pub enum TreeInnerError { InvalidStateTransitionProogArg(String), #[error("state transition proof does not verify, reason: {0}")] StateTransitionProofFail(String), - #[error("key too short (key length: {0}) for the max_depth: {1}")] - TooShortKey(usize, usize), + #[error("circuit max_depth {0} is smaller than proof depth {1}")] + CircuitDepthTooSmall(usize, usize), } #[derive(thiserror::Error)] @@ -78,7 +78,7 @@ impl TreeError { pub(crate) fn state_transition_fail(reason: String) -> Self { new!(StateTransitionProofFail(reason)) } - pub(crate) fn too_short_key(depth: usize, max_depth: usize) -> Self { - new!(TooShortKey(depth, max_depth)) + pub(crate) fn circuit_depth_too_small(circuit_depth: usize, proof_depth: usize) -> Self { + new!(CircuitDepthTooSmall(circuit_depth, proof_depth)) } } diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index e0dff04..07f87dc 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -13,28 +13,30 @@ pub use circuit::*; pub mod error; pub use error::{TreeError, TreeResult}; +/// Theoretical max depth of a merkle tree. This limits appears because we store keys of 256 bits. +const MAX_DEPTH: usize = 256; + /// Implements the MerkleTree specified at /// #[derive(Clone, Debug)] pub struct MerkleTree { - max_depth: usize, root: Node, } impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values - pub fn new(max_depth: usize, kvs: &HashMap) -> TreeResult { + pub fn new(kvs: &HashMap) -> Self { // Start with an empty node as root. let mut root = Node::None; // Iterate over key-value pairs (if any) and add them. for (k, v) in kvs.iter() { - root.apply_op(max_depth, MerkleTreeOp::Insert, *k, Some(*v))?; + root.apply_op(MerkleTreeOp::Insert, *k, Some(*v)).unwrap(); } // Fill in hashes. let _ = root.compute_hash(); - Ok(Self { max_depth, root }) + Self { root } } /// returns the root of the tree @@ -42,15 +44,10 @@ impl MerkleTree { self.root.hash() } - /// returns the max_depth parameter from the tree - pub fn max_depth(&self) -> usize { - self.max_depth - } - /// returns the value at the given key pub fn get(&self, key: &RawValue) -> TreeResult { - let path = keypath(self.max_depth, *key)?; - let (key_resolution, _) = self.root.down(0, self.max_depth, path, None)?; + let path = keypath(*key); + let (key_resolution, _) = self.root.down(0, path, None); match key_resolution { Some((k, v)) if &k == key => Ok(v), _ => Err(TreeError::key_not_found()), @@ -59,15 +56,9 @@ impl MerkleTree { /// returns a boolean indicating whether the key exists in the tree pub fn contains(&self, key: &RawValue) -> TreeResult { - let path = keypath(self.max_depth, *key)?; - match self.root.down(0, self.max_depth, path, None) { - Ok((Some((k, _)), _)) => { - if &k == key { - Ok(true) - } else { - Ok(false) - } - } + let path = keypath(*key); + match self.root.down(0, path, None) { + (Some((k, _)), _) if &k == key => Ok(true), _ => Ok(false), } } @@ -81,7 +72,7 @@ impl MerkleTree { let old_root: Hash = self.root.hash(); self.root - .apply_op(self.max_depth, MerkleTreeOp::Insert, *key, Some(*value))?; + .apply_op(MerkleTreeOp::Insert, *key, Some(*value))?; let new_root = self.root.compute_hash(); let (v, proof) = self.prove(key)?; @@ -110,7 +101,7 @@ impl MerkleTree { let old_root: Hash = self.root.hash(); self.root - .apply_op(self.max_depth, MerkleTreeOp::Update, *key, Some(*value))?; + .apply_op(MerkleTreeOp::Update, *key, Some(*value))?; let new_root = self.root.compute_hash(); let (v, proof) = self.prove(key)?; @@ -134,8 +125,7 @@ impl MerkleTree { let (value, proof_existence) = self.prove(key)?; let old_root: Hash = self.root.hash(); - self.root - .apply_op(self.max_depth, MerkleTreeOp::Delete, *key, None)?; + self.root.apply_op(MerkleTreeOp::Delete, *key, None)?; let new_root = self.root.compute_hash(); let proof = self.prove_nonexistence(key)?; @@ -157,14 +147,11 @@ impl MerkleTree { /// the tree. It returns the `value` of the leaf at the given `key`, and the /// `MerkleProof`. pub fn prove(&self, key: &RawValue) -> TreeResult<(RawValue, MerkleProof)> { - let path = keypath(self.max_depth, *key)?; + let path = keypath(*key); let mut siblings: Vec = Vec::new(); - match self - .root - .down(0, self.max_depth, path, Some(&mut siblings))? - { + match self.root.down(0, path, Some(&mut siblings)) { (Some((k, v)), _) if &k == key => Ok(( v, MerkleProof { @@ -182,15 +169,12 @@ impl MerkleTree { /// the key-value pair in the leaf reached as a result of /// resolving `key` as well as a `MerkleProof`. pub fn prove_nonexistence(&self, key: &RawValue) -> TreeResult { - let path = keypath(self.max_depth, *key)?; + let path = keypath(*key); let mut siblings: Vec = Vec::new(); // note: non-existence of a key can be in 2 cases: - match self - .root - .down(0, self.max_depth, path, Some(&mut siblings))? - { + match self.root.down(0, path, Some(&mut siblings)) { // case i) the expected leaf does not exist (None, _) => Ok(MerkleProof { existence: false, @@ -203,20 +187,19 @@ impl MerkleTree { siblings, other_leaf: Some((k, v)), }), - _ => Err(TreeError::key_not_found()), + _ => Err(TreeError::key_exists()), } // both cases prove that the given key don't exist in the tree. } /// verifies an inclusion proof for the given `key` and `value` pub fn verify( - max_depth: usize, root: Hash, proof: &MerkleProof, key: &RawValue, value: &RawValue, ) -> TreeResult<()> { - let h = proof.compute_root_from_leaf(max_depth, key, Some(*value))?; + let h = proof.compute_root_from_leaf(key, Some(*value))?; if h != root { Err(TreeError::proof_fail("inclusion".to_string())) @@ -227,12 +210,7 @@ impl MerkleTree { /// verifies a non-inclusion proof for the given `key`, that is, the given /// `key` does not exist in the tree - pub fn verify_nonexistence( - max_depth: usize, - root: Hash, - proof: &MerkleProof, - key: &RawValue, - ) -> TreeResult<()> { + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &RawValue) -> TreeResult<()> { match proof.other_leaf { Some((k, _v)) if &k == key => { Err(TreeError::invalid_proof("non-existence".to_string())) @@ -240,7 +218,7 @@ impl MerkleTree { _ => { let k = proof.other_leaf.map(|(k, _)| k).unwrap_or(*key); let v: Option = proof.other_leaf.map(|(_, v)| v); - let h = proof.compute_root_from_leaf(max_depth, &k, v)?; + let h = proof.compute_root_from_leaf(&k, v)?; if h != root { Err(TreeError::proof_fail("exclusion".to_string())) @@ -251,10 +229,7 @@ impl MerkleTree { } } - pub fn verify_state_transition( - max_depth: usize, - proof: &MerkleTreeStateTransitionProof, - ) -> TreeResult<()> { + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> TreeResult<()> { let mut old_siblings = proof.op_proof.siblings.clone(); let new_siblings = proof.siblings.clone(); @@ -267,12 +242,11 @@ impl MerkleTree { old_root: proof.new_root, ..proof.clone() }; - Self::verify_state_transition(max_depth, &equivalent_insertion_proof) + Self::verify_state_transition(&equivalent_insertion_proof) } MerkleTreeOp::Update => { // check that for the old_root, (op_key, value) *does* exist in the tree Self::verify( - max_depth, proof.old_root, &proof.op_proof, &proof.op_key, @@ -280,7 +254,6 @@ impl MerkleTree { )?; // check that for the new_root, (op_key, op_value) *does* exist in the tree Self::verify( - max_depth, proof.new_root, &MerkleProof { existence: true, @@ -301,16 +274,10 @@ impl MerkleTree { } MerkleTreeOp::Insert => { // check that for the old_root, the new_key does not exist in the tree - Self::verify_nonexistence( - max_depth, - proof.old_root, - &proof.op_proof, - &proof.op_key, - )?; + Self::verify_nonexistence(proof.old_root, &proof.op_proof, &proof.op_key)?; // check that new_siblings verify with the new_root Self::verify( - max_depth, proof.new_root, &MerkleProof { existence: true, @@ -323,8 +290,8 @@ impl MerkleTree { // if other_leaf exists, check path divergence if let Some((other_key, _)) = proof.op_proof.other_leaf { - let old_path = keypath(max_depth, other_key)?; - let new_path = keypath(max_depth, proof.op_key)?; + let old_path = keypath(other_key); + let new_path = keypath(proof.op_key); let divergence_lvl: usize = match zip_eq(old_path, new_path).position(|(x, y)| x != y) { @@ -455,26 +422,12 @@ impl MerkleProof { /// Computes the root of the Merkle tree suggested by a Merkle proof given a /// key & value. If a value is not provided, the terminal node is assumed to /// be empty. - fn compute_root_from_leaf( - &self, - max_depth: usize, - key: &RawValue, - value: Option, - ) -> TreeResult { - let path = keypath(max_depth, *key)?; + fn compute_root_from_leaf(&self, key: &RawValue, value: Option) -> TreeResult { + let path = keypath(*key); let h = kv_hash(key, value); - self.compute_root_from_node(max_depth, &h, path) + self.compute_root_from_node(&h, path) } - fn compute_root_from_node( - &self, - max_depth: usize, - node_hash: &Hash, - path: Vec, - ) -> TreeResult { - if self.siblings.len() >= max_depth { - return Err(TreeError::max_depth()); - } - + fn compute_root_from_node(&self, node_hash: &Hash, path: Vec) -> TreeResult { let mut h = *node_hash; for (i, sibling) in self.siblings.iter().enumerate().rev() { let mut input: Vec = if path[i] { @@ -677,26 +630,21 @@ impl Node { fn down( &self, lvl: usize, - max_depth: usize, path: Vec, mut siblings: Option<&mut Vec>, - ) -> TreeResult<(Option<(RawValue, RawValue)>, usize)> { - if lvl >= max_depth { - return Err(TreeError::max_depth()); - } - + ) -> (Option<(RawValue, RawValue)>, usize) { match self { Self::Intermediate(n) => { if path[lvl] { if let Some(s) = siblings.as_mut() { s.push(n.left.hash()); } - n.right.down(lvl + 1, max_depth, path, siblings) + n.right.down(lvl + 1, path, siblings) } else { if let Some(s) = siblings.as_mut() { s.push(n.right.hash()); } - n.left.down(lvl + 1, max_depth, path, siblings) + n.left.down(lvl + 1, path, siblings) } } Self::Leaf(Leaf { @@ -704,20 +652,19 @@ impl Node { value, path: _p, hash: _h, - }) => Ok((Some((*key, *value)), lvl)), - _ => Ok((None, lvl)), + }) => (Some((*key, *value)), lvl), + _ => (None, lvl), } } /// Applies given Merkle tree op without computing hashes. pub(crate) fn apply_op( &mut self, - max_depth: usize, op: MerkleTreeOp, key: RawValue, maybe_value: Option, ) -> TreeResult<()> { - let key_path = keypath(max_depth, key)?; + let key_path = keypath(key); // Rule out invalid arguments match (op, maybe_value) { (MerkleTreeOp::Insert, None) | (MerkleTreeOp::Update, None) => { @@ -736,7 +683,7 @@ impl Node { }?; // Loop through to leaf. - self.apply_op_loop(0, max_depth, op, key, &key_path, maybe_value)?; + self.apply_op_loop(0, op, key, &key_path, maybe_value)?; // If we are dealing with a deletion, normalise along key // path. @@ -777,28 +724,23 @@ impl Node { fn apply_op_loop( &mut self, lvl: usize, - max_depth: usize, op: MerkleTreeOp, key: RawValue, key_path: &[bool], maybe_value: Option, ) -> TreeResult<()> { - if lvl >= max_depth { - return Err(TreeError::max_depth()); - } - match self { Self::Intermediate(n) => { if key_path[lvl] { n.right - .apply_op_loop(lvl + 1, max_depth, op, key, key_path, maybe_value) + .apply_op_loop(lvl + 1, op, key, key_path, maybe_value) } else { n.left - .apply_op_loop(lvl + 1, max_depth, op, key, key_path, maybe_value) + .apply_op_loop(lvl + 1, op, key, key_path, maybe_value) } } _ => { - *self = Self::op_node_check(max_depth, lvl, self, op, key, key_path, maybe_value)?; + *self = Self::op_node_check(lvl, self, op, key, key_path, maybe_value)?; Ok(()) } } @@ -813,7 +755,6 @@ impl Node { /// value is replaced in the case of an update and the leaf removed /// in the case of a deletion. pub(crate) fn op_node_check( - max_depth: usize, lvl: usize, node: &Node, op: MerkleTreeOp, @@ -826,7 +767,7 @@ impl Node { // Invalid args are assumed to have been ruled out. match (op, node, maybe_value) { // Insertion case - (Insert, Node::None, Some(value)) => Ok(Node::Leaf(Leaf::new(max_depth, key, value)?)), + (Insert, Node::None, Some(value)) => Ok(Node::Leaf(Leaf::new(key, value))), (Insert, Node::Leaf(l), Some(value)) => { // in this case, it means that we found a leaf in the new-leaf // path, thus we need to push both leaves (old-leaf and @@ -845,7 +786,6 @@ impl Node { let mut new_node = Node::Intermediate(Intermediate::empty()); new_node.down_till_divergence( lvl, - max_depth, old_leaf, Leaf { hash: None, @@ -859,7 +799,7 @@ impl Node { } // Update case (Update, Node::Leaf(l), Some(value)) if l.key == key => { - Ok(Node::Leaf(Leaf::new(max_depth, key, value)?)) + Ok(Node::Leaf(Leaf::new(key, value))) } // Deletion case (Delete, Node::Leaf(l), None) if l.key == key => Ok(Node::None), @@ -878,14 +818,9 @@ impl Node { fn down_till_divergence( &mut self, lvl: usize, - max_depth: usize, old_leaf: Leaf, new_leaf: Leaf, ) -> TreeResult<()> { - if lvl >= max_depth { - return Err(TreeError::max_depth()); - } - if let Node::Intermediate(ref mut n) = self { if old_leaf.path[lvl] != new_leaf.path[lvl] { // reached divergence in next level, set the leaves as children @@ -903,14 +838,10 @@ impl Node { // no divergence yet, continue going down if new_leaf.path[lvl] { n.right = Box::new(Node::Intermediate(Intermediate::empty())); - return n - .right - .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); + return n.right.down_till_divergence(lvl + 1, old_leaf, new_leaf); } else { n.left = Box::new(Node::Intermediate(Intermediate::empty())); - return n - .left - .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); + return n.left.down_till_divergence(lvl + 1, old_leaf, new_leaf); } } Ok(()) @@ -956,13 +887,13 @@ pub(crate) struct Leaf { pub(crate) value: RawValue, } impl Leaf { - fn new(max_depth: usize, key: RawValue, value: RawValue) -> TreeResult { - Ok(Self { + fn new(key: RawValue, value: RawValue) -> Self { + Self { hash: None, - path: keypath(max_depth, key)?, + path: keypath(key), key, value, - }) + } } fn compute_hash(&mut self) -> Hash { let h = kv_hash(&self.key, Some(self.value)); @@ -981,17 +912,12 @@ impl Leaf { // max-depth? ie, what happens when two keys share the same path for more bits // than the max_depth? /// returns the path of the given key -pub(crate) fn keypath(max_depth: usize, k: RawValue) -> TreeResult> { +pub(crate) fn keypath(k: RawValue) -> Vec { let bytes = k.to_bytes(); - if max_depth > 8 * bytes.len() { - // note that our current keys are of Value type, which are 4 Goldilocks - // field elements, ie ~256 bits, therefore the max_depth can not be - // bigger than 256. - return Err(TreeError::too_short_key(8 * bytes.len(), max_depth)); - } - Ok((0..max_depth) + debug_assert_eq!(MAX_DEPTH, bytes.len() * 8); + (0..MAX_DEPTH) .map(|n| bytes[n / 8] & (1 << (n % 8)) != 0) - .collect()) + .collect() } pub struct Iter<'a> { @@ -1035,7 +961,6 @@ pub mod tests { #[test] fn test_merkletree() -> TreeResult<()> { - let max_depth: usize = 32; let mut kvs = HashMap::new(); for i in 0..8 { if i == 1 { @@ -1047,7 +972,7 @@ pub mod tests { let value = RawValue::from(1013); kvs.insert(key, value); - let tree = MerkleTree::new(max_depth, &kvs)?; + let tree = MerkleTree::new(&kvs); // when printing the tree, it should print the same tree as in // https://0xparc.github.io/pod2/merkletree.html#example-2 println!("{}", tree); @@ -1057,7 +982,7 @@ pub mod tests { assert_eq!(v, RawValue::from(1013)); println!("{}", proof); - MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; + MerkleTree::verify(tree.root(), &proof, &key, &value)?; // Exclusion checks let key = RawValue::from(12); @@ -1068,42 +993,40 @@ pub mod tests { ); println!("{}", proof); - MerkleTree::verify_nonexistence(max_depth, tree.root(), &proof, &key)?; + MerkleTree::verify_nonexistence(tree.root(), &proof, &key)?; let key = RawValue::from(1); let proof = tree.prove_nonexistence(&RawValue::from(1))?; assert_eq!(proof.other_leaf, None); println!("{}", proof); - MerkleTree::verify_nonexistence(max_depth, tree.root(), &proof, &key)?; + MerkleTree::verify_nonexistence(tree.root(), &proof, &key)?; // Check iterator let collected_kvs: Vec<_> = tree.into_iter().collect::>(); // Expected key ordering - let cmp = |max_depth: usize| { - move |k1, k2| { - let path1 = keypath(max_depth, k1).unwrap(); - let path2 = keypath(max_depth, k2).unwrap(); + let cmp = |k1, k2| { + let path1 = keypath(k1); + let path2 = keypath(k2); - let first_unequal_bits = std::iter::zip(path1, path2).find(|(b1, b2)| b1 != b2); + let first_unequal_bits = std::iter::zip(path1, path2).find(|(b1, b2)| b1 != b2); - match first_unequal_bits { - Some((b1, b2)) => { - if !b1 & b2 { - Ordering::Less - } else { - Ordering::Greater - } + match first_unequal_bits { + Some((b1, b2)) => { + if !b1 & b2 { + Ordering::Less + } else { + Ordering::Greater } - _ => Ordering::Equal, } + _ => Ordering::Equal, } }; let sorted_kvs = kvs .iter() - .sorted_by(|(k1, _), (k2, _)| cmp(max_depth)(**k1, **k2)) + .sorted_by(|(k1, _), (k2, _)| cmp(**k1, **k2)) .collect::>(); assert_eq!(collected_kvs, sorted_kvs); @@ -1113,13 +1036,12 @@ pub mod tests { #[test] fn test_state_transition() -> TreeResult<()> { - let max_depth: usize = 32; let mut kvs = HashMap::new(); for i in 0..8 { kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let mut tree = MerkleTree::new(max_depth, &kvs)?; + let mut tree = MerkleTree::new(&kvs); let old_root = tree.root(); // key=37 shares path with key=5, till the level 6, needing 2 extra @@ -1129,7 +1051,7 @@ pub mod tests { let value = RawValue::from(1037); let state_transition_proof = tree.insert(&key, &value)?; - MerkleTree::verify_state_transition(max_depth, &state_transition_proof)?; + MerkleTree::verify_state_transition(&state_transition_proof)?; assert_eq!(state_transition_proof.old_root, old_root); assert_eq!(state_transition_proof.new_root, tree.root()); assert_eq!(state_transition_proof.op_key, key); @@ -1140,7 +1062,7 @@ pub mod tests { // should be the same (mutatis mutandis). let mut tree_with_deleted_key = tree.clone(); let state_transition_proof1 = tree_with_deleted_key.delete(&key)?; - MerkleTree::verify_state_transition(max_depth, &state_transition_proof1)?; + MerkleTree::verify_state_transition(&state_transition_proof1)?; assert_eq!( state_transition_proof1.old_root, state_transition_proof.new_root @@ -1172,14 +1094,14 @@ pub mod tests { let value = RawValue::from(1021); let state_transition_proof = tree_with_another_leaf.insert(&key, &value)?; - MerkleTree::verify_state_transition(max_depth, &state_transition_proof)?; + MerkleTree::verify_state_transition(&state_transition_proof)?; // Alternatively add this key with another value then update. let value1 = RawValue::from(99); tree.insert(&key, &value1)?; let state_transition_proof1 = tree.update(&key, &value)?; - MerkleTree::verify_state_transition(max_depth, &state_transition_proof1)?; + MerkleTree::verify_state_transition(&state_transition_proof1)?; // `tree` and `tree_with_another_leaf` should coincide. assert_eq!(tree.root(), tree_with_another_leaf.root()); diff --git a/src/examples/mod.rs b/src/examples/mod.rs index b5b2604..0801978 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -5,7 +5,7 @@ use std::{collections::HashSet, sync::LazyLock}; use custom::eth_dos_batch; use num::BigUint; -pub static MOCK_VD_SET: LazyLock = LazyLock::new(|| VDSet::new(6, &[]).unwrap()); +pub static MOCK_VD_SET: LazyLock = LazyLock::new(|| VDSet::new(&[])); use crate::{ backends::plonky2::{primitives::ec::schnorr::SecretKey, signer::Signer}, @@ -50,8 +50,7 @@ pub fn zu_kyc_pod_builder( .iter() .map(|s| Value::from(*s)) .collect(); - let sanction_set = - Value::from(Set::new(params.max_depth_mt_containers, sanctions_values).unwrap()); + let sanction_set = Value::from(Set::new(sanctions_values)); let mut kyc = MainPodBuilder::new(params, vd_set); kyc.pub_op(Operation::dict_signed_by(gov_id))?; @@ -72,13 +71,11 @@ pub fn zu_kyc_pod_builder( } pub fn zu_kyc_pod_request(gov_signer: &Value, pay_signer: &Value) -> Result { - let params = Params::default(); let sanctions_values: HashSet = ZU_KYC_SANCTION_LIST .iter() .map(|s| Value::from(*s)) .collect(); - let sanction_set = - Value::from(Set::new(params.max_depth_mt_containers, sanctions_values).unwrap()); + let sanction_set = Value::from(Set::new(sanctions_values)); let input = format!( r#" REQUEST( @@ -347,9 +344,8 @@ pub fn great_boy_pod_full_flow() -> Result { alice_friend_pods.push(friend.sign(&charlie_signer).unwrap()); let good_boy_issuers = Value::from(Set::new( - params.max_depth_mt_containers, good_boy_issuers.into_iter().map(Value::from).collect(), - )?); + )); let builder = great_boy_pod_builder( ¶ms, @@ -433,6 +429,6 @@ pub fn tickets_pod_full_flow(params: &Params, vd_set: &VDSet) -> Result = [1, 2, 3].iter().map(|i| Value::from(*i)).collect(); - let s1 = Set::new(params.max_depth_mt_containers, set_values)?; + let s1 = Set::new(set_values); let s2 = 1; let set_contains = mp_builder.pub_op(Operation::set_contains(s1, s2))?; diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 2b1c070..1f5046d 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -59,7 +59,7 @@ impl SignedDictBuilder { pub fn sign(&self, signer: &S) -> Result { // Sign committed KV store. - let dict = Dictionary::new(self.params.max_depth_mt_containers, self.kvs.clone())?; + let dict = Dictionary::new(self.kvs.clone()); // NOTE: This is the same way that `TypedValue::Dictionary` computes the `RawValue` let msg_raw = RawValue::from(dict.commitment()); let signature = signer.sign(msg_raw); @@ -1026,11 +1026,11 @@ pub mod tests { let vd_set = &*MOCK_VD_SET; let mut builder = SignedDictBuilder::new(¶ms); - let dict = dict!(params.max_depth_mt_containers, { + let dict = dict!({ "a" => 1, "b" => 2, "c" => 3, - })?; + }); let dict_root = Value::from(dict.clone()); builder.insert("dict", dict_root); @@ -1042,7 +1042,7 @@ pub mod tests { .pub_op(Operation::dict_signed_by(&signed_dict)) .unwrap(); let st0 = signed_dict.get_statement("dict").unwrap(); - let local = dict!(32, {"key" => "a"})?; + let local = dict!({"key" => "a"}); let st1 = builder .op(true, vec![], Operation::dict_contains(local, "key", "a")) .unwrap(); @@ -1115,7 +1115,7 @@ pub mod tests { let vd_set = &*MOCK_VD_SET; let mut builder = MainPodBuilder::new(¶ms, vd_set); - let empty_set = Set::new(params.max_depth_mt_containers, [].into())?; + let empty_set = Set::new([].into()); let mut set1 = empty_set.clone(); set1.insert(&1.into())?; @@ -1163,7 +1163,7 @@ pub mod tests { let vd_set = &*MOCK_VD_SET; let mut builder = MainPodBuilder::new(¶ms, vd_set); - let array1 = Array::new(params.max_depth_mt_containers, [1.into()].into())?; + let array1 = Array::new([1.into()].into()); let mut array2 = array1.clone(); array2.update(0, &5.into())?; @@ -1217,7 +1217,7 @@ pub mod tests { "owner", Value::from(pk), ))?; - let local = dict!(32, { "known_secret" => sk.clone() })?; + let local = dict!({ "known_secret" => sk.clone() }); let st1 = builder.priv_op(Operation::dict_contains( local, "known_secret", @@ -1262,7 +1262,7 @@ pub mod tests { .pub_op(Operation::dict_signed_by(&signed_dict)) .unwrap(); let st0 = signed_dict.get_statement("owner").unwrap(); - let local = dict!(32, {"known_secret" => SecretKey(BigUint::from(123u32))})?; + let local = dict!({"known_secret" => SecretKey(BigUint::from(123u32))}); let st1 = builder .op( true, @@ -1333,7 +1333,7 @@ pub mod tests { let params = Params::default(); let vd_set = &*MOCK_VD_SET; let mut builder = MainPodBuilder::new(¶ms, vd_set); - let local = dict!(32, {"a" => 3, "b" => 27}).unwrap(); + let local = dict!({"a" => 3, "b" => 27}); let value_of_a = Statement::contains(local.clone(), "a", 3); let value_of_b = Statement::contains(local.clone(), "b", 27); diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 9e2d020..8a47db3 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -89,19 +89,18 @@ mod tests { #[test] fn test_value_serialization() { - let params = &Params::default(); // Pairs of values and their expected serialized representations let values = vec![ (TypedValue::String("hello".to_string()), "\"hello\""), (TypedValue::Int(42), "{\"Int\":\"42\"}"), (TypedValue::Bool(true), "true"), ( - TypedValue::Array(Array::new(params.max_depth_mt_containers, vec!["foo".into(), false.into()]).unwrap()), - "{\"max_depth\":32,\"array\":[\"foo\",false]}", + TypedValue::Array(Array::new(vec!["foo".into(), false.into()])), + "{\"array\":[\"foo\",false]}", ), ( TypedValue::Dictionary( - Dictionary::new(params.max_depth_mt_containers, HashMap::from([ + Dictionary::new(HashMap::from([ // The set of valid keys is equal to the set of valid JSON keys ("foo".into(), 123.into()), // Empty strings are valid JSON keys @@ -115,13 +114,12 @@ mod tests { // Keys can contain emojis (("🥳".into()), "party time!".into()), ])) - .unwrap(), ), - "{\"max_depth\":32,\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}", + "{\"kvs\":{\"\":\"baz\",\"\\u0000\":\"\",\" hi\":false,\"!@£$%^&&*()\":\"\",\"foo\":{\"Int\":\"123\"},\"🥳\":\"party time!\"}}", ), ( - TypedValue::Set(Set::new(params.max_depth_mt_containers, HashSet::from(["foo".into(), "bar".into()])).unwrap()), - "{\"max_depth\":32,\"set\":[\"bar\",\"foo\"]}", + TypedValue::Set(Set::new(HashSet::from(["foo".into(), "bar".into()]))), + "{\"set\":[\"bar\",\"foo\"]}", ), ]; @@ -147,39 +145,21 @@ mod tests { builder.insert("very_large_int", 1152921504606846976); builder.insert( "a_dict_containing_one_key", - Dictionary::new( - params.max_depth_mt_containers, - HashMap::from([ - ("foo".into(), 123.into()), - ( - "an_array_containing_three_ints".into(), - Array::new( - params.max_depth_mt_containers, - vec![1.into(), 2.into(), 3.into()], - ) - .unwrap() - .into(), - ), - ( - "a_set_containing_two_strings".into(), - Set::new( - params.max_depth_mt_containers, - HashSet::from([ - Array::new( - params.max_depth_mt_containers, - vec!["foo".into(), "bar".into()], - ) - .unwrap() - .into(), - "baz".into(), - ]), - ) - .unwrap() - .into(), - ), - ]), - ) - .unwrap(), + Dictionary::new(HashMap::from([ + ("foo".into(), 123.into()), + ( + "an_array_containing_three_ints".into(), + Array::new(vec![1.into(), 2.into(), 3.into()]).into(), + ), + ( + "a_set_containing_two_strings".into(), + Set::new(HashSet::from([ + Array::new(vec!["foo".into(), "bar".into()]).into(), + "baz".into(), + ])) + .into(), + ), + ])), ); builder } @@ -228,7 +208,7 @@ mod tests { }; let mut vds = DEFAULT_VD_LIST.clone(); vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); - let vd_set = VDSet::new(params.max_depth_mt_vds, &vds).unwrap(); + let vd_set = VDSet::new(&vds); let (gov_id_builder, pay_stub_builder) = zu_kyc_sign_dict_builders(¶ms); let signer = Signer(SecretKey(1u32.into())); diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index e5ad305..a8863d0 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -372,7 +372,7 @@ impl<'a> Lowerer<'a> { // Convert AST args to BuilderArgs let mut builder = StatementTmplBuilder::new(predicate); for arg in &stmt.args { - let builder_arg = self.lower_statement_arg_to_builder(arg)?; + let builder_arg = Self::lower_statement_arg_to_builder(arg)?; builder = builder.arg(builder_arg); } @@ -380,13 +380,10 @@ impl<'a> Lowerer<'a> { Ok(builder) } - fn lower_statement_arg_to_builder( - &self, - arg: &StatementTmplArg, - ) -> Result { + fn lower_statement_arg_to_builder(arg: &StatementTmplArg) -> Result { match arg { StatementTmplArg::Literal(lit) => { - let value = self.lower_literal(lit)?; + let value = Self::lower_literal(lit)?; Ok(BuilderArg::Literal(value)) } StatementTmplArg::Wildcard(id) => { @@ -403,7 +400,7 @@ impl<'a> Lowerer<'a> { } } - fn lower_literal(&self, lit: &LiteralValue) -> Result { + fn lower_literal(lit: &LiteralValue) -> Result { let value = match lit { LiteralValue::Int(i) => middleware::Value::from(i.value), LiteralValue::Bool(b) => middleware::Value::from(b.value), @@ -413,15 +410,15 @@ impl<'a> Lowerer<'a> { LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()), LiteralValue::Array(a) => { let elements: Result, _> = - a.elements.iter().map(|e| self.lower_literal(e)).collect(); - let array = containers::Array::new(self.params.max_depth_mt_containers, elements?)?; + a.elements.iter().map(Self::lower_literal).collect(); + let array = containers::Array::new(elements?); middleware::Value::from(array) } LiteralValue::Set(s) => { let elements: Result, _> = - s.elements.iter().map(|e| self.lower_literal(e)).collect(); + s.elements.iter().map(Self::lower_literal).collect(); let set_values: std::collections::HashSet<_> = elements?.into_iter().collect(); - let set = containers::Set::new(self.params.max_depth_mt_containers, set_values)?; + let set = containers::Set::new(set_values); middleware::Value::from(set) } LiteralValue::Dict(d) => { @@ -430,13 +427,12 @@ impl<'a> Lowerer<'a> { .iter() .map(|pair| { let key = middleware::Key::from(pair.key.value.as_str()); - let value = self.lower_literal(&pair.value)?; + let value = Self::lower_literal(&pair.value)?; Ok((key, value)) }) .collect(); let dict_map: std::collections::HashMap<_, _> = pairs?.into_iter().collect(); - let dict = - containers::Dictionary::new(self.params.max_depth_mt_containers, dict_map)?; + let dict = containers::Dictionary::new(dict_map); middleware::Value::from(dict) } }; diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index 8a7f1cc..d01f43f 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -22,34 +22,30 @@ pub struct Dictionary { #[serde(skip)] #[schemars(skip)] mt: MerkleTree, - max_depth: usize, #[serde(serialize_with = "ordered_map")] kvs: HashMap, } #[macro_export] macro_rules! dict { - ($max_depth:expr, { $($key:expr => $val:expr),* , }) => ( - $crate::dict!($max_depth, { $($key => $val),* }) + ({ $($key:expr => $val:expr),* , }) => ( + $crate::dict!({ $($key => $val),* }) ); - ($max_depth:expr, { $($key:expr => $val:expr),* }) => ({ + ({ $($key:expr => $val:expr),* }) => ({ let mut map = ::std::collections::HashMap::new(); $( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )* - $crate::middleware::containers::Dictionary::new($max_depth, map) + $crate::middleware::containers::Dictionary::new( map) }); } impl Dictionary { - /// max_depth determines the depth of the underlying MerkleTree, allowing to - /// store 2^max_depth elements in the Dictionary - pub fn new(max_depth: usize, kvs: HashMap) -> Result { + pub fn new(kvs: HashMap) -> Self { let kvs_raw: HashMap = kvs.iter().map(|(k, v)| (k.raw(), v.raw())).collect(); - Ok(Self { - mt: MerkleTree::new(max_depth, &kvs_raw)?, - max_depth, + Self { + mt: MerkleTree::new(&kvs_raw), kvs, - }) + } } pub fn commitment(&self) -> Hash { self.mt.root() @@ -82,46 +78,21 @@ impl Dictionary { self.kvs.remove(key); Ok(mtp) } - pub fn verify( - max_depth: usize, - root: Hash, - proof: &MerkleProof, - key: &Key, - value: &Value, - ) -> Result<()> { + pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> { let key = key.raw(); - Ok(MerkleTree::verify( - max_depth, - root, - proof, - &key, - &value.raw(), - )?) + Ok(MerkleTree::verify(root, proof, &key, &value.raw())?) } - pub fn verify_nonexistence( - max_depth: usize, - root: Hash, - proof: &MerkleProof, - key: &Key, - ) -> Result<()> { + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> { let key = key.raw(); - Ok(MerkleTree::verify_nonexistence( - max_depth, root, proof, &key, - )?) + Ok(MerkleTree::verify_nonexistence(root, proof, &key)?) } - pub fn verify_state_transition( - max_depth: usize, - proof: &MerkleTreeStateTransitionProof, - ) -> Result<()> { - MerkleTree::verify_state_transition(max_depth, proof).map_err(|e| e.into()) + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { + MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) } // TODO: Rename to dict to be consistent maybe? pub fn kvs(&self) -> &HashMap { &self.kvs } - pub fn max_depth(&self) -> usize { - self.max_depth - } } impl PartialEq for Dictionary { @@ -140,10 +111,9 @@ impl<'de> Deserialize<'de> for Dictionary { struct Aux { #[serde(serialize_with = "ordered_map")] kvs: HashMap, - max_depth: usize, } let aux = Aux::deserialize(deserializer)?; - Dictionary::new(aux.max_depth, aux.kvs).map_err(serde::de::Error::custom) + Ok(Dictionary::new(aux.kvs)) } } @@ -155,15 +125,12 @@ pub struct Set { #[serde(skip)] #[schemars(skip)] mt: MerkleTree, - max_depth: usize, #[serde(serialize_with = "ordered_set")] set: HashSet, } impl Set { - /// max_depth determines the depth of the underlying MerkleTree, allowing to - /// store 2^max_depth elements in the Array - pub fn new(max_depth: usize, set: HashSet) -> Result { + pub fn new(set: HashSet) -> Self { let kvs_raw: HashMap = set .iter() .map(|e| { @@ -171,11 +138,10 @@ impl Set { (rv, rv) }) .collect(); - Ok(Self { - mt: MerkleTree::new(max_depth, &kvs_raw)?, - max_depth, + Self { + mt: MerkleTree::new(&kvs_raw), set, - }) + } } pub fn commitment(&self) -> Hash { self.mt.root() @@ -203,33 +169,20 @@ impl Set { self.set.remove(value); Ok(mtp) } - pub fn verify(max_depth: usize, root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { + pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { let rv = value.raw(); - Ok(MerkleTree::verify(max_depth, root, proof, &rv, &rv)?) + Ok(MerkleTree::verify(root, proof, &rv, &rv)?) } - pub fn verify_nonexistence( - max_depth: usize, - root: Hash, - proof: &MerkleProof, - value: &Value, - ) -> Result<()> { + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { let rv = value.raw(); - Ok(MerkleTree::verify_nonexistence( - max_depth, root, proof, &rv, - )?) + Ok(MerkleTree::verify_nonexistence(root, proof, &rv)?) } - pub fn verify_state_transition( - max_depth: usize, - proof: &MerkleTreeStateTransitionProof, - ) -> Result<()> { - MerkleTree::verify_state_transition(max_depth, proof).map_err(|e| e.into()) + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { + MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) } pub fn set(&self) -> &HashSet { &self.set } - pub fn max_depth(&self) -> usize { - self.max_depth - } } impl PartialEq for Set { @@ -248,10 +201,9 @@ impl<'de> Deserialize<'de> for Set { struct Aux { #[serde(serialize_with = "ordered_set")] set: HashSet, - max_depth: usize, } let aux = Aux::deserialize(deserializer)?; - Set::new(aux.max_depth, aux.set).map_err(serde::de::Error::custom) + Ok(Set::new(aux.set)) } } @@ -264,25 +216,21 @@ pub struct Array { #[serde(skip)] #[schemars(skip)] mt: MerkleTree, - max_depth: usize, array: Vec, } impl Array { - /// max_depth determines the depth of the underlying MerkleTree, allowing to - /// store 2^max_depth elements in the Array - pub fn new(max_depth: usize, array: Vec) -> Result { + pub fn new(array: Vec) -> Self { let kvs_raw: HashMap = array .iter() .enumerate() .map(|(i, e)| (RawValue::from(i as i64), e.raw())) .collect(); - Ok(Self { - mt: MerkleTree::new(max_depth, &kvs_raw)?, - max_depth, + Self { + mt: MerkleTree::new(&kvs_raw), array, - }) + } } pub fn commitment(&self) -> Hash { self.mt.root() @@ -302,33 +250,20 @@ impl Array { self.array[i] = value.clone(); Ok(mtp) } - pub fn verify( - max_depth: usize, - root: Hash, - proof: &MerkleProof, - i: usize, - value: &Value, - ) -> Result<()> { + pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> { Ok(MerkleTree::verify( - max_depth, root, proof, &RawValue::from(i as i64), &value.raw(), )?) } - pub fn verify_state_transition( - max_depth: usize, - proof: &MerkleTreeStateTransitionProof, - ) -> Result<()> { - MerkleTree::verify_state_transition(max_depth, proof).map_err(|e| e.into()) + pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { + MerkleTree::verify_state_transition(proof).map_err(|e| e.into()) } pub fn array(&self) -> &[Value] { &self.array } - pub fn max_depth(&self) -> usize { - self.max_depth - } } impl PartialEq for Array { @@ -346,9 +281,8 @@ impl<'de> Deserialize<'de> for Array { #[derive(Deserialize, JsonSchema)] struct Aux { array: Vec, - max_depth: usize, } let aux = Aux::deserialize(deserializer)?; - Array::new(aux.max_depth, aux.array).map_err(serde::de::Error::custom) + Ok(Array::new(aux.array)) } } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 642284b..9254760 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -524,13 +524,13 @@ mod tests { )?], ); - let d0 = dict!(32, { + let d0 = dict!({ "a" => 10, - })?; - let d1 = dict!(32, { + }); + let d1 = dict!({ "b" => 15, "c" => 17, - })?; + }); let custom_statement = Statement::Custom( CustomPredicateRef::new(cust_pred_batch.clone(), 0), vec![Value::from(d0.clone())], diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 21d34f0..2200513 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -446,24 +446,13 @@ impl Operation { let root = val(root_v, root_s)?; let key = val(key_v, key_s)?; let value = val(val_v, val_s)?; - MerkleTree::verify( - params.max_depth_mt_containers, - root.raw().into(), - pf, - &key.raw(), - &value.raw(), - )?; + MerkleTree::verify(root.raw().into(), pf, &key.raw(), &value.raw())?; true } (Self::NotContainsFromEntries(root_s, key_s, pf), NotContains(root_v, key_v)) => { let root = val(root_v, root_s)?; let key = val(key_v, key_s)?; - MerkleTree::verify_nonexistence( - params.max_depth_mt_containers, - root.raw().into(), - pf, - &key.raw(), - )?; + MerkleTree::verify_nonexistence(root.raw().into(), pf, &key.raw())?; true } ( @@ -507,7 +496,7 @@ impl Operation { "The provided Merkle tree state transition proof does not match the claim." .into(), ))?; - MerkleTree::verify_state_transition(params.max_depth_mt_containers, pf)?; + MerkleTree::verify_state_transition(pf)?; true } ( @@ -528,7 +517,7 @@ impl Operation { "The provided Merkle tree state transition proof does not match the claim." .into(), ))?; - MerkleTree::verify_state_transition(params.max_depth_mt_containers, pf)?; + MerkleTree::verify_state_transition(pf)?; true } ( @@ -547,7 +536,7 @@ impl Operation { "The provided Merkle tree state transition proof does not match the claim." .into(), ))?; - MerkleTree::verify_state_transition(params.max_depth_mt_containers, pf)?; + MerkleTree::verify_state_transition(pf)?; true } (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) @@ -815,7 +804,7 @@ mod tests { let kvs = (0..10) .map(|i| (hash_value(&i.into()).into(), i.into())) .collect::>(); - let mt = MerkleTree::new(params.max_depth_mt_containers, &kvs)?; + let mt = MerkleTree::new(&kvs); let root = mt.root(); // Check existence proofs @@ -873,7 +862,7 @@ mod tests { let kvs = (0..10) .map(|i| (hash_value(&i.into()).into(), i.into())) .collect::>(); - let mut mt = MerkleTree::new(params.max_depth_mt_containers, &kvs)?; + let mut mt = MerkleTree::new(&kvs); // Check insertion proofs (11..20)