From c232c8dae52697f120804677f1c9c485754ab048 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 16 Apr 2025 11:59:30 +0200 Subject: [PATCH] Refactor frontend/middleware types (#194) * unify fe/be NativeOp and NativePred * remove Origin in favour of PodId * Combine string and hash in Key * use middleware::AnchoredKey in frontend * merge frontend/middleware types * refactor custom predicates * clean up a bit * fix middleware custom tests * clean up * clean up 2 * add acronyms in typos list --- .github/workflows/typos.toml | 2 + src/backends/plonky2/basetypes.rs | 256 +--- src/backends/plonky2/circuits/common.rs | 6 +- src/backends/plonky2/circuits/mainpod.rs | 58 +- src/backends/plonky2/circuits/signedpod.rs | 43 +- src/backends/plonky2/mainpod.rs | 75 +- src/backends/plonky2/mock/mainpod/mod.rs | 58 +- .../plonky2/mock/mainpod/operation.rs | 32 +- .../plonky2/mock/mainpod/statement.rs | 14 +- src/backends/plonky2/mock/signedpod.rs | 101 +- src/backends/plonky2/primitives/merkletree.rs | 64 +- .../plonky2/primitives/merkletree_circuit.rs | 83 +- src/backends/plonky2/primitives/signature.rs | 47 +- .../plonky2/primitives/signature_circuit.rs | 33 +- src/backends/plonky2/signedpod.rs | 70 +- src/examples/custom.rs | 64 +- src/examples/mod.rs | 114 +- src/frontend/containers.rs | 112 -- src/frontend/custom.rs | 481 +------- src/frontend/mod.rs | 1079 +++++------------ src/frontend/operation.rs | 135 +-- src/frontend/predicate.rs | 58 - src/frontend/serialization.rs | 62 +- src/frontend/statement.rs | 162 --- src/lib.rs | 1 - src/middleware/basetypes.rs | 209 +++- src/middleware/containers.rs | 135 ++- src/middleware/custom.rs | 479 ++++---- src/middleware/mod.rs | 314 ++++- src/middleware/operation.rs | 282 +++-- src/middleware/serialization.rs | 3 + src/middleware/statement.rs | 133 +- src/util.rs | 20 - 33 files changed, 1985 insertions(+), 2800 deletions(-) delete mode 100644 src/frontend/containers.rs delete mode 100644 src/frontend/predicate.rs delete mode 100644 src/frontend/statement.rs delete mode 100644 src/util.rs diff --git a/.github/workflows/typos.toml b/.github/workflows/typos.toml index 07d01ab..beca8e5 100644 --- a/.github/workflows/typos.toml +++ b/.github/workflows/typos.toml @@ -5,3 +5,5 @@ Ded = "Ded" # "ANDed", it thought "Ded" should be "Dead" OT = "OT" aks = "aks" # anchored keys nin = "nin" # not in +kow = "kow" # key or wildcard +KOW = "KOW" # Key Or Wildcard diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index 170ff74..0f17151 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -2,37 +2,10 @@ //! `backend_plonky2` feature is enabled. //! See src/middleware/basetypes.rs for more details. -use std::{ - cmp::{Ord, Ordering}, - fmt, -}; +use plonky2::plonk::{config::PoseidonGoldilocksConfig, proof::Proof as Plonky2Proof}; -use anyhow::{anyhow, Error, Result}; -use hex::{FromHex, FromHexError}; -use plonky2::{ - field::{ - goldilocks_field::GoldilocksField, - types::{Field, PrimeField64}, - }, - hash::poseidon::PoseidonHash, - plonk::{ - config::{Hasher, PoseidonGoldilocksConfig}, - proof::Proof as Plonky2Proof, - }, -}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use crate::middleware::F; -use crate::middleware::{ - serialization::{ - deserialize_hash_tuple, deserialize_value_tuple, serialize_hash_tuple, - serialize_value_tuple, - }, - Params, ToFields, -}; - -/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2 -pub type F = GoldilocksField; /// C is the Plonky2 config used in POD2 to work with Plonky2 recursion. pub type C = PoseidonGoldilocksConfig; /// D defines the extension degree of the field used in the Plonky2 proofs (quadratic extension). @@ -40,228 +13,3 @@ pub const D: usize = 2; /// proof system proof pub type Proof = Plonky2Proof; - -pub const HASH_SIZE: usize = 4; -pub const VALUE_SIZE: usize = 4; - -pub const EMPTY_VALUE: Value = Value([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); -pub const SELF_ID_HASH: Hash = Hash([F::ONE, F::ZERO, F::ZERO, F::ZERO]); -pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); - -#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -#[schemars(rename = "MiddlewareValue")] -pub struct Value( - #[serde( - serialize_with = "serialize_value_tuple", - deserialize_with = "deserialize_value_tuple" - )] - // We know that Serde will serialize and deserialize this as a string, so we can - // use the JsonSchema to validate the format. - #[schemars(with = "String", regex(pattern = r"^[0-9a-fA-F]{64}$"))] - pub [F; VALUE_SIZE], -); - -impl ToFields for Value { - fn to_fields(&self, _params: &Params) -> Vec { - self.0.to_vec() - } -} - -impl Value { - pub fn to_bytes(self) -> Vec { - self.0 - .iter() - .flat_map(|e| e.to_canonical_u64().to_le_bytes()) - .collect() - } -} - -impl Ord for Value { - fn cmp(&self, other: &Self) -> Ordering { - for (lhs, rhs) in self.0.iter().zip(other.0.iter()).rev() { - let (lhs, rhs) = (lhs.to_canonical_u64(), rhs.to_canonical_u64()); - match lhs.cmp(&rhs) { - Ordering::Less => return Ordering::Less, - Ordering::Greater => return Ordering::Greater, - _ => {} - } - } - Ordering::Equal - } -} - -impl PartialOrd for Value { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl From for Value { - fn from(v: i64) -> Self { - let lo = F::from_canonical_u64((v as u64) & 0xffffffff); - let hi = F::from_canonical_u64((v as u64) >> 32); - Value([lo, hi, F::ZERO, F::ZERO]) - } -} - -impl From for Value { - fn from(h: Hash) -> Self { - Value(h.0) - } -} - -impl TryInto for Value { - type Error = Error; - fn try_into(self) -> std::result::Result { - let value = self.0; - if value[2..] != [F::ZERO, F::ZERO] - || value[..2] - .iter() - .all(|x| x.to_canonical_u64() > u32::MAX as u64) - { - Err(anyhow!("Value not an element of the i64 embedding.")) - } else { - Ok((value[0].to_canonical_u64() | (value[1].to_canonical_u64() << 32)) as i64) - } - } -} - -impl fmt::Display for Value { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.0[2].is_zero() && self.0[3].is_zero() { - // Assume this is an integer - let (l0, l1) = (self.0[0].to_canonical_u64(), self.0[1].to_canonical_u64()); - assert!(l0 < (1 << 32)); - assert!(l1 < (1 << 32)); - write!(f, "{}", l0 + l1 * (1 << 32)) - } else { - // Assume this is a hash - Hash(self.0).fmt(f) - } - } -} - -#[derive(Clone, Copy, Debug, Default, Hash, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] -pub struct Hash( - #[serde( - serialize_with = "serialize_hash_tuple", - deserialize_with = "deserialize_hash_tuple" - )] - #[schemars(with = "String", regex(pattern = r"^[0-9a-fA-F]{64}$"))] - pub [F; HASH_SIZE], -); - -pub fn hash_value(input: &Value) -> Hash { - hash_fields(&input.0) -} - -pub fn hash_fields(input: &[F]) -> Hash { - Hash(PoseidonHash::hash_no_pad(input).elements) -} - -impl From for Hash { - fn from(v: Value) -> Self { - Hash(v.0) - } -} -impl Hash { - pub fn value(self) -> Value { - Value(self.0) - } -} - -impl ToFields for Hash { - fn to_fields(&self, _params: &Params) -> Vec { - self.0.to_vec() - } -} - -impl Ord for Hash { - fn cmp(&self, other: &Self) -> Ordering { - Value(self.0).cmp(&Value(other.0)) - } -} - -impl PartialOrd for Hash { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl fmt::Display for Hash { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let v0 = self.0[0].to_canonical_u64(); - for i in 0..HASH_SIZE { - write!(f, "{:02x}", (v0 >> (i * 8)) & 0xff)?; - } - write!(f, "…") - } -} - -impl FromHex for Hash { - type Error = FromHexError; - - // TODO make it dependant on backend::Value len - fn from_hex>(hex: T) -> Result { - // In little endian - let bytes = <[u8; 32]>::from_hex(hex)?; - let mut buf: [u8; 8] = [0; 8]; - let mut inner = [F::ZERO; HASH_SIZE]; - for i in 0..HASH_SIZE { - buf.copy_from_slice(&bytes[8 * i..8 * (i + 1)]); - inner[i] = F::from_canonical_u64(u64::from_le_bytes(buf)); - } - Ok(Self(inner)) - } -} - -impl From<&str> for Hash { - fn from(s: &str) -> Self { - hash_str(s) - } -} - -pub fn hash_str(s: &str) -> Hash { - let mut input = s.as_bytes().to_vec(); - input.push(1); // padding - - // Merge 7 bytes into 1 field, because the field is slightly below 64 bits - let input: Vec = input - .chunks(7) - .map(|bytes| { - let mut v: u64 = 0; - for b in bytes.iter().rev() { - v <<= 8; - v += *b as u64; - } - F::from_canonical_u64(v) - }) - .collect(); - hash_fields(&input) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_i64_value_roundtrip() { - let test_cases = [ - 0i64, - 1, - -1, - i64::MAX, - i64::MIN, - 42, - -42, - 1 << 32, - -(1 << 32), - ]; - - for &original in test_cases.iter() { - let value = Value::from(original); - let roundtrip: i64 = value.try_into().unwrap(); - assert_eq!(original, roundtrip, "Failed roundtrip for {}", original); - } - } -} diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 63c6c8e..eb70900 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -23,7 +23,7 @@ use crate::{ primitives::merkletree::MerkleClaimAndProofTarget, }, middleware::{ - NativeOperation, NativePredicate, Params, Predicate, StatementArg, ToFields, Value, + NativeOperation, NativePredicate, Params, Predicate, RawValue, StatementArg, ToFields, EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE, }, @@ -294,7 +294,7 @@ pub trait CircuitBuilderPod, const D: usize> { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget; - fn constant_value(&mut self, v: Value) -> ValueTarget; + fn constant_value(&mut self, v: RawValue) -> ValueTarget; fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget; // Convenience methods for checking values. @@ -365,7 +365,7 @@ impl CircuitBuilderPod for CircuitBuilder { BoolTarget::new_unsafe(self.select(b, x.target, y.target)) } - fn constant_value(&mut self, v: Value) -> ValueTarget { + fn constant_value(&mut self, v: RawValue) -> ValueTarget { ValueTarget { elements: std::array::from_fn(|i| { self.constant(F::from_noncanonical_u64(v.0[i].to_noncanonical_u64())) diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 135e262..c2e238c 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -8,7 +8,7 @@ use plonky2::{ use crate::{ backends::plonky2::{ - basetypes::{Value, D, EMPTY_HASH, F, VALUE_SIZE}, + basetypes::D, circuits::{ common::{ CircuitBuilderPod, Flattenable, MerkleClaimTarget, OperationTarget, @@ -24,8 +24,8 @@ use crate::{ signedpod::SignedPod, }, middleware::{ - hash_str, AnchoredKey, NativeOperation, NativePredicate, Params, PodType, Statement, - StatementArg, ToFields, KEY_TYPE, SELF, + AnchoredKey, NativeOperation, NativePredicate, Params, PodType, Statement, StatementArg, + ToFields, Value, F, KEY_TYPE, SELF, VALUE_SIZE, }, }; @@ -304,7 +304,7 @@ impl OperationVerifyGadget { let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::ValueOf); let expected_arg_prefix = builder.constants( - &StatementArg::Key(AnchoredKey(SELF, EMPTY_HASH)).to_fields(&self.params)[..VALUE_SIZE], + &StatementArg::Key(AnchoredKey::from((SELF, ""))).to_fields(&self.params)[..VALUE_SIZE], ); let arg_prefix_ok = builder.is_equal_slice(&st.args[0].elements[..VALUE_SIZE], &expected_arg_prefix); @@ -422,11 +422,13 @@ impl MainPodVerifyGadget { let type_statement = &pub_statements[0]; // TODO: Store this hash in a global static with lazy init so that we don't have to // compute it every time. - let key_type = hash_str(KEY_TYPE); let expected_type_statement = StatementTarget::from_flattened( &builder.constants( - &Statement::ValueOf(AnchoredKey(SELF, key_type), Value::from(PodType::MockMain)) - .to_fields(params), + &Statement::ValueOf( + AnchoredKey::from((SELF, KEY_TYPE)), + Value::from(PodType::MockMain), + ) + .to_fields(params), ), ); builder.connect_flattenable(type_statement, &expected_type_statement); @@ -541,7 +543,7 @@ mod tests { mainpod::{OperationArg, OperationAux}, }, }, - middleware::{OperationType, PodId}, + middleware::{OperationType, PodId, RawValue}, }; fn operation_verify( @@ -573,7 +575,7 @@ mod tests { .map(|pf| pf.into()) .collect(); - let operation_verify = OperationVerifyGadget { + OperationVerifyGadget { params: params.clone(), } .eval( @@ -634,10 +636,10 @@ mod tests { // NewEntry let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey(SELF, "hello".into()), 55.into()).into(); + Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); let st2: mainpod::Statement = Statement::ValueOf( - AnchoredKey(PodId(Value::from(75).into()), "hello".into()), - 55.into(), + AnchoredKey::from((PodId(RawValue::from(75).into()), "hello")), + Value::from(55), ) .into(); let prev_statements = vec![st2]; @@ -665,13 +667,13 @@ mod tests { // Eq let st2: mainpod::Statement = Statement::ValueOf( - AnchoredKey(PodId(Value::from(75).into()), "world".into()), - 55.into(), + AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), + Value::from(55), ) .into(); let st: mainpod::Statement = Statement::Equal( - AnchoredKey(SELF, "hello".into()), - AnchoredKey(PodId(Value::from(75).into()), "world".into()), + AnchoredKey::from((SELF, "hello")), + AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), ) .into(); let op = mainpod::Operation( @@ -684,13 +686,13 @@ mod tests { // Lt let st2: mainpod::Statement = Statement::ValueOf( - AnchoredKey(PodId(Value::from(88).into()), "hello".into()), - 56.into(), + AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), + Value::from(56), ) .into(); let st: mainpod::Statement = Statement::Lt( - AnchoredKey(SELF, "hello".into()), - AnchoredKey(PodId(Value::from(88).into()), "hello".into()), + AnchoredKey::from((SELF, "hello")), + AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) .into(); let op = mainpod::Operation( @@ -711,16 +713,16 @@ mod tests { .collect(); let mt = MerkleTree::new(params.max_depth_mt_gadget, &kvs)?; - let root = mt.root().into(); - let root_ak = AnchoredKey(PodId(Value::from(88).into()), "merkle root".into()); + let root = Value::from(mt.root()); + let root_ak = AnchoredKey::from((PodId(RawValue::from(88).into()), "merkle root")); let key = 5.into(); - let key_ak = AnchoredKey(PodId(Value::from(88).into()), "key".into()); + let key_ak = AnchoredKey::from((PodId(RawValue::from(88).into()), "key")); let no_key_pf = mt.prove_nonexistence(&key)?; - let root_st: mainpod::Statement = Statement::ValueOf(root_ak, root).into(); - let key_st: mainpod::Statement = Statement::ValueOf(key_ak, key).into(); + let root_st: mainpod::Statement = Statement::ValueOf(root_ak.clone(), root.clone()).into(); + let key_st: mainpod::Statement = Statement::ValueOf(key_ak.clone(), key.into()).into(); let st: mainpod::Statement = Statement::NotContains(root_ak, key_ak).into(); let op = mainpod::Operation( OperationType::Native(NativeOperation::NotContainsFromEntries), @@ -729,7 +731,11 @@ mod tests { ); let merkle_proofs = vec![mainpod::MerkleClaimAndProof::try_from_middleware( - ¶ms, &root, &key, None, &no_key_pf, + ¶ms, + &root.raw(), + &key, + None, + &no_key_pf, )?]; let prev_statements = vec![root_st, key_st]; operation_verify(st, op, prev_statements, merkle_proofs.clone())?; diff --git a/src/backends/plonky2/circuits/signedpod.rs b/src/backends/plonky2/circuits/signedpod.rs index 9008089..1091784 100644 --- a/src/backends/plonky2/circuits/signedpod.rs +++ b/src/backends/plonky2/circuits/signedpod.rs @@ -13,7 +13,7 @@ use plonky2::{ use crate::{ backends::plonky2::{ - basetypes::{Value, D, EMPTY_VALUE, F}, + basetypes::D, circuits::common::{CircuitBuilderPod, StatementArgTarget, StatementTarget, ValueTarget}, primitives::{ merkletree::{MerkleProof, MerkleProofExistenceGadget, MerkleProofExistenceTarget}, @@ -22,7 +22,8 @@ use crate::{ signedpod::SignedPod, }, middleware::{ - hash_str, NativePredicate, Params, PodType, Predicate, ToFields, KEY_SIGNER, KEY_TYPE, SELF, + hash_str, Key, NativePredicate, Params, PodType, Predicate, RawValue, ToFields, Value, + EMPTY_VALUE, F, KEY_SIGNER, KEY_TYPE, SELF, }, }; @@ -48,7 +49,7 @@ impl SignedPodVerifyGadget { let type_mt_proof = &mt_proofs[0]; let key_type = builder.constant_value(hash_str(KEY_TYPE).into()); builder.connect_values(type_mt_proof.key, key_type); - let value_type = builder.constant_value(Value::from(PodType::Signed)); + let value_type = builder.constant_value(Value::from(PodType::Signed).raw()); builder.connect_values(type_mt_proof.value, value_type); // 3.a. Verify signature @@ -56,7 +57,7 @@ impl SignedPodVerifyGadget { // 3.b. Verify signer (ie. signature.pk == merkletree.signer_leaf) let signer_mt_proof = &mt_proofs[1]; - let key_signer = builder.constant_value(hash_str(KEY_SIGNER).into()); + let key_signer = builder.constant_value(Key::from(KEY_SIGNER).raw()); builder.connect_values(signer_mt_proof.key, key_signer); builder.connect_values(signer_mt_proof.value, signature.pk); @@ -122,21 +123,28 @@ impl SignedPodVerifyTarget { // - empty leaves (if needed) // add proof verification of KEY_TYPE & KEY_SIGNER leaves - let key_type_key = Value::from(hash_str(KEY_TYPE)); - let key_signer_key = Value::from(hash_str(KEY_SIGNER)); - let key_signer_value = [key_type_key, key_signer_key] + let key_type_key = Key::from(KEY_TYPE); + let key_signer_key = Key::from(KEY_SIGNER); + let key_signer_value = [&key_type_key, &key_signer_key] .iter() .enumerate() .map(|(i, k)| { let (v, proof) = pod.dict.prove(k)?; - self.mt_proofs[i].set_targets(pw, true, pod.dict.commitment(), proof, *k, v)?; + self.mt_proofs[i].set_targets( + pw, + true, + pod.dict.commitment(), + proof, + k.raw(), + v.raw(), + )?; Ok(v) }) - .collect::>>()?[1]; + .collect::>>()?[1]; // add the verification of the rest of leaves let mut curr = 2; // since we already added key_type and key_signer - for (k, v) in pod.dict.iter().sorted_by_key(|kv| kv.0) { + for (k, v) in pod.dict.kvs().iter().sorted_by_key(|kv| kv.0.hash()) { if *k == key_type_key || *k == key_signer_key { // skip the key_type & key_signer leaves, since they have // already been checked @@ -144,9 +152,16 @@ impl SignedPodVerifyTarget { } let (obtained_v, proof) = pod.dict.prove(k)?; - assert_eq!(obtained_v, *v); // sanity check + assert_eq!(obtained_v, v); // sanity check - self.mt_proofs[curr].set_targets(pw, true, pod.dict.commitment(), proof, *k, *v)?; + self.mt_proofs[curr].set_targets( + pw, + true, + pod.dict.commitment(), + proof, + k.raw(), + v.raw(), + )?; curr += 1; } // sanity check @@ -170,9 +185,9 @@ impl SignedPodVerifyTarget { } // get the signer pk - let pk = PublicKey(key_signer_value); + let pk = PublicKey(key_signer_value.raw()); // the msg signed is the pod.id - let msg = Value::from(pod.id.0); + let msg = RawValue::from(pod.id.0); // set signature targets values self.signature diff --git a/src/backends/plonky2/mainpod.rs b/src/backends/plonky2/mainpod.rs index 4e0e68f..fca8542 100644 --- a/src/backends/plonky2/mainpod.rs +++ b/src/backends/plonky2/mainpod.rs @@ -11,13 +11,13 @@ use plonky2::{ use crate::{ backends::plonky2::{ - basetypes::{C, D, F}, + basetypes::{C, D}, circuits::mainpod::{MainPodVerifyCircuit, MainPodVerifyInput}, mock::mainpod::{hash_statements, MockMainPod, Statement}, signedpod::SignedPod, }, middleware::{ - self, AnchoredKey, MainPodInputs, Params, Pod, PodId, PodProver, StatementArg, SELF, + self, AnchoredKey, MainPodInputs, Params, Pod, PodId, PodProver, StatementArg, F, SELF, }, }; // TODO: Move the shared components between MockMainPod and MainPod to a common place. @@ -136,10 +136,10 @@ impl Pod for MainPod { .1 .iter() .map(|sa| match &sa { - StatementArg::Key(AnchoredKey(pod_id, h)) if *pod_id == SELF => { - StatementArg::Key(AnchoredKey(self.id(), *h)) + StatementArg::Key(AnchoredKey { pod_id, key }) if *pod_id == SELF => { + StatementArg::Key(AnchoredKey::new(self.id(), key.clone())) } - _ => *sa, + _ => sa.clone(), }) .collect(), ) @@ -168,61 +168,12 @@ pub mod tests { backends::plonky2::{ mock::mainpod::MockProver, primitives::signature::SecretKey, signedpod::Signer, }, - examples::zu_kyc_sign_pod_builders, + examples::{zu_kyc_pod_builder, zu_kyc_sign_pod_builders}, frontend, middleware, - middleware::Value, + middleware::RawValue, op, }; - // TODO: Use the method from examples once everything works - pub fn zu_kyc_pod_builder( - params: &Params, - gov_id: &frontend::SignedPod, - pay_stub: &frontend::SignedPod, - sanction_list: &frontend::SignedPod, - ) -> Result { - let sanction_set = match sanction_list.kvs.get("sanctionList") { - Some(frontend::Value::Set(s)) => Ok(s), - _ => Err(anyhow!("Missing sanction list!")), - }?; - let now_minus_18y: i64 = 1169909388; - let now_minus_1y: i64 = 1706367566; - - let gov_id_kvs = gov_id.kvs(); - let id_number_value = gov_id_kvs.get(&"idNumber".into()).unwrap(); - - let mut kyc = frontend::MainPodBuilder::new(params); - kyc.add_signed_pod(gov_id); - kyc.add_signed_pod(pay_stub); - kyc.add_signed_pod(sanction_list); - kyc.pub_op(op!( - set_not_contains, - (sanction_list, "sanctionList"), - (gov_id, "idNumber"), - sanction_set - .middleware_set() - .prove_nonexistence(id_number_value)? - ))?; - kyc.pub_op(op!(lt, (gov_id, "dateOfBirth"), now_minus_18y))?; - kyc.pub_op(op!( - eq, - (gov_id, "socialSecurityNumber"), - (pay_stub, "socialSecurityNumber") - ))?; - let start_date_st = kyc.pub_op(frontend::Operation( - frontend::OperationType::Native(frontend::NativeOperation::NewEntry), - vec![frontend::OperationArg::Entry( - "startDate".to_string(), - now_minus_1y.into(), - )], - middleware::OperationAux::None, - ))?; - kyc.pub_op(op!(eq, (pay_stub, "startDate"), start_date_st))?; - kyc.pub_op(op!(eq, (pay_stub, "startDate"), now_minus_1y))?; - - Ok(kyc) - } - #[test] fn test_main_zu_kyc() -> Result<()> { let params = middleware::Params { @@ -232,15 +183,13 @@ pub mod tests { ..Default::default() }; - let sanctions_values = vec!["A343434340".into()]; - let sanction_set = frontend::Value::Set(frontend::containers::Set::new(sanctions_values)?); let (gov_id_builder, pay_stub_builder, sanction_list_builder) = - zu_kyc_sign_pod_builders(¶ms, &sanction_set); - let mut signer = Signer(SecretKey(Value::from(1))); + zu_kyc_sign_pod_builders(¶ms); + let mut signer = Signer(SecretKey(RawValue::from(1))); let gov_id_pod = gov_id_builder.sign(&mut signer)?; - let mut signer = Signer(SecretKey(Value::from(2))); + let mut signer = Signer(SecretKey(RawValue::from(2))); let pay_stub_pod = pay_stub_builder.sign(&mut signer)?; - let mut signer = Signer(SecretKey(Value::from(3))); + let mut signer = Signer(SecretKey(RawValue::from(3))); let sanction_list_pod = sanction_list_builder.sign(&mut signer)?; let kyc_builder = zu_kyc_pod_builder(¶ms, &gov_id_pod, &pay_stub_pod, &sanction_list_pod)?; @@ -267,7 +216,7 @@ pub mod tests { gov_id_builder.insert("idNumber", "4242424242"); gov_id_builder.insert("dateOfBirth", 1169909384); gov_id_builder.insert("socialSecurityNumber", "G2121210"); - let mut signer = Signer(SecretKey(Value::from(42))); + let mut signer = Signer(SecretKey(RawValue::from(42))); let gov_id = gov_id_builder.sign(&mut signer).unwrap(); let now_minus_18y: i64 = 1169909388; let mut kyc_builder = frontend::MainPodBuilder::new(¶ms); diff --git a/src/backends/plonky2/mock/mainpod/mod.rs b/src/backends/plonky2/mock/mainpod/mod.rs index 8f27883..93a0fe0 100644 --- a/src/backends/plonky2/mock/mainpod/mod.rs +++ b/src/backends/plonky2/mock/mainpod/mod.rs @@ -1,10 +1,10 @@ use std::{any::Any, fmt}; use anyhow::{anyhow, Result}; -use base64::prelude::*; +// use base64::prelude::*; use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher}; -use serde::{Deserialize, Serialize}; +// use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::primitives::merkletree, middleware::{ @@ -27,7 +27,7 @@ impl PodProver for MockProver { } } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug)] pub struct MockMainPod { params: Params, id: PodId, @@ -199,7 +199,7 @@ impl MockMainPod { // Public statements assert!(inputs.public_statements.len() < params.max_public_statements); let mut type_st = middleware::Statement::ValueOf( - AnchoredKey(SELF, hash_str(KEY_TYPE)), + AnchoredKey::from((SELF, KEY_TYPE)), middleware::Value::from(PodType::MockMain), ) .into(); @@ -235,9 +235,9 @@ impl MockMainPod { pf, ) => Some(MerkleClaimAndProof::try_from_middleware( params, - root, - key, - Some(value), + &root.raw(), + &key.raw(), + Some(&value.raw()), pf, )), middleware::Operation::NotContainsFromEntries( @@ -245,7 +245,11 @@ impl MockMainPod { middleware::Statement::ValueOf(_, key), pf, ) => Some(MerkleClaimAndProof::try_from_middleware( - params, root, key, None, pf, + params, + &root.raw(), + &key.raw(), + None, + pf, )), _ => None, }) @@ -417,14 +421,14 @@ impl MockMainPod { fill_pad(args, OperationArg::None, params.max_operation_args) } - pub fn deserialize(serialized: String) -> Result { - let proof = String::from_utf8(BASE64_STANDARD.decode(&serialized)?) - .map_err(|e| anyhow::anyhow!("Invalid base64 encoding: {}", e))?; - let pod: MockMainPod = serde_json::from_str(&proof) - .map_err(|e| anyhow::anyhow!("Failed to parse proof: {}", e))?; + // pub fn deserialize(serialized: String) -> Result { + // let proof = String::from_utf8(BASE64_STANDARD.decode(&serialized)?) + // .map_err(|e| anyhow::anyhow!("Invalid base64 encoding: {}", e))?; + // let pod: MockMainPod = serde_json::from_str(&proof) + // .map_err(|e| anyhow::anyhow!("Failed to parse proof: {}", e))?; - Ok(pod) - } + // Ok(pod) + // } } pub fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash { @@ -449,8 +453,8 @@ impl Pod for MockMainPod { let has_type_statement = self.public_statements.iter().any(|s| { s.0 == Predicate::Native(NativePredicate::ValueOf) && !s.1.is_empty() - && if let StatementArg::Key(AnchoredKey(pod_id, key_hash)) = s.1[0] { - pod_id == SELF && key_hash == hash_str(KEY_TYPE) + && if let StatementArg::Key(AnchoredKey { pod_id, ref key }) = s.1[0] { + pod_id == SELF && key.hash() == hash_str(KEY_TYPE) } else { false } @@ -477,7 +481,7 @@ impl Pod for MockMainPod { .filter(|(_, s)| s.0 == Predicate::Native(NativePredicate::ValueOf)) .flat_map(|(i, s)| { if let StatementArg::Key(ak) = &s.1[0] { - vec![(i, ak.1, ak.0)] + vec![(i, ak.pod_id, ak.key.hash())] } else { vec![] } @@ -536,10 +540,10 @@ impl Pod for MockMainPod { .1 .iter() .map(|sa| match &sa { - StatementArg::Key(AnchoredKey(pod_id, h)) if *pod_id == SELF => { - StatementArg::Key(AnchoredKey(self.id(), *h)) + StatementArg::Key(AnchoredKey { pod_id, key }) if *pod_id == SELF => { + StatementArg::Key(AnchoredKey::new(self.id(), key.clone())) } - _ => *sa, + _ => sa.clone(), }) .collect(), ) @@ -557,7 +561,8 @@ impl Pod for MockMainPod { } fn serialized_proof(&self) -> String { - BASE64_STANDARD.encode(serde_json::to_string(self).unwrap()) + todo!() + // BASE64_STANDARD.encode(serde_json::to_string(self).unwrap()) } } @@ -570,19 +575,14 @@ pub mod tests { great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders, }, - middleware, + middleware::{self}, }; #[test] fn test_mock_main_zu_kyc() -> Result<()> { let params = middleware::Params::default(); - let sanctions_values = ["A343434340"].map(|s| crate::frontend::Value::from(s)); - let sanction_set = crate::frontend::Value::Set(crate::frontend::containers::Set::new( - sanctions_values.to_vec(), - )?); - let (gov_id_builder, pay_stub_builder, sanction_list_builder) = - zu_kyc_sign_pod_builders(¶ms, &sanction_set); + zu_kyc_sign_pod_builders(¶ms); let mut signer = MockSigner { pk: "ZooGov".into(), }; diff --git a/src/backends/plonky2/mock/mainpod/operation.rs b/src/backends/plonky2/mock/mainpod/operation.rs index 325f827..fb361bd 100644 --- a/src/backends/plonky2/mock/mainpod/operation.rs +++ b/src/backends/plonky2/mock/mainpod/operation.rs @@ -2,17 +2,19 @@ use std::{fmt, iter}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; -use serde::{Deserialize, Serialize}; +// use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ mock::mainpod::Statement, primitives::merkletree::{self}, }, - middleware::{self, Hash, OperationType, Params, ToFields, Value, EMPTY_HASH, EMPTY_VALUE, F}, + middleware::{ + self, Hash, OperationType, Params, RawValue, ToFields, EMPTY_HASH, EMPTY_VALUE, F, + }, }; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub enum OperationArg { None, Index(usize), @@ -34,7 +36,7 @@ impl OperationArg { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub enum OperationAux { None, MerkleProofIndex(usize), @@ -50,17 +52,17 @@ impl ToFields for OperationAux { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub struct MerkleClaimAndProof { pub enabled: bool, pub root: Hash, - pub key: Value, - pub value: Value, + pub key: RawValue, + pub value: RawValue, pub existence: bool, pub siblings: Vec, pub case_ii_selector: bool, - pub other_key: Value, - pub other_value: Value, + pub other_key: RawValue, + pub other_value: RawValue, } impl MerkleClaimAndProof { @@ -68,7 +70,7 @@ impl MerkleClaimAndProof { Self { enabled: false, root: EMPTY_HASH, - key: Value::from(1), + key: RawValue::from(1), value: EMPTY_VALUE, existence: false, siblings: iter::repeat(EMPTY_HASH).take(max_depth).collect(), @@ -79,9 +81,9 @@ impl MerkleClaimAndProof { } pub fn try_from_middleware( params: &Params, - root: &Value, - key: &Value, - value: Option<&Value>, + root: &RawValue, + key: &RawValue, + value: Option<&RawValue>, mid_mp: &merkletree::MerkleProof, ) -> Result { if mid_mp.siblings.len() > params.max_depth_mt_gadget { @@ -152,7 +154,7 @@ impl fmt::Display for MerkleClaimAndProof { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub struct Operation(pub OperationType, pub Vec, pub OperationAux); impl Operation { @@ -209,7 +211,7 @@ impl fmt::Display for Operation { } match self.2 { OperationAux::None => (), - OperationAux::MerkleProofIndex(i) => write!(f, "merkle_proof_{:02}", i)?, + OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, } Ok(()) } diff --git a/src/backends/plonky2/mock/mainpod/statement.rs b/src/backends/plonky2/mock/mainpod/statement.rs index f9bfe8f..225ce2b 100644 --- a/src/backends/plonky2/mock/mainpod/statement.rs +++ b/src/backends/plonky2/mock/mainpod/statement.rs @@ -1,13 +1,13 @@ use std::fmt; use anyhow::{anyhow, Result}; -use serde::{Deserialize, Serialize}; +// use serde::{Deserialize, Serialize}; use crate::middleware::{ - self, AnchoredKey, NativePredicate, Params, Predicate, StatementArg, ToFields, + self, NativePredicate, Params, Predicate, StatementArg, ToFields, WildcardValue, }; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub struct Statement(pub Predicate, pub Vec); impl Statement { @@ -81,15 +81,15 @@ impl TryFrom for middleware::Statement { _ => Err(anyhow!("Ill-formed statement expression {:?}", s))?, }, Predicate::Custom(cpr) => { - let aks: Vec = proper_args + let vs: Vec = proper_args .into_iter() .filter_map(|arg| match arg { SA::None => None, - SA::Key(ak) => Some(ak), - SA::Literal(_) => unreachable!(), + SA::WildcardLiteral(v) => Some(v), + _ => unreachable!(), }) .collect(); - S::Custom(cpr, aks) + S::Custom(cpr, vs) } Predicate::BatchSelf(_) => { unreachable!() diff --git a/src/backends/plonky2/mock/signedpod.rs b/src/backends/plonky2/mock/signedpod.rs index e93fa42..a48979d 100644 --- a/src/backends/plonky2/mock/signedpod.rs +++ b/src/backends/plonky2/mock/signedpod.rs @@ -7,8 +7,8 @@ use crate::{ backends::plonky2::primitives::merkletree::MerkleTree, constants::MAX_DEPTH, middleware::{ - containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, - PodType, Statement, Value, KEY_SIGNER, KEY_TYPE, + containers::Dictionary, hash_str, AnchoredKey, Hash, Key, Params, Pod, PodId, PodSigner, + PodType, RawValue, Statement, Value, KEY_SIGNER, KEY_TYPE, }, }; @@ -17,26 +17,22 @@ pub struct MockSigner { } impl MockSigner { - pub fn pubkey(&self) -> Value { - Value(hash_str(&self.pk).0) + pub fn pubkey(&self) -> Hash { + hash_str(&self.pk) } } impl PodSigner for MockSigner { - fn sign(&mut self, _params: &Params, kvs: &HashMap) -> Result> { + fn sign(&mut self, _params: &Params, kvs: &HashMap) -> Result> { let mut kvs = kvs.clone(); let pubkey = self.pubkey(); - kvs.insert(hash_str(KEY_SIGNER), pubkey); - kvs.insert(hash_str(KEY_TYPE), Value::from(PodType::MockSigned)); + kvs.insert(Key::from(KEY_SIGNER), Value::from(pubkey)); + kvs.insert(Key::from(KEY_TYPE), Value::from(PodType::MockSigned)); - let dict = Dictionary::new(&kvs)?; + let dict = Dictionary::new(kvs.clone())?; let id = PodId(dict.commitment()); let signature = format!("{}_signed_by_{}", id, pubkey); - Ok(Box::new(MockSignedPod { - dict, - id, - signature, - })) + Ok(Box::new(MockSignedPod { id, signature, kvs })) } } @@ -44,18 +40,18 @@ impl PodSigner for MockSigner { pub struct MockSignedPod { id: PodId, signature: String, - dict: Dictionary, + kvs: HashMap, } -impl MockSignedPod { - pub fn deserialize(id: PodId, signature: String, dict: Dictionary) -> Self { - Self { - id, - signature, - dict, - } - } -} +// impl MockSignedPod { +// pub fn deserialize(id: PodId, signature: String, dict: Dictionary) -> Self { +// Self { +// id, +// signature, +// dict, +// } +// } +// } impl Pod for MockSignedPod { fn verify(&self) -> Result<()> { @@ -63,10 +59,10 @@ impl Pod for MockSignedPod { let mt = MerkleTree::new( MAX_DEPTH, &self - .dict + .kvs .iter() - .map(|(&k, &v)| (k, v)) - .collect::>(), + .map(|(k, v)| (k.raw(), v.raw())) + .collect::>(), )?; let id = PodId(mt.root()); if id != self.id { @@ -78,8 +74,11 @@ impl Pod for MockSignedPod { } // 2. Verify type - let value_at_type = self.dict.get(&hash_str(KEY_TYPE).into())?; - if Value::from(PodType::MockSigned) != value_at_type { + let value_at_type = self + .kvs + .get(&Key::from(KEY_TYPE)) + .ok_or(anyhow!("key not found"))?; + if &Value::from(PodType::MockSigned) != value_at_type { return Err(anyhow!( "type does not match, expected MockSigned ({}), found {}", PodType::MockSigned, @@ -88,7 +87,10 @@ impl Pod for MockSignedPod { } // 3. Verify signature - let pk_hash = self.dict.get(&hash_str(KEY_SIGNER).into())?; + let pk_hash = self + .kvs + .get(&Key::from(KEY_SIGNER)) + .ok_or(anyhow!("key not found"))?; let signature = format!("{}_signed_by_{}", id, pk_hash); if signature != self.signature { return Err(anyhow!( @@ -108,15 +110,15 @@ impl Pod for MockSignedPod { fn pub_statements(&self) -> Vec { let id = self.id(); // By convention we put the KEY_TYPE first and KEY_SIGNER second - let mut kvs: HashMap<_, _> = self.dict.iter().collect(); - let key_type = Value::from(hash_str(KEY_TYPE)); + let mut kvs = self.kvs.clone(); + let key_type = Key::from(KEY_TYPE); let value_type = kvs.remove(&key_type).expect("KEY_TYPE"); - let key_signer = Value::from(hash_str(KEY_SIGNER)); + let key_signer = Key::from(KEY_SIGNER); let value_signer = kvs.remove(&key_signer).expect("KEY_SIGNER"); - [(&key_type, value_type), (&key_signer, value_signer)] + [(key_type, value_type), (key_signer, value_signer)] .into_iter() - .chain(kvs.into_iter().sorted_by_key(|kv| kv.0)) - .map(|(k, v)| Statement::ValueOf(AnchoredKey(id, Hash(k.0)), *v)) + .chain(kvs.into_iter().sorted_by_key(|kv| kv.0.hash())) + .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((id, k)), v)) .collect() } @@ -140,9 +142,8 @@ pub mod tests { use super::*; use crate::{ - constants::MAX_DEPTH, frontend, - middleware::{self, EMPTY_HASH, F}, + middleware::{self, EMPTY_VALUE, F}, }; #[test] @@ -170,27 +171,25 @@ pub mod tests { assert!(bad_pod.verify().is_err()); let mut bad_pod = pod.clone(); - let bad_kv = (hash_str(KEY_SIGNER).into(), Value(PodId(EMPTY_HASH).0 .0)); - let bad_kvs_mt = &bad_pod - .kvs() + let bad_kv = (Key::from(KEY_SIGNER), Value::from(EMPTY_VALUE)); + let bad_kvs = bad_pod + .kvs + .clone() .into_iter() - .map(|(AnchoredKey(_, k), v)| (Value(k.0), v)) .chain(iter::once(bad_kv)) - .collect::>(); - let bad_mt = MerkleTree::new(MAX_DEPTH, bad_kvs_mt)?; - bad_pod.dict.mt = bad_mt; + .collect::>(); + bad_pod.kvs = bad_kvs; assert!(bad_pod.verify().is_err()); let mut bad_pod = pod.clone(); - let bad_kv = (hash_str(KEY_TYPE).into(), Value::from(0)); - let bad_kvs_mt = &bad_pod - .kvs() + let bad_kv = (Key::from(KEY_TYPE), Value::from(0)); + let bad_kvs = bad_pod + .kvs + .clone() .into_iter() - .map(|(AnchoredKey(_, k), v)| (Value(k.0), v)) .chain(iter::once(bad_kv)) - .collect::>(); - let bad_mt = MerkleTree::new(MAX_DEPTH, bad_kvs_mt)?; - bad_pod.dict.mt = bad_mt; + .collect::>(); + bad_pod.kvs = bad_kvs; assert!(bad_pod.verify().is_err()); Ok(()) diff --git a/src/backends/plonky2/primitives/merkletree.rs b/src/backends/plonky2/primitives/merkletree.rs index b2ff86d..87694d1 100644 --- a/src/backends/plonky2/primitives/merkletree.rs +++ b/src/backends/plonky2/primitives/merkletree.rs @@ -4,10 +4,10 @@ use std::{collections::HashMap, fmt, iter::IntoIterator}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; -use serde::{Deserialize, Serialize}; +// use serde::{Deserialize, Serialize}; pub use super::merkletree_circuit::*; -use crate::backends::plonky2::basetypes::{hash_fields, Hash, Value, EMPTY_HASH, F}; +use crate::middleware::{hash_fields, Hash, RawValue, EMPTY_HASH, F}; /// Implements the MerkleTree specified at /// https://0xparc.github.io/pod2/merkletree.html @@ -19,7 +19,7 @@ pub struct MerkleTree { impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values - pub fn new(max_depth: usize, kvs: &HashMap) -> Result { + pub fn new(max_depth: usize, kvs: &HashMap) -> Result { // Construct leaves. let mut leaves: Vec<_> = kvs .iter() @@ -50,7 +50,7 @@ impl MerkleTree { } /// returns the value at the given key - pub fn get(&self, key: &Value) -> Result { + pub fn get(&self, key: &RawValue) -> Result { let path = keypath(self.max_depth, *key)?; let key_resolution = self.root.down(0, self.max_depth, path, None)?; match key_resolution { @@ -60,7 +60,7 @@ impl MerkleTree { } /// returns a boolean indicating whether the key exists in the tree - pub fn contains(&self, key: &Value) -> Result { + pub fn contains(&self, key: &RawValue) -> Result { let path = keypath(self.max_depth, *key)?; match self.root.down(0, self.max_depth, path, None) { Ok(Some((k, _))) => { @@ -77,7 +77,7 @@ impl MerkleTree { /// returns a proof of existence, which proves that the given key exists in /// the tree. It returns the `value` of the leaf at the given `key`, and the /// `MerkleProof`. - pub fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { + pub fn prove(&self, key: &RawValue) -> Result<(RawValue, MerkleProof)> { let path = keypath(self.max_depth, *key)?; let mut siblings: Vec = Vec::new(); @@ -102,7 +102,7 @@ impl MerkleTree { /// `key` does not exist in the tree. The return value specifies /// the key-value pair in the leaf reached as a result of /// resolving `key` as well as a `MerkleProof`. - pub fn prove_nonexistence(&self, key: &Value) -> Result { + pub fn prove_nonexistence(&self, key: &RawValue) -> Result { let path = keypath(self.max_depth, *key)?; let mut siblings: Vec = Vec::new(); @@ -134,8 +134,8 @@ impl MerkleTree { max_depth: usize, root: Hash, proof: &MerkleProof, - key: &Value, - value: &Value, + key: &RawValue, + value: &RawValue, ) -> Result<()> { let h = proof.compute_root_from_leaf(max_depth, key, Some(*value))?; @@ -152,13 +152,13 @@ impl MerkleTree { max_depth: usize, root: Hash, proof: &MerkleProof, - key: &Value, + key: &RawValue, ) -> Result<()> { match proof.other_leaf { Some((k, _v)) if &k == key => Err(anyhow!("Invalid non-existence proof.")), _ => { let k = proof.other_leaf.map(|(k, _)| k).unwrap_or(*key); - let v: Option = proof.other_leaf.map(|(_, v)| v); + let v: Option = proof.other_leaf.map(|(_, v)| v); let h = proof.compute_root_from_leaf(max_depth, &k, v)?; if h != root { @@ -180,14 +180,14 @@ impl MerkleTree { /// Hash function for key-value pairs. Different branch pair hashes to /// mitigate fake proofs. -pub fn kv_hash(key: &Value, value: Option) -> Hash { +pub fn kv_hash(key: &RawValue, value: Option) -> Hash { value .map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![F::ONE]].concat())) .unwrap_or(EMPTY_HASH) } impl<'a> IntoIterator for &'a MerkleTree { - type Item = (&'a Value, &'a Value); + type Item = (&'a RawValue, &'a RawValue); type IntoIter = Iter<'a>; fn into_iter(self) -> Self::IntoIter { @@ -208,7 +208,7 @@ impl fmt::Display for MerkleTree { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub struct MerkleProof { // note: currently we don't use the `_existence` field, we would use if we merge the methods // `verify` and `verify_nonexistence` into a single one @@ -216,7 +216,7 @@ pub struct MerkleProof { pub(crate) existence: bool, pub(crate) siblings: Vec, // other_leaf is used for non-existence proofs - pub(crate) other_leaf: Option<(Value, Value)>, + pub(crate) other_leaf: Option<(RawValue, RawValue)>, } impl fmt::Display for MerkleProof { @@ -238,8 +238,8 @@ impl MerkleProof { fn compute_root_from_leaf( &self, max_depth: usize, - key: &Value, - value: Option, + key: &RawValue, + value: Option, ) -> Result { if self.siblings.len() >= max_depth { return Err(anyhow!("max depth reached")); @@ -335,7 +335,7 @@ impl Node { max_depth: usize, path: Vec, mut siblings: Option<&mut Vec>, - ) -> Result> { + ) -> Result> { if lvl >= max_depth { return Err(anyhow!("max depth reached")); } @@ -494,11 +494,11 @@ impl Intermediate { struct Leaf { hash: Option, path: Vec, - key: Value, - value: Value, + key: RawValue, + value: RawValue, } impl Leaf { - fn new(max_depth: usize, key: Value, value: Value) -> Result { + fn new(max_depth: usize, key: RawValue, value: RawValue) -> Result { Ok(Self { hash: None, path: keypath(max_depth, key)?, @@ -523,7 +523,7 @@ impl Leaf { // max-depth? ie, what happens when two keys share the same path for more bits // than the max_depth? /// returns the path of the given key -pub(crate) fn keypath(max_depth: usize, k: Value) -> Result> { +pub(crate) fn keypath(max_depth: usize, k: RawValue) -> Result> { let bytes = k.to_bytes(); if max_depth > 8 * bytes.len() { // note that our current keys are of Value type, which are 4 Goldilocks @@ -545,7 +545,7 @@ pub struct Iter<'a> { } impl<'a> Iterator for Iter<'a> { - type Item = (&'a Value, &'a Value); + type Item = (&'a RawValue, &'a RawValue); fn next(&mut self) -> Option { let node = self.state.pop(); @@ -586,10 +586,10 @@ pub mod tests { if i == 1 { continue; } - kvs.insert(Value::from(i), Value::from(1000 + i)); + kvs.insert(RawValue::from(i), RawValue::from(1000 + i)); } - let key = Value::from(13); - let value = Value::from(1013); + let key = RawValue::from(13); + let value = RawValue::from(1013); kvs.insert(key, value); let tree = MerkleTree::new(32, &kvs)?; @@ -598,25 +598,25 @@ pub mod tests { println!("{}", tree); // Inclusion checks - let (v, proof) = tree.prove(&Value::from(13))?; - assert_eq!(v, Value::from(1013)); + let (v, proof) = tree.prove(&RawValue::from(13))?; + assert_eq!(v, RawValue::from(1013)); println!("{}", proof); MerkleTree::verify(32, tree.root(), &proof, &key, &value)?; // Exclusion checks - let key = Value::from(12); + let key = RawValue::from(12); let proof = tree.prove_nonexistence(&key)?; assert_eq!( proof.other_leaf.unwrap(), - (Value::from(4), Value::from(1004)) + (RawValue::from(4), RawValue::from(1004)) ); println!("{}", proof); MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key)?; - let key = Value::from(1); - let proof = tree.prove_nonexistence(&Value::from(1))?; + let key = RawValue::from(1); + let proof = tree.prove_nonexistence(&RawValue::from(1))?; assert_eq!(proof.other_leaf, None); println!("{}", proof); diff --git a/src/backends/plonky2/primitives/merkletree_circuit.rs b/src/backends/plonky2/primitives/merkletree_circuit.rs index dd62f0f..4b3327d 100644 --- a/src/backends/plonky2/primitives/merkletree_circuit.rs +++ b/src/backends/plonky2/primitives/merkletree_circuit.rs @@ -24,10 +24,13 @@ use plonky2::{ plonk::circuit_builder::CircuitBuilder, }; -use crate::backends::plonky2::{ - basetypes::{Hash, Value, D, EMPTY_HASH, EMPTY_VALUE, F, HASH_SIZE}, - circuits::common::{CircuitBuilderPod, ValueTarget}, - primitives::merkletree::MerkleProof, +use crate::{ + backends::plonky2::{ + basetypes::D, + circuits::common::{CircuitBuilderPod, ValueTarget}, + primitives::merkletree::MerkleProof, + }, + middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F, HASH_SIZE}, }; /// `MerkleProofGadget` allows to verify both proofs of existence and proofs @@ -163,8 +166,8 @@ impl MerkleClaimAndProofTarget { existence: bool, root: Hash, proof: MerkleProof, - key: Value, - value: Value, + key: RawValue, + value: RawValue, ) -> Result<()> { pw.set_bool_target(self.enabled, enabled)?; pw.set_hash_target(self.root, HashOut::from_vec(root.0.to_vec()))?; @@ -268,8 +271,8 @@ impl MerkleProofExistenceTarget { enabled: bool, root: Hash, proof: MerkleProof, - key: Value, - value: Value, + key: RawValue, + value: RawValue, ) -> Result<()> { assert!(proof.existence); // sanity check @@ -405,9 +408,9 @@ pub mod tests { use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}; use super::*; - use crate::backends::plonky2::{ - basetypes::{hash_value, C}, - primitives::merkletree::*, + use crate::{ + backends::plonky2::{basetypes::C, primitives::merkletree::*}, + middleware::{hash_value, RawValue}, }; #[test] @@ -424,7 +427,7 @@ pub mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::::new(); - let key = Value::from(hash_value(&Value::from(i))); + let key = RawValue::from(hash_value(&RawValue::from(i))); let expected_path = keypath(max_depth, key)?; // small circuit logic to check @@ -455,8 +458,8 @@ pub mod tests { #[test] fn test_kv_hash() -> Result<()> { for i in 0..10 { - let key = Value::from(hash_value(&Value::from(i))); - let value = Value::from(1000 + i); + let key = RawValue::from(hash_value(&RawValue::from(i))); + let value = RawValue::from(1000 + i); let h = kv_hash(&key, Some(value)); // circuit @@ -502,20 +505,23 @@ pub mod tests { // test logic to be reused both by the existence & nonexistence tests fn test_merkleproof_verify_opt(max_depth: usize, existence: bool) -> Result<()> { - let mut kvs: HashMap = HashMap::new(); + let mut kvs: HashMap = HashMap::new(); for i in 0..10 { - kvs.insert(Value::from(hash_value(&Value::from(i))), Value::from(i)); + kvs.insert( + RawValue::from(hash_value(&RawValue::from(i))), + RawValue::from(i), + ); } let tree = MerkleTree::new(max_depth, &kvs)?; let (key, value, proof) = if existence { - let key = Value::from(hash_value(&Value::from(5))); + let key = RawValue::from(hash_value(&RawValue::from(5))); let (value, proof) = tree.prove(&key)?; - assert_eq!(value, Value::from(5)); + assert_eq!(value, RawValue::from(5)); (key, value, proof) } else { - let key = Value::from(hash_value(&Value::from(200))); + let key = RawValue::from(hash_value(&RawValue::from(200))); (key, EMPTY_VALUE, tree.prove_nonexistence(&key)?) }; assert_eq!(proof.existence, existence); @@ -559,16 +565,19 @@ pub mod tests { } fn test_merkleproof_only_existence_verify_opt(max_depth: usize) -> Result<()> { - let mut kvs: HashMap = HashMap::new(); + let mut kvs: HashMap = HashMap::new(); for i in 0..10 { - kvs.insert(Value::from(hash_value(&Value::from(i))), Value::from(i)); + kvs.insert( + RawValue::from(hash_value(&RawValue::from(i))), + RawValue::from(i), + ); } let tree = MerkleTree::new(max_depth, &kvs)?; - let key = Value::from(hash_value(&Value::from(5))); + let key = RawValue::from(hash_value(&RawValue::from(5))); let (value, proof) = tree.prove(&key)?; - assert_eq!(value, Value::from(5)); + assert_eq!(value, RawValue::from(5)); assert_eq!(proof.existence, true); MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; @@ -604,24 +613,28 @@ pub mod tests { // 5 13 let mut kvs = HashMap::new(); - kvs.insert(Value::from(0), Value::from(1000)); - kvs.insert(Value::from(2), Value::from(1002)); - kvs.insert(Value::from(5), Value::from(1005)); - kvs.insert(Value::from(13), Value::from(1013)); + kvs.insert(RawValue::from(0), RawValue::from(1000)); + kvs.insert(RawValue::from(2), RawValue::from(1002)); + kvs.insert(RawValue::from(5), RawValue::from(1005)); + kvs.insert(RawValue::from(13), RawValue::from(1013)); let max_depth = 5; let tree = MerkleTree::new(max_depth, &kvs)?; // existence - test_merkletree_edgecase_opt(max_depth, &tree, Value::from(5))?; + test_merkletree_edgecase_opt(max_depth, &tree, RawValue::from(5))?; // non-existence case i) expected leaf does not exist - test_merkletree_edgecase_opt(max_depth, &tree, Value::from(1))?; + test_merkletree_edgecase_opt(max_depth, &tree, RawValue::from(1))?; // non-existence case ii) expected leaf does exist but it has a different 'key' - test_merkletree_edgecase_opt(max_depth, &tree, Value::from(21))?; + test_merkletree_edgecase_opt(max_depth, &tree, RawValue::from(21))?; Ok(()) } - fn test_merkletree_edgecase_opt(max_depth: usize, tree: &MerkleTree, key: Value) -> Result<()> { + fn test_merkletree_edgecase_opt( + max_depth: usize, + tree: &MerkleTree, + key: RawValue, + ) -> Result<()> { let contains = tree.contains(&key)?; // generate merkleproof let (value, proof) = if contains { @@ -666,19 +679,19 @@ pub mod tests { #[test] fn test_wrong_witness() -> Result<()> { - let mut kvs: HashMap = HashMap::new(); + let mut kvs: HashMap = HashMap::new(); for i in 0..10 { - kvs.insert(Value::from(i), Value::from(i)); + kvs.insert(RawValue::from(i), RawValue::from(i)); } let max_depth = 16; let tree = MerkleTree::new(max_depth, &kvs)?; - let key = Value::from(3); + let key = RawValue::from(3); let (value, proof) = tree.prove(&key)?; // build another tree with an extra key-value, so that it has a // different root - kvs.insert(Value::from(100), Value::from(100)); + kvs.insert(RawValue::from(100), RawValue::from(100)); let tree2 = MerkleTree::new(max_depth, &kvs)?; MerkleTree::verify(max_depth, tree.root(), &proof, &key, &value)?; diff --git a/src/backends/plonky2/primitives/signature.rs b/src/backends/plonky2/primitives/signature.rs index abe416e..cbd3d6c 100644 --- a/src/backends/plonky2/primitives/signature.rs +++ b/src/backends/plonky2/primitives/signature.rs @@ -21,7 +21,10 @@ use plonky2::{ }; pub use super::signature_circuit::*; -use crate::backends::plonky2::basetypes::{Proof, Value, C, D, F, VALUE_SIZE}; +use crate::{ + backends::plonky2::basetypes::{Proof, C, D}, + middleware::{RawValue, F, VALUE_SIZE}, +}; lazy_static! { /// Signature prover parameters @@ -45,10 +48,10 @@ pub struct ProverParams { pub struct VerifierParams(pub(crate) VerifierCircuitData); #[derive(Clone, Debug)] -pub struct SecretKey(pub(crate) Value); +pub struct SecretKey(pub(crate) RawValue); #[derive(Clone, Debug)] -pub struct PublicKey(pub(crate) Value); +pub struct PublicKey(pub(crate) RawValue); #[derive(Clone, Debug)] pub struct Signature(pub(crate) Proof); @@ -57,16 +60,16 @@ pub struct Signature(pub(crate) Proof); impl SecretKey { pub fn new_rand() -> Self { // note: the `F::rand()` internally uses `rand::rngs::OsRng` - Self(Value(std::array::from_fn(|_| F::rand()))) + Self(RawValue(std::array::from_fn(|_| F::rand()))) } pub fn public_key(&self) -> PublicKey { - PublicKey(Value(PoseidonHash::hash_no_pad(&self.0 .0).elements)) + PublicKey(RawValue(PoseidonHash::hash_no_pad(&self.0 .0).elements)) } - pub fn sign(&self, msg: Value) -> Result { + pub fn sign(&self, msg: RawValue) -> Result { let pk = self.public_key(); - let s = Value(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); + let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); let mut pw = PartialWitness::::new(); PP.circuit.set_targets(&mut pw, self.clone(), pk, msg, s)?; @@ -108,9 +111,9 @@ impl Signature { Ok((builder, circuit)) } - pub fn verify(&self, pk: &PublicKey, msg: Value) -> Result<()> { + pub fn verify(&self, pk: &PublicKey, msg: RawValue) -> Result<()> { // prepare public inputs as [pk, msg, s] - let s = Value(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); + let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); let public_inputs: Vec = [pk.0 .0, msg.0, s.0].concat(); // verify plonky2 proof @@ -122,16 +125,16 @@ impl Signature { } fn dummy_public_inputs() -> Result> { - let sk = SecretKey(Value::from(0)); + let sk = SecretKey(RawValue::from(0)); let pk = sk.public_key(); - let msg = Value::from(0); - let s = Value(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); + let msg = RawValue::from(0); + let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); Ok([pk.0 .0, msg.0, s.0].concat()) } fn dummy_signature() -> Result { - let sk = SecretKey(Value::from(0)); - let msg = Value::from(0); + let sk = SecretKey(RawValue::from(0)); + let msg = RawValue::from(0); sk.sign(msg) } @@ -186,8 +189,8 @@ impl SignatureInternalCircuit { pw: &mut PartialWitness, sk: SecretKey, pk: PublicKey, - msg: Value, - s: Value, + msg: RawValue, + s: RawValue, ) -> Result<()> { pw.set_target_arr(&self.sk_targ, sk.0 .0.as_ref())?; pw.set_hash_target(self.pk_targ, HashOut::::from_vec(pk.0 .0.to_vec()))?; @@ -201,23 +204,23 @@ impl SignatureInternalCircuit { #[cfg(test)] pub mod tests { use super::*; - use crate::backends::plonky2::basetypes::Hash; + use crate::middleware::hash_str; #[test] fn test_signature() -> Result<()> { let sk = SecretKey::new_rand(); let pk = sk.public_key(); - let msg = Value::from(42); + let msg = RawValue::from(42); let sig = sk.sign(msg)?; sig.verify(&pk, msg)?; // expect the signature verification to fail when using a different msg - let v = sig.verify(&pk, Value::from(24)); + let v = sig.verify(&pk, RawValue::from(24)); assert!(v.is_err(), "should fail to verify"); // perform a 2nd signature over another msg and verify it - let msg_2 = Value::from(Hash::from("message")); + let msg_2 = RawValue::from(hash_str("message")); let sig2 = sk.sign(msg_2)?; sig2.verify(&pk, msg_2)?; @@ -226,9 +229,9 @@ pub mod tests { #[test] fn test_dummy_signature() -> Result<()> { - let sk = SecretKey(Value::from(0)); + let sk = SecretKey(RawValue::from(0)); let pk = sk.public_key(); - let msg = Value::from(0); + let msg = RawValue::from(0); DUMMY_SIGNATURE.clone().verify(&pk, msg)?; Ok(()) diff --git a/src/backends/plonky2/primitives/signature_circuit.rs b/src/backends/plonky2/primitives/signature_circuit.rs index 2d7ae88..82e2d76 100644 --- a/src/backends/plonky2/primitives/signature_circuit.rs +++ b/src/backends/plonky2/primitives/signature_circuit.rs @@ -22,12 +22,15 @@ use plonky2::{ }, }; -use crate::backends::plonky2::{ - basetypes::{Hash, Proof, Value, C, D, EMPTY_HASH, EMPTY_VALUE, F, VALUE_SIZE}, - circuits::common::{CircuitBuilderPod, ValueTarget}, - primitives::signature::{ - PublicKey, SecretKey, Signature, DUMMY_PUBLIC_INPUTS, DUMMY_SIGNATURE, +use crate::{ + backends::plonky2::{ + basetypes::{Proof, C, D}, + circuits::common::{CircuitBuilderPod, ValueTarget}, + primitives::signature::{ + PublicKey, SecretKey, Signature, DUMMY_PUBLIC_INPUTS, DUMMY_SIGNATURE, + }, }, + middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F, VALUE_SIZE}, }; lazy_static! { @@ -81,12 +84,12 @@ impl SignatureVerifyGadget { let dummy_pi = DUMMY_PUBLIC_INPUTS.clone(); let pk_targ_dummy = - builder.constant_value(Value(dummy_pi[..VALUE_SIZE].try_into().unwrap())); - let msg_targ_dummy = builder.constant_value(Value( + builder.constant_value(RawValue(dummy_pi[..VALUE_SIZE].try_into().unwrap())); + let msg_targ_dummy = builder.constant_value(RawValue( dummy_pi[VALUE_SIZE..VALUE_SIZE * 2].try_into().unwrap(), )); let s_targ_dummy = - builder.constant_value(Value(dummy_pi[VALUE_SIZE * 2..].try_into().unwrap())); + builder.constant_value(RawValue(dummy_pi[VALUE_SIZE * 2..].try_into().unwrap())); // connect the {pk, msg, s} with the proof_targ.public_inputs conditionally let pk_targ_connect = builder.select_value(enabled, pk_targ, pk_targ_dummy); @@ -129,7 +132,7 @@ impl SignatureVerifyTarget { pw: &mut PartialWitness, enabled: bool, pk: PublicKey, - msg: Value, + msg: RawValue, signature: Signature, ) -> Result<()> { pw.set_bool_target(self.enabled, enabled)?; @@ -137,7 +140,7 @@ impl SignatureVerifyTarget { pw.set_target_arr(&self.msg.elements, &msg.0)?; // note that this hash is checked again in-circuit at the `SignatureInternalCircuit` - let s = Value(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); + let s = RawValue(PoseidonHash::hash_no_pad(&[pk.0 .0, msg.0].concat()).elements); let public_inputs: Vec = [pk.0 .0, msg.0, s.0].concat(); if enabled { @@ -170,14 +173,14 @@ impl SignatureVerifyTarget { #[cfg(test)] pub mod tests { use super::*; - use crate::backends::plonky2::{basetypes::Hash, primitives::signature::SecretKey}; + use crate::{backends::plonky2::primitives::signature::SecretKey, middleware::Hash}; #[test] fn test_signature_gadget() -> Result<()> { // generate a valid signature let sk = SecretKey::new_rand(); let pk = sk.public_key(); - let msg = Value::from(42); + let msg = RawValue::from(42); let sig = sk.sign(msg)?; sig.verify(&pk, msg)?; @@ -208,15 +211,15 @@ pub mod tests { // generate a valid signature let sk = SecretKey::new_rand(); let pk = sk.public_key(); - let msg = Value::from(42); + let msg = RawValue::from(42); let sig = sk.sign(msg)?; // verification should pass sig.verify(&pk, msg)?; // replace the message, so that verifications should fail - let msg = Value::from(24); + let msg = RawValue::from(24); // expect signature native verification to fail - let v = sig.verify(&pk, Value::from(24)); + let v = sig.verify(&pk, RawValue::from(24)); assert!(v.is_err(), "should fail to verify"); // circuit diff --git a/src/backends/plonky2/signedpod.rs b/src/backends/plonky2/signedpod.rs index 818239b..9158130 100644 --- a/src/backends/plonky2/signedpod.rs +++ b/src/backends/plonky2/signedpod.rs @@ -10,22 +10,22 @@ use crate::{ }, constants::MAX_DEPTH, middleware::{ - containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, - PodType, Statement, Value, KEY_SIGNER, KEY_TYPE, + containers::Dictionary, AnchoredKey, Hash, Key, Params, Pod, PodId, PodSigner, PodType, + RawValue, Statement, Value, KEY_SIGNER, KEY_TYPE, }, }; pub struct Signer(pub SecretKey); impl PodSigner for Signer { - fn sign(&mut self, _params: &Params, kvs: &HashMap) -> Result> { + fn sign(&mut self, _params: &Params, kvs: &HashMap) -> Result> { let mut kvs = kvs.clone(); let pubkey = self.0.public_key(); - kvs.insert(hash_str(KEY_SIGNER), pubkey.0); - kvs.insert(hash_str(KEY_TYPE), Value::from(PodType::Signed)); + kvs.insert(Key::from(KEY_SIGNER), Value::from(pubkey.0)); + kvs.insert(Key::from(KEY_TYPE), Value::from(PodType::Signed)); - let dict = Dictionary::new(&kvs)?; - let id = Value::from(dict.commitment()); // PodId as Value + let dict = Dictionary::new(kvs)?; + let id = RawValue::from(dict.commitment()); // PodId as Value let signature: Signature = self.0.sign(id)?; Ok(Box::new(SignedPod { @@ -46,8 +46,8 @@ pub struct SignedPod { impl Pod for SignedPod { fn verify(&self) -> Result<()> { // 1. Verify type - let value_at_type = self.dict.get(&hash_str(KEY_TYPE).into())?; - if Value::from(PodType::Signed) != value_at_type { + let value_at_type = self.dict.get(&Key::from(KEY_TYPE))?; + if Value::from(PodType::Signed) != *value_at_type { return Err(anyhow!( "type does not match, expected Signed ({}), found {}", PodType::Signed, @@ -60,9 +60,10 @@ impl Pod for SignedPod { MAX_DEPTH, &self .dict + .kvs() .iter() - .map(|(&k, &v)| (k, v)) - .collect::>(), + .map(|(k, v)| (k.raw(), v.raw())) + .collect::>(), )?; let id = PodId(mt.root()); if id != self.id { @@ -74,9 +75,9 @@ impl Pod for SignedPod { } // 3. Verify signature - let pk_value = self.dict.get(&hash_str(KEY_SIGNER).into())?; - let pk = PublicKey(pk_value); - self.signature.verify(&pk, Value::from(id.0))?; + let pk_value = self.dict.get(&Key::from(KEY_SIGNER))?; + let pk = PublicKey(pk_value.raw()); + self.signature.verify(&pk, RawValue::from(id.0))?; Ok(()) } @@ -88,15 +89,15 @@ impl Pod for SignedPod { fn pub_statements(&self) -> Vec { let id = self.id(); // By convention we put the KEY_TYPE first and KEY_SIGNER second - let mut kvs: HashMap<_, _> = self.dict.iter().collect(); - let key_type = Value::from(hash_str(KEY_TYPE)); + let mut kvs: HashMap = self.dict.kvs().clone(); + let key_type = Key::from(KEY_TYPE); let value_type = kvs.remove(&key_type).expect("KEY_TYPE"); - let key_signer = Value::from(hash_str(KEY_SIGNER)); + let key_signer = Key::from(KEY_SIGNER); let value_signer = kvs.remove(&key_signer).expect("KEY_SIGNER"); - [(&key_type, value_type), (&key_signer, value_signer)] + [(key_type, value_type), (key_signer, value_signer)] .into_iter() - .chain(kvs.into_iter().sorted_by_key(|kv| kv.0)) - .map(|(k, v)| Statement::ValueOf(AnchoredKey(id, Hash(k.0)), *v)) + .chain(kvs.into_iter().sorted_by_key(|kv| kv.0.hash())) + .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((id, k)), v)) .collect() } @@ -123,9 +124,8 @@ pub mod tests { use super::*; use crate::{ - constants::MAX_DEPTH, frontend, - middleware::{self, EMPTY_HASH, F}, + middleware::{self, EMPTY_VALUE, F}, }; #[test] @@ -147,7 +147,7 @@ pub mod tests { println!("kvs: {:?}", pod.kvs()); let mut bad_pod = pod.clone(); - bad_pod.signature = signer.0.sign(Value::from(42_i64))?; + bad_pod.signature = signer.0.sign(RawValue::from(42_i64))?; assert!(bad_pod.verify().is_err()); let mut bad_pod = pod.clone(); @@ -155,27 +155,27 @@ pub mod tests { assert!(bad_pod.verify().is_err()); let mut bad_pod = pod.clone(); - let bad_kv = (hash_str(KEY_SIGNER).into(), Value(PodId(EMPTY_HASH).0 .0)); - let bad_kvs_mt = &bad_pod + let bad_kv = (Key::from(KEY_SIGNER), Value::from(EMPTY_VALUE)); + let bad_kvs = bad_pod + .dict .kvs() + .clone() .into_iter() - .map(|(AnchoredKey(_, k), v)| (Value(k.0), v)) .chain(iter::once(bad_kv)) - .collect::>(); - let bad_mt = MerkleTree::new(MAX_DEPTH, bad_kvs_mt)?; - bad_pod.dict.mt = bad_mt; + .collect::>(); + bad_pod.dict = Dictionary::new(bad_kvs).unwrap(); assert!(bad_pod.verify().is_err()); let mut bad_pod = pod.clone(); - let bad_kv = (hash_str(KEY_TYPE).into(), Value::from(0)); - let bad_kvs_mt = &bad_pod + let bad_kv = (Key::from(KEY_TYPE), Value::from(0)); + let bad_kvs = bad_pod + .dict .kvs() + .clone() .into_iter() - .map(|(AnchoredKey(_, k), v)| (Value(k.0), v)) .chain(iter::once(bad_kv)) - .collect::>(); - let bad_mt = MerkleTree::new(MAX_DEPTH, bad_kvs_mt)?; - bad_pod.dict.mt = bad_mt; + .collect::>(); + bad_pod.dict = Dictionary::new(bad_kvs).unwrap(); assert!(bad_pod.verify().is_err()); Ok(()) diff --git a/src/examples/custom.rs b/src/examples/custom.rs index 16622a1..439fa05 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -1,21 +1,21 @@ use std::sync::Arc; use anyhow::Result; -use NativePredicate as NP; use StatementTmplBuilder as STB; use crate::{ - frontend::{ - literal, CustomPredicateBatch, CustomPredicateBatchBuilder, CustomPredicateRef, - NativePredicate, Predicate, StatementTmplBuilder, + frontend::{key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::{ + CustomPredicateBatch, CustomPredicateRef, NativePredicate as NP, Params, PodType, + Predicate, KEY_SIGNER, KEY_TYPE, }, - middleware::{self, Params, PodType, KEY_SIGNER, KEY_TYPE}, }; /// Instantiates an ETH friend batch pub fn eth_friend_batch(params: &Params) -> Result> { let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); let _eth_friend = builder.predicate_and( + "eth_friend", params, // arguments: &["src_ori", "src_key", "dst_ori", "dst_key"], @@ -25,18 +25,17 @@ pub fn eth_friend_batch(params: &Params) -> Result> { &[ // there is an attestation pod that's a SignedPod STB::new(NP::ValueOf) - .arg(("attestation_pod", literal(KEY_TYPE))) - .arg(middleware::Value::from(PodType::MockSigned)), // TODO + .arg(("attestation_pod", key(KEY_TYPE))) + .arg(literal(PodType::MockSigned)), // TODO // the attestation pod is signed by (src_or, src_key) STB::new(NP::Equal) - .arg(("attestation_pod", literal(KEY_SIGNER))) + .arg(("attestation_pod", key(KEY_SIGNER))) .arg(("src_ori", "src_key")), // that same attestation pod has an "attestation" STB::new(NP::Equal) - .arg(("attestation_pod", literal("attestation"))) + .arg(("attestation_pod", key("attestation"))) .arg(("dst_ori", "dst_key")), ], - "eth_friend", )?; println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); @@ -53,6 +52,7 @@ pub fn eth_dos_batch(params: &Params) -> Result> { // ValueOf(distance_or, distance_key, 0) // > let eth_dos_distance_base = builder.predicate_and( + "eth_dos_distance_base", params, &[ // arguments: @@ -72,9 +72,8 @@ pub fn eth_dos_batch(params: &Params) -> Result> { .arg(("dst_ori", "dst_key")), STB::new(NP::ValueOf) .arg(("distance_ori", "distance_key")) - .arg(0), + .arg(literal(0)), ], - "eth_dos_distance_base", )?; println!( "b.0. eth_dos_distance_base = {}", @@ -84,6 +83,7 @@ pub fn eth_dos_batch(params: &Params) -> Result> { let eth_dos_distance = Predicate::BatchSelf(2); let eth_dos_distance_ind = builder.predicate_and( + "eth_dos_distance_ind", params, &[ // arguments: @@ -106,21 +106,27 @@ pub fn eth_dos_batch(params: &Params) -> Result> { &[ // statement templates: STB::new(eth_dos_distance) - .arg(("src_ori", "src_key")) - .arg(("intermed_ori", "intermed_key")) - .arg(("shorter_distance_ori", "shorter_distance_key")), + .arg("src_ori") + .arg("src_key") + .arg("intermed_ori") + .arg("intermed_key") + .arg("shorter_distance_ori") + .arg("shorter_distance_key"), // distance == shorter_distance + 1 - STB::new(NP::ValueOf).arg(("one_ori", "one_key")).arg(1), + STB::new(NP::ValueOf) + .arg(("one_ori", "one_key")) + .arg(literal(1)), STB::new(NP::SumOf) .arg(("distance_ori", "distance_key")) .arg(("shorter_distance_ori", "shorter_distance_key")) .arg(("one_ori", "one_key")), // intermed is a friend of dst STB::new(eth_friend) - .arg(("intermed_ori", "intermed_key")) - .arg(("dst_ori", "dst_key")), + .arg("intermed_ori") + .arg("intermed_key") + .arg("dst_ori") + .arg("dst_key"), ], - "eth_dos_distance_ind", )?; println!( @@ -129,6 +135,7 @@ pub fn eth_dos_batch(params: &Params) -> Result> { ); let _eth_dos_distance = builder.predicate_or( + "eth_dos_distance", params, &[ "src_ori", @@ -141,15 +148,20 @@ pub fn eth_dos_batch(params: &Params) -> Result> { &[], &[ STB::new(eth_dos_distance_base) - .arg(("src_ori", "src_key")) - .arg(("dst_ori", "dst_key")) - .arg(("distance_ori", "distance_key")), + .arg("src_ori") + .arg("src_key") + .arg("dst_ori") + .arg("dst_key") + .arg("distance_ori") + .arg("distance_key"), STB::new(eth_dos_distance_ind) - .arg(("src_ori", "src_key")) - .arg(("dst_ori", "dst_key")) - .arg(("distance_ori", "distance_key")), + .arg("src_ori") + .arg("src_key") + .arg("dst_ori") + .arg("dst_key") + .arg("distance_ori") + .arg("distance_key"), ], - "eth_dos_distance", )?; println!( diff --git a/src/examples/mod.rs b/src/examples/mod.rs index 5161f27..63296d3 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -1,17 +1,17 @@ pub mod custom; -use std::collections::HashMap; +use std::collections::HashSet; use anyhow::{anyhow, Result}; use custom::{eth_dos_batch, eth_friend_batch}; use crate::{ backends::plonky2::mock::signedpod::MockSigner, - frontend::{ - containers::Dictionary, CustomPredicateRef, MainPodBuilder, SignedPod, SignedPodBuilder, - Statement, Value, + frontend::{MainPodBuilder, SignedPod, SignedPodBuilder}, + middleware::{ + containers::Set, CustomPredicateRef, Key, Params, PodType, Statement, TypedValue, Value, + KEY_SIGNER, KEY_TYPE, }, - middleware::{Params, PodType, KEY_SIGNER, KEY_TYPE}, op, }; @@ -19,8 +19,10 @@ use crate::{ pub fn zu_kyc_sign_pod_builders( params: &Params, - sanction_set: &Value, ) -> (SignedPodBuilder, SignedPodBuilder, SignedPodBuilder) { + let sanctions_values: HashSet = ["A343434340"].iter().map(|s| Value::from(*s)).collect(); + let sanction_set = Value::from(Set::new(sanctions_values).unwrap()); + let mut gov_id = SignedPodBuilder::new(params); gov_id.insert("idNumber", "4242424242"); gov_id.insert("dateOfBirth", 1169909384); @@ -31,7 +33,8 @@ pub fn zu_kyc_sign_pod_builders( pay_stub.insert("startDate", 1706367566); let mut sanction_list = SignedPodBuilder::new(params); - sanction_list.insert("sanctionList", sanction_set.clone()); + + sanction_list.insert("sanctionList", sanction_set); (gov_id, pay_stub, sanction_list) } @@ -42,15 +45,15 @@ pub fn zu_kyc_pod_builder( pay_stub: &SignedPod, sanction_list: &SignedPod, ) -> Result { - let sanction_set = match sanction_list.kvs.get("sanctionList") { - Some(Value::Set(s)) => Ok(s), + let sanction_set = match sanction_list.get("sanctionList").map(|v| v.typed()) { + Some(TypedValue::Set(s)) => Ok(s), _ => Err(anyhow!("Missing sanction list!")), }?; let now_minus_18y: i64 = 1169909388; let now_minus_1y: i64 = 1706367566; let gov_id_kvs = gov_id.kvs(); - let id_number_value = gov_id_kvs.get(&"idNumber".into()).unwrap(); + let id_number_value = gov_id_kvs.get(&Key::from("idNumber")).unwrap(); let mut kyc = MainPodBuilder::new(params); kyc.add_signed_pod(gov_id); @@ -60,9 +63,7 @@ pub fn zu_kyc_pod_builder( set_not_contains, (sanction_list, "sanctionList"), (gov_id, "idNumber"), - sanction_set - .middleware_set() - .prove_nonexistence(id_number_value)? + sanction_set.prove_nonexistence(id_number_value)? ))?; kyc.pub_op(op!(lt, (gov_id, "dateOfBirth"), now_minus_18y))?; kyc.pub_op(op!( @@ -77,7 +78,10 @@ pub fn zu_kyc_pod_builder( // ETHDoS -pub fn eth_friend_signed_pod_builder(params: &Params, friend_pubkey: Value) -> SignedPodBuilder { +pub fn eth_friend_signed_pod_builder( + params: &Params, + friend_pubkey: TypedValue, +) -> SignedPodBuilder { let mut attestation = SignedPodBuilder::new(params); attestation.insert("attestation", friend_pubkey); @@ -88,7 +92,7 @@ pub fn eth_dos_pod_builder( params: &Params, alice_attestation: &SignedPod, charlie_attestation: &SignedPod, - bob_pubkey: &Value, + bob_pubkey: &TypedValue, ) -> Result { // Will need ETH friend and ETH DoS custom predicate batches. let eth_friend = CustomPredicateRef::new(eth_friend_batch(params)?, 0); @@ -103,14 +107,14 @@ pub fn eth_dos_pod_builder( alice_bob_ethdos.add_signed_pod(charlie_attestation); // Attestation POD entries - let alice_pubkey = *alice_attestation - .kvs() - .get(&KEY_SIGNER.into()) - .ok_or(anyhow!("Could not find Alice's public key!"))?; - let charlie_pubkey = *charlie_attestation - .kvs() - .get(&KEY_SIGNER.into()) - .ok_or(anyhow!("Could not find Charlie's public key!"))?; + let alice_pubkey = alice_attestation + .get(KEY_SIGNER) + .expect("Could not find Alice's public key!") + .clone(); + let charlie_pubkey = charlie_attestation + .get(KEY_SIGNER) + .expect("Could not find Charlie's public key!") + .clone(); // Include Alice and Bob's keys as public statements. We don't // want to reveal the middleman. @@ -119,7 +123,7 @@ pub fn eth_dos_pod_builder( let charlie_pubkey = alice_bob_ethdos.priv_op(op!(new_entry, ("Charlie", charlie_pubkey)))?; // The ETHDoS distance from Alice to Alice is 0. - let zero = alice_bob_ethdos.priv_literal(&0)?; + let zero = alice_bob_ethdos.priv_literal(0)?; let alice_equals_alice = alice_bob_ethdos.priv_op(op!( eq, (alice_attestation, KEY_SIGNER), @@ -134,11 +138,12 @@ pub fn eth_dos_pod_builder( let ethdos_alice_alice_is_zero = alice_bob_ethdos.priv_op(op!( custom, eth_dos.clone(), - ethdos_alice_alice_is_zero_base + ethdos_alice_alice_is_zero_base, + Statement::None ))?; // Alice and Charlie are ETH friends. - let attestation_is_signed_pod = Statement::from((alice_attestation, KEY_TYPE)); + let attestation_is_signed_pod = alice_attestation.get_statement(KEY_TYPE).unwrap(); let attestation_signed_by_alice = alice_bob_ethdos.priv_op(op!(eq, (alice_attestation, KEY_SIGNER), alice_pubkey_copy))?; let alice_attests_to_charlie = alice_bob_ethdos.priv_op(op!( @@ -155,7 +160,7 @@ pub fn eth_dos_pod_builder( ))?; // ...and so are Chuck and Bob. - let attestation_is_signed_pod = Statement::from((charlie_attestation, KEY_TYPE)); + let attestation_is_signed_pod = charlie_attestation.get_statement(KEY_TYPE).unwrap(); let attestation_signed_by_charlie = alice_bob_ethdos.priv_op(op!(eq, (charlie_attestation, KEY_SIGNER), charlie_pubkey))?; let charlie_attests_to_bob = alice_bob_ethdos.priv_op(op!( @@ -172,7 +177,7 @@ pub fn eth_dos_pod_builder( ))?; // The ETHDoS distance from Alice to Charlie is 1. - let one = alice_bob_ethdos.priv_literal(&1)?; + let one = alice_bob_ethdos.priv_literal(1)?; // 1 = 0 + 1 let ethdos_sum = alice_bob_ethdos.priv_op(op!(sum_of, one.clone(), zero.clone(), one.clone()))?; @@ -187,13 +192,14 @@ pub fn eth_dos_pod_builder( let ethdos_alice_charlie_is_one = alice_bob_ethdos.priv_op(op!( custom, eth_dos.clone(), + Statement::None, ethdos_alice_charlie_is_one_ind ))?; // The ETHDoS distance from Alice to Bob is 2. // The constant "TWO" and the final statement are both to be // public. - let two = alice_bob_ethdos.pub_literal(&2)?; + let two = alice_bob_ethdos.pub_literal(2)?; // 2 = 1 + 1 let ethdos_sum = alice_bob_ethdos.priv_op(op!(sum_of, two.clone(), one.clone(), one.clone()))?; @@ -205,8 +211,12 @@ pub fn eth_dos_pod_builder( ethdos_sum, ethfriends_charlie_bob ))?; - let _ethdos_alice_bob_is_two = - alice_bob_ethdos.pub_op(op!(custom, eth_dos.clone(), ethdos_alice_bob_is_two_ind))?; + let _ethdos_alice_bob_is_two = alice_bob_ethdos.pub_op(op!( + custom, + eth_dos.clone(), + Statement::None, + ethdos_alice_bob_is_two_ind + ))?; Ok(alice_bob_ethdos) } @@ -268,18 +278,15 @@ pub fn great_boy_pod_builder( PodType::MockSigned as i64 ))?; // Each good boy POD comes from a valid issuer - let good_boy_proof = match good_boy_issuers { - Value::Dictionary(dict) => Ok(dict), + let good_boy_proof = match good_boy_issuers.typed() { + TypedValue::Set(set) => Ok(set), _ => Err(anyhow!("Invalid good boy issuers!")), }? - .middleware_dict() - .prove(pod_kvs.get(&KEY_SIGNER.into()).unwrap())? - .1; + .prove(pod_kvs.get(&Key::from(KEY_SIGNER)).unwrap())?; great_boy.pub_op(op!( - dict_contains, + set_contains, good_boy_issuers, (good_boy_pods[good_boy_idx * 2 + issuer_idx], KEY_SIGNER), - 0, good_boy_proof ))?; // Each good boy has 2 good boy pods @@ -357,11 +364,8 @@ pub fn great_boy_pod_full_flow() -> Result { alice_friend_pods.push(friend.sign(&mut bob_signer).unwrap()); alice_friend_pods.push(friend.sign(&mut charlie_signer).unwrap()); - let good_boy_issuers = Value::Dictionary(Dictionary::new( - good_boy_issuers - .into_iter() - .map(|issuer| (issuer.to_string(), 0.into())) - .collect(), + let good_boy_issuers = Value::from(Set::new( + good_boy_issuers.into_iter().map(Value::from).collect(), )?); great_boy_pod_builder( @@ -397,13 +401,15 @@ pub fn tickets_pod_builder( signed_pod: &SignedPod, expected_event_id: i64, expect_consumed: bool, - blacklisted_emails: &Dictionary, + blacklisted_emails: &Set, ) -> Result { - let attendee_email_value = signed_pod.kvs.get("attendeeEmail").unwrap(); - let attendee_nin_blacklist_pf = blacklisted_emails - .middleware_dict() - .prove_nonexistence(&attendee_email_value.into())?; - let blacklisted_email_dict_value = Value::Dictionary(blacklisted_emails.clone()); + let attendee_email_value = signed_pod + .kvs() + .get(&Key::from("attendeeEmail")) + .unwrap() + .clone(); + let attendee_nin_blacklist_pf = blacklisted_emails.prove_nonexistence(&attendee_email_value)?; + let blacklisted_email_set_value = Value::from(TypedValue::Set(blacklisted_emails.clone())); // Create a main pod referencing this signed pod with some statements let mut builder = MainPodBuilder::new(params); builder.add_signed_pod(signed_pod); @@ -412,7 +418,7 @@ pub fn tickets_pod_builder( builder.pub_op(op!(eq, (signed_pod, "isRevoked"), false))?; builder.pub_op(op!( dict_not_contains, - blacklisted_email_dict_value, + blacklisted_email_set_value, (signed_pod, "attendeeEmail"), attendee_nin_blacklist_pf ))?; @@ -423,11 +429,5 @@ pub fn tickets_pod_full_flow() -> Result { let params = Params::default(); let builder = tickets_sign_pod_builder(¶ms); let signed_pod = builder.sign(&mut MockSigner { pk: "test".into() }).unwrap(); - tickets_pod_builder( - ¶ms, - &signed_pod, - 123, - true, - &Dictionary::new(HashMap::new())?, - ) + tickets_pod_builder(¶ms, &signed_pod, 123, true, &Set::new(HashSet::new())?) } diff --git a/src/frontend/containers.rs b/src/frontend/containers.rs deleted file mode 100644 index 081bbb6..0000000 --- a/src/frontend/containers.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::collections::HashMap; - -use anyhow::Result; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -use crate::{ - frontend::{serialization::ordered_map, Value}, - middleware::{ - containers::{ - Array as MiddlewareArray, Dictionary as MiddlewareDictionary, Set as MiddlewareSet, - }, - hash_str, Value as MiddlewareValue, - }, -}; - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] -#[serde(transparent)] -pub struct Set(Vec, #[serde(skip)] MiddlewareSet); - -impl Set { - pub fn new(values: Vec) -> Result { - let set = - MiddlewareSet::new(&values.iter().map(MiddlewareValue::from).collect::>())?; - Ok(Self(values, set)) - } - - pub fn middleware_set(&self) -> &MiddlewareSet { - &self.1 - } - - pub fn values(&self) -> &Vec { - &self.0 - } -} - -impl<'de> Deserialize<'de> for Set { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let values: Vec = Vec::deserialize(deserializer)?; - Set::new(values).map_err(serde::de::Error::custom) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] -#[serde(transparent)] -pub struct Dictionary( - #[serde(serialize_with = "ordered_map")] HashMap, - #[serde(skip)] MiddlewareDictionary, -); - -impl Dictionary { - pub fn new(values: HashMap) -> Result { - let dict = MiddlewareDictionary::new( - &values - .iter() - .map(|(k, v)| (hash_str(k), MiddlewareValue::from(v))) - .collect::>(), - )?; - Ok(Self(values, dict)) - } - - pub fn middleware_dict(&self) -> &MiddlewareDictionary { - &self.1 - } - - pub fn values(&self) -> &HashMap { - &self.0 - } -} - -impl<'de> Deserialize<'de> for Dictionary { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let values: HashMap = HashMap::deserialize(deserializer)?; - Dictionary::new(values).map_err(serde::de::Error::custom) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] -#[serde(transparent)] -pub struct Array(Vec, #[serde(skip)] MiddlewareArray); - -impl Array { - pub fn new(values: Vec) -> Result { - let array = - MiddlewareArray::new(&values.iter().map(MiddlewareValue::from).collect::>())?; - Ok(Self(values, array)) - } - - pub fn middleware_array(&self) -> &MiddlewareArray { - &self.1 - } - - pub fn values(&self) -> &Vec { - &self.0 - } -} - -impl<'de> Deserialize<'de> for Array { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let values: Vec = Vec::deserialize(deserializer)?; - Array::new(values).map_err(serde::de::Error::custom) - } -} diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index b709014..d7e6298 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -3,97 +3,40 @@ use std::{collections::HashMap, fmt, hash as h, iter, iter::zip, sync::Arc}; use anyhow::{anyhow, Result}; use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +// use serde::{Deserialize, Serialize}; use crate::{ - frontend::{AnchoredKey, NativePredicate, Origin, Statement, StatementArg, Value}, - middleware::{self, hash_str, HashOrWildcard, Params, PodId, ToFields}, - util::hashmap_insert_no_dupe, + frontend::{AnchoredKey, Statement, StatementArg}, + middleware::{ + self, hash_str, CustomPredicate, CustomPredicateBatch, Key, KeyOrWildcard, NativePredicate, + Params, PodId, Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard, + }, }; -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq, Eq)] /// Argument to a statement template pub enum KeyOrWildcardStr { Key(String), // represents a literal key Wildcard(String), } -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] -pub struct IndexedWildcard { - wildcard: String, - index: usize, -} - -impl IndexedWildcard { - pub fn new(wildcard: String, index: usize) -> Self { - Self { wildcard, index } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", content = "value")] -/// Represents a key or resolved wildcard -pub enum KeyOrWildcard { - Key(String), - Wildcard(IndexedWildcard), -} - -impl KeyOrWildcard { - /// Matches a key or wildcard against a value, returning a pair - /// representing a wildcard binding (if any) or an error if no - /// match is possible. - pub fn match_against(&self, v: &Value) -> Result> { - match self { - KeyOrWildcard::Key(k) if Value::from(k.as_str()) == *v => Ok(None), - KeyOrWildcard::Wildcard(i) => Ok(Some((i.index, v.clone()))), - _ => Err(anyhow!( - "Failed to match key or wildcard {} against value {}.", - self, - v - )), - } - } -} - -impl From for middleware::HashOrWildcard { - fn from(v: KeyOrWildcard) -> Self { - match v { - KeyOrWildcard::Key(k) => middleware::HashOrWildcard::Hash(hash_str(&k)), - KeyOrWildcard::Wildcard(n) => middleware::HashOrWildcard::Wildcard(n.index), - } - } -} -impl fmt::Display for KeyOrWildcard { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Key(k) => write!(f, "{}", k), - Self::Wildcard(n) => write!(f, "*{}", n.wildcard), - } - } -} - /// helper to build a literal KeyOrWildcardStr::Key from the given str -pub fn literal(s: &str) -> KeyOrWildcardStr { +pub fn key(s: &str) -> KeyOrWildcardStr { KeyOrWildcardStr::Key(s.to_string()) } -/// helper to build a KeyOrWildcardStr::Wildcard from the given str. For the -/// moment this method does not need to be public. -fn wildcard(s: &str) -> KeyOrWildcardStr { - KeyOrWildcardStr::Wildcard(s.to_string()) -} - /// Builder Argument for the StatementTmplBuilder pub enum BuilderArg { Literal(Value), - /// Key: (origin, key), where origin & key can be both Hash or Wildcard - Key(KeyOrWildcardStr, KeyOrWildcardStr), + /// Key: (origin, key), where origin is a Wildcard and key can be both Key or Wildcard + Key(String, KeyOrWildcardStr), + WildcardLiteral(String), } /// When defining a `BuilderArg`, it can be done from 3 different inputs: /// i. (&str, literal): this is to set a POD and a field, ie. (POD, literal("field")) /// ii. (&str, &str): this is to define a origin-key wildcard pair, ie. (src_origin, src_dest) -/// iii. Value: this is to define a literal value, ie. 0 +/// iii. &str: this is to define a WildcardValue wildcard, ie. "src_or" /// /// case i. impl From<(&str, KeyOrWildcardStr)> for BuilderArg { @@ -103,267 +46,24 @@ impl From<(&str, KeyOrWildcardStr)> for BuilderArg { KeyOrWildcardStr::Key(_) => (), _ => panic!("not supported"), }; - Self::Key(wildcard(origin), lit) + Self::Key(origin.into(), lit) } } /// case ii. impl From<(&str, &str)> for BuilderArg { fn from((origin, field): (&str, &str)) -> Self { - Self::Key(wildcard(origin), wildcard(field)) + Self::Key(origin.into(), KeyOrWildcardStr::Wildcard(field.to_string())) } } /// case iii. -impl From for BuilderArg -where - V: Into, -{ - fn from(v: V) -> Self { - Self::Literal(v.into()) +impl From<&str> for BuilderArg { + fn from(wc: &str) -> Self { + Self::WildcardLiteral(wc.to_string()) } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", content = "value")] -pub enum Predicate { - Native(NativePredicate), - BatchSelf(usize), - Custom(CustomPredicateRef), -} - -impl From for Predicate { - fn from(v: NativePredicate) -> Self { - Self::Native(v) - } -} - -impl From for middleware::Predicate { - fn from(v: Predicate) -> Self { - match v { - Predicate::Native(p) => middleware::Predicate::Native(p.into()), - Predicate::BatchSelf(i) => middleware::Predicate::BatchSelf(i), - Predicate::Custom(CustomPredicateRef { - batch: pb, - index: i, - }) => { - let cpb: middleware::CustomPredicateBatch = Arc::unwrap_or_clone(pb).into(); - middleware::Predicate::Custom(middleware::CustomPredicateRef(Arc::new(cpb), i)) - } - } - } -} - -impl fmt::Display for Predicate { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Native(p) => write!(f, "{:?}", p), - Self::BatchSelf(i) => write!(f, "self.{}", i), - Self::Custom(CustomPredicateRef { batch, index }) => { - write!(f, "{}.{}", batch.name, index) - } - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub struct CustomPredicateRef { - pub batch: Arc, - pub index: usize, -} - -impl From for middleware::CustomPredicateRef { - fn from(v: CustomPredicateRef) -> Self { - let cpb: middleware::CustomPredicateBatch = Arc::unwrap_or_clone(v.batch).into(); - middleware::CustomPredicateRef(Arc::new(cpb), v.index) - } -} - -impl CustomPredicateRef { - pub fn new(batch: Arc, index: usize) -> Self { - Self { batch, index } - } - - pub fn arg_len(&self) -> usize { - self.batch.predicates[self.index].args_len - } - pub fn match_against(&self, statements: &[Statement]) -> Result> { - let mut bindings = HashMap::new(); - // Single out custom predicate, replacing batch-self - // references with custom predicate references. - let custom_predicate = { - let cp = &Arc::unwrap_or_clone(self.batch.clone()).predicates[self.index]; - CustomPredicate { - conjunction: cp.conjunction, - statements: cp - .statements - .iter() - .map(|StatementTmpl { pred: p, args }| StatementTmpl { - pred: match p { - Predicate::BatchSelf(i) => { - Predicate::Custom(CustomPredicateRef::new(self.batch.clone(), *i)) - } - _ => p.clone(), - }, - args: args.to_vec(), - }) - .collect(), - args_len: cp.args_len, - name: cp.name.to_string(), - } - }; - match custom_predicate.conjunction { - true if custom_predicate.statements.len() == statements.len() => { - // Match op args against statement templates - let match_bindings = iter::zip(custom_predicate.statements, statements).map( - |(s_tmpl, s)| s_tmpl.match_against(s) - ).collect::>>() - .map(|v| v.concat())?; - // Add bindings to binding table, throwing if there is an inconsistency. - match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; - Ok(bindings) - }, - false if statements.len() == 1 => { - // Match op arg against each statement template - custom_predicate.statements.iter().map( - |s_tmpl| { - let mut bindings = bindings.clone(); - s_tmpl.match_against(&statements[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; - Ok::<_, anyhow::Error>(bindings) - } - ).find(|m| m.is_ok()).unwrap_or(Err(anyhow!("Statement {} does not match disjunctive custom predicate {}.", &statements[0], custom_predicate))) - }, - _ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, statements)) - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub struct CustomPredicateBatch { - pub name: String, - pub predicates: Vec, -} - -impl From for middleware::CustomPredicateBatch { - fn from(v: CustomPredicateBatch) -> Self { - middleware::CustomPredicateBatch { - name: v.name, - predicates: v.predicates.into_iter().map(|p| p.into()).collect(), - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -/// NOTE: fields are not public (outside of crate) to enforce the struct instantiation through -/// the `::and/or` methods, which performs checks on the values. -pub struct CustomPredicate { - /// true for "and", false for "or" - pub(crate) conjunction: bool, - pub(crate) statements: Vec, - pub(crate) args_len: usize, - // TODO: Add private args length? - // TODO: Add args type information? - pub(crate) name: String, -} - -impl CustomPredicate { - pub fn and( - params: &Params, - statements: Vec, - args_len: usize, - name: &str, - ) -> Result { - Self::new(params, true, statements, args_len, name) - } - pub fn or( - params: &Params, - statements: Vec, - args_len: usize, - name: &str, - ) -> Result { - Self::new(params, false, statements, args_len, name) - } - pub fn new( - params: &Params, - conjunction: bool, - statements: Vec, - args_len: usize, - name: &str, - ) -> Result { - if statements.len() > params.max_custom_predicate_arity { - return Err(anyhow!("Custom predicate depends on too many statements")); - } - - Ok(Self { - conjunction, - statements, - args_len, - name: name.to_string(), - }) - } -} - -impl From for middleware::CustomPredicate { - fn from(v: CustomPredicate) -> Self { - middleware::CustomPredicate { - conjunction: v.conjunction, - statements: v.statements.into_iter().map(|s| s.into()).collect(), - args_len: v.args_len, - } - } -} -impl fmt::Display for CustomPredicate { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; - for st in &self.statements { - write!(f, " {}", st.pred)?; - for (i, arg) in st.args.iter().enumerate() { - if i != 0 { - write!(f, ", ")?; - } - write!(f, "{}", arg)?; - } - writeln!(f, "),")?; - } - write!(f, ">(")?; - for i in 0..self.args_len { - if i != 0 { - write!(f, ", ")?; - } - write!(f, "*{}", i)?; - } - writeln!(f, ")")?; - Ok(()) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", content = "value")] -pub enum StatementTmplArg { - None, - Literal(Value), - // #[schemars(with = "Vec")] - Key(KeyOrWildcard, KeyOrWildcard), -} - -impl From for middleware::StatementTmplArg { - fn from(v: StatementTmplArg) -> Self { - match v { - StatementTmplArg::None => middleware::StatementTmplArg::None, - StatementTmplArg::Literal(v) => middleware::StatementTmplArg::Literal((&v).into()), - StatementTmplArg::Key(pod_id, key) => { - middleware::StatementTmplArg::Key(pod_id.into(), key.into()) - } - } - } -} - -impl fmt::Display for StatementTmplArg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::None => write!(f, "none"), - Self::Literal(v) => write!(f, "{}", v), - Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), - } - } +pub fn literal(v: impl Into) -> BuilderArg { + BuilderArg::Literal(v.into()) } pub struct StatementTmplBuilder { @@ -385,83 +85,6 @@ impl StatementTmplBuilder { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub struct StatementTmpl { - pub pred: Predicate, - pub args: Vec, -} - -impl StatementTmpl { - pub fn pred(&self) -> &Predicate { - &self.pred - } - pub fn args(&self) -> &[StatementTmplArg] { - &self.args - } - /// Matches a statement template against a statement, returning - /// the variable bindings as an association list. Returns an error - /// if there is type or argument mismatch. - pub fn match_against(&self, s: &Statement) -> Result> { - type P = Predicate; - if matches!( - self, - Self { - pred: P::BatchSelf(_), - args: _ - } - ) { - Err(anyhow!( - "Cannot check self-referencing statement templates." - )) - } else if self.pred() != &s.predicate { - Err(anyhow!("Type mismatch between {:?} and {}.", self, s)) - } else { - zip(self.args(), s.args.clone()) - .map(|(t_arg, s_arg)| t_arg.match_against(&s_arg)) - .collect::>>() - .map(|v| v.concat()) - } - } -} - -impl From for middleware::StatementTmpl { - fn from(v: StatementTmpl) -> Self { - middleware::StatementTmpl( - v.pred.into(), - v.args.into_iter().map(|a| a.into()).collect(), - ) - } -} - -impl StatementTmplArg { - /// Matches a statement template argument against a statement - /// argument, returning a wildcard correspondence in the case of - /// one or more wildcard matches, nothing in the case of a - /// literal/hash match, and an error otherwise. - pub fn match_against(&self, s_arg: &StatementArg) -> Result> { - match (self, s_arg) { - // (Self::None, StatementArg::None) => Ok(vec![]), - (Self::Literal(v), StatementArg::Literal(w)) if v == w => Ok(vec![]), - ( - Self::Key(tmpl_o, tmpl_k), - StatementArg::Key(AnchoredKey { - origin: Origin { pod_id: PodId(o) }, - key: k, - }), - ) => { - let o_corr = tmpl_o.match_against(&(middleware::Value::from(*o)).into())?; - let k_corr = tmpl_k.match_against(&(*k.as_str()).into())?; - Ok([o_corr, k_corr].into_iter().flatten().collect()) - } - _ => Err(anyhow!( - "Failed to match statement template argument {:?} against statement argument {:?}.", - self, - s_arg - )), - } - } -} - pub struct CustomPredicateBatchBuilder { pub name: String, pub predicates: Vec, @@ -477,37 +100,52 @@ impl CustomPredicateBatchBuilder { pub fn predicate_and( &mut self, + name: &str, params: &Params, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], - name: &str, ) -> Result { - self.predicate(params, true, args, priv_args, sts, name) + self.predicate(name, params, true, args, priv_args, sts) } pub fn predicate_or( &mut self, + name: &str, params: &Params, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], - name: &str, ) -> Result { - self.predicate(params, false, args, priv_args, sts, name) + self.predicate(name, params, false, args, priv_args, sts) } /// creates the custom predicate from the given input, adds it to the /// self.predicates, and returns the index of the created predicate fn predicate( &mut self, + name: &str, params: &Params, conjunction: bool, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], - name: &str, ) -> Result { + if args.len() > params.max_statement_args { + return Err(anyhow!( + "args.len {} is over the limit {}", + args.len(), + params.max_statement_args + )); + } + if (args.len() + priv_args.len()) > params.max_custom_predicate_wildcards { + return Err(anyhow!( + "wildcards.len {} is over the limit {}", + args.len() + priv_args.len(), + params.max_custom_predicate_wildcards + )); + } + let statements = sts .iter() .map(|sb| { @@ -518,8 +156,11 @@ impl CustomPredicateBatchBuilder { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Key(pod_id, key) => StatementTmplArg::Key( resolve_wildcard(args, priv_args, pod_id), - resolve_wildcard(args, priv_args, key), + resolve_key_or_wildcard(args, priv_args, key), ), + BuilderArg::WildcardLiteral(v) => { + StatementTmplArg::WildcardLiteral(resolve_wildcard(args, priv_args, v)) + } }) .collect(); StatementTmpl { @@ -529,7 +170,7 @@ impl CustomPredicateBatchBuilder { }) .collect(); let custom_predicate = - CustomPredicate::new(params, conjunction, statements, args.len(), name)?; + CustomPredicate::new(name.into(), params, conjunction, statements, args.len())?; self.predicates.push(custom_predicate); Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } @@ -542,26 +183,34 @@ impl CustomPredicateBatchBuilder { } } -fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &KeyOrWildcardStr) -> KeyOrWildcard { +fn resolve_key_or_wildcard( + args: &[&str], + priv_args: &[&str], + v: &KeyOrWildcardStr, +) -> KeyOrWildcard { match v { - KeyOrWildcardStr::Key(k) => KeyOrWildcard::Key(k.clone()), - KeyOrWildcardStr::Wildcard(s) => KeyOrWildcard::Wildcard( - args.iter() - .chain(priv_args.iter()) - .enumerate() - .find_map(|(i, name)| (s == name).then_some(IndexedWildcard::new(s.clone(), i))) - .unwrap(), - ), + KeyOrWildcardStr::Key(k) => KeyOrWildcard::Key(Key::from(k)), + KeyOrWildcardStr::Wildcard(s) => { + KeyOrWildcard::Wildcard(resolve_wildcard(args, priv_args, s)) + } } } +fn resolve_wildcard(args: &[&str], priv_args: &[&str], s: &str) -> Wildcard { + args.iter() + .chain(priv_args.iter()) + .enumerate() + .find_map(|(i, name)| (s == *name).then_some(Wildcard::new(s.to_string(), i))) + .unwrap() +} + #[cfg(test)] mod tests { use super::*; use crate::{ examples::custom::{eth_dos_batch, eth_friend_batch}, middleware, - // middleware::{CustomPredicateRef, Params, PodType}, + middleware::{CustomPredicateRef, Params, PodType}, }; #[test] @@ -569,7 +218,11 @@ mod tests { use NativePredicate as NP; use StatementTmplBuilder as STB; - let params = Params::default(); + let params = Params { + max_statement_args: 6, + max_custom_predicate_wildcards: 12, + ..Default::default() + }; params.print_serialized_sizes(); // ETH friend custom predicate batch diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 370a099..a38759e 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1,186 +1,43 @@ //! The frontend includes the user-level abstractions and user-friendly types to define and work //! with Pods. -use std::{collections::HashMap, convert::From, fmt, hash as h, hash::Hasher}; +use std::{collections::HashMap, convert::From, fmt}; -use anyhow::{anyhow, Error, Result}; -use containers::{Array, Dictionary, Set}; +use anyhow::{anyhow, Result}; use itertools::Itertools; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use crate::{ - frontend::serialization::*, - middleware::{ - self, hash_str, hash_value, Hash, MainPodInputs, OperationAux, Params, PodId, PodProver, - PodSigner, EMPTY_VALUE, KEY_SIGNER, KEY_TYPE, SELF, - }, +// use schemars::JsonSchema; + +// use serde::{Deserialize, Serialize}; +use crate::middleware::{ + self, check_st_tmpl, hash_str, AnchoredKey, Key, MainPodInputs, NativeOperation, + NativePredicate, OperationAux, OperationType, Params, PodId, PodProver, PodSigner, Predicate, + Statement, StatementArg, Value, WildcardValue, EMPTY_VALUE, KEY_TYPE, SELF, }; -pub mod containers; mod custom; mod operation; -mod predicate; -mod serialization; -mod statement; -pub use custom::{CustomPredicateRef, Predicate, *}; +pub use custom::*; pub use operation::*; -pub use predicate::*; -pub use statement::*; /// This type is just for presentation purposes. -#[derive(Clone, Debug, Default, h::Hash, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub enum PodClass { #[default] Signed, Main, } -// An Origin, which represents a reference to an ancestor POD. -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Default, Serialize, Deserialize, JsonSchema)] -pub struct Origin { - pub pod_id: PodId, -} - -impl Origin { - pub fn new(pod_id: PodId) -> Self { - Self { pod_id } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -#[schemars(transform = serialization::transform_value_schema)] -pub enum Value { - // Serde cares about the order of the enum variants, with untagged variants - // appearing at the end. - // Variants without "untagged" will be serialized as "tagged" values by - // default, meaning that a Set appears in JSON as {"Set":[...]} - // and not as [...] - // Arrays, Strings and Booleans are untagged, as there is a natural JSON - // representation for them that is unambiguous to deserialize and is fully - // compatible with the semantics of the POD types. - // As JSON integers do not specify precision, and JavaScript is limited to - // 53-bit precision for integers, integers are represented as tagged - // strings, with a custom serializer and deserializer. - // TAGGED TYPES: - Set(Set), - Dictionary(Dictionary), - Int( - #[serde(serialize_with = "serialize_i64", deserialize_with = "deserialize_i64")] - #[schemars(with = "String", regex(pattern = r"^\d+$"))] - i64, - ), - // Uses the serialization for middleware::Value: - Raw(middleware::Value), - // UNTAGGED TYPES: - #[serde(untagged)] - #[schemars(skip)] - Array(Array), - #[serde(untagged)] - #[schemars(skip)] - String(String), - #[serde(untagged)] - #[schemars(skip)] - Bool(bool), -} - -impl h::Hash for Value { - fn hash(&self, state: &mut H) { - // Hash the discriminant first - std::mem::discriminant(self).hash(state); - - // Hash the inner values only for types that implement Hash - match self { - Value::String(s) => s.hash(state), - Value::Int(i) => i.hash(state), - Value::Bool(b) => b.hash(state), - Value::Dictionary(d) => d.middleware_dict().commitment().hash(state), - Value::Set(s) => s.middleware_set().commitment().hash(state), - Value::Array(a) => a.middleware_array().commitment().hash(state), - Value::Raw(r) => r.hash(state), - } - } -} - -impl From<&str> for Value { - fn from(s: &str) -> Self { - Value::String(s.to_string()) - } -} - -impl From for Value { - fn from(v: i64) -> Self { - Value::Int(v) - } -} - -impl From for Value { - fn from(b: bool) -> Self { - Value::Bool(b) - } -} - -impl From<&Value> for middleware::Value { - fn from(v: &Value) -> Self { - match v { - Value::String(s) => hash_str(s).value(), - Value::Int(v) => middleware::Value::from(*v), - Value::Bool(b) => middleware::Value::from(*b as i64), - Value::Dictionary(d) => d.middleware_dict().commitment().value(), - Value::Set(s) => s.middleware_set().commitment().value(), - Value::Array(a) => a.middleware_array().commitment().value(), - Value::Raw(v) => *v, - } - } -} - -impl From for Value { - fn from(v: middleware::Value) -> Self { - Self::Raw(v) - } -} - -impl From for Value { - fn from(v: middleware::Hash) -> Self { - Self::Raw(v.into()) - } -} - -impl TryInto for Value { - type Error = Error; - fn try_into(self) -> std::result::Result { - if let Value::Int(n) = self { - Ok(n) - } else { - Err(anyhow!("Value not an int")) - } - } -} - -impl fmt::Display for Value { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Value::String(s) => write!(f, "\"{}\"", s), - Value::Int(v) => write!(f, "{}", v), - Value::Bool(b) => write!(f, "{}", b), - Value::Dictionary(d) => write!(f, "dict:{}", d.middleware_dict().commitment()), - Value::Set(s) => write!(f, "set:{}", s.middleware_set().commitment()), - Value::Array(a) => write!(f, "arr:{}", a.middleware_array().commitment()), - Value::Raw(v) => write!(f, "{}", v), - } - } -} - #[derive(Clone, Debug)] pub struct SignedPodBuilder { pub params: Params, - pub kvs: HashMap, + pub kvs: HashMap, } impl fmt::Display for SignedPodBuilder { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "SignedPodBuilder:")?; - for (k, v) in self.kvs.iter().sorted_by_key(|kv| kv.0) { + for (k, v) in self.kvs.iter().sorted_by_key(|kv| kv.0.hash()) { writeln!(f, " - {}: {}", k, v)?; } Ok(()) @@ -195,55 +52,26 @@ impl SignedPodBuilder { } } - pub fn insert(&mut self, key: impl Into, value: impl Into) { + pub fn insert(&mut self, key: impl Into, value: impl Into) { self.kvs.insert(key.into(), value.into()); } pub fn sign(&self, signer: &mut S) -> Result { // Sign POD with committed KV store. - let committed_kvs = self - .kvs - .iter() - .map(|(k, v)| (hash_str(k), v.into())) - .collect::>(); - let pod = signer.sign(&self.params, &committed_kvs)?; + let pod = signer.sign(&self.params, &self.kvs)?; - let mut kvs = self.kvs.clone(); - - // Type and signer information are passed in by the - // backend. Include these in the frontend representation. - let mid_kvs = pod.kvs(); - let pod_type = mid_kvs - .get(&crate::middleware::AnchoredKey( - pod.id(), - hash_str(KEY_TYPE), - )) - .cloned() - .ok_or(anyhow!("Missing POD type information in POD: {:?}", pod))?; - let pod_signer = mid_kvs - .get(&crate::middleware::AnchoredKey( - pod.id(), - hash_str(KEY_SIGNER), - )) - .cloned() - .ok_or(anyhow!("Missing POD signer in POD: {:?}", pod))?; - kvs.insert(KEY_TYPE.to_string(), pod_type.into()); - kvs.insert(KEY_SIGNER.to_string(), pod_signer.into()); - Ok(SignedPod { pod, kvs }) + Ok(SignedPod::new(pod)) } } /// SignedPod is a wrapper on top of backend::SignedPod, which additionally stores the /// string<-->hash relation of the keys. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(try_from = "SignedPodHelper", into = "SignedPodHelper")] +#[derive(Debug, Clone)] +// #[serde(try_from = "SignedPodHelper", into = "SignedPodHelper")] pub struct SignedPod { pub pod: Box, - /// Key-value pairs as represented in the frontend. These should - /// correspond to the entries of `pod.kvs()` after hashing and - /// replacing each key with its corresponding anchored key. - #[serde(serialize_with = "ordered_map")] - pub kvs: HashMap, + // We store a copy of the key values for quick access + kvs: HashMap, } impl fmt::Display for SignedPod { @@ -253,56 +81,45 @@ impl fmt::Display for SignedPod { // https://0xparc.github.io/pod2/merkletree.html will not need it since it will be // deterministic based on the keys values not on the order of the keys when added into the // tree. - for (k, v) in self.kvs.iter().sorted_by_key(|kv| kv.0) { - writeln!( - f, - " - {} = {}: {}", - hash_str(k), - k, - crate::middleware::Value::from(v) - )?; + for (k, v) in self.pod.kvs().iter().sorted_by_key(|kv| kv.0.key.hash()) { + writeln!(f, " - {} = {}", k, v)?; } Ok(()) } } impl SignedPod { + pub fn new(pod: Box) -> Self { + let kvs = pod + .kvs() + .into_iter() + .map(|(AnchoredKey { key, .. }, v)| (key, v)) + .collect(); + Self { pod, kvs } + } pub fn id(&self) -> PodId { self.pod.id() } - pub fn origin(&self) -> Origin { - Origin::new(self.id()) - } pub fn verify(&self) -> Result<()> { self.pod.verify() } - pub fn kvs(&self) -> HashMap { - self.pod - .kvs() - .into_iter() - .map(|(middleware::AnchoredKey(_, k), v)| (k, v)) - .collect() - } -} - -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] -pub struct AnchoredKey { - pub origin: Origin, - pub key: String, -} - -impl AnchoredKey { - pub fn new(origin: Origin, key: String) -> Self { - Self { origin, key } - } -} - -impl From for middleware::AnchoredKey { - fn from(ak: AnchoredKey) -> Self { - middleware::AnchoredKey(ak.origin.pod_id, hash_str(&ak.key)) + pub fn kvs(&self) -> &HashMap { + &self.kvs + } + pub fn get(&self, key: impl Into) -> Option<&Value> { + self.kvs.get(&key.into()) + } + // Returns the ValueOf statement that defines key if it exists. + pub fn get_statement(&self, key: impl Into) -> Option { + let key: Key = key.into(); + self.kvs() + .get(&key) + .map(|value| Statement::ValueOf(AnchoredKey::from((self.id(), key)), value.clone())) } } +/// The MainPodBuilder allows interactive creation of a MainPod by applying operations and creating +/// the corresponding statements. #[derive(Debug)] pub struct MainPodBuilder { pub params: Params, @@ -312,8 +129,10 @@ pub struct MainPodBuilder { pub operations: Vec, pub public_statements: Vec, // Internal state + /// Counter for constants created from literals const_cnt: usize, - key_table: HashMap, + /// Map from (public, Value) to Key of already created literals via ValueOf statements. + literals: HashMap<(bool, Value), Key>, } impl fmt::Display for MainPodBuilder { @@ -347,36 +166,29 @@ impl MainPodBuilder { operations: Vec::new(), public_statements: Vec::new(), const_cnt: 0, - key_table: HashMap::new(), + literals: HashMap::new(), } } pub fn add_signed_pod(&mut self, pod: &SignedPod) { self.input_signed_pods.push(pod.clone()); - // Add key-hash correspondences to key table. - pod.kvs.iter().for_each(|(key, _)| { - self.key_table.insert(hash_str(key), key.clone()); - }); } pub fn add_main_pod(&mut self, pod: MainPod) { - // Add key-hash and POD ID-class correspondences to tables. - pod.public_statements - .iter() - .flat_map(|s| &s.args) - .flat_map(|arg| match arg { - StatementArg::Key(AnchoredKey { origin: _, key }) => { - Some((hash_str(key), key.clone())) - } - _ => None, - }) - .for_each(|(hash, key)| { - self.key_table.insert(hash, key); - }); self.input_main_pods.push(pod); } - pub fn insert(&mut self, st_op: (Statement, Operation)) { + pub fn insert(&mut self, public: bool, st_op: (Statement, Operation)) { + // TODO: Do error handling instead of panic let (st, op) = st_op; + if public { + self.public_statements.push(st.clone()); + } + if self.public_statements.len() > self.params.max_public_statements { + panic!("too many public statements"); + } self.statements.push(st); self.operations.push(op); + if self.statements.len() > self.params.max_statements { + panic!("too many statements"); + } } /// Convert [OperationArg]s to [StatementArg]s for the operations that work with entries @@ -386,26 +198,24 @@ impl MainPodBuilder { args: &mut [OperationArg], ) -> Result> { let mut st_args = Vec::new(); + // TODO: Rewrite without calling args() and instead using matches? for arg in args.iter_mut() { match arg { OperationArg::Statement(s) => { - if s.predicate == Predicate::Native(NativePredicate::ValueOf) { - st_args.push(s.args[0].clone()) + if s.predicate() == Predicate::Native(NativePredicate::ValueOf) { + st_args.push(s.args()[0].clone()) } else { panic!("Invalid statement argument."); } } // todo: better error handling OperationArg::Literal(v) => { - let value_of_st = self.literal(public, v)?; + let value_of_st = self.literal(public, v.clone())?; *arg = OperationArg::Statement(value_of_st.clone()); - st_args.push(value_of_st.args[0].clone()) + st_args.push(value_of_st.args()[0].clone()) } OperationArg::Entry(k, v) => { - st_args.push(StatementArg::Key(AnchoredKey::new( - Origin::new(SELF), - k.clone(), - ))); + st_args.push(StatementArg::Key(AnchoredKey::from((SELF, k.as_str())))); st_args.push(StatementArg::Literal(v.clone())) } }; @@ -421,14 +231,46 @@ impl MainPodBuilder { self.op(false, op) } - fn op(&mut self, public: bool, mut op: Operation) -> Result { + /// Lower syntactic sugar operation into backend compatible operation. + /// - {Dict,Array,Set}Contains/NotContains becomes Contains/NotContains. + fn lower_op(op: Operation) -> Operation { use NativeOperation::*; + use OperationType::*; + match op.0 { + Native(DictContainsFromEntries) => { + let [dict, key, value] = op.1.try_into().unwrap(); // TODO: Error handling + Operation(Native(ContainsFromEntries), vec![dict, key, value], op.2) + } + Native(DictNotContainsFromEntries) => { + let [dict, key] = op.1.try_into().unwrap(); // TODO: Error handling + Operation(Native(NotContainsFromEntries), vec![dict, key], op.2) + } + Native(SetContainsFromEntries) => { + let [set, value] = op.1.try_into().unwrap(); // TODO: Error handling + let empty = OperationArg::Literal(Value::from(EMPTY_VALUE)); + Operation(Native(ContainsFromEntries), vec![set, value, empty], op.2) + } + Native(SetNotContainsFromEntries) => { + let [set, value] = op.1.try_into().unwrap(); // TODO: Error handling + Operation(Native(NotContainsFromEntries), vec![set, value], op.2) + } + Native(ArrayContainsFromEntries) => { + let [array, index, value] = op.1.try_into().unwrap(); // TODO: Error handling + Operation(Native(ContainsFromEntries), vec![array, index, value], op.2) + } + _ => op, + } + } + + fn op(&mut self, public: bool, op: Operation) -> Result { + use NativeOperation::*; + let mut op = Self::lower_op(op); let Operation(op_type, ref mut args, _) = &mut op; // TODO: argument type checking let pred = op_type.output_predicate().map(Ok).unwrap_or_else(|| { // We are dealing with a copy here. match (args).first() { - Some(OperationArg::Statement(s)) if args.len() == 1 => Ok(s.predicate.clone()), + Some(OperationArg::Statement(s)) if args.len() == 1 => Ok(s.predicate().clone()), _ => Err(anyhow!("Invalid arguments to copy operation: {:?}", args)), } })?; @@ -438,7 +280,7 @@ impl MainPodBuilder { None => vec![], NewEntry => self.op_args_entries(public, args)?, CopyStatement => match &args[0] { - OperationArg::Statement(s) => s.args.clone(), + OperationArg::Statement(s) => s.args().clone(), _ => { return Err(anyhow!("Invalid arguments to copy operation: {}", op)); } @@ -450,20 +292,14 @@ impl MainPodBuilder { TransitiveEqualFromStatements => { match (args[0].clone(), args[1].clone()) { ( - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::Equal), - args: st0_args, - }), - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::Equal), - args: st1_args, - }), + OperationArg::Statement(Statement::Equal(ak0, ak1)), + OperationArg::Statement(Statement::Equal(ak2, ak3)), ) => { // st_args0 == vec![ak0, ak1] // st_args1 == vec![ak1, ak2] // output statement Equals(ak0, ak2) - if st0_args[1] == st1_args[0] { - vec![st0_args[0].clone(), st1_args[1].clone()] + if ak1 == ak2 { + vec![StatementArg::Key(ak0), StatementArg::Key(ak3)] } else { return Err(anyhow!( "Invalid arguments to transitive equality operation" @@ -478,22 +314,16 @@ impl MainPodBuilder { } } GtToNotEqual => match args[0].clone() { - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::Gt), - args: st_args, - }) => { - vec![st_args[0].clone()] + OperationArg::Statement(Statement::Gt(ak0, ak1)) => { + vec![StatementArg::Key(ak0), StatementArg::Key(ak1)] } _ => { return Err(anyhow!("Invalid arguments to gt-to-neq operation")); } }, LtToNotEqual => match args[0].clone() { - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::Lt), - args: st_args, - }) => { - vec![st_args[0].clone()] + OperationArg::Statement(Statement::Lt(ak0, ak1)) => { + vec![StatementArg::Key(ak0), StatementArg::Key(ak1)] } _ => { return Err(anyhow!("Invalid arguments to lt-to-neq operation")); @@ -501,47 +331,22 @@ impl MainPodBuilder { }, SumOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { ( - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st0_args, - }), - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st1_args, - }), - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st2_args, - }), + OperationArg::Statement(Statement::ValueOf(ak0, v0)), + OperationArg::Statement(Statement::ValueOf(ak1, v1)), + OperationArg::Statement(Statement::ValueOf(ak2, v2)), ) => { - let st_args: Vec = match ( - st0_args[1].clone(), - st1_args[1].clone(), - st2_args[1].clone(), - ) { - ( - StatementArg::Literal(v0), - StatementArg::Literal(v1), - StatementArg::Literal(v2), - ) => { - let v0: i64 = v0.clone().try_into()?; - let v1: i64 = v1.clone().try_into()?; - let v2: i64 = v2.clone().try_into()?; - if v0 == v1 + v2 { - vec![ - st0_args[0].clone(), - st1_args[0].clone(), - st2_args[0].clone(), - ] - } else { - return Err(anyhow!("Invalid arguments to sum-of operation")); - } - } - _ => { - return Err(anyhow!("Invalid arguments to sum-of operation")); - } - }; - st_args + let v0: i64 = v0.typed().try_into()?; + let v1: i64 = v1.typed().try_into()?; + let v2: i64 = v2.typed().try_into()?; + if v0 == v1 + v2 { + vec![ + StatementArg::Key(ak0), + StatementArg::Key(ak1), + StatementArg::Key(ak2), + ] + } else { + return Err(anyhow!("Invalid arguments to sum-of operation")); + } } _ => { return Err(anyhow!("Invalid arguments to sum-of operation")); @@ -549,49 +354,22 @@ impl MainPodBuilder { }, ProductOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { ( - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st0_args, - }), - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st1_args, - }), - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st2_args, - }), + OperationArg::Statement(Statement::ValueOf(ak0, v0)), + OperationArg::Statement(Statement::ValueOf(ak1, v1)), + OperationArg::Statement(Statement::ValueOf(ak2, v2)), ) => { - let st_args: Vec = match ( - st0_args[1].clone(), - st1_args[1].clone(), - st2_args[1].clone(), - ) { - ( - StatementArg::Literal(v0), - StatementArg::Literal(v1), - StatementArg::Literal(v2), - ) => { - let v0: i64 = v0.clone().try_into()?; - let v1: i64 = v1.clone().try_into()?; - let v2: i64 = v2.clone().try_into()?; - if v0 == v1 * v2 { - vec![ - st0_args[0].clone(), - st1_args[0].clone(), - st2_args[0].clone(), - ] - } else { - return Err(anyhow!( - "Invalid arguments to product-of operation" - )); - } - } - _ => { - return Err(anyhow!("Invalid arguments to product-of operation")); - } - }; - st_args + let v0: i64 = v0.typed().try_into()?; + let v1: i64 = v1.typed().try_into()?; + let v2: i64 = v2.typed().try_into()?; + if v0 == v1 * v2 { + vec![ + StatementArg::Key(ak0), + StatementArg::Key(ak1), + StatementArg::Key(ak2), + ] + } else { + return Err(anyhow!("Invalid arguments to product-of operation")); + } } _ => { return Err(anyhow!("Invalid arguments to product-of operation")); @@ -599,52 +377,31 @@ impl MainPodBuilder { }, MaxOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { ( - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st0_args, - }), - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st1_args, - }), - OperationArg::Statement(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: st2_args, - }), + OperationArg::Statement(Statement::ValueOf(ak0, v0)), + OperationArg::Statement(Statement::ValueOf(ak1, v1)), + OperationArg::Statement(Statement::ValueOf(ak2, v2)), ) => { - let st_args: Vec = match ( - st0_args[1].clone(), - st1_args[1].clone(), - st2_args[1].clone(), - ) { - ( - StatementArg::Literal(v0), - StatementArg::Literal(v1), - StatementArg::Literal(v2), - ) => { - let v0: i64 = v0.clone().try_into()?; - let v1: i64 = v1.clone().try_into()?; - let v2: i64 = v2.clone().try_into()?; - if v0 == std::cmp::max(v1, v2) { - vec![ - st0_args[0].clone(), - st1_args[0].clone(), - st2_args[0].clone(), - ] - } else { - return Err(anyhow!("Invalid arguments to max-of operation")); - } - } - _ => { - return Err(anyhow!("Invalid arguments to max-of operation")); - } - }; - st_args + let v0: i64 = v0.typed().try_into()?; + let v1: i64 = v1.typed().try_into()?; + let v2: i64 = v2.typed().try_into()?; + if v0 == std::cmp::max(v1, v2) { + vec![ + StatementArg::Key(ak0), + StatementArg::Key(ak1), + StatementArg::Key(ak2), + ] + } else { + return Err(anyhow!("Invalid arguments to max-of operation")); + } } _ => { - return Err(anyhow!("Invalid arguments to operation")); + return Err(anyhow!("Invalid arguments to max-of operation")); } }, + ContainsFromEntries => self.op_args_entries(public, args)?, + NotContainsFromEntries => self.op_args_entries(public, args)?, + // NOTE: Could we remove these and assume that this function is never called with + // syntax sugar operations? DictContainsFromEntries => self.op_args_entries(public, args)?, DictNotContainsFromEntries => self.op_args_entries(public, args)?, SetContainsFromEntries => self.op_args_entries(public, args)?, @@ -652,84 +409,78 @@ impl MainPodBuilder { ArrayContainsFromEntries => self.op_args_entries(public, args)?, }, OperationType::Custom(cpr) => { + let pred = &cpr.batch.predicates[cpr.index]; + if pred.statements.len() != args.len() { + return Err(anyhow!( + "Custom predicate operation needs {} statements but has {}.", + pred.statements.len(), + args.len() + )); + } // All args should be statements to be pattern matched against statement templates. let args = args.iter().map( |a| match a { OperationArg::Statement(s) => Ok(s.clone()), - _ => Err(anyhow!("Invalid argument {} to operation corresponding to custom predicate {:?}.", a, cpr)) + _ => Err(anyhow!("Invalid argument {} to operation corresponding to custom predicate {:?}.", a, cpr)) } ).collect::>>()?; - // Match these statements against the custom predicate definition - let bindings = cpr.match_against(&args)?; - let output_arg_values = (0..cpr.arg_len()) - .map(|i| { - bindings.get(&i).cloned().ok_or(anyhow!( - "Wildcard {} of custom predicate {:?} is unbound.", - i, - cpr - )) - }) - .collect::>>()?; - output_arg_values - .chunks(2) - .map(|chunk| { - Ok(StatementArg::Key(AnchoredKey::new( - Origin::new(PodId(match chunk[0] { - Value::Raw(v) => v.into(), - _ => return Err(anyhow!("Invalid POD class value.")), - })), - self.key_table - .get(&match &chunk[1] { - Value::String(s) => hash_str(s.as_str()), - _ => return Err(anyhow!("Invalid key value.")), - }) - .cloned() - .ok_or(anyhow!("Missing key corresponding to hash."))?, - ))) - }) - .collect::>>()? + let mut wildcard_map = + vec![Option::None; self.params.max_custom_predicate_wildcards]; + for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) { + let st_args = st.args(); + for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { + if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { + // TODO: Add wildcard_map in the error for better context + return Err(anyhow!("{} doesn't match {}", st, st_tmpl)); + } + } + } + let v_default = WildcardValue::PodId(SELF); + wildcard_map + .into_iter() + .take(pred.args_len) + .map(|v| StatementArg::WildcardLiteral(v.unwrap_or_else(|| v_default.clone()))) + .collect() } }; - let st = Statement::new(pred, st_args); - self.operations.push(op); - if public { - self.public_statements.push(st.clone()); - } + let st = Statement::from_args(pred, st_args).expect("valid arguments"); + self.insert(public, (st, op)); - // Add key-hash pairs in statement to table. - st.args.iter().for_each(|arg| { - if let StatementArg::Key(AnchoredKey { origin: _, key }) = arg { - self.key_table.insert(hash_str(key), key.clone()); - } - }); - - self.statements.push(st); Ok(self.statements[self.statements.len() - 1].clone()) } /// Convenience method for introducing public constants. - pub fn pub_literal>(&mut self, v: &V) -> Result { - self.literal(true, v) + pub fn pub_literal(&mut self, v: impl Into) -> Result { + self.literal(true, v.into()) } /// Convenience method for introducing private constants. - pub fn priv_literal>(&mut self, v: &V) -> Result { - self.literal(false, v) + pub fn priv_literal(&mut self, v: impl Into) -> Result { + self.literal(false, v.into()) } - fn literal>(&mut self, public: bool, v: &V) -> Result { - let v: Value = v.clone().into(); - let k = format!("c{}", self.const_cnt); - self.const_cnt += 1; - self.op( - public, - Operation( - OperationType::Native(NativeOperation::NewEntry), - vec![OperationArg::Entry(k.clone(), v)], - OperationAux::None, - ), - ) + fn literal(&mut self, public: bool, value: Value) -> Result { + let public_value = (public, value); + if let Some(key) = self.literals.get(&public_value) { + Ok(Statement::ValueOf( + AnchoredKey::new(SELF, key.clone()), + public_value.1, + )) + } else { + let key = format!("c{}", self.const_cnt); + self.literals + .insert(public_value.clone(), Key::new(key.clone())); + self.const_cnt += 1; + self.op( + public, + Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![OperationArg::Entry(key.clone(), public_value.1)], + OperationAux::None, + ), + ) + } } pub fn reveal(&mut self, st: &Statement) { @@ -772,19 +523,14 @@ impl MainPodBuilder { .pub_statements() .into_iter() .find_map(|s| match s { - crate::middleware::Statement::ValueOf( - crate::middleware::AnchoredKey(id, key), - value, - ) if id == pod_id && key == type_key_hash => Some(Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: vec![ - StatementArg::Key(AnchoredKey::new( - Origin::new(pod_id), - KEY_TYPE.to_string(), - )), - StatementArg::Literal(value.into()), - ], - }), + Statement::ValueOf(AnchoredKey { pod_id: id, key }, value) + if id == pod_id && key.hash() == type_key_hash => + { + Some(Statement::ValueOf( + AnchoredKey::from((pod_id, KEY_TYPE)), + value, + )) + } _ => None, }) .ok_or(anyhow!("Missing POD type information in POD: {:?}", pod))?; @@ -793,21 +539,18 @@ impl MainPodBuilder { let public_statements = [type_statement] .into_iter() .chain(self.public_statements.clone().into_iter().map(|s| { - let s_type = s.predicate; + let s_type = s.predicate(); let s_args = s - .args + .args() .into_iter() .map(|arg| match arg { - StatementArg::Key(AnchoredKey { - origin: Origin { pod_id: id }, - key, - }) if id == SELF => { - StatementArg::Key(AnchoredKey::new(Origin::new(pod_id), key)) + StatementArg::Key(AnchoredKey { pod_id: id, key }) if id == SELF => { + StatementArg::Key(AnchoredKey::new(pod_id, key)) } _ => arg, }) .collect(); - Statement::new(s_type, s_args) + Statement::from_args(s_type, s_args).expect("valid arguments") })) .collect(); @@ -818,8 +561,8 @@ impl MainPodBuilder { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(try_from = "MainPodHelper", into = "MainPodHelper")] +#[derive(Debug, Clone)] +// #[serde(try_from = "MainPodHelper", into = "MainPodHelper")] pub struct MainPod { pub pod: Box, pub public_statements: Vec, @@ -845,9 +588,6 @@ impl MainPod { pub fn id(&self) -> PodId { self.pod.id() } - pub fn origin(&self) -> Origin { - Origin::new(self.id()) - } } struct MainPodCompilerInputs<'a> { @@ -858,18 +598,12 @@ struct MainPodCompilerInputs<'a> { pub public_statements: &'a [Statement], } +/// The compiler converts frontend::Operation into middleware::Operation struct MainPodCompiler { params: Params, // Output - statements: Vec, + statements: Vec, operations: Vec, - // Internal state - // Tracks literal constants assigned to ValueOf statements by self.literal() - // If `val` has been added as a literal, - // then `self.literals.get(&val)` returns `Some(idx)`, and - // then `self.statements[idx]` is the ValueOf statement - // where it was introduced. - literals: HashMap, } impl MainPodCompiler { @@ -878,18 +612,20 @@ impl MainPodCompiler { params: params.clone(), statements: Vec::new(), operations: Vec::new(), - literals: HashMap::new(), } } - fn push_st_op(&mut self, st: middleware::Statement, op: middleware::Operation) { + fn push_st_op(&mut self, st: Statement, op: middleware::Operation) { self.statements.push(st); self.operations.push(op); + if self.statements.len() > self.params.max_statements { + panic!("too many statements"); + } } - fn compile_op_arg(&self, op_arg: &OperationArg) -> Option { + fn compile_op_arg(&self, op_arg: &OperationArg) -> Option { match op_arg { - OperationArg::Statement(s) => self.compile_st(s).ok(), + OperationArg::Statement(s) => Some(s.clone()), OperationArg::Literal(_v) => { // OperationArg::Literal is a syntax sugar for the frontend. This is translated to // a new ValueOf statement and it's key used instead. @@ -904,151 +640,28 @@ impl MainPodCompiler { } } - // Introduces a literal value if it hasn't been introduced, - // or else returns the existing ValueOf statement where it was first introduced. - // TODO: this might produce duplicate keys, fix - fn literal>(&mut self, val: V) -> &middleware::Statement { - let val: middleware::Value = val.into(); - match self.literals.get(&val) { - Some(idx) => &self.statements[*idx], - None => { - let ak = middleware::AnchoredKey(SELF, hash_value(&val)); - let st = middleware::Statement::ValueOf(ak, val); - let op = middleware::Operation::NewEntry; - self.statements.push(st); - self.operations.push(op); - self.statements.last().unwrap() - } - } - } - - // Returns the existing ValueOf statement where it was first introduced, - // or None if it does not exist. - fn get_literal>( - &self, - val: V, - ) -> Option<&middleware::Statement> { - let val: middleware::Value = val.into(); - match self.literals.get(&val) { - Some(idx) => Some(&self.statements[*idx]), - None => None, - } - } - - // This function handles cases where one frontend statement - // compiles to multiple middleware statements. - // For example: DictContains(x, y) on the frontend compiles to: - // ValueOf(empty, EMPTY_VALUE) - // Contains(x, y, empty) - fn manual_compile_st_op(&mut self, st: &Statement, op: &Operation) -> Result<()> { - match st.predicate { - Predicate::Native(NativePredicate::DictContains) => { - let empty_st = self.literal(EMPTY_VALUE).clone(); - let empty_ak = match empty_st { - middleware::Statement::ValueOf(ak, _) => ak, - _ => unreachable!(), - }; - let (ak1, ak2) = match (st.args.get(0).cloned(), st.args.get(1).cloned()) { - (Some(StatementArg::Key(ak1)), Some(StatementArg::Key(ak2))) => (ak1, ak2), - _ => Err(anyhow!("Ill-formed statement: {}", st))?, - }; - let middle_st = middleware::Statement::Contains(ak1.into(), ak2.into(), empty_ak); - let middle_op = middleware::Operation::ContainsFromEntries( - match &op.1[0] { - OperationArg::Statement(s) => self.compile_st(s)?, - _ => Err(anyhow!("Statement compile failed in manual compile"))?, - }, - match &op.1[1] { - OperationArg::Statement(s) => self.compile_st(s)?, - _ => Err(anyhow!("Statement compile failed in manual compile"))?, - }, - empty_st, - match &op.2 { - OperationAux::MerkleProof(mp) => mp.clone(), - _ => { - return Err(anyhow!( - "Auxiliary argument to DictContainsFromEntries must be Merkle proof" - )); - } - }, - ); - self.statements.push(middle_st); - self.operations.push(middle_op); - assert_eq!(self.statements.len(), self.operations.len()); - Ok(()) - } - _ => unreachable!(), - } - } - - // If the frontend statement `st` compiles to a single middleware statement, - // returns that middleware statement. - // If it compiles to multiple middlewarestatements, returns StatementConversionError. - // This is only a helper method within compile_st_op(). - // If you want to compile a statement in general, run compile_st(). - fn compile_st_try_simple( - &self, - st: &Statement, - ) -> Result { - st.clone().try_into() - } - - // Compiles the frontend statement `st` to a middleware statement. - // This function assumes the middleware statement already exists -- - // it should not be called from compile_st_op. - fn compile_st(&self, st: &Statement) -> Result { - match self.compile_st_try_simple(st) { - Ok(s) => Ok(s), - Err(StatementConversionError::Error(e)) => Err(e), - Err(StatementConversionError::MCR(_)) => { - let empty_st = self - .get_literal(EMPTY_VALUE) - .ok_or(anyhow!("Literal value not found for empty literal."))?; - let empty_ak = match empty_st { - middleware::Statement::ValueOf(ak, _) => ak, - _ => unreachable!(), - }; - let (ak1, ak2) = match (st.args.get(0).cloned(), st.args.get(1).cloned()) { - (Some(StatementArg::Key(ak1)), Some(StatementArg::Key(ak2))) => (ak1, ak2), - _ => Err(anyhow!("Ill-formed statement: {}", st))?, - }; - let middle_st = middleware::Statement::Contains(ak1.into(), ak2.into(), *empty_ak); - Ok(middle_st) - } - } - } - fn compile_op(&self, op: &Operation) -> Result { - let mop_code: middleware::OperationType = op.0.clone().try_into()?; - // TODO: Take Merkle proof into account. let mop_args = op.1.iter() .flat_map(|arg| self.compile_op_arg(arg)) .collect_vec(); - middleware::Operation::op(mop_code, &mop_args, &op.2) + middleware::Operation::op(op.0.clone(), &mop_args, &op.2) } fn compile_st_op(&mut self, st: &Statement, op: &Operation, params: &Params) -> Result<()> { - let middle_st_res = self.compile_st_try_simple(st); - match middle_st_res { - Ok(middle_st) => { - let middle_op = self.compile_op(op)?; - let is_correct = middle_op.check(params, &middle_st)?; - if !is_correct { - // todo: improve error handling - Err(anyhow!( - "Compile failed due to invalid deduction:\n {} ⇏ {}", - middle_op, - middle_st - )) - } else { - self.push_st_op(middle_st, middle_op); - Ok(()) - } - } - Err(StatementConversionError::Error(e)) => Err(e), - Err(StatementConversionError::MCR(_)) => self.manual_compile_st_op(st, op), + let middle_op = self.compile_op(op)?; + let is_correct = middle_op.check(params, st)?; + if !is_correct { + // todo: improve error handling + Err(anyhow!( + "Compile failed due to invalid deduction:\n {} ⇏ {}", + middle_op, + st + )) + } else { + self.push_st_op(st.clone(), middle_op); + Ok(()) } } @@ -1057,9 +670,9 @@ impl MainPodCompiler { inputs: MainPodCompilerInputs<'_>, params: &Params, ) -> Result<( - Vec, // input statements + Vec, // input statements Vec, - Vec, // public statements + Vec, // public statements )> { let MainPodCompilerInputs { // signed_pods: _, @@ -1070,15 +683,8 @@ impl MainPodCompiler { } = inputs; for (st, op) in statements.iter().zip_eq(operations.iter()) { self.compile_st_op(st, op, params)?; - if self.statements.len() > self.params.max_statements { - panic!("too many statements"); - } } - let public_statements = public_statements - .iter() - .map(|st| self.compile_st(st)) - .collect::>>()?; - Ok((self.statements, self.operations, public_statements)) + Ok((self.statements, self.operations, public_statements.to_vec())) } } @@ -1095,56 +701,56 @@ pub mod build_utils { #[macro_export] macro_rules! op { (new_entry, ($key:expr, $value:expr)) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::NewEntry), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::NewEntry), $crate::op_args!(($key, $value)), $crate::middleware::OperationAux::None) }; (eq, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::EqualFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::EqualFromEntries), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (ne, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::NotEqualFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::NotEqualFromEntries), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (gt, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::GtFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::GtFromEntries), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (lt, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::LtFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::LtFromEntries), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (transitive_eq, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::TransitiveEqualFromStatements), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::TransitiveEqualFromStatements), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (gt_to_ne, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::GtToNotEqual), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::GtToNotEqual), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (lt_to_ne, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::LtToNotEqual), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::LtToNotEqual), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (sum_of, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::SumOf), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::SumOf), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (product_of, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::ProductOf), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::ProductOf), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (max_of, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::MaxOf), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::MaxOf), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (custom, $op:expr, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Custom($op), + $crate::middleware::OperationType::Custom($op), $crate::op_args!($($arg),*), $crate::middleware::OperationAux::None) }; (dict_contains, $dict:expr, $key:expr, $value:expr, $aux:expr) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::DictContainsFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::DictContainsFromEntries), $crate::op_args!($dict, $key, $value), $crate::middleware::OperationAux::MerkleProof($aux)) }; (dict_not_contains, $dict:expr, $key:expr, $aux:expr) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::DictNotContainsFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::DictNotContainsFromEntries), $crate::op_args!($dict, $key), $crate::middleware::OperationAux::MerkleProof($aux)) }; (set_contains, $set:expr, $value:expr, $aux:expr) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::SetContainsFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::SetContainsFromEntries), $crate::op_args!($set, $value), $crate::middleware::OperationAux::MerkleProof($aux)) }; (set_not_contains, $set:expr, $value:expr, $aux:expr) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::SetNotContainsFromEntries), + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::SetNotContainsFromEntries), $crate::op_args!($set, $value), $crate::middleware::OperationAux::MerkleProof($aux)) }; - (array_contains, $array:expr, $value:expr, $aux:expr) => { $crate::frontend::Operation( - $crate::frontend::OperationType::Native($crate::frontend::NativeOperation::ArrayContainsFromEntries), - $crate::op_args!($array, $value), $crate::middleware::OperationAux::MerkleProof($aux)) }; + (array_contains, $array:expr, $index:expr, $value:expr, $aux:expr) => { $crate::frontend::Operation( + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::ArrayContainsFromEntries), + $crate::op_args!($array, $index, $value), $crate::middleware::OperationAux::MerkleProof($aux)) }; } } @@ -1152,14 +758,12 @@ pub mod build_utils { pub mod tests { use super::*; use crate::{ - backends::plonky2::{ - basetypes, - mock::{mainpod::MockProver, signedpod::MockSigner}, - }, + backends::plonky2::mock::{mainpod::MockProver, signedpod::MockSigner}, examples::{ eth_dos_pod_builder, eth_friend_signed_pod_builder, great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders, }, + middleware::{containers::Dictionary, Value}, }; // Check that frontend public statements agree with those @@ -1179,14 +783,15 @@ pub mod tests { fn check_kvs(pod: &SignedPod) -> Result<()> { let kvs = pod .kvs - .iter() - .map(|(k, v)| (hash_str(k), middleware::Value::from(v))) + .clone() + .into_iter() + .map(|(k, v)| (k, v)) .collect::>(); let embedded_kvs = pod .pod .kvs() .into_iter() - .map(|(middleware::AnchoredKey(_, k), v)| (k, v)) + .map(|(middleware::AnchoredKey { key, .. }, v)| (key, v)) .collect::>(); if kvs == embedded_kvs { @@ -1203,9 +808,7 @@ pub mod tests { #[test] fn test_front_zu_kyc() -> Result<()> { let params = Params::default(); - let sanctions_values = vec!["A343434340".into()]; - let sanction_set = Value::Set(Set::new(sanctions_values)?); - let (gov_id, pay_stub, sanction_list) = zu_kyc_sign_pod_builders(¶ms, &sanction_set); + let (gov_id, pay_stub, sanction_list) = zu_kyc_sign_pod_builders(¶ms); println!("{}", gov_id); println!("{}", pay_stub); @@ -1251,10 +854,11 @@ pub mod tests { max_statements: 31, max_signed_pod_values: 8, max_public_statements: 10, - max_statement_args: 5, + max_statement_args: 6, max_operation_args: 5, max_custom_predicate_arity: 5, max_custom_batch_size: 5, + max_custom_predicate_wildcards: 12, ..Default::default() }; @@ -1385,15 +989,14 @@ pub mod tests { let params = Params::default(); let mut builder = SignedPodBuilder::new(¶ms); - type BeValue = basetypes::Value; - let mut my_dict_kvs: HashMap = HashMap::new(); - my_dict_kvs.insert("a".to_string(), Value::from(1)); - my_dict_kvs.insert("b".to_string(), Value::from(2)); - my_dict_kvs.insert("c".to_string(), Value::from(3)); + let mut my_dict_kvs: HashMap = HashMap::new(); + my_dict_kvs.insert(Key::from("a"), Value::from(1)); + my_dict_kvs.insert(Key::from("b"), Value::from(2)); + my_dict_kvs.insert(Key::from("c"), Value::from(3)); // let my_dict_as_mt = MerkleTree::new(5, &my_dict_kvs).unwrap(); // let dict = Dictionary { mt: my_dict_as_mt }; let dict = Dictionary::new(my_dict_kvs)?; - let dict_root = Value::Dictionary(dict.clone()); + let dict_root = Value::from(dict.clone()); builder.insert("dict", dict_root); let mut signer = MockSigner { @@ -1403,9 +1006,9 @@ pub mod tests { let mut builder = MainPodBuilder::new(¶ms); builder.add_signed_pod(&pod); - let st0 = Statement::from((&pod, "dict")); + let st0 = pod.get_statement("dict").unwrap(); let st1 = builder.op(true, op!(new_entry, ("key", "a"))).unwrap(); - let st2 = builder.literal(false, &Value::Int(1)).unwrap(); + let st2 = builder.literal(false, Value::from(1)).unwrap(); builder .pub_op(Operation( @@ -1417,12 +1020,7 @@ pub mod tests { OperationArg::Statement(st1), OperationArg::Statement(st2), ], - OperationAux::MerkleProof( - dict.middleware_dict() - .prove(&Hash::from("a").into()) - .unwrap() - .1, - ), + OperationAux::MerkleProof(dict.prove(&Key::from("a")).unwrap().1), )) .unwrap(); let mut main_prover = MockProver {}; @@ -1434,6 +1032,7 @@ pub mod tests { } #[should_panic] + #[test] fn test_incorrect_pod() { // try to insert the same key multiple times // right now this is not caught when you build the pod, @@ -1442,93 +1041,45 @@ pub mod tests { let params = Params::default(); let mut builder = MainPodBuilder::new(¶ms); - builder.insert(( - Statement::new( - Predicate::Native(NativePredicate::ValueOf), - vec![ - StatementArg::Key(AnchoredKey::new(Origin::new(SELF), "a".into())), - StatementArg::Literal(Value::Int(3)), - ], - ), - Operation( - OperationType::Native(NativeOperation::NewEntry), - vec![], - OperationAux::None, - ), - )); - builder.insert(( - Statement::new( - Predicate::Native(NativePredicate::ValueOf), - vec![ - StatementArg::Key(AnchoredKey::new(Origin::new(SELF), "a".into())), - StatementArg::Literal(Value::Int(28)), - ], - ), - Operation( - OperationType::Native(NativeOperation::NewEntry), - vec![], - OperationAux::None, - ), - )); + let st = Statement::ValueOf(AnchoredKey::from((SELF, "a")), Value::from(3)); + let op_new_entry = Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![], + OperationAux::None, + ); + builder.insert(false, (st, op_new_entry.clone())); + + let st = Statement::ValueOf(AnchoredKey::from((SELF, "a")), Value::from(28)); + builder.insert(false, (st, op_new_entry.clone())); let mut prover = MockProver {}; let pod = builder.prove(&mut prover, ¶ms).unwrap(); - pod.pod.verify(); + pod.pod.verify().unwrap(); // try to insert a statement that doesn't follow from the operation // right now the mock prover catches this when it calls compile() let params = Params::default(); let mut builder = MainPodBuilder::new(¶ms); - let self_a = AnchoredKey::new(Origin::new(SELF), "a".into()); - let self_b = AnchoredKey::new(Origin::new(SELF), "b".into()); - let value_of_a = Statement::new( - Predicate::Native(NativePredicate::ValueOf), - vec![ - StatementArg::Key(self_a.clone()), - StatementArg::Literal(Value::Int(3)), - ], - ); - let value_of_b = Statement::new( - Predicate::Native(NativePredicate::ValueOf), - vec![ - StatementArg::Key(self_b.clone()), - StatementArg::Literal(Value::Int(27)), - ], - ); + let self_a = AnchoredKey::from((SELF, "a")); + let self_b = AnchoredKey::from((SELF, "b")); + let value_of_a = Statement::ValueOf(self_a.clone(), Value::from(3)); + let value_of_b = Statement::ValueOf(self_b.clone(), Value::from(27)); - builder.insert(( - value_of_a.clone(), - Operation( - OperationType::Native(NativeOperation::NewEntry), - vec![], - OperationAux::None, - ), - )); - builder.insert(( - value_of_b.clone(), - Operation( - OperationType::Native(NativeOperation::NewEntry), - vec![], - OperationAux::None, - ), - )); - builder.insert(( - Statement::new( - Predicate::Native(NativePredicate::Equal), - vec![StatementArg::Key(self_a), StatementArg::Key(self_b)], - ), - Operation( - OperationType::Native(NativeOperation::EqualFromEntries), - vec![ - OperationArg::Statement(value_of_a), - OperationArg::Statement(value_of_b), - ], - OperationAux::None, - ), - )); + builder.insert(false, (value_of_a.clone(), op_new_entry.clone())); + builder.insert(false, (value_of_b.clone(), op_new_entry)); + let st = Statement::Equal(self_a, self_b); + let op = Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![ + OperationArg::Statement(value_of_a), + OperationArg::Statement(value_of_b), + ], + OperationAux::None, + ); + builder.insert(false, (st, op)); let mut prover = MockProver {}; let pod = builder.prove(&mut prover, ¶ms).unwrap(); - pod.pod.verify(); + pod.pod.verify().unwrap(); } } diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 1c884ee..15730de 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,13 +1,12 @@ use std::fmt; -use serde::{Deserialize, Serialize}; - +// use serde::{Deserialize, Serialize}; use crate::{ - frontend::{CustomPredicateRef, NativePredicate, Predicate, SignedPod, Statement, Value}, - middleware::{self, OperationAux}, + frontend::SignedPod, + middleware::{AnchoredKey, OperationAux, OperationType, Statement, Value}, }; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq)] pub enum OperationArg { Statement(Statement), Literal(Value), @@ -56,7 +55,16 @@ impl From for OperationArg { impl From<(&SignedPod, &str)> for OperationArg { fn from((pod, key): (&SignedPod, &str)) -> Self { - Self::Statement((pod, key).into()) + // TODO: TryFrom. + let value = pod + .kvs() + .get(&key.into()) + .cloned() + .unwrap_or_else(|| panic!("Key {} is not present in POD: {}", key, pod)); + Self::Statement(Statement::ValueOf( + AnchoredKey::from((pod.id(), key)), + value, + )) } } @@ -72,120 +80,7 @@ impl> From<(&str, V)> for OperationArg { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum OperationType { - Native(NativeOperation), - Custom(CustomPredicateRef), -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum NativeOperation { - None = 0, - NewEntry = 1, - CopyStatement = 2, - EqualFromEntries = 3, - NotEqualFromEntries = 4, - GtFromEntries = 5, - LtFromEntries = 6, - TransitiveEqualFromStatements = 7, - GtToNotEqual = 8, - LtToNotEqual = 9, - SumOf = 13, - ProductOf = 14, - MaxOf = 15, - DictContainsFromEntries = 16, - DictNotContainsFromEntries = 17, - SetContainsFromEntries = 18, - SetNotContainsFromEntries = 19, - ArrayContainsFromEntries = 20, -} - -impl TryFrom for middleware::OperationType { - type Error = anyhow::Error; - fn try_from(fe_ot: OperationType) -> Result { - type FeOT = OperationType; - type FeNO = NativeOperation; - type MwOT = middleware::OperationType; - type MwNO = middleware::NativeOperation; - let mw_ot = match fe_ot { - FeOT::Native(FeNO::None) => MwOT::Native(MwNO::None), - FeOT::Native(FeNO::NewEntry) => MwOT::Native(MwNO::NewEntry), - FeOT::Native(FeNO::CopyStatement) => MwOT::Native(MwNO::CopyStatement), - FeOT::Native(FeNO::EqualFromEntries) => MwOT::Native(MwNO::EqualFromEntries), - FeOT::Native(FeNO::NotEqualFromEntries) => MwOT::Native(MwNO::NotEqualFromEntries), - FeOT::Native(FeNO::GtFromEntries) => MwOT::Native(MwNO::GtFromEntries), - FeOT::Native(FeNO::LtFromEntries) => MwOT::Native(MwNO::LtFromEntries), - FeOT::Native(FeNO::TransitiveEqualFromStatements) => { - MwOT::Native(MwNO::TransitiveEqualFromStatements) - } - FeOT::Native(FeNO::GtToNotEqual) => MwOT::Native(MwNO::GtToNotEqual), - FeOT::Native(FeNO::LtToNotEqual) => MwOT::Native(MwNO::LtToNotEqual), - FeOT::Native(FeNO::SumOf) => MwOT::Native(MwNO::SumOf), - FeOT::Native(FeNO::ProductOf) => MwOT::Native(MwNO::ProductOf), - FeOT::Native(FeNO::MaxOf) => MwOT::Native(MwNO::MaxOf), - FeOT::Native(FeNO::DictContainsFromEntries) => MwOT::Native(MwNO::ContainsFromEntries), - FeOT::Native(FeNO::DictNotContainsFromEntries) => { - MwOT::Native(MwNO::NotContainsFromEntries) - } - FeOT::Native(FeNO::SetContainsFromEntries) => MwOT::Native(MwNO::ContainsFromEntries), - FeOT::Native(FeNO::SetNotContainsFromEntries) => { - MwOT::Native(MwNO::NotContainsFromEntries) - } - FeOT::Native(FeNO::ArrayContainsFromEntries) => MwOT::Native(MwNO::ContainsFromEntries), - FeOT::Custom(mw_cpr) => MwOT::Custom(mw_cpr.into()), - }; - Ok(mw_ot) - } -} - -impl OperationType { - /// Gives the type of predicate that the operation will output, if known. - /// CopyStatement may output any predicate (it will match the statement copied), - /// so output_predicate returns None on CopyStatement. - pub fn output_predicate(&self) -> Option { - match self { - OperationType::Native(native_op) => match native_op { - NativeOperation::None => Some(Predicate::Native(NativePredicate::None)), - NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::ValueOf)), - NativeOperation::CopyStatement => None, - NativeOperation::EqualFromEntries => { - Some(Predicate::Native(NativePredicate::Equal)) - } - NativeOperation::NotEqualFromEntries => { - Some(Predicate::Native(NativePredicate::NotEqual)) - } - NativeOperation::GtFromEntries => Some(Predicate::Native(NativePredicate::Gt)), - NativeOperation::LtFromEntries => Some(Predicate::Native(NativePredicate::Lt)), - NativeOperation::TransitiveEqualFromStatements => { - Some(Predicate::Native(NativePredicate::Equal)) - } - NativeOperation::GtToNotEqual => Some(Predicate::Native(NativePredicate::NotEqual)), - NativeOperation::LtToNotEqual => Some(Predicate::Native(NativePredicate::NotEqual)), - NativeOperation::SumOf => Some(Predicate::Native(NativePredicate::SumOf)), - NativeOperation::ProductOf => Some(Predicate::Native(NativePredicate::ProductOf)), - NativeOperation::MaxOf => Some(Predicate::Native(NativePredicate::MaxOf)), - NativeOperation::DictContainsFromEntries => { - Some(Predicate::Native(NativePredicate::DictContains)) - } - NativeOperation::DictNotContainsFromEntries => { - Some(Predicate::Native(NativePredicate::DictNotContains)) - } - NativeOperation::SetContainsFromEntries => { - Some(Predicate::Native(NativePredicate::SetContains)) - } - NativeOperation::SetNotContainsFromEntries => { - Some(Predicate::Native(NativePredicate::SetNotContains)) - } - NativeOperation::ArrayContainsFromEntries => { - Some(Predicate::Native(NativePredicate::ArrayContains)) - } - }, - OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq)] pub struct Operation(pub OperationType, pub Vec, pub OperationAux); impl fmt::Display for Operation { diff --git a/src/frontend/predicate.rs b/src/frontend/predicate.rs deleted file mode 100644 index b4581d6..0000000 --- a/src/frontend/predicate.rs +++ /dev/null @@ -1,58 +0,0 @@ -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -use crate::middleware::{self, CustomPredicateRef}; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] -pub enum NativePredicate { - None = 0, - ValueOf = 1, - Equal = 2, - NotEqual = 3, - Gt = 4, - Lt = 5, - SumOf = 8, - ProductOf = 9, - MaxOf = 10, - DictContains = 11, - DictNotContains = 12, - SetContains = 13, - SetNotContains = 14, - ArrayContains = 15, // there is no ArrayNotContains -} - -impl From for middleware::NativePredicate { - fn from(np: NativePredicate) -> Self { - use middleware::NativePredicate as MidNP; - use NativePredicate::*; - match np { - None => MidNP::None, - ValueOf => MidNP::ValueOf, - Equal => MidNP::Equal, - NotEqual => MidNP::NotEqual, - Gt => MidNP::Gt, - Lt => MidNP::Lt, - SumOf => MidNP::SumOf, - ProductOf => MidNP::ProductOf, - MaxOf => MidNP::MaxOf, - DictContains => MidNP::Contains, - DictNotContains => MidNP::NotContains, - SetContains => MidNP::Contains, - SetNotContains => MidNP::NotContains, - ArrayContains => MidNP::Contains, - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub enum Predicate { - Native(NativePredicate), - BatchSelf(usize), - Custom(CustomPredicateRef), -} - -impl From for Predicate { - fn from(v: NativePredicate) -> Self { - Self::Native(v) - } -} diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 28cda95..7add3a8 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -1,3 +1,4 @@ +/* use std::collections::{BTreeMap, HashMap}; use schemars::{JsonSchema, Schema}; @@ -5,14 +6,14 @@ use serde::{Deserialize, Serialize, Serializer}; use crate::{ backends::plonky2::mock::{mainpod::MockMainPod, signedpod::MockSignedPod}, - frontend::{containers::Dictionary, MainPod, SignedPod, Statement, Value}, + frontend::{containers::Dictionary, MainPod, SignedPod, Statement, TypedValue}, middleware::PodId, }; #[derive(Serialize, Deserialize, JsonSchema)] #[schemars(title = "SignedPod")] pub struct SignedPodHelper { - entries: HashMap, + entries: HashMap, proof: String, pod_class: String, pod_type: String, @@ -169,30 +170,34 @@ mod tests { fn test_value_serialization() { // Pairs of values and their expected serialized representations let values = vec![ - (Value::String("hello".to_string()), "\"hello\""), - (Value::Int(42), "{\"Int\":\"42\"}"), - (Value::Bool(true), "true"), + (TypedValue::String("hello".to_string()), "\"hello\""), + (TypedValue::Int(42), "{\"Int\":\"42\"}"), + (TypedValue::Bool(true), "true"), ( - Value::Array( - Array::new(vec![Value::String("foo".to_string()), Value::Bool(false)]).unwrap(), + TypedValue::Array( + Array::new(vec![ + TypedValue::String("foo".to_string()), + TypedValue::Bool(false), + ]) + .unwrap(), ), "[\"foo\",false]", ), ( - Value::Dictionary( + TypedValue::Dictionary( Dictionary::new(HashMap::from([ - ("foo".to_string(), Value::Int(123)), - ("bar".to_string(), Value::String("baz".to_string())), + ("foo".to_string(), TypedValue::Int(123)), + ("bar".to_string(), TypedValue::String("baz".to_string())), ])) .unwrap(), ), "{\"Dictionary\":{\"bar\":\"baz\",\"foo\":{\"Int\":\"123\"}}}", ), ( - Value::Set( + TypedValue::Set( Set::new(vec![ - Value::String("foo".to_string()), - Value::String("bar".to_string()), + TypedValue::String("foo".to_string()), + TypedValue::String("bar".to_string()), ]) .unwrap(), ), @@ -203,9 +208,9 @@ mod tests { for (value, expected) in values { let serialized = serde_json::to_string(&value).unwrap(); assert_eq!(serialized, expected); - let deserialized: Value = serde_json::from_str(&serialized).unwrap(); + let deserialized: TypedValue = serde_json::from_str(&serialized).unwrap(); assert_eq!(value, deserialized); - let expected_deserialized: Value = serde_json::from_str(&expected).unwrap(); + let expected_deserialized: TypedValue = serde_json::from_str(&expected).unwrap(); assert_eq!(value, expected_deserialized); } } @@ -219,27 +224,32 @@ mod tests { builder.insert("very_large_int", 1152921504606846976); builder.insert( "a_dict_containing_one_key", - Value::Dictionary( + TypedValue::Dictionary( Dictionary::new(HashMap::from([ - ("foo".to_string(), Value::Int(123)), + ("foo".to_string(), TypedValue::Int(123)), ( "an_array_containing_three_ints".to_string(), - Value::Array( - Array::new(vec![Value::Int(1), Value::Int(2), Value::Int(3)]).unwrap(), + TypedValue::Array( + Array::new(vec![ + TypedValue::Int(1), + TypedValue::Int(2), + TypedValue::Int(3), + ]) + .unwrap(), ), ), ( "a_set_containing_two_strings".to_string(), - Value::Set( + TypedValue::Set( Set::new(vec![ - Value::Array( + TypedValue::Array( Array::new(vec![ - Value::String("foo".to_string()), - Value::String("bar".to_string()), + TypedValue::String("foo".to_string()), + TypedValue::String("bar".to_string()), ]) .unwrap(), ), - Value::String("baz".to_string()), + TypedValue::String("baz".to_string()), ]) .unwrap(), ), @@ -256,7 +266,6 @@ mod tests { let deserialized: SignedPod = serde_json::from_str(&serialized).unwrap(); assert_eq!(pod.kvs, deserialized.kvs); - assert_eq!(pod.origin(), deserialized.origin()); assert_eq!(pod.verify().is_ok(), deserialized.verify().is_ok()); assert_eq!(pod.id(), deserialized.id()) } @@ -265,7 +274,7 @@ mod tests { fn test_main_pod_serialization() -> Result<()> { let params = middleware::Params::default(); let sanctions_values = vec!["A343434340".into()]; - let sanction_set = Value::Set(Set::new(sanctions_values)?); + let sanction_set = TypedValue::Set(Set::new(sanctions_values)?); let (gov_id_builder, pay_stub_builder, sanction_list_builder) = zu_kyc_sign_pod_builders(¶ms, &sanction_set); @@ -311,3 +320,4 @@ mod tests { ); } } +*/ diff --git a/src/frontend/statement.rs b/src/frontend/statement.rs deleted file mode 100644 index cf183f5..0000000 --- a/src/frontend/statement.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::fmt; - -use anyhow::{anyhow, Result}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -use crate::{ - frontend::{AnchoredKey, NativePredicate, Predicate, SignedPod, Value}, - middleware, -}; - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub enum StatementArg { - Literal(Value), - Key(AnchoredKey), -} - -impl fmt::Display for StatementArg { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Literal(v) => write!(f, "{}", v), - Self::Key(r) => write!(f, "{}.{}", r.origin.pod_id, r.key), - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub struct Statement { - pub predicate: Predicate, - pub args: Vec, -} - -impl Statement { - pub fn new(predicate: Predicate, args: Vec) -> Self { - Self { predicate, args } - } -} - -impl From<(&SignedPod, &str)> for Statement { - fn from((pod, key): (&SignedPod, &str)) -> Self { - // TODO: TryFrom. - let value = pod - .kvs - .get(key) - .cloned() - .unwrap_or_else(|| panic!("Key {} is not present in POD: {}", key, pod)); - Statement { - predicate: Predicate::Native(NativePredicate::ValueOf), - args: vec![ - StatementArg::Key(AnchoredKey::new(pod.origin(), key.to_string())), - StatementArg::Literal(value), - ], - } - } -} - -#[derive(Debug)] -pub struct ManualConversionRequired(); - -impl std::fmt::Display for StatementConversionError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Statement conversion error: statement conversion must be implemented manually." - ) - } -} - -impl std::error::Error for StatementConversionError {} - -#[derive(Debug)] -pub enum StatementConversionError { - MCR(ManualConversionRequired), - Error(anyhow::Error), -} - -impl From for StatementConversionError { - fn from(value: anyhow::Error) -> Self { - Self::Error(value) - } -} - -impl TryFrom for middleware::Statement { - type Error = StatementConversionError; - fn try_from(s: Statement) -> Result { - type MS = middleware::Statement; - type NP = NativePredicate; - type SA = StatementArg; - let args = ( - s.args.first().cloned(), - s.args.get(1).cloned(), - s.args.get(2).cloned(), - ); - Ok(match &s.predicate { - Predicate::Native(np) => match (np, args) { - (NP::None, (None, None, None)) => MS::None, - (NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => { - MS::ValueOf(ak.into(), (&v).into()) - } - (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::Equal(ak1.into(), ak2.into()) - } - (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::NotEqual(ak1.into(), ak2.into()) - } - (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::Gt(ak1.into(), ak2.into()) - } - (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::Lt(ak1.into(), ak2.into()) - } - (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { - MS::SumOf(ak1.into(), ak2.into(), ak3.into()) - } - (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { - MS::ProductOf(ak1.into(), ak2.into(), ak3.into()) - } - (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { - MS::MaxOf(ak1.into(), ak2.into(), ak3.into()) - } - ( - NP::DictContains, - (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), - ) => MS::Contains(ak1.into(), ak2.into(), ak3.into()), - (NP::DictNotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::NotContains(ak1.into(), ak2.into()) - } - (NP::SetContains, (Some(SA::Key(_)), Some(SA::Key(_)), None)) => { - return Err(StatementConversionError::MCR(ManualConversionRequired())); - } - (NP::SetNotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { - MS::NotContains(ak1.into(), ak2.into()) - } - _ => Err(anyhow!("Ill-formed statement: {}", s))?, - }, - Predicate::Custom(cpr) => MS::Custom( - cpr.clone().into(), - s.args - .iter() - .map(|arg| match arg { - StatementArg::Key(ak) => Ok(ak.clone().into()), - _ => Err(anyhow!("Invalid statement arg: {}", arg)), - }) - .collect::>>()?, - ), - _ => Err(anyhow!("Ill-formed statement: {}", s))?, - }) - } -} - -impl fmt::Display for Statement { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?} ", self.predicate)?; - for (i, arg) in self.args.iter().enumerate() { - if i != 0 { - write!(f, " ")?; - } - write!(f, "{}", arg)?; - } - Ok(()) - } -} diff --git a/src/lib.rs b/src/lib.rs index 99fce2d..48e89a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,6 @@ pub mod backends; pub mod constants; pub mod frontend; pub mod middleware; -mod util; #[cfg(test)] pub mod examples; diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index f9960d0..5dcb0d5 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -1,3 +1,4 @@ +// TODO: Update this doc //! This file exposes the backend dependent basetypes as middleware types, //! taking them from the feature-enabled backend. //! @@ -29,12 +30,212 @@ //! u64/i64 to F conversion. Eventually we will do those conversions through the //! approach described in this file, removing the imports of plonky2 in the //! middleware. +//! TODO: Update this doc /// Value, Hash and F are imported based on 'features'. For example by default /// we use the 'plonky2' feature, but it could be used a 'plonky3' feature, so /// then the Value, Hash and F types would come from the plonky3 backend. -#[cfg(feature = "backend_plonky2")] -pub use crate::backends::plonky2::basetypes::{ - hash_fields, hash_str, hash_value, Hash, Value, EMPTY_HASH, EMPTY_VALUE, F, HASH_SIZE, - SELF_ID_HASH, VALUE_SIZE, +// #[cfg(feature = "backend_plonky2")] +// pub use crate::backends::plonky2::basetypes::{ +// hash_fields, hash_str, hash_value, Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F, HASH_SIZE, +// SELF_ID_HASH, VALUE_SIZE, +// }; +use std::{ + cmp::{Ord, Ordering}, + fmt, }; + +use anyhow::Result; +use hex::{FromHex, FromHexError}; +use plonky2::{ + field::{ + goldilocks_field::GoldilocksField, + types::{Field, PrimeField64}, + }, + hash::poseidon::PoseidonHash, + plonk::config::Hasher, +}; + +use crate::middleware::{ + // serialization::{ + // deserialize_hash_tuple, deserialize_value_tuple, serialize_hash_tuple, + // serialize_value_tuple, + // }, + Params, + ToFields, +}; + +/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2 +pub type F = GoldilocksField; + +pub const HASH_SIZE: usize = 4; +pub const VALUE_SIZE: usize = 4; + +pub const EMPTY_VALUE: RawValue = RawValue([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); +pub const SELF_ID_HASH: Hash = Hash([F::ONE, F::ZERO, F::ZERO, F::ZERO]); +pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); + +#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)] +// #[schemars(rename = "RawValue")] +pub struct RawValue( + // #[serde( + // serialize_with = "serialize_value_tuple", + // deserialize_with = "deserialize_value_tuple" + // )] + // We know that Serde will serialize and deserialize this as a string, so we can + // use the JsonSchema to validate the format. + // #[schemars(with = "String", regex(pattern = r"^[0-9a-fA-F]{64}$"))] + pub [F; VALUE_SIZE], +); + +impl ToFields for RawValue { + fn to_fields(&self, _params: &Params) -> Vec { + self.0.to_vec() + } +} + +impl RawValue { + pub fn to_bytes(self) -> Vec { + self.0 + .iter() + .flat_map(|e| e.to_canonical_u64().to_le_bytes()) + .collect() + } +} + +impl Ord for RawValue { + fn cmp(&self, other: &Self) -> Ordering { + for (lhs, rhs) in self.0.iter().zip(other.0.iter()).rev() { + let (lhs, rhs) = (lhs.to_canonical_u64(), rhs.to_canonical_u64()); + match lhs.cmp(&rhs) { + Ordering::Less => return Ordering::Less, + Ordering::Greater => return Ordering::Greater, + _ => {} + } + } + Ordering::Equal + } +} + +impl PartialOrd for RawValue { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl From for RawValue { + fn from(v: i64) -> Self { + let lo = F::from_canonical_u64((v as u64) & 0xffffffff); + let hi = F::from_canonical_u64((v as u64) >> 32); + RawValue([lo, hi, F::ZERO, F::ZERO]) + } +} + +impl From for RawValue { + fn from(h: Hash) -> Self { + RawValue(h.0) + } +} + +impl fmt::Display for RawValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.0[2].is_zero() && self.0[3].is_zero() { + // Assume this is an integer + let (l0, l1) = (self.0[0].to_canonical_u64(), self.0[1].to_canonical_u64()); + assert!(l0 < (1 << 32)); + assert!(l1 < (1 << 32)); + write!(f, "{}", l0 + l1 * (1 << 32)) + } else { + // Assume this is a hash + Hash(self.0).fmt(f) + } + } +} + +#[derive(Clone, Copy, Debug, Default, Hash, Eq, PartialEq)] +pub struct Hash( + // #[serde( + // serialize_with = "serialize_hash_tuple", + // deserialize_with = "deserialize_hash_tuple" + // )] + // #[schemars(with = "String", regex(pattern = r"^[0-9a-fA-F]{64}$"))] + pub [F; HASH_SIZE], +); + +pub fn hash_value(input: &RawValue) -> Hash { + hash_fields(&input.0) +} + +pub fn hash_fields(input: &[F]) -> Hash { + Hash(PoseidonHash::hash_no_pad(input).elements) +} + +impl From for Hash { + fn from(v: RawValue) -> Self { + Hash(v.0) + } +} + +impl ToFields for Hash { + fn to_fields(&self, _params: &Params) -> Vec { + self.0.to_vec() + } +} + +impl PartialOrd for Hash { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Hash { + fn cmp(&self, other: &Self) -> Ordering { + RawValue(self.0).cmp(&RawValue(other.0)) + } +} + +impl fmt::Display for Hash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let v0 = self.0[0].to_canonical_u64(); + for i in 0..HASH_SIZE { + write!(f, "{:02x}", (v0 >> (i * 8)) & 0xff)?; + } + write!(f, "…") + } +} + +impl FromHex for Hash { + type Error = FromHexError; + + // TODO make it dependant on backend::Value len + fn from_hex>(hex: T) -> Result { + // In little endian + let bytes = <[u8; 32]>::from_hex(hex)?; + let mut buf: [u8; 8] = [0; 8]; + let mut inner = [F::ZERO; HASH_SIZE]; + for i in 0..HASH_SIZE { + buf.copy_from_slice(&bytes[8 * i..8 * (i + 1)]); + inner[i] = F::from_canonical_u64(u64::from_le_bytes(buf)); + } + Ok(Self(inner)) + } +} + +pub fn hash_str(s: &str) -> Hash { + let mut input = s.as_bytes().to_vec(); + input.push(1); // padding + + // Merge 7 bytes into 1 field, because the field is slightly below 64 bits + let input: Vec = input + .chunks(7) + .map(|bytes| { + let mut v: u64 = 0; + for b in bytes.iter().rev() { + v <<= 8; + v += *b as u64; + } + F::from_canonical_u64(v) + }) + .collect(); + hash_fields(&input) +} diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index 7ddab7d..4c819ac 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -1,14 +1,14 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; /// This file implements the types defined at /// https://0xparc.github.io/pod2/values.html#dictionary-array-set . -use anyhow::Result; +use anyhow::{anyhow, Result}; #[cfg(feature = "backend_plonky2")] -use crate::backends::plonky2::primitives::merkletree::{Iter as TreeIter, MerkleProof, MerkleTree}; +use crate::backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree}; use crate::{ constants::MAX_DEPTH, - middleware::basetypes::{hash_value, Hash, Value, EMPTY_VALUE}, + middleware::{hash_value, Hash, Key, RawValue, Value, EMPTY_VALUE}, }; /// Dictionary: the user original keys and values are hashed to be used in the leaf. @@ -16,47 +16,58 @@ use crate::{ /// leaf.value=hash(original_value) #[derive(Clone, Debug)] pub struct Dictionary { - // exposed with pub(crate) so that it can be modified at tests - pub(crate) mt: MerkleTree, + mt: MerkleTree, + kvs: HashMap, } impl Dictionary { - pub fn new(kvs: &HashMap) -> Result { - let kvs: HashMap = kvs.iter().map(|(&k, &v)| (Value(k.0), v)).collect(); + pub fn new(kvs: HashMap) -> Result { + let kvs_raw: HashMap = kvs + .iter() + .map(|(k, v)| (RawValue(k.hash().0), v.raw())) + .collect(); Ok(Self { - mt: MerkleTree::new(MAX_DEPTH, &kvs)?, + mt: MerkleTree::new(MAX_DEPTH, &kvs_raw)?, + kvs, }) } pub fn commitment(&self) -> Hash { self.mt.root() } - pub fn get(&self, key: &Value) -> Result { - self.mt.get(key) + pub fn get(&self, key: &Key) -> Result<&Value> { + self.kvs + .get(key) + .ok_or_else(|| anyhow!("key \"{}\" not found", key.name())) } - pub fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { - self.mt.prove(key) + pub fn prove(&self, key: &Key) -> Result<(&Value, MerkleProof)> { + let (_, mtp) = self.mt.prove(&RawValue(key.hash().0))?; + let value = self.kvs.get(key).expect("key exists"); + Ok((value, mtp)) } - pub fn prove_nonexistence(&self, key: &Value) -> Result { - self.mt.prove_nonexistence(key) + pub fn prove_nonexistence(&self, key: &Key) -> Result { + self.mt.prove_nonexistence(&RawValue(key.hash().0)) } - pub fn verify(root: Hash, proof: &MerkleProof, key: &Value, value: &Value) -> Result<()> { - MerkleTree::verify(MAX_DEPTH, root, proof, key, value) + pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> { + let key = RawValue(key.hash().0); + MerkleTree::verify(MAX_DEPTH, root, proof, &key, &value.raw()) } - pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Value) -> Result<()> { - MerkleTree::verify_nonexistence(MAX_DEPTH, root, proof, key) + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> { + let key = RawValue(key.hash().0); + MerkleTree::verify_nonexistence(MAX_DEPTH, root, proof, &key) } - pub fn iter(&self) -> TreeIter { - self.mt.iter() - } -} -impl<'a> IntoIterator for &'a Dictionary { - type Item = (&'a Value, &'a Value); - type IntoIter = TreeIter<'a>; - - fn into_iter(self) -> Self::IntoIter { - self.mt.iter() + // TODO: Rename to dict to be consistent maybe? + pub fn kvs(&self) -> &HashMap { + &self.kvs } } +// impl<'a> IntoIterator for &'a Dictionary { +// type Item = (&'a RawValue, &'a RawValue); +// type IntoIter = TreeIter<'a>; +// +// fn into_iter(self) -> Self::IntoIter { +// self.mt.iter() +// } +// } impl PartialEq for Dictionary { fn eq(&self, other: &Self) -> bool { @@ -71,42 +82,48 @@ impl Eq for Dictionary {} #[derive(Clone, Debug)] pub struct Set { mt: MerkleTree, + set: HashSet, } impl Set { - pub fn new(set: &[Value]) -> Result { - let kvs: HashMap = set + pub fn new(set: HashSet) -> Result { + let kvs_raw: HashMap = set .iter() .map(|e| { - let h = hash_value(e); - (Value::from(h), EMPTY_VALUE) + let h = hash_value(&e.raw()); + (RawValue::from(h), EMPTY_VALUE) }) .collect(); Ok(Self { - mt: MerkleTree::new(MAX_DEPTH, &kvs)?, + mt: MerkleTree::new(MAX_DEPTH, &kvs_raw)?, + set, }) } pub fn commitment(&self) -> Hash { self.mt.root() } - pub fn contains(&self, value: &Value) -> Result { - self.mt.contains(value) + pub fn contains(&self, value: &Value) -> bool { + self.set.contains(value) } pub fn prove(&self, value: &Value) -> Result { - let (_, proof) = self.mt.prove(value)?; + let h = hash_value(&value.raw()); + let (_, proof) = self.mt.prove(&RawValue::from(h))?; Ok(proof) } pub fn prove_nonexistence(&self, value: &Value) -> Result { - self.mt.prove_nonexistence(value) + let h = hash_value(&value.raw()); + self.mt.prove_nonexistence(&RawValue::from(h)) } pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - MerkleTree::verify(MAX_DEPTH, root, proof, value, &EMPTY_VALUE) + let h = hash_value(&value.raw()); + MerkleTree::verify(MAX_DEPTH, root, proof, &RawValue::from(h), &EMPTY_VALUE) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - MerkleTree::verify_nonexistence(MAX_DEPTH, root, proof, value) + let h = hash_value(&value.raw()); + MerkleTree::verify_nonexistence(MAX_DEPTH, root, proof, &RawValue::from(h)) } - pub fn iter(&self) -> TreeIter { - self.mt.iter() + pub fn set(&self) -> &HashSet { + &self.set } } @@ -124,34 +141,46 @@ impl Eq for Set {} #[derive(Clone, Debug)] pub struct Array { mt: MerkleTree, + array: Vec, } impl Array { - pub fn new(array: &[Value]) -> Result { - let kvs: HashMap = array + pub fn new(array: Vec) -> Result { + let kvs_raw: HashMap = array .iter() .enumerate() - .map(|(i, &e)| (Value::from(i as i64), e)) + .map(|(i, e)| (RawValue::from(i as i64), e.raw())) .collect(); Ok(Self { - mt: MerkleTree::new(MAX_DEPTH, &kvs)?, + mt: MerkleTree::new(MAX_DEPTH, &kvs_raw)?, + array, }) } pub fn commitment(&self) -> Hash { self.mt.root() } - pub fn get(&self, i: usize) -> Result { - self.mt.get(&Value::from(i as i64)) + pub fn get(&self, i: usize) -> Result<&Value> { + self.array + .get(i) + .ok_or_else(|| anyhow!("index {} out of bounds 0..{}", i, self.array.len())) } - pub fn prove(&self, i: usize) -> Result<(Value, MerkleProof)> { - self.mt.prove(&Value::from(i as i64)) + pub fn prove(&self, i: usize) -> Result<(&Value, MerkleProof)> { + let (_, mtp) = self.mt.prove(&RawValue::from(i as i64))?; + let value = self.array.get(i).expect("valid index"); + Ok((value, mtp)) } pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> { - MerkleTree::verify(MAX_DEPTH, root, proof, &Value::from(i as i64), value) + MerkleTree::verify( + MAX_DEPTH, + root, + proof, + &RawValue::from(i as i64), + &value.raw(), + ) } - pub fn iter(&self) -> TreeIter { - self.mt.iter() + pub fn array(&self) -> &[Value] { + &self.array } } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index ca17cb3..63832a9 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -1,101 +1,85 @@ -use std::{collections::HashMap, fmt, hash as h, iter, iter::zip, sync::Arc}; +use std::{fmt, iter, sync::Arc}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +// use schemars::JsonSchema; + +// use serde::{Deserialize, Serialize}; use crate::{ - backends::plonky2::basetypes::HASH_SIZE, - middleware::{ - hash_fields, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg, - ToFields, Value, F, - }, - util::hashmap_insert_no_dupe, + middleware::HASH_SIZE, + middleware::{hash_fields, Hash, Key, NativePredicate, Params, ToFields, Value, F}, }; -// BEGIN Custom 1b - -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] -pub enum HashOrWildcard { - Hash(Hash), - Wildcard(usize), +#[derive(Clone, Debug, PartialEq)] +pub struct Wildcard { + pub name: String, + pub index: usize, } -impl HashOrWildcard { - /// Matches a hash or wildcard against a value, returning a pair - /// representing a wildcard binding (if any) or an error if no - /// match is possible. - pub fn match_against(&self, v: &Value) -> Result> { - match self { - HashOrWildcard::Hash(h) if &Value::from(*h) == v => Ok(None), - HashOrWildcard::Wildcard(i) => Ok(Some((*i, *v))), - _ => Err(anyhow!( - "Failed to match hash or wildcard {} against value {}.", - self, - v - )), - } +impl Wildcard { + pub fn new(name: String, index: usize) -> Self { + Self { name, index } } } -impl fmt::Display for HashOrWildcard { +impl fmt::Display for Wildcard { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "*{}[{}]", self.index, self.name) + } +} + +impl ToFields for Wildcard { + fn to_fields(&self, _params: &Params) -> Vec { + vec![F::from_canonical_u64(self.index as u64)] + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum KeyOrWildcard { + Key(Key), + Wildcard(Wildcard), +} + +impl fmt::Display for KeyOrWildcard { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Self::Hash(h) => write!(f, "{}", h), - Self::Wildcard(n) => write!(f, "*{}", n), + Self::Key(k) => write!(f, "{}", k), + Self::Wildcard(wc) => write!(f, "{}", wc), } } } -impl ToFields for HashOrWildcard { +impl ToFields for KeyOrWildcard { fn to_fields(&self, params: &Params) -> Vec { match self { - HashOrWildcard::Hash(h) => h.to_fields(params), - HashOrWildcard::Wildcard(w) => (0..HASH_SIZE - 1) - .chain(iter::once(*w)) - .map(|x| F::from_canonical_u64(x as u64)) + KeyOrWildcard::Key(k) => k.hash().to_fields(params), + KeyOrWildcard::Wildcard(wc) => iter::once(F::ZERO) + .take(HASH_SIZE - 1) + .chain(iter::once(F::from_canonical_u64(wc.index as u64))) .collect(), } } } -#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq)] pub enum StatementTmplArg { None, Literal(Value), - Key(HashOrWildcard, HashOrWildcard), -} - -impl StatementTmplArg { - /// Matches a statement template argument against a statement - /// argument, returning a wildcard correspondence in the case of - /// one or more wildcard matches, nothing in the case of a - /// literal/hash match, and an error otherwise. - pub fn match_against(&self, s_arg: &StatementArg) -> Result> { - match (self, s_arg) { - (Self::None, StatementArg::None) => Ok(vec![]), - (Self::Literal(v), StatementArg::Literal(w)) if v == w => Ok(vec![]), - (Self::Key(tmpl_o, tmpl_k), StatementArg::Key(AnchoredKey(PodId(o), k))) => { - let o_corr = tmpl_o.match_against(&(*o).into())?; - let k_corr = tmpl_k.match_against(&(*k).into())?; - Ok([o_corr, k_corr].into_iter().flatten().collect()) - } - _ => Err(anyhow!( - "Failed to match statement template argument {:?} against statement argument {:?}.", - self, - s_arg - )), - } - } + // AnchoredKey + Key(Wildcard, KeyOrWildcard), + // TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard... + // Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key? + WildcardLiteral(Wildcard), } impl ToFields for StatementTmplArg { fn to_fields(&self, params: &Params) -> Vec { // None => (0, ...) // Literal(value) => (1, [value], 0, 0, 0, 0) - // Key(hash_or_wildcard1, hash_or_wildcard2) - // => (2, [hash_or_wildcard1], [hash_or_wildcard2]) + // Key(wildcard1, key_or_wildcard2) + // => (2, [wildcard1], [key_or_wildcard2]) + // WildcardLiteral(wildcard) => (3, [wildcard], 0, 0, 0, 0) // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements let statement_tmpl_arg_size = 2 * HASH_SIZE + 1; match self { @@ -107,15 +91,22 @@ impl ToFields for StatementTmplArg { } StatementTmplArg::Literal(v) => { let fields: Vec = iter::once(F::from_canonical_u64(1)) - .chain(v.to_fields(params)) + .chain(v.raw().to_fields(params)) .chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .collect(); fields } - StatementTmplArg::Key(hw1, hw2) => { + StatementTmplArg::Key(wc1, kw2) => { let fields: Vec = iter::once(F::from_canonical_u64(2)) - .chain(hw1.to_fields(params)) - .chain(hw2.to_fields(params)) + .chain(wc1.to_fields(params)) + .chain(kw2.to_fields(params)) + .collect(); + fields + } + StatementTmplArg::WildcardLiteral(wc) => { + let fields: Vec = iter::once(F::from_canonical_u64(3)) + .chain(wc.to_fields(params)) + .chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .collect(); fields } @@ -129,50 +120,37 @@ impl fmt::Display for StatementTmplArg { Self::None => write!(f, "none"), Self::Literal(v) => write!(f, "{}", v), Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), + Self::WildcardLiteral(v) => write!(f, "{}", v), } } } -// END - -// BEGIN Custom 2 - -// pub enum StatementTmplArg { -// None, -// Literal(Value), -// Wildcard(usize), -// } - -// END - /// Statement Template for a Custom Predicate -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub struct StatementTmpl(pub Predicate, pub Vec); +#[derive(Clone, Debug, PartialEq)] +pub struct StatementTmpl { + pub pred: Predicate, + pub args: Vec, +} impl StatementTmpl { pub fn pred(&self) -> &Predicate { - &self.0 + &self.pred } pub fn args(&self) -> &[StatementTmplArg] { - &self.1 + &self.args } - /// Matches a statement template against a statement, returning - /// the variable bindings as an association list. Returns an error - /// if there is type or argument mismatch. - pub fn match_against(&self, s: &Statement) -> Result> { - type P = Predicate; - if matches!(self, Self(P::BatchSelf(_), _)) { - Err(anyhow!( - "Cannot check self-referencing statement templates." - )) - } else if self.pred() != &s.predicate() { - Err(anyhow!("Type mismatch between {:?} and {}.", self, s)) - } else { - zip(self.args(), s.args()) - .map(|(t_arg, s_arg)| t_arg.match_against(&s_arg)) - .collect::>>() - .map(|v| v.concat()) +} + +impl fmt::Display for StatementTmpl { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}(", self.pred)?; + for (i, arg) in self.args.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; } + writeln!(f) } } @@ -185,25 +163,26 @@ impl ToFields for StatementTmpl { // TODO think if this check should go into the StatementTmpl creation, // instead of at the `to_fields` method, where we should assume that the // values are already valid - if self.1.len() > params.max_statement_args { + if self.args.len() > params.max_statement_args { panic!("Statement template has too many arguments"); } let mut fields: Vec = self - .0 + .pred .to_fields(params) .into_iter() - .chain(self.1.iter().flat_map(|sta| sta.to_fields(params))) + .chain(self.args.iter().flat_map(|sta| sta.to_fields(params))) .collect(); fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0)); fields } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq)] /// NOTE: fields are not public (outside of crate) to enforce the struct instantiation through /// the `::and/or` methods, which performs checks on the values. pub struct CustomPredicate { + pub name: String, // Non-cryptographic metadata /// true for "and", false for "or" pub(crate) conjunction: bool, pub(crate) statements: Vec, @@ -213,13 +192,24 @@ pub struct CustomPredicate { } impl CustomPredicate { - pub fn and(params: &Params, statements: Vec, args_len: usize) -> Result { - Self::new(params, true, statements, args_len) + pub fn and( + name: String, + params: &Params, + statements: Vec, + args_len: usize, + ) -> Result { + Self::new(name, params, true, statements, args_len) } - pub fn or(params: &Params, statements: Vec, args_len: usize) -> Result { - Self::new(params, false, statements, args_len) + pub fn or( + name: String, + params: &Params, + statements: Vec, + args_len: usize, + ) -> Result { + Self::new(name, params, false, statements, args_len) } pub fn new( + name: String, params: &Params, conjunction: bool, statements: Vec, @@ -230,6 +220,7 @@ impl CustomPredicate { } Ok(Self { + name, conjunction, statements, args_len, @@ -266,8 +257,8 @@ impl fmt::Display for CustomPredicate { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; for st in &self.statements { - write!(f, " {}", st.0)?; - for (i, arg) in st.1.iter().enumerate() { + write!(f, " {}(", st.pred)?; + for (i, arg) in st.args.iter().enumerate() { if i != 0 { write!(f, ", ")?; } @@ -287,7 +278,7 @@ impl fmt::Display for CustomPredicate { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, PartialEq)] pub struct CustomPredicateBatch { pub name: String, pub predicates: Vec, @@ -324,67 +315,23 @@ impl CustomPredicateBatch { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -pub struct CustomPredicateRef(pub Arc, pub usize); +#[derive(Clone, Debug, PartialEq)] +pub struct CustomPredicateRef { + pub batch: Arc, + pub index: usize, +} impl CustomPredicateRef { - pub fn arg_len(&self) -> usize { - self.0.predicates[self.1].args_len + pub fn new(batch: Arc, index: usize) -> Self { + Self { batch, index } } - pub fn match_against(&self, statements: &[Statement]) -> Result> { - let mut bindings = HashMap::new(); - // Single out custom predicate, replacing batch-self - // references with custom predicate references. - let custom_predicate = { - let cp = &Arc::unwrap_or_clone(self.0.clone()).predicates[self.1]; - CustomPredicate { - conjunction: cp.conjunction, - statements: cp - .statements - .iter() - .map(|StatementTmpl(p, args)| { - StatementTmpl( - match p { - Predicate::BatchSelf(i) => { - Predicate::Custom(CustomPredicateRef(self.0.clone(), *i)) - } - _ => p.clone(), - }, - args.to_vec(), - ) - }) - .collect(), - args_len: cp.args_len, - } - }; - match custom_predicate.conjunction { - true if custom_predicate.statements.len() == statements.len() => { - // Match op args against statement templates - let match_bindings = iter::zip(custom_predicate.statements, statements).map( - |(s_tmpl, s)| s_tmpl.match_against(s) - ).collect::>>() - .map(|v| v.concat())?; - // Add bindings to binding table, throwing if there is an inconsistency. - match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; - Ok(bindings) - }, - false if statements.len() == 1 => { - // Match op arg against each statement template - custom_predicate.statements.iter().map( - |s_tmpl| { - let mut bindings = bindings.clone(); - s_tmpl.match_against(&statements[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; - Ok::<_, anyhow::Error>(bindings) - } - ).find(|m| m.is_ok()).unwrap_or(Err(anyhow!("Statement {} does not match disjunctive custom predicate {}.", &statements[0], custom_predicate))) - }, - _ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, statements)) - } + pub fn arg_len(&self) -> usize { + self.batch.predicates[self.index].args_len } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", content = "value")] +#[derive(Clone, Debug, PartialEq)] +// #[serde(tag = "type", content = "value")] pub enum Predicate { Native(NativePredicate), BatchSelf(usize), @@ -414,10 +361,12 @@ impl ToFields for Predicate { Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2)) .chain(iter::once(F::from_canonical_usize(*i))) .collect(), - Self::Custom(CustomPredicateRef(pb, i)) => iter::once(F::from_canonical_u64(3)) - .chain(pb.hash(params).0) - .chain(iter::once(F::from_canonical_usize(*i))) - .collect(), + Self::Custom(CustomPredicateRef { batch, index }) => { + iter::once(F::from_canonical_u64(3)) + .chain(batch.hash(params).0) + .chain(iter::once(F::from_canonical_usize(*index))) + .collect() + } }; fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0)); fields @@ -429,7 +378,13 @@ impl fmt::Display for Predicate { match self { Self::Native(p) => write!(f, "{:?}", p), Self::BatchSelf(i) => write!(f, "self.{}", i), - Self::Custom(CustomPredicateRef(pb, i)) => write!(f, "{}.{}", pb.name, i), + Self::Custom(CustomPredicateRef { batch, index }) => { + write!( + f, + "{}.{}[{}]", + batch.name, index, batch.predicates[*index].name + ) + } } } } @@ -441,18 +396,29 @@ mod tests { use anyhow::Result; use plonky2::field::goldilocks_field::GoldilocksField; + use super::*; use crate::middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash, - HashOrWildcard, NativePredicate, Operation, Params, PodId, PodType, Predicate, Statement, - StatementTmpl, StatementTmplArg, SELF, + KeyOrWildcard, NativePredicate, Operation, Params, PodId, PodType, Predicate, Statement, + StatementTmpl, StatementTmplArg, WildcardValue, SELF, }; fn st(p: Predicate, args: Vec) -> StatementTmpl { - StatementTmpl(p, args) + StatementTmpl { pred: p, args } + } + + fn kow_wc(i: usize) -> KOW { + KOW::Wildcard(wc(i)) + } + fn wc(i: usize) -> Wildcard { + Wildcard { + name: format!("{}", i), + index: i, + } } type STA = StatementTmplArg; - type HOW = HashOrWildcard; + type KOW = KeyOrWildcard; type P = Predicate; type NP = NativePredicate; @@ -468,44 +434,42 @@ mod tests { let cust_pred_batch = Arc::new(CustomPredicateBatch { name: "is_double".to_string(), predicates: vec![CustomPredicate::and( + "_".into(), ¶ms, vec![ st( P::Native(NP::ValueOf), - vec![ - STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), - STA::Literal(2.into()), - ], + vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(2.into())], ), st( P::Native(NP::ProductOf), vec![ - STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), - STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), - STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + STA::Key(wc(0), kow_wc(1)), + STA::Key(wc(4), kow_wc(5)), + STA::Key(wc(2), kow_wc(3)), ], ), ], - 4, + 2, )?], }); let custom_statement = Statement::Custom( - CustomPredicateRef(cust_pred_batch.clone(), 0), + CustomPredicateRef::new(cust_pred_batch.clone(), 0), vec![ - AnchoredKey(SELF, "Some value".into()), - AnchoredKey(SELF, "Some other value".into()), + WildcardValue::PodId(SELF), + WildcardValue::Key(Key::from("Some value")), ], ); let custom_deduction = Operation::Custom( - CustomPredicateRef(cust_pred_batch, 0), + CustomPredicateRef::new(cust_pred_batch, 0), vec![ - Statement::ValueOf(AnchoredKey(SELF, "Some constant".into()), 2.into()), + Statement::ValueOf(AnchoredKey::from((SELF, "Some constant")), 2.into()), Statement::ProductOf( - AnchoredKey(SELF, "Some value".into()), - AnchoredKey(SELF, "Some constant".into()), - AnchoredKey(SELF, "Some other value".into()), + AnchoredKey::from((SELF, "Some value")), + AnchoredKey::from((SELF, "Some constant")), + AnchoredKey::from((SELF, "Some other value")), ), ], ); @@ -517,30 +481,34 @@ mod tests { #[test] fn ethdos_test() -> Result<()> { - let params = Params::default(); + let params = Params { + max_custom_predicate_wildcards: 12, + ..Default::default() + }; let eth_friend_cp = CustomPredicate::and( + "eth_friend_cp".into(), ¶ms, vec![ st( P::Native(NP::ValueOf), vec![ - STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("type".into())), + STA::Key(wc(4), KeyOrWildcard::Key("type".into())), STA::Literal(PodType::Signed.into()), ], ), st( P::Native(NP::Equal), vec![ - STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("signer".into())), - STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), + STA::Key(wc(4), KeyOrWildcard::Key("signer".into())), + STA::Key(wc(0), kow_wc(1)), ], ), st( P::Native(NP::Equal), vec![ - STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("attestation".into())), - STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + STA::Key(wc(4), KeyOrWildcard::Key("attestation".into())), + STA::Key(wc(2), kow_wc(3)), ], ), ], @@ -552,81 +520,89 @@ mod tests { predicates: vec![eth_friend_cp], }); + // 0 let eth_dos_base = CustomPredicate::and( + "eth_dos_base".into(), ¶ms, vec![ st( P::Native(NP::Equal), - vec![ - STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), - STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), - ], + vec![STA::Key(wc(0), kow_wc(1)), STA::Key(wc(2), kow_wc(3))], ), st( P::Native(NP::ValueOf), - vec![ - STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), - STA::Literal(0.into()), - ], + vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(0.into())], ), ], 6, )?; + // 1 let eth_dos_ind = CustomPredicate::and( + "eth_dos_ind".into(), ¶ms, vec![ st( P::BatchSelf(2), vec![ - STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), - STA::Key(HOW::Wildcard(10), HOW::Wildcard(11)), - STA::Key(HOW::Wildcard(8), HOW::Wildcard(9)), + STA::WildcardLiteral(wc(0)), + STA::WildcardLiteral(wc(1)), + STA::WildcardLiteral(wc(10)), + STA::WildcardLiteral(wc(11)), + STA::WildcardLiteral(wc(8)), + STA::WildcardLiteral(wc(9)), ], ), st( P::Native(NP::ValueOf), - vec![ - STA::Key(HOW::Wildcard(6), HOW::Wildcard(7)), - STA::Literal(1.into()), - ], + vec![STA::Key(wc(6), kow_wc(7)), STA::Literal(1.into())], ), st( P::Native(NP::SumOf), vec![ - STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), - STA::Key(HOW::Wildcard(8), HOW::Wildcard(9)), - STA::Key(HOW::Wildcard(6), HOW::Wildcard(7)), + STA::Key(wc(4), kow_wc(5)), + STA::Key(wc(8), kow_wc(9)), + STA::Key(wc(6), kow_wc(7)), ], ), st( - P::Custom(CustomPredicateRef(eth_friend_batch.clone(), 0)), + P::Custom(CustomPredicateRef::new(eth_friend_batch.clone(), 0)), vec![ - STA::Key(HOW::Wildcard(10), HOW::Wildcard(11)), - STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), + STA::WildcardLiteral(wc(10)), + STA::WildcardLiteral(wc(11)), + STA::WildcardLiteral(wc(2)), + STA::WildcardLiteral(wc(3)), ], ), ], 6, )?; + // 2 let eth_dos_distance_either = CustomPredicate::or( + "eth_dos_distance_either".into(), ¶ms, vec![ st( P::BatchSelf(0), vec![ - STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), - STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), - STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + STA::WildcardLiteral(wc(0)), + STA::WildcardLiteral(wc(1)), + STA::WildcardLiteral(wc(2)), + STA::WildcardLiteral(wc(3)), + STA::WildcardLiteral(wc(4)), + STA::WildcardLiteral(wc(5)), ], ), st( P::BatchSelf(1), vec![ - STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)), - STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)), - STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)), + STA::WildcardLiteral(wc(0)), + STA::WildcardLiteral(wc(1)), + STA::WildcardLiteral(wc(2)), + STA::WildcardLiteral(wc(3)), + STA::WildcardLiteral(wc(4)), + STA::WildcardLiteral(wc(5)), ], ), ], @@ -646,11 +622,14 @@ mod tests { // Example statement let ethdos_example = Statement::Custom( - CustomPredicateRef(eth_dos_distance_batch.clone(), 2), + CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), vec![ - AnchoredKey(pod_id1, "Alice".into()), - AnchoredKey(pod_id2, "Bob".into()), - AnchoredKey(SELF, "Seven".into()), + WildcardValue::PodId(pod_id1), + WildcardValue::Key(Key::from("Alice")), + WildcardValue::PodId(pod_id2), + WildcardValue::Key(Key::from("Bob")), + WildcardValue::PodId(SELF), + WildcardValue::Key(Key::from("Seven")), ], ); @@ -659,17 +638,20 @@ mod tests { // This could arise as the inductive step. let ethdos_ind_example = Statement::Custom( - CustomPredicateRef(eth_dos_distance_batch.clone(), 1), + CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1), vec![ - AnchoredKey(pod_id1, "Alice".into()), - AnchoredKey(pod_id2, "Bob".into()), - AnchoredKey(SELF, "Seven".into()), + WildcardValue::PodId(pod_id1), + WildcardValue::Key(Key::from("Alice")), + WildcardValue::PodId(pod_id2), + WildcardValue::Key(Key::from("Bob")), + WildcardValue::PodId(SELF), + WildcardValue::Key(Key::from("Seven")), ], ); assert!(Operation::Custom( - CustomPredicateRef(eth_dos_distance_batch.clone(), 2), - vec![ethdos_ind_example.clone()] + CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), + vec![Statement::None, ethdos_ind_example.clone()] ) .check(¶ms, ðdos_example)?); @@ -678,30 +660,35 @@ mod tests { // less than 7, and Charlie is ETH-friends with Bob. let ethdos_facts = vec![ Statement::Custom( - CustomPredicateRef(eth_dos_distance_batch.clone(), 2), + CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), vec![ - AnchoredKey(pod_id1, "Alice".into()), - AnchoredKey(pod_id3, "Charlie".into()), - AnchoredKey(pod_id4, "Six".into()), + WildcardValue::PodId(pod_id1), + WildcardValue::Key(Key::from("Alice")), + WildcardValue::PodId(pod_id3), + WildcardValue::Key(Key::from("Charlie")), + WildcardValue::PodId(pod_id4), + WildcardValue::Key(Key::from("Six")), ], ), - Statement::ValueOf(AnchoredKey(SELF, "One".into()), 1.into()), + Statement::ValueOf(AnchoredKey::from((SELF, "One")), 1.into()), Statement::SumOf( - AnchoredKey(SELF, "Seven".into()), - AnchoredKey(pod_id4, "Six".into()), - AnchoredKey(SELF, "One".into()), + AnchoredKey::from((SELF, "Seven")), + AnchoredKey::from((pod_id4, "Six")), + AnchoredKey::from((SELF, "One")), ), Statement::Custom( - CustomPredicateRef(eth_friend_batch.clone(), 0), + CustomPredicateRef::new(eth_friend_batch.clone(), 0), vec![ - AnchoredKey(pod_id3, "Charlie".into()), - AnchoredKey(pod_id2, "Bob".into()), + WildcardValue::PodId(pod_id3), + WildcardValue::Key(Key::from("Charlie")), + WildcardValue::PodId(pod_id2), + WildcardValue::Key(Key::from("Bob")), ], ), ]; assert!(Operation::Custom( - CustomPredicateRef(eth_dos_distance_batch.clone(), 1), + CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1), ethdos_facts ) .check(¶ms, ðdos_ind_example)?); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 40ef643..b39591c 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,7 +1,15 @@ //! The middleware includes the type definitions and the traits used to connect the frontend and //! the backend. +use std::sync::Arc; mod basetypes; +use std::{ + cmp::{Ordering, PartialEq, PartialOrd}, + hash, +}; + +use anyhow::anyhow; +use containers::{Array, Dictionary, Set}; pub mod containers; mod custom; mod operation; @@ -14,12 +22,214 @@ pub use basetypes::*; pub use custom::*; use dyn_clone::DynClone; pub use operation::*; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +// use schemars::JsonSchema; +// use serde::{Deserialize, Serialize}; pub use statement::*; pub const SELF: PodId = PodId(SELF_ID_HASH); +// TODO: Move all value-related types to to `value.rs` +#[derive(Clone, Debug)] +// TODO #[schemars(transform = serialization::transform_value_schema)] +pub enum TypedValue { + // Serde cares about the order of the enum variants, with untagged variants + // appearing at the end. + // Variants without "untagged" will be serialized as "tagged" values by + // default, meaning that a Set appears in JSON as {"Set":[...]} + // and not as [...] + // Arrays, Strings and Booleans are untagged, as there is a natural JSON + // representation for them that is unambiguous to deserialize and is fully + // compatible with the semantics of the POD types. + // As JSON integers do not specify precision, and JavaScript is limited to + // 53-bit precision for integers, integers are represented as tagged + // strings, with a custom serializer and deserializer. + // TAGGED TYPES: + Set(Set), + Dictionary(Dictionary), + Int( + // TODO #[serde(serialize_with = "serialize_i64", deserialize_with = "deserialize_i64")] + // #[schemars(with = "String", regex(pattern = r"^\d+$"))] + i64, + ), + // Uses the serialization for middleware::Value: + Raw(RawValue), + // UNTAGGED TYPES: + // #[serde(untagged)] + // #[schemars(skip)] + Array(Array), + // #[serde(untagged)] + // #[schemars(skip)] + String(String), + // #[serde(untagged)] + // #[schemars(skip)] + Bool(bool), +} + +impl From<&str> for TypedValue { + fn from(s: &str) -> Self { + TypedValue::String(s.to_string()) + } +} + +impl From for TypedValue { + fn from(s: String) -> Self { + TypedValue::String(s) + } +} + +impl From for TypedValue { + fn from(v: i64) -> Self { + TypedValue::Int(v) + } +} + +impl From for TypedValue { + fn from(b: bool) -> Self { + TypedValue::Bool(b) + } +} + +impl From for TypedValue { + fn from(h: Hash) -> Self { + TypedValue::Raw(RawValue(h.0)) + } +} + +impl From for TypedValue { + fn from(s: Set) -> Self { + TypedValue::Set(s) + } +} + +impl From for TypedValue { + fn from(d: Dictionary) -> Self { + TypedValue::Dictionary(d) + } +} + +impl From for TypedValue { + fn from(a: Array) -> Self { + TypedValue::Array(a) + } +} + +impl From for TypedValue { + fn from(v: RawValue) -> Self { + TypedValue::Raw(v) + } +} + +impl From for TypedValue { + fn from(t: PodType) -> Self { + TypedValue::from(t as i64) + } +} + +impl TryFrom<&TypedValue> for i64 { + type Error = anyhow::Error; + fn try_from(v: &TypedValue) -> std::result::Result { + if let TypedValue::Int(n) = v { + Ok(*n) + } else { + Err(anyhow!("Value not an int")) + } + } +} + +impl fmt::Display for TypedValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TypedValue::String(s) => write!(f, "\"{}\"", s), + TypedValue::Int(v) => write!(f, "{}", v), + TypedValue::Bool(b) => write!(f, "{}", b), + TypedValue::Dictionary(d) => write!(f, "dict:{}", d.commitment()), + TypedValue::Set(s) => write!(f, "set:{}", s.commitment()), + TypedValue::Array(a) => write!(f, "arr:{}", a.commitment()), + TypedValue::Raw(v) => write!(f, "{}", v), + } + } +} + +impl From<&TypedValue> for RawValue { + fn from(v: &TypedValue) -> Self { + match v { + TypedValue::String(s) => RawValue::from(hash_str(s)), + TypedValue::Int(v) => RawValue::from(*v), + TypedValue::Bool(b) => RawValue::from(*b as i64), + TypedValue::Dictionary(d) => RawValue::from(d.commitment()), + TypedValue::Set(s) => RawValue::from(s.commitment()), + TypedValue::Array(a) => RawValue::from(a.commitment()), + TypedValue::Raw(v) => *v, + } + } +} + +#[derive(Clone, Debug)] +pub struct Value { + // The `TypedValue` is under `Arc` so that cloning a `Value` is cheap. + typed: Arc, + raw: RawValue, +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.raw == other.raw + } +} + +impl Eq for Value {} + +impl PartialOrd for Value { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.raw.cmp(&other.raw)) + } +} + +impl Ord for Value { + fn cmp(&self, other: &Self) -> Ordering { + self.raw.cmp(&other.raw) + } +} + +impl hash::Hash for Value { + fn hash(&self, state: &mut H) { + self.raw.hash(state) + } +} + +impl fmt::Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.typed) + } +} + +impl Value { + pub fn new(value: TypedValue) -> Self { + let raw_value = RawValue::from(&value); + Self { + typed: Arc::new(value), + raw: raw_value, + } + } + + pub fn typed(&self) -> &TypedValue { + &self.typed + } + pub fn raw(&self) -> RawValue { + self.raw + } +} + +// A Value can be created from any type Into type: bool, string-like, i64, ... +impl From for Value +where + T: Into, +{ + fn from(t: T) -> Self { + Self::new(t.into()) + } +} + impl fmt::Display for PodId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if *self == SELF { @@ -32,30 +242,93 @@ impl fmt::Display for PodId { } } -/// AnchoredKey is a tuple containing (OriginId: PodId, key: Hash) -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct AnchoredKey(pub PodId, pub Hash); +impl From<&Value> for Hash { + fn from(v: &Value) -> Self { + Self(v.raw.0) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Key { + name: String, + hash: Hash, +} + +impl Key { + pub fn new(name: String) -> Self { + let hash = hash_str(&name); + Self { name, hash } + } + + pub fn name(&self) -> &str { + &self.name + } + pub fn hash(&self) -> Hash { + self.hash + } + pub fn raw(&self) -> RawValue { + RawValue(self.hash.0) + } +} + +// A Key can easily be created from a string-like type +impl From for Key +where + T: Into, +{ + fn from(t: T) -> Self { + Self::new(t.into()) + } +} + +impl ToFields for Key { + fn to_fields(&self, params: &Params) -> Vec { + self.hash.to_fields(params) + } +} + +impl fmt::Display for Key { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name)?; + Ok(()) + } +} + +impl From for RawValue { + fn from(key: Key) -> RawValue { + RawValue(key.hash.0) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct AnchoredKey { + pub pod_id: PodId, + pub key: Key, +} impl AnchoredKey { - pub fn origin(&self) -> PodId { - self.0 - } - pub fn key(&self) -> Hash { - self.1 + pub fn new(pod_id: PodId, key: Key) -> Self { + Self { pod_id, key } } } impl fmt::Display for AnchoredKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}.{}", self.0, self.1)?; + write!(f, "{}.{}", self.pod_id, self.key)?; Ok(()) } } -/// An entry consists of a key-value pair. -pub type Entry = (String, Value); +impl From<(PodId, T)> for AnchoredKey +where + T: Into, +{ + fn from((pod_id, t): (PodId, T)) -> Self { + Self::new(pod_id, t.into()) + } +} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)] pub struct PodId(pub Hash); impl ToFields for PodId { @@ -72,6 +345,7 @@ pub enum PodType { Signed = 3, Main = 4, } + impl fmt::Display for PodType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -84,13 +358,7 @@ impl fmt::Display for PodType { } } -impl From for Value { - fn from(v: PodType) -> Self { - Value::from(v as i64) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Params { pub max_input_signed_pods: usize, pub max_input_main_pods: usize, @@ -102,6 +370,7 @@ pub struct Params { // max number of statements that can be ANDed or ORed together // in a custom predicate pub max_custom_predicate_arity: usize, + pub max_custom_predicate_wildcards: usize, pub max_custom_batch_size: usize, // maximum number of merkle proofs pub max_merkle_proofs: usize, @@ -120,6 +389,7 @@ impl Default for Params { max_statement_args: 5, max_operation_args: 5, max_custom_predicate_arity: 5, + max_custom_predicate_wildcards: 10, max_custom_batch_size: 5, max_merkle_proofs: 5, max_depth_mt_gadget: 32, @@ -213,7 +483,7 @@ pub trait Pod: fmt::Debug + DynClone { dyn_clone::clone_trait_object!(Pod); pub trait PodSigner { - fn sign(&mut self, params: &Params, kvs: &HashMap) -> Result>; + fn sign(&mut self, params: &Params, kvs: &HashMap) -> Result>; } /// This is a filler type that fulfills the Pod trait and always verifies. It's empty. This diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index a412bcd..b6131a2 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,25 +1,26 @@ -use std::{fmt, iter}; +use std::{fmt, iter, sync::Arc}; use anyhow::{anyhow, Result}; use log::error; use plonky2::field::types::Field; -use serde::{Deserialize, Serialize}; +// use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::primitives::merkletree::{MerkleProof, MerkleTree}, middleware::{ - AnchoredKey, CustomPredicateRef, NativePredicate, Params, Predicate, Statement, - StatementArg, ToFields, Value, F, SELF, + custom::KeyOrWildcard, AnchoredKey, CustomPredicateBatch, CustomPredicateRef, + NativePredicate, Params, Predicate, Statement, StatementArg, StatementTmplArg, ToFields, + Wildcard, WildcardValue, F, SELF, }, }; -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub enum OperationType { Native(NativeOperation), Custom(CustomPredicateRef), } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq)] pub enum OperationAux { None, MerkleProof(MerkleProof), @@ -41,17 +42,19 @@ impl ToFields for OperationType { Self::Native(p) => iter::once(F::from_canonical_u64(1)) .chain(p.to_fields(params)) .collect(), - Self::Custom(CustomPredicateRef(pb, i)) => iter::once(F::from_canonical_u64(3)) - .chain(pb.hash(params).0) - .chain(iter::once(F::from_canonical_usize(*i))) - .collect(), + Self::Custom(CustomPredicateRef { batch, index }) => { + iter::once(F::from_canonical_u64(3)) + .chain(batch.hash(params).0) + .chain(iter::once(F::from_canonical_usize(*index))) + .collect() + } }; fields.resize_with(Params::operation_type_size(), || F::from_canonical_u64(0)); fields } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NativeOperation { None = 0, NewEntry = 1, @@ -68,6 +71,14 @@ pub enum NativeOperation { SumOf = 13, ProductOf = 14, MaxOf = 15, + + // Syntactic sugar operations. These operations are not supported by the backend. The + // frontend compiler is responsible of translating these operations into the operations above. + DictContainsFromEntries = 1001, + DictNotContainsFromEntries = 1002, + SetContainsFromEntries = 1003, + SetNotContainsFromEntries = 1004, + ArrayContainsFromEntries = 1005, } impl ToFields for NativeOperation { @@ -108,6 +119,7 @@ impl OperationType { NativeOperation::SumOf => Some(Predicate::Native(NativePredicate::SumOf)), NativeOperation::ProductOf => Some(Predicate::Native(NativePredicate::ProductOf)), NativeOperation::MaxOf => Some(Predicate::Native(NativePredicate::MaxOf)), + no => unreachable!("Unexpected syntactic sugar op {:?}", no), }, OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), } @@ -115,7 +127,7 @@ impl OperationType { } // TODO: Refine this enum. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq)] pub enum Operation { None, NewEntry, @@ -263,7 +275,10 @@ impl Operation { Self::CopyStatement(s1) => Some(s1.args()), Self::EqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { if v1 == v2 { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + ]) } else { return Err(anyhow!("Invalid operation")); } @@ -273,7 +288,10 @@ impl Operation { } Self::NotEqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { if v1 != v2 { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + ]) } else { return Err(anyhow!("Invalid operation")); } @@ -283,7 +301,10 @@ impl Operation { } Self::GtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { if v1 > v2 { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + ]) } else { return Err(anyhow!("Invalid operation")); } @@ -293,7 +314,10 @@ impl Operation { } Self::LtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { if v1 < v2 { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + ]) } else { return Err(anyhow!("Invalid operation")); } @@ -303,7 +327,10 @@ impl Operation { } Self::TransitiveEqualFromStatements(Equal(ak1, ak2), Equal(ak3, ak4)) => { if ak2 == ak3 { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak4)]) + Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak4.clone()), + ]) } else { return Err(anyhow!("Invalid operation")); } @@ -311,48 +338,54 @@ impl Operation { Self::TransitiveEqualFromStatements(_, _) => { return Err(anyhow!("Invalid operation")); } - Self::GtToNotEqual(Gt(ak1, ak2)) => { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) - } + Self::GtToNotEqual(Gt(ak1, ak2)) => Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + ]), Self::GtToNotEqual(_) => { return Err(anyhow!("Invalid operation")); } - Self::LtToNotEqual(Gt(ak1, ak2)) => { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) - } + Self::LtToNotEqual(Gt(ak1, ak2)) => Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + ]), Self::LtToNotEqual(_) => { return Err(anyhow!("Invalid operation")); } Self::ContainsFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3), pf) - if MerkleTree::verify(pf.siblings.len(), (*v1).into(), pf, v2, v3).is_ok() => + if MerkleTree::verify(pf.siblings.len(), v1.into(), pf, &v2.raw(), &v3.raw()) + .is_ok() => { Some(vec![ - StatementArg::Key(*ak1), - StatementArg::Key(*ak2), - StatementArg::Key(*ak3), + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + StatementArg::Key(ak3.clone()), ]) } Self::ContainsFromEntries(_, _, _, _) => { return Err(anyhow!("Invalid operation")); } Self::NotContainsFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2), pf) - if MerkleTree::verify_nonexistence(pf.siblings.len(), (*v1).into(), pf, v2) + if MerkleTree::verify_nonexistence(pf.siblings.len(), v1.into(), pf, &v2.raw()) .is_ok() => { - Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + Some(vec![ + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + ]) } Self::NotContainsFromEntries(_, _, _) => { return Err(anyhow!("Invalid operation")); } Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)) => { - let v1: i64 = (*v1).try_into()?; - let v2: i64 = (*v2).try_into()?; - let v3: i64 = (*v3).try_into()?; + let v1: i64 = v1.typed().try_into()?; + let v2: i64 = v2.typed().try_into()?; + let v3: i64 = v3.typed().try_into()?; if v1 == v2 + v3 { Some(vec![ - StatementArg::Key(*ak1), - StatementArg::Key(*ak2), - StatementArg::Key(*ak3), + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + StatementArg::Key(ak3.clone()), ]) } else { return Err(anyhow!("Invalid operation")); @@ -362,14 +395,14 @@ impl Operation { return Err(anyhow!("Invalid operation")); } Self::ProductOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)) => { - let v1: i64 = (*v1).try_into()?; - let v2: i64 = (*v2).try_into()?; - let v3: i64 = (*v3).try_into()?; + let v1: i64 = v1.typed().try_into()?; + let v2: i64 = v2.typed().try_into()?; + let v3: i64 = v3.typed().try_into()?; if v1 == v2 * v3 { Some(vec![ - StatementArg::Key(*ak1), - StatementArg::Key(*ak2), - StatementArg::Key(*ak3), + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + StatementArg::Key(ak3.clone()), ]) } else { return Err(anyhow!("Invalid operation")); @@ -379,14 +412,14 @@ impl Operation { return Err(anyhow!("Invalid operation")); } Self::MaxOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)) => { - let v1: i64 = (*v1).try_into()?; - let v2: i64 = (*v2).try_into()?; - let v3: i64 = (*v3).try_into()?; + let v1: i64 = v1.typed().try_into()?; + let v2: i64 = v2.typed().try_into()?; + let v3: i64 = v3.typed().try_into()?; if v1 == std::cmp::max(v2, v3) { Some(vec![ - StatementArg::Key(*ak1), - StatementArg::Key(*ak2), - StatementArg::Key(*ak3), + StatementArg::Key(ak1.clone()), + StatementArg::Key(ak2.clone()), + StatementArg::Key(ak3.clone()), ]) } else { return Err(anyhow!("Invalid operation")); @@ -413,11 +446,11 @@ impl Operation { Ok(valid) } /// Checks the given operation against a statement. - pub fn check(&self, _params: &Params, output_statement: &Statement) -> Result { + pub fn check(&self, params: &Params, output_statement: &Statement) -> Result { use Statement::*; match (self, output_statement) { (Self::None, None) => Ok(true), - (Self::NewEntry, ValueOf(AnchoredKey(pod_id, _), _)) => Ok(pod_id == &SELF), + (Self::NewEntry, ValueOf(AnchoredKey { pod_id, .. }, _)) => Ok(pod_id == &SELF), (Self::CopyStatement(s1), s2) => Ok(s1 == s2), (Self::EqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Equal(ak3, ak4)) => { Ok(v1 == v2 && ak3 == ak1 && ak4 == ak2) @@ -451,40 +484,15 @@ impl Operation { Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)), SumOf(ak4, ak5, ak6), ) => { - let v1: i64 = (*v1).try_into()?; - let v2: i64 = (*v2).try_into()?; - let v3: i64 = (*v3).try_into()?; + let v1: i64 = v1.typed().try_into()?; + let v2: i64 = v2.typed().try_into()?; + let v3: i64 = v3.typed().try_into()?; Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3) } - (Self::Custom(CustomPredicateRef(cpb, i), args), Custom(cpr, s_args)) - if cpb == &cpr.0 && i == &cpr.1 => + (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) + if batch == &cpr.batch && index == &cpr.index => { - // Bind according to custom predicate pattern match against arg list. - let bindings = cpr.match_against(args)?; - // Check arg length - let arg_len = cpr.arg_len(); - if arg_len != 2 * s_args.len() { - Err(anyhow!("Custom predicate arg list {:?} must have {} arguments after destructuring.", s_args, arg_len)) - } else { - let bound_args = (0..arg_len) - .map(|i| { - bindings.get(&i).cloned().ok_or(anyhow!( - "Wildcard {} of custom predicate {:?} is unbound.", - i, - cpr - )) - }) - .collect::>>()?; - let s_args = s_args - .iter() - .flat_map(|AnchoredKey(o, k)| [Value::from(o.0), (*k).into()]) - .collect::>(); - if bound_args != s_args { - Err(anyhow!("Arguments to output statement {} do not match those implied by operation {:?}", output_statement,self)) - } else { - Ok(true) - } - } + check_custom_pred(params, batch, *index, args, s_args) } _ => Err(anyhow!( "Invalid deduction: {:?} ⇏ {:#}", @@ -495,6 +503,120 @@ impl Operation { } } +/// Check that a StatementArg follows a StatementTmplArg based on the currently mapped wildcards. +/// Update the wildcard map with newly found wildcards. +pub fn check_st_tmpl( + st_tmpl_arg: &StatementTmplArg, + st_arg: &StatementArg, + // Map from wildcards to values that we have seen so far. + wildcard_map: &mut [Option], +) -> bool { + // Check that the value `v` at wildcard `wc` exists in the map or set it. + fn check_or_set( + v: WildcardValue, + wc: &Wildcard, + wildcard_map: &mut [Option], + ) -> bool { + if let Some(prev) = &wildcard_map[wc.index] { + if *prev != v { + // TODO: Return nice error + return false; + } + } else { + wildcard_map[wc.index] = Some(v); + } + true + } + + match (st_tmpl_arg, st_arg) { + (StatementTmplArg::None, StatementArg::None) => true, + (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true, + ( + StatementTmplArg::Key(pod_id_wc, key_or_wc), + StatementArg::Key(AnchoredKey { pod_id, key }), + ) => { + let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map); + let key_ok = match key_or_wc { + KeyOrWildcard::Key(tmpl_key) => tmpl_key == key, + KeyOrWildcard::Wildcard(key_wc) => { + check_or_set(WildcardValue::Key(key.clone()), key_wc, wildcard_map) + } + }; + pod_id_ok && key_ok + } + (StatementTmplArg::WildcardLiteral(wc), StatementArg::WildcardLiteral(v)) => { + check_or_set(v.clone(), wc, wildcard_map) + } + _ => false, + } +} + +fn check_custom_pred( + params: &Params, + batch: &Arc, + index: usize, + args: &[Statement], + s_args: &[WildcardValue], +) -> Result { + let pred = &batch.predicates[index]; + if pred.statements.len() != args.len() { + return Err(anyhow!( + "Custom predicate operation needs {} statements but has {}.", + pred.statements.len(), + args.len() + )); + } + if pred.args_len != s_args.len() { + return Err(anyhow!( + "Custom predicate statement needs {} args but has {}.", + pred.args_len, + s_args.len() + )); + } + + // Check that all wildcard have consistent values as assigned in the statements while storing a + // map of their values. Count the number of statements that match the templates by predicate. + // NOTE: We assume the statements have the same order as defined in the custom predicate. For + // disjunctions we expect Statement::None for the unused statements. + let mut num_matches = 0; + let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards]; + for (st_tmpl, st) in pred.statements.iter().zip(args) { + let st_args = st.args(); + for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { + if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { + // TODO: Better errors. Example: + // println!("{} doesn't match {}", st_arg, st_tmpl_arg); + // println!("{} doesn't match {}", st, st_tmpl); + return Ok(false); + } + } + + let st_tmpl_pred = match &st_tmpl.pred { + Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { + batch: batch.clone(), + index: *i, + }), + p => p.clone(), + }; + if st_tmpl_pred == st.predicate() { + num_matches += 1; + } + } + + // Check that the resolved wildcard match the statement arguments. + for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) { + if !wc_value.as_ref().is_none_or(|wc_value| *wc_value == *s_arg) { + return Ok(false); + } + } + + if pred.conjunction { + Ok(num_matches == pred.statements.len()) + } else { + Ok(num_matches > 0) + } +} + impl ToFields for Operation { fn to_fields(&self, _params: &Params) -> Vec { todo!() diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index a743252..eafa920 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -1,3 +1,5 @@ +// TODO: Reenable +/* use plonky2::field::types::Field; use serde::Deserialize; @@ -67,3 +69,4 @@ where { deserialize_field_tuple::(deserializer) } +*/ diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index e27f0a8..a648f65 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -2,14 +2,16 @@ use std::{fmt, iter}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +// use schemars::JsonSchema; +// use serde::{Deserialize, Serialize}; use strum_macros::FromRepr; use crate::middleware::{ - AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, VALUE_SIZE, + AnchoredKey, CustomPredicateRef, Key, Params, PodId, Predicate, RawValue, ToFields, Value, F, + VALUE_SIZE, }; +// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static // hash(KEY_SIGNER) = [2145458785152392366, 15113074911296146791, 15323228995597834291, 11804480340100333725] pub const KEY_SIGNER: &str = "_signer"; // hash(KEY_TYPE) = [17948789436443445142, 12513915140657440811, 15878361618879468769, 938231894693848619] @@ -18,7 +20,7 @@ pub const STATEMENT_ARG_F_LEN: usize = 8; pub const OPERATION_ARG_F_LEN: usize = 1; pub const OPERATION_AUX_F_LEN: usize = 1; -#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash)] pub enum NativePredicate { None = 0, ValueOf = 1, @@ -31,6 +33,14 @@ pub enum NativePredicate { SumOf = 8, ProductOf = 9, MaxOf = 10, + + // Syntactic sugar predicates. These predicates are not supported by the backend. The + // frontend compiler is responsible of translating these predicates into the predicates above. + DictContains = 1000, + DictNotContains = 1001, + SetContains = 1002, + SetNotContains = 1003, + ArrayContains = 1004, // there is no ArrayNotContains } impl ToFields for NativePredicate { @@ -39,8 +49,41 @@ impl ToFields for NativePredicate { } } +#[derive(Clone, Debug, PartialEq)] +pub enum WildcardValue { + PodId(PodId), + Key(Key), +} + +impl WildcardValue { + pub fn raw(&self) -> RawValue { + match self { + WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0), + WildcardValue::Key(key) => key.raw(), + } + } +} + +impl fmt::Display for WildcardValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id), + WildcardValue::Key(key) => write!(f, "{}", key), + } + } +} + +impl ToFields for WildcardValue { + fn to_fields(&self, params: &Params) -> Vec { + match self { + WildcardValue::PodId(pod_id) => pod_id.to_fields(params), + WildcardValue::Key(key) => key.to_fields(params), + } + } +} + /// Type encapsulating statements with their associated arguments. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq)] pub enum Statement { None, ValueOf(AnchoredKey, Value), @@ -57,7 +100,7 @@ pub enum Statement { SumOf(AnchoredKey, AnchoredKey, AnchoredKey), ProductOf(AnchoredKey, AnchoredKey, AnchoredKey), MaxOf(AnchoredKey, AnchoredKey, AnchoredKey), - Custom(CustomPredicateRef, Vec), + Custom(CustomPredicateRef, Vec), } impl Statement { @@ -95,7 +138,7 @@ impl Statement { Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], - Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Key)), + Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(WildcardLiteral)), } } pub fn from_args(pred: Predicate, args: Vec) -> Result { @@ -103,35 +146,45 @@ impl Statement { let st: Result = match pred { Native(NativePredicate::None) => Ok(Self::None), Native(NativePredicate::ValueOf) => { - if let (StatementArg::Key(a0), StatementArg::Literal(v1)) = (args[0], args[1]) { + if let (StatementArg::Key(a0), StatementArg::Literal(v1)) = + (args[0].clone(), args[1].clone()) + { Ok(Self::ValueOf(a0, v1)) } else { Err(anyhow!("Incorrect statement args")) } } Native(NativePredicate::Equal) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = + (args[0].clone(), args[1].clone()) + { Ok(Self::Equal(a0, a1)) } else { Err(anyhow!("Incorrect statement args")) } } Native(NativePredicate::NotEqual) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = + (args[0].clone(), args[1].clone()) + { Ok(Self::NotEqual(a0, a1)) } else { Err(anyhow!("Incorrect statement args")) } } Native(NativePredicate::Gt) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = + (args[0].clone(), args[1].clone()) + { Ok(Self::Gt(a0, a1)) } else { Err(anyhow!("Incorrect statement args")) } } Native(NativePredicate::Lt) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = + (args[0].clone(), args[1].clone()) + { Ok(Self::Lt(a0, a1)) } else { Err(anyhow!("Incorrect statement args")) @@ -139,7 +192,7 @@ impl Statement { } Native(NativePredicate::Contains) => { if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0], args[1], args[2]) + (args[0].clone(), args[1].clone(), args[2].clone()) { Ok(Self::Contains(a0, a1, a2)) } else { @@ -147,7 +200,9 @@ impl Statement { } } Native(NativePredicate::NotContains) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = + (args[0].clone(), args[1].clone()) + { Ok(Self::NotContains(a0, a1)) } else { Err(anyhow!("Incorrect statement args")) @@ -155,7 +210,7 @@ impl Statement { } Native(NativePredicate::SumOf) => { if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0], args[1], args[2]) + (args[0].clone(), args[1].clone(), args[2].clone()) { Ok(Self::SumOf(a0, a1, a2)) } else { @@ -164,7 +219,7 @@ impl Statement { } Native(NativePredicate::ProductOf) => { if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0], args[1], args[2]) + (args[0].clone(), args[1].clone(), args[2].clone()) { Ok(Self::ProductOf(a0, a1, a2)) } else { @@ -173,23 +228,24 @@ impl Statement { } Native(NativePredicate::MaxOf) => { if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0], args[1], args[2]) + (args[0].clone(), args[1].clone(), args[2].clone()) { Ok(Self::MaxOf(a0, a1, a2)) } else { Err(anyhow!("Incorrect statement args")) } } + Native(np) => Err(anyhow!("Predicate {:?} is syntax sugar", np)), BatchSelf(_) => unreachable!(), Custom(cpr) => { - let ak_args: Result> = args + let v_args: Result> = args .iter() .map(|x| match x { - StatementArg::Key(ak) => Ok(*ak), + StatementArg::WildcardLiteral(v) => Ok(v.clone()), _ => Err(anyhow!("Incorrect statement args")), }) .collect(); - Ok(Self::Custom(cpr, ak_args?)) + Ok(Self::Custom(cpr, v_args?)) } }; st @@ -207,23 +263,24 @@ impl ToFields for Statement { impl fmt::Display for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?} ", self.predicate())?; + write!(f, "{}(", self.predicate())?; for (i, arg) in self.args().iter().enumerate() { if i != 0 { - write!(f, " ")?; + write!(f, ", ")?; } write!(f, "{}", arg)?; } - Ok(()) + write!(f, ")") } } /// Statement argument type. Useful for statement decompositions. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq)] pub enum StatementArg { None, Literal(Value), Key(AnchoredKey), + WildcardLiteral(WildcardValue), } impl fmt::Display for StatementArg { @@ -231,7 +288,8 @@ impl fmt::Display for StatementArg { match self { StatementArg::None => write!(f, "none"), StatementArg::Literal(v) => write!(f, "{}", v), - StatementArg::Key(r) => write!(f, "{}.{}", r.0, r.1), + StatementArg::Key(r) => write!(f, "{}.{}", r.pod_id, r.key), + StatementArg::WildcardLiteral(v) => write!(f, "{}", v), } } } @@ -242,13 +300,13 @@ impl StatementArg { } pub fn literal(&self) -> Result { match self { - Self::Literal(value) => Ok(*value), + Self::Literal(value) => Ok(value.clone()), _ => Err(anyhow!("Statement argument {:?} is not a literal.", self)), } } pub fn key(&self) -> Result { match self { - Self::Key(ak) => Ok(*ak), + Self::Key(ak) => Ok(ak.clone()), _ => Err(anyhow!("Statement argument {:?} is not a key.", self)), } } @@ -265,16 +323,23 @@ impl ToFields for StatementArg { // dealing with `Literal` it would be of length 4. let f = match self { StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN], - StatementArg::Literal(v) => { - v.0.into_iter() - .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE)) - .collect() - } + StatementArg::Literal(v) => v + .raw() + .0 + .into_iter() + .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE)) + .collect(), StatementArg::Key(ak) => { - let mut fields = ak.0.to_fields(_params); - fields.extend(ak.1.to_fields(_params)); + let mut fields = ak.pod_id.to_fields(_params); + fields.extend(ak.key.to_fields(_params)); fields } + StatementArg::WildcardLiteral(v) => v + .raw() + .0 + .into_iter() + .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE)) + .collect(), }; assert_eq!(f.len(), STATEMENT_ARG_F_LEN); // sanity check f diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index e969528..0000000 --- a/src/util.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::{collections::HashMap, fmt::Debug, hash::Hash}; - -use anyhow::{anyhow, Result}; - -pub(crate) fn hashmap_insert_no_dupe( - hm: &mut HashMap, - kv: (S, T), -) -> Result<()> { - let (k, v) = kv.clone(); - let res = hm.insert(kv.0, kv.1); - match res { - Some(w) if w != v => Err(anyhow!( - "Key {:?} exists in table with value {:?} != {:?}.", - k, - w, - v - )), - _ => Ok(()), - } -}