diff --git a/src/examples/custom.rs b/src/examples/custom.rs index e88ca93..81a63ec 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -3,11 +3,11 @@ use std::sync::Arc; use anyhow::Result; use crate::{ - frontend::{literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, - middleware::{ - CustomPredicateBatch, CustomPredicateRef, NativePredicate, Params, PodType, Predicate, - KEY_SIGNER, KEY_TYPE, + frontend::{ + literal, CustomPredicateBatch, CustomPredicateBatchBuilder, CustomPredicateRef, + NativePredicate, Predicate, StatementTmplBuilder, Value, }, + middleware::{self, Params, PodType, KEY_SIGNER, KEY_TYPE}, }; use NativePredicate as NP; @@ -27,7 +27,7 @@ pub fn eth_friend_batch(params: &Params) -> Result> { // there is an attestation pod that's a SignedPod STB::new(NP::ValueOf) .arg(("attestation_pod", literal(KEY_TYPE))) - .arg(PodType::MockSigned), // TODO + .arg(middleware::Value::from(PodType::MockSigned)), // TODO // the attestation pod is signed by (src_or, src_key) STB::new(NP::Equal) .arg(("attestation_pod", literal(KEY_SIGNER))) @@ -37,6 +37,7 @@ pub fn eth_friend_batch(params: &Params) -> Result> { .arg(("attestation_pod", literal("attestation"))) .arg(("dst_ori", "dst_key")), ], + "eth_friend", )?; println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); @@ -45,7 +46,7 @@ pub fn eth_friend_batch(params: &Params) -> Result> { /// Instantiates an ETHDoS batch pub fn eth_dos_batch(params: &Params) -> Result> { - let eth_friend = Predicate::Custom(CustomPredicateRef(eth_friend_batch(params)?, 0)); + let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params)?, 0)); let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); // eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and< @@ -74,6 +75,7 @@ pub fn eth_dos_batch(params: &Params) -> Result> { .arg(("distance_ori", "distance_key")) .arg(0), ], + "eth_dos_distance_base", )?; println!( "b.0. eth_dos_distance_base = {}", @@ -119,6 +121,7 @@ pub fn eth_dos_batch(params: &Params) -> Result> { .arg(("intermed_ori", "intermed_key")) .arg(("dst_ori", "dst_key")), ], + "eth_dos_distance_ind", )?; println!( @@ -147,6 +150,7 @@ pub fn eth_dos_batch(params: &Params) -> Result> { .arg(("dst_ori", "dst_key")) .arg(("distance_ori", "distance_key")), ], + "eth_dos_distance", )?; println!( diff --git a/src/examples/mod.rs b/src/examples/mod.rs index b52ad35..38e8b9b 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -5,11 +5,11 @@ use custom::{eth_dos_batch, eth_friend_batch}; use std::collections::HashMap; use crate::backends::plonky2::mock::signedpod::MockSigner; +use crate::frontend::CustomPredicateRef; use crate::frontend::{ containers::{Dictionary, Set}, MainPodBuilder, SignedPod, SignedPodBuilder, Statement, Value, }; -use crate::middleware::CustomPredicateRef; use crate::middleware::{Params, PodType, KEY_SIGNER, KEY_TYPE}; use crate::op; @@ -94,11 +94,11 @@ pub fn eth_dos_pod_builder( bob_pubkey: &Value, ) -> Result { // Will need ETH friend and ETH DoS custom predicate batches. - let eth_friend = CustomPredicateRef(eth_friend_batch(params)?, 0); + let eth_friend = CustomPredicateRef::new(eth_friend_batch(params)?, 0); let eth_dos_batch = eth_dos_batch(params)?; - let eth_dos_base = CustomPredicateRef(eth_dos_batch.clone(), 0); - let eth_dos_ind = CustomPredicateRef(eth_dos_batch.clone(), 1); - let eth_dos = CustomPredicateRef(eth_dos_batch.clone(), 2); + let eth_dos_base = CustomPredicateRef::new(eth_dos_batch.clone(), 0); + let eth_dos_ind = CustomPredicateRef::new(eth_dos_batch.clone(), 1); + let eth_dos = CustomPredicateRef::new(eth_dos_batch.clone(), 2); // ETHDoS POD builder let mut alice_bob_ethdos = MainPodBuilder::new(params); diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 9ef8e8d..2074ad6 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -1,34 +1,94 @@ #![allow(unused)] -use anyhow::Result; +use anyhow::{anyhow, Result}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::iter::zip; use std::sync::Arc; +use std::{fmt, hash as h, iter}; -use crate::middleware::{ - hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, Params, - Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F, -}; +use crate::middleware::{self, hash_str, HashOrWildcard, Params, PodId, ToFields}; +use crate::util::hashmap_insert_no_dupe; +use super::{AnchoredKey, NativePredicate, Origin, Statement, StatementArg, Value}; + +#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] /// Argument to a statement template -pub enum HashOrWildcardStr { - Hash(Hash), // represents a literal key +pub enum KeyOrWildcardStr { + Key(String), // represents a literal key Wildcard(String), } -/// helper to build a literal HashOrWildcardStr::Hash from the given str -pub fn literal(s: &str) -> HashOrWildcardStr { - HashOrWildcardStr::Hash(hash_str(s)) +#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] +pub struct IndexedWildcard { + wildcard: String, + index: usize, } -/// helper to build a HashOrWildcardStr::Wildcard from the given str. For the +impl IndexedWildcard { + pub fn new(wildcard: String, index: usize) -> Self { + Self { wildcard, index } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", content = "value")] +/// Represents a key or resolved wildcard +pub enum KeyOrWildcard { + Key(String), + Wildcard(IndexedWildcard), +} + +impl KeyOrWildcard { + /// Matches a key or wildcard against a value, returning a pair + /// representing a wildcard binding (if any) or an error if no + /// match is possible. + pub fn match_against(&self, v: &Value) -> Result> { + match self { + KeyOrWildcard::Key(k) if Value::from(k.as_str()) == *v => Ok(None), + KeyOrWildcard::Wildcard(i) => Ok(Some((i.index, v.clone()))), + _ => Err(anyhow!( + "Failed to match key or wildcard {} against value {}.", + self, + v + )), + } + } +} + +impl From for middleware::HashOrWildcard { + fn from(v: KeyOrWildcard) -> Self { + match v { + KeyOrWildcard::Key(k) => middleware::HashOrWildcard::Hash(hash_str(&k)), + KeyOrWildcard::Wildcard(n) => middleware::HashOrWildcard::Wildcard(n.index), + } + } +} +impl fmt::Display for KeyOrWildcard { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Key(k) => write!(f, "{}", k), + Self::Wildcard(n) => write!(f, "*{}", n.wildcard), + } + } +} + +/// helper to build a literal KeyOrWildcardStr::Key from the given str +pub fn literal(s: &str) -> KeyOrWildcardStr { + KeyOrWildcardStr::Key(s.to_string()) +} + +/// helper to build a KeyOrWildcardStr::Wildcard from the given str. For the /// moment this method does not need to be public. -fn wildcard(s: &str) -> HashOrWildcardStr { - HashOrWildcardStr::Wildcard(s.to_string()) +fn wildcard(s: &str) -> KeyOrWildcardStr { + KeyOrWildcardStr::Wildcard(s.to_string()) } /// Builder Argument for the StatementTmplBuilder pub enum BuilderArg { Literal(Value), /// Key: (origin, key), where origin & key can be both Hash or Wildcard - Key(HashOrWildcardStr, HashOrWildcardStr), + Key(KeyOrWildcardStr, KeyOrWildcardStr), } /// When defining a `BuilderArg`, it can be done from 3 different inputs: @@ -37,11 +97,11 @@ pub enum BuilderArg { /// iii. Value: this is to define a literal value, ie. 0 /// /// case i. -impl From<(&str, HashOrWildcardStr)> for BuilderArg { - fn from((origin, lit): (&str, HashOrWildcardStr)) -> Self { +impl From<(&str, KeyOrWildcardStr)> for BuilderArg { + fn from((origin, lit): (&str, KeyOrWildcardStr)) -> Self { // ensure that `lit` is of HashOrWildcardStr::Hash type match lit { - HashOrWildcardStr::Hash(_) => (), + KeyOrWildcardStr::Key(_) => (), _ => panic!("not supported"), }; Self::Key(wildcard(origin), lit) @@ -63,6 +123,251 @@ where } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", content = "value")] +pub enum Predicate { + Native(NativePredicate), + BatchSelf(usize), + Custom(CustomPredicateRef), +} + +impl From for Predicate { + fn from(v: NativePredicate) -> Self { + Self::Native(v) + } +} + +impl From for middleware::Predicate { + fn from(v: Predicate) -> Self { + match v { + Predicate::Native(p) => middleware::Predicate::Native(p.into()), + Predicate::BatchSelf(i) => middleware::Predicate::BatchSelf(i), + Predicate::Custom(CustomPredicateRef { + batch: pb, + index: i, + }) => { + let cpb: middleware::CustomPredicateBatch = Arc::unwrap_or_clone(pb).into(); + middleware::Predicate::Custom(middleware::CustomPredicateRef(Arc::new(cpb), i)) + } + } + } +} + +impl fmt::Display for Predicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Native(p) => write!(f, "{:?}", p), + Self::BatchSelf(i) => write!(f, "self.{}", i), + Self::Custom(CustomPredicateRef { batch, index }) => { + write!(f, "{}.{}", batch.name, index) + } + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +pub struct CustomPredicateRef { + pub batch: Arc, + pub index: usize, +} + +impl From for middleware::CustomPredicateRef { + fn from(v: CustomPredicateRef) -> Self { + let cpb: middleware::CustomPredicateBatch = Arc::unwrap_or_clone(v.batch).into(); + middleware::CustomPredicateRef(Arc::new(cpb), v.index) + } +} + +impl CustomPredicateRef { + pub fn new(batch: Arc, index: usize) -> Self { + Self { batch, index } + } + + pub fn arg_len(&self) -> usize { + self.batch.predicates[self.index].args_len + } + pub fn match_against(&self, statements: &[Statement]) -> Result> { + let mut bindings = HashMap::new(); + // Single out custom predicate, replacing batch-self + // references with custom predicate references. + let custom_predicate = { + let cp = &Arc::unwrap_or_clone(self.batch.clone()).predicates[self.index]; + CustomPredicate { + conjunction: cp.conjunction, + statements: cp + .statements + .iter() + .map(|StatementTmpl { pred: p, args }| StatementTmpl { + pred: match p { + Predicate::BatchSelf(i) => { + Predicate::Custom(CustomPredicateRef::new(self.batch.clone(), *i)) + } + _ => p.clone(), + }, + args: args.to_vec(), + }) + .collect(), + args_len: cp.args_len, + name: cp.name.to_string(), + } + }; + match custom_predicate.conjunction { + true if custom_predicate.statements.len() == statements.len() => { + // Match op args against statement templates + let match_bindings = iter::zip(custom_predicate.statements, statements).map( + |(s_tmpl, s)| s_tmpl.match_against(s) + ).collect::>>() + .map(|v| v.concat())?; + // Add bindings to binding table, throwing if there is an inconsistency. + match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; + Ok(bindings) + }, + false if statements.len() == 1 => { + // Match op arg against each statement template + custom_predicate.statements.iter().map( + |s_tmpl| { + let mut bindings = bindings.clone(); + s_tmpl.match_against(&statements[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?; + Ok::<_, anyhow::Error>(bindings) + } + ).find(|m| m.is_ok()).unwrap_or(Err(anyhow!("Statement {} does not match disjunctive custom predicate {}.", &statements[0], custom_predicate))) + }, + _ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, statements)) + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +pub struct CustomPredicateBatch { + pub name: String, + pub predicates: Vec, +} + +impl From for middleware::CustomPredicateBatch { + fn from(v: CustomPredicateBatch) -> Self { + middleware::CustomPredicateBatch { + name: v.name, + predicates: v.predicates.into_iter().map(|p| p.into()).collect(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +pub struct CustomPredicate { + /// NOTE: fields are not public (outside of crate) to enforce the struct instantiation through + /// the `::and/or` methods, which performs checks on the values. + + /// true for "and", false for "or" + pub(crate) conjunction: bool, + pub(crate) statements: Vec, + pub(crate) args_len: usize, + // TODO: Add private args length? + // TODO: Add args type information? + pub(crate) name: String, +} + +impl CustomPredicate { + pub fn and( + params: &Params, + statements: Vec, + args_len: usize, + name: &str, + ) -> Result { + Self::new(params, true, statements, args_len, name) + } + pub fn or( + params: &Params, + statements: Vec, + args_len: usize, + name: &str, + ) -> Result { + Self::new(params, false, statements, args_len, name) + } + pub fn new( + params: &Params, + conjunction: bool, + statements: Vec, + args_len: usize, + name: &str, + ) -> Result { + if statements.len() > params.max_custom_predicate_arity { + return Err(anyhow!("Custom predicate depends on too many statements")); + } + + Ok(Self { + conjunction, + statements, + args_len, + name: name.to_string(), + }) + } +} + +impl From for middleware::CustomPredicate { + fn from(v: CustomPredicate) -> Self { + middleware::CustomPredicate { + conjunction: v.conjunction, + statements: v.statements.into_iter().map(|s| s.into()).collect(), + args_len: v.args_len, + } + } +} +impl fmt::Display for CustomPredicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; + for st in &self.statements { + write!(f, " {}", st.pred)?; + for (i, arg) in st.args.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + writeln!(f, "),")?; + } + write!(f, ">(")?; + for i in 0..self.args_len { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "*{}", i)?; + } + writeln!(f, ")")?; + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", content = "value")] +pub enum StatementTmplArg { + None, + Literal(Value), + // #[schemars(with = "Vec")] + Key(KeyOrWildcard, KeyOrWildcard), +} + +impl From for middleware::StatementTmplArg { + fn from(v: StatementTmplArg) -> Self { + match v { + StatementTmplArg::None => middleware::StatementTmplArg::None, + StatementTmplArg::Literal(v) => middleware::StatementTmplArg::Literal((&v).into()), + StatementTmplArg::Key(pod_id, key) => { + middleware::StatementTmplArg::Key(pod_id.into(), key.into()) + } + } + } +} + +impl fmt::Display for StatementTmplArg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::None => write!(f, "none"), + Self::Literal(v) => write!(f, "{}", v), + Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), + } + } +} + pub struct StatementTmplBuilder { predicate: Predicate, args: Vec, @@ -82,6 +387,83 @@ impl StatementTmplBuilder { } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +pub struct StatementTmpl { + pub pred: Predicate, + pub args: Vec, +} + +impl StatementTmpl { + pub fn pred(&self) -> &Predicate { + &self.pred + } + pub fn args(&self) -> &[StatementTmplArg] { + &self.args + } + /// Matches a statement template against a statement, returning + /// the variable bindings as an association list. Returns an error + /// if there is type or argument mismatch. + pub fn match_against(&self, s: &Statement) -> Result> { + type P = Predicate; + if matches!( + self, + Self { + pred: P::BatchSelf(_), + args: _ + } + ) { + Err(anyhow!( + "Cannot check self-referencing statement templates." + )) + } else if self.pred() != &s.predicate { + Err(anyhow!("Type mismatch between {:?} and {}.", self, s)) + } else { + zip(self.args(), s.args.clone()) + .map(|(t_arg, s_arg)| t_arg.match_against(&s_arg)) + .collect::>>() + .map(|v| v.concat()) + } + } +} + +impl From for middleware::StatementTmpl { + fn from(v: StatementTmpl) -> Self { + middleware::StatementTmpl( + v.pred.into(), + v.args.into_iter().map(|a| a.into()).collect(), + ) + } +} + +impl StatementTmplArg { + /// Matches a statement template argument against a statement + /// argument, returning a wildcard correspondence in the case of + /// one or more wildcard matches, nothing in the case of a + /// literal/hash match, and an error otherwise. + pub fn match_against(&self, s_arg: &StatementArg) -> Result> { + match (self, s_arg) { + // (Self::None, StatementArg::None) => Ok(vec![]), + (Self::Literal(v), StatementArg::Literal(w)) if v == w => Ok(vec![]), + ( + Self::Key(tmpl_o, tmpl_k), + StatementArg::Key(AnchoredKey { + origin: Origin { pod_id: PodId(o) }, + key: k, + }), + ) => { + let o_corr = tmpl_o.match_against(&(middleware::Value::from(*o)).into())?; + let k_corr = tmpl_k.match_against(&(*k.as_str()).into())?; + Ok([o_corr, k_corr].into_iter().flatten().collect()) + } + _ => Err(anyhow!( + "Failed to match statement template argument {:?} against statement argument {:?}.", + self, + s_arg + )), + } + } +} + pub struct CustomPredicateBatchBuilder { pub name: String, pub predicates: Vec, @@ -101,8 +483,9 @@ impl CustomPredicateBatchBuilder { args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], + name: &str, ) -> Result { - self.predicate(params, true, args, priv_args, sts) + self.predicate(params, true, args, priv_args, sts, name) } pub fn predicate_or( @@ -111,8 +494,9 @@ impl CustomPredicateBatchBuilder { args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], + name: &str, ) -> Result { - self.predicate(params, false, args, priv_args, sts) + self.predicate(params, false, args, priv_args, sts, name) } /// creates the custom predicate from the given input, adds it to the @@ -124,6 +508,7 @@ impl CustomPredicateBatchBuilder { args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], + name: &str, ) -> Result { let statements = sts .iter() @@ -132,17 +517,21 @@ impl CustomPredicateBatchBuilder { .args .iter() .map(|a| match a { - BuilderArg::Literal(v) => StatementTmplArg::Literal(*v), + BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Key(pod_id, key) => StatementTmplArg::Key( resolve_wildcard(args, priv_args, pod_id), resolve_wildcard(args, priv_args, key), ), }) .collect(); - StatementTmpl(sb.predicate.clone(), args) + StatementTmpl { + pred: sb.predicate.clone(), + args, + } }) .collect(); - let custom_predicate = CustomPredicate::new(params, conjunction, statements, args.len())?; + let custom_predicate = + CustomPredicate::new(params, conjunction, statements, args.len(), name)?; self.predicates.push(custom_predicate); Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } @@ -155,14 +544,14 @@ impl CustomPredicateBatchBuilder { } } -fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &HashOrWildcardStr) -> HashOrWildcard { +fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &KeyOrWildcardStr) -> KeyOrWildcard { match v { - HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h), - HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard( + KeyOrWildcardStr::Key(k) => KeyOrWildcard::Key(k.clone()), + KeyOrWildcardStr::Wildcard(s) => KeyOrWildcard::Wildcard( args.iter() .chain(priv_args.iter()) .enumerate() - .find_map(|(i, name)| (&s == name).then_some(i)) + .find_map(|(i, name)| (&s == name).then_some(IndexedWildcard::new(s.clone(), i))) .unwrap(), ), } @@ -173,7 +562,8 @@ mod tests { use super::*; use crate::{ examples::custom::{eth_dos_batch, eth_friend_batch}, - middleware::{CustomPredicateRef, Params, PodType}, + middleware, + // middleware::{CustomPredicateRef, Params, PodType}, }; #[test] @@ -188,10 +578,12 @@ mod tests { let eth_friend = eth_friend_batch(¶ms)?; // This batch only has 1 predicate, so we pick it already for convenience - let eth_friend = Predicate::Custom(CustomPredicateRef(eth_friend, 0)); + let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend, 0)); let eth_dos_batch = eth_dos_batch(¶ms)?; - let fields = eth_dos_batch.to_fields(¶ms); + let eth_dos_batch_mw: middleware::CustomPredicateBatch = + Arc::unwrap_or_clone(eth_dos_batch).into(); + let fields = eth_dos_batch_mw.to_fields(¶ms); println!("Batch b, serialized: {:?}", fields); Ok(()) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index d891a7e..a9679e4 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -8,13 +8,12 @@ use crate::middleware::{ use crate::middleware::{KEY_SIGNER, KEY_TYPE}; use anyhow::{anyhow, Error, Result}; use containers::{Array, Dictionary, Set}; -use env_logger; use itertools::Itertools; -use log::error; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::convert::From; +use std::hash::Hasher; use std::{fmt, hash as h}; use crate::middleware::{hash_value, OperationAux, EMPTY_VALUE}; @@ -26,6 +25,7 @@ mod predicate; mod serialization; mod statement; pub use custom::*; +pub use custom::{CustomPredicateRef, Predicate}; pub use operation::*; pub use predicate::*; pub use statement::*; @@ -41,13 +41,12 @@ pub enum PodClass { // An Origin, which represents a reference to an ancestor POD. #[derive(Clone, Debug, PartialEq, Eq, h::Hash, Default, Serialize, Deserialize, JsonSchema)] pub struct Origin { - pub pod_class: PodClass, pub pod_id: PodId, } impl Origin { - pub fn new(pod_class: PodClass, pod_id: PodId) -> Self { - Self { pod_class, pod_id } + pub fn new(pod_id: PodId) -> Self { + Self { pod_id } } } @@ -87,6 +86,24 @@ pub enum Value { Bool(bool), } +impl h::Hash for Value { + fn hash(&self, state: &mut H) { + // Hash the discriminant first + std::mem::discriminant(self).hash(state); + + // Hash the inner values only for types that implement Hash + match self { + Value::String(s) => s.hash(state), + Value::Int(i) => i.hash(state), + Value::Bool(b) => b.hash(state), + Value::Dictionary(d) => d.middleware_dict().commitment().hash(state), + Value::Set(s) => s.middleware_set().commitment().hash(state), + Value::Array(a) => a.middleware_array().commitment().hash(state), + Value::Raw(r) => r.hash(state), + } + } +} + impl From<&str> for Value { fn from(s: &str) -> Self { Value::String(s.to_string()) @@ -256,7 +273,7 @@ impl SignedPod { self.pod.id() } pub fn origin(&self) -> Origin { - Origin::new(PodClass::Signed, self.id()) + Origin::new(self.id()) } pub fn verify(&self) -> Result<()> { self.pod.verify() @@ -299,7 +316,6 @@ pub struct MainPodBuilder { // Internal state const_cnt: usize, key_table: HashMap, - pod_class_table: HashMap, } impl fmt::Display for MainPodBuilder { @@ -334,7 +350,6 @@ impl MainPodBuilder { public_statements: Vec::new(), const_cnt: 0, key_table: HashMap::new(), - pod_class_table: HashMap::from_iter([(SELF, PodClass::Main)]), } } pub fn add_signed_pod(&mut self, pod: &SignedPod) { @@ -343,25 +358,19 @@ impl MainPodBuilder { pod.kvs.iter().for_each(|(key, _)| { self.key_table.insert(hash_str(key), key.clone()); }); - // Add POD class to POD class table. - self.pod_class_table.insert(pod.id(), PodClass::Signed); } pub fn add_main_pod(&mut self, pod: MainPod) { - // Add POD class to POD class table. - self.pod_class_table.insert(pod.id(), PodClass::Main); // Add key-hash and POD ID-class correspondences to tables. pod.public_statements .iter() .flat_map(|s| &s.args) .flat_map(|arg| match arg { - StatementArg::Key(AnchoredKey { - origin: Origin { pod_class, pod_id }, - key, - }) => Some((*pod_id, pod_class.clone(), hash_str(key), key.clone())), + StatementArg::Key(AnchoredKey { origin: _, key }) => { + Some((hash_str(key), key.clone())) + } _ => None, }) - .for_each(|(pod_id, pod_class, hash, key)| { - self.pod_class_table.insert(pod_id, pod_class); + .for_each(|(hash, key)| { self.key_table.insert(hash, key); }); self.input_main_pods.push(pod); @@ -396,7 +405,7 @@ impl MainPodBuilder { } OperationArg::Entry(k, v) => { st_args.push(StatementArg::Key(AnchoredKey::new( - Origin::new(PodClass::Main, SELF), + Origin::new(SELF), k.clone(), ))); st_args.push(StatementArg::Literal(v.clone())) @@ -652,7 +661,7 @@ impl MainPodBuilder { // All args should be statements to be pattern matched against statement templates. let args = args.iter().map( |a| match a { - OperationArg::Statement(s) => Ok(middleware::Statement::try_from(s.clone())?), + OperationArg::Statement(s) => Ok(s.clone()), _ => Err(anyhow!("Invalid argument {} to operation corresponding to custom predicate {:?}.", a, cpr)) } ).collect::>>()?; @@ -672,15 +681,15 @@ impl MainPodBuilder { .chunks(2) .map(|chunk| { Ok(StatementArg::Key(AnchoredKey::new( - Origin::new( - self.pod_class_table - .get(&PodId(chunk[0].into())) - .cloned() - .ok_or(anyhow!("Missing POD class value."))?, - PodId(chunk[0].into()), - ), + Origin::new(PodId(match chunk[0] { + Value::Raw(v) => v.try_into()?, + _ => return Err(anyhow!("Invalid POD class value.")), + })), self.key_table - .get(&chunk[1].into()) + .get(&match &chunk[1] { + Value::String(s) => hash_str(s.as_str()), + _ => return Err(anyhow!("Invalid key value.")), + }) .cloned() .ok_or(anyhow!("Missing key corresponding to hash."))?, ))) @@ -768,7 +777,7 @@ impl MainPodBuilder { predicate: Predicate::Native(NativePredicate::ValueOf), args: vec![ StatementArg::Key(AnchoredKey::new( - Origin::new(PodClass::Main, pod_id), + Origin::new(pod_id), KEY_TYPE.to_string(), )), StatementArg::Literal(value.into()), @@ -788,14 +797,10 @@ impl MainPodBuilder { .into_iter() .map(|arg| match arg { StatementArg::Key(AnchoredKey { - origin: - Origin { - pod_class: class, - pod_id: id, - }, + origin: Origin { pod_id: id }, key, }) if id == SELF => { - StatementArg::Key(AnchoredKey::new(Origin::new(class, pod_id), key)) + StatementArg::Key(AnchoredKey::new(Origin::new(pod_id), key)) } _ => arg, }) @@ -839,7 +844,7 @@ impl MainPod { self.pod.id() } pub fn origin(&self) -> Origin { - Origin::new(PodClass::Main, self.id()) + Origin::new(self.id()) } } @@ -1020,9 +1025,9 @@ impl MainPodCompiler { // TODO: Take Merkle proof into account. let mop_args = op.1.iter() - .flat_map(|arg| self.compile_op_arg(arg).map(|op_arg| Ok(op_arg))) - .collect::>>()?; - middleware::Operation::op(mop_code, &mop_args, &op.2) + .flat_map(|arg| self.compile_op_arg(arg).map(|s| Ok(s.try_into()?))) + .collect::>>()?; + middleware::Operation::op(mop_code.into(), &mop_args, &op.2) } fn compile_st_op(&mut self, st: &Statement, op: &Operation, params: &Params) -> Result<()> { @@ -1340,7 +1345,7 @@ pub mod tests { vec![OperationArg::Statement(st1), OperationArg::Statement(st2)], OperationAux::None, ); - let st3 = builder.op(true, op_eq3); + builder.op(true, op_eq3).unwrap(); let mut prover = MockProver {}; let pod = builder.prove(&mut prover, ¶ms).unwrap(); @@ -1440,10 +1445,7 @@ pub mod tests { Statement::new( Predicate::Native(NativePredicate::ValueOf), vec![ - StatementArg::Key(AnchoredKey::new( - Origin::new(PodClass::Main, SELF), - "a".into(), - )), + StatementArg::Key(AnchoredKey::new(Origin::new(SELF), "a".into())), StatementArg::Literal(Value::Int(3)), ], ), @@ -1457,10 +1459,7 @@ pub mod tests { Statement::new( Predicate::Native(NativePredicate::ValueOf), vec![ - StatementArg::Key(AnchoredKey::new( - Origin::new(PodClass::Main, SELF), - "a".into(), - )), + StatementArg::Key(AnchoredKey::new(Origin::new(SELF), "a".into())), StatementArg::Literal(Value::Int(28)), ], ), @@ -1479,25 +1478,19 @@ pub mod tests { // right now the mock prover catches this when it calls compile() let params = Params::default(); let mut builder = MainPodBuilder::new(¶ms); - let self_a = AnchoredKey::new(Origin::new(PodClass::Main, SELF), "a".into()); - let self_b = AnchoredKey::new(Origin::new(PodClass::Main, SELF), "b".into()); + let self_a = AnchoredKey::new(Origin::new(SELF), "a".into()); + let self_b = AnchoredKey::new(Origin::new(SELF), "b".into()); let value_of_a = Statement::new( Predicate::Native(NativePredicate::ValueOf), vec![ - StatementArg::Key(AnchoredKey::new( - Origin::new(PodClass::Main, SELF), - "a".into(), - )), + StatementArg::Key(self_a.clone()), StatementArg::Literal(Value::Int(3)), ], ); let value_of_b = Statement::new( Predicate::Native(NativePredicate::ValueOf), vec![ - StatementArg::Key(AnchoredKey::new( - Origin::new(PodClass::Main, SELF), - "b".into(), - )), + StatementArg::Key(self_b.clone()), StatementArg::Literal(Value::Int(27)), ], ); diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 9a6be7a..d996300 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,10 +1,9 @@ use std::fmt; -use super::{NativePredicate, Predicate, SignedPod, Statement, Value}; -use crate::{ - backends::plonky2::primitives::merkletree::MerkleProof, - middleware::{self, OperationAux}, -}; +use serde::{Deserialize, Serialize}; + +use super::{CustomPredicateRef, NativePredicate, Predicate, SignedPod, Statement, Value}; +use crate::middleware::{self, OperationAux}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { @@ -71,13 +70,13 @@ impl> From<(&str, V)> for OperationArg { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum OperationType { Native(NativeOperation), - Custom(middleware::CustomPredicateRef), + Custom(CustomPredicateRef), } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum NativeOperation { None = 0, NewEntry = 1, @@ -131,7 +130,7 @@ impl TryFrom for middleware::OperationType { MwOT::Native(MwNO::NotContainsFromEntries) } FeOT::Native(FeNO::ArrayContainsFromEntries) => MwOT::Native(MwNO::ContainsFromEntries), - FeOT::Custom(mw_cpr) => MwOT::Custom(mw_cpr), + FeOT::Custom(mw_cpr) => MwOT::Custom(mw_cpr.into()), }; Ok(mw_ot) } diff --git a/src/frontend/predicate.rs b/src/frontend/predicate.rs index 58a774b..ae52a61 100644 --- a/src/frontend/predicate.rs +++ b/src/frontend/predicate.rs @@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize}; use std::fmt; use super::{AnchoredKey, SignedPod, Value}; -//use crate::middleware::{self, NativePredicate, Predicate}; use crate::middleware::{self, CustomPredicateRef}; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] @@ -25,6 +24,29 @@ pub enum NativePredicate { ArrayContains = 15, // there is no ArrayNotContains } +impl From for middleware::NativePredicate { + fn from(np: NativePredicate) -> Self { + use middleware::NativePredicate as MidNP; + use NativePredicate::*; + match np { + None => MidNP::None, + ValueOf => MidNP::ValueOf, + Equal => MidNP::Equal, + NotEqual => MidNP::NotEqual, + Gt => MidNP::Gt, + Lt => MidNP::Lt, + SumOf => MidNP::SumOf, + ProductOf => MidNP::ProductOf, + MaxOf => MidNP::MaxOf, + DictContains => MidNP::Contains, + DictNotContains => MidNP::NotContains, + SetContains => MidNP::Contains, + SetNotContains => MidNP::NotContains, + ArrayContains => MidNP::Contains, + } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] pub enum Predicate { Native(NativePredicate), diff --git a/src/frontend/serialization.rs b/src/frontend/serialization.rs index 199a67e..6a4425d 100644 --- a/src/frontend/serialization.rs +++ b/src/frontend/serialization.rs @@ -12,6 +12,7 @@ use crate::middleware::PodId; use super::{MainPod, SignedPod, Value}; #[derive(Serialize, Deserialize, JsonSchema)] +#[schemars(title = "SignedPod")] pub struct SignedPodHelper { entries: HashMap, proof: String, @@ -54,6 +55,7 @@ impl From for SignedPodHelper { } #[derive(Serialize, Deserialize, JsonSchema)] +#[schemars(title = "MainPod")] pub struct MainPodHelper { public_statements: Vec, proof: String, @@ -152,6 +154,7 @@ pub fn transform_value_schema(schema: &mut Schema) { #[cfg(test)] mod tests { use anyhow::Result; + use schemars::generate::SchemaSettings; use crate::{ backends::plonky2::mock::{mainpod::MockProver, signedpod::MockSigner}, @@ -297,4 +300,17 @@ mod tests { Ok(()) } + + #[test] + fn test_schema() { + let generator = SchemaSettings::draft07().into_generator(); + let mainpod_schema = generator.clone().into_root_schema_for::(); + let signedpod_schema = generator.into_root_schema_for::(); + + println!("{}", serde_json::to_string_pretty(&mainpod_schema).unwrap()); + println!( + "{}", + serde_json::to_string_pretty(&signedpod_schema).unwrap() + ); + } } diff --git a/src/frontend/statement.rs b/src/frontend/statement.rs index 8b7858c..d819b07 100644 --- a/src/frontend/statement.rs +++ b/src/frontend/statement.rs @@ -1,4 +1,5 @@ -use super::{AnchoredKey, NativePredicate, Predicate, SignedPod, Value}; +use super::{AnchoredKey, NativePredicate, SignedPod, Value}; +use crate::frontend::Predicate; use crate::middleware; use anyhow::{anyhow, Result}; use schemars::JsonSchema; @@ -130,7 +131,7 @@ impl TryFrom for middleware::Statement { _ => Err(anyhow!("Ill-formed statement: {}", s))?, }, Predicate::Custom(cpr) => MS::Custom( - cpr.clone(), + cpr.clone().into(), s.args .iter() .map(|arg| match arg { diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index d8379c9..9a0294c 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -5,9 +5,7 @@ use serde::{Deserialize, Serialize}; use std::{fmt, iter}; use strum_macros::FromRepr; -use super::{ - AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, HASH_SIZE, VALUE_SIZE, -}; +use super::{AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, VALUE_SIZE}; // hash(KEY_SIGNER) = [2145458785152392366, 15113074911296146791, 15323228995597834291, 11804480340100333725] pub const KEY_SIGNER: &str = "_signer";