diff --git a/src/backends/mock_main/mod.rs b/src/backends/mock_main/mod.rs index 4db4359..9008710 100644 --- a/src/backends/mock_main/mod.rs +++ b/src/backends/mock_main/mod.rs @@ -115,6 +115,11 @@ fn fill_pad(v: &mut Vec, pad_value: T, len: usize) { } } +/// Inputs are sorted as: +/// - SignedPods +/// - MainPods +/// - private Statements +/// - public Statements impl MockMainPod { fn offset_input_signed_pods(&self) -> usize { 0 @@ -136,6 +141,8 @@ impl MockMainPod { fill_pad(&mut op.1, OperationArg::None, params.max_operation_args) } + /// Returns the statements from the given MainPodInputs, padding to the + /// respective max lengths defined at the given Params. fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec { let mut statements = Vec::new(); @@ -259,8 +266,10 @@ impl MockMainPod { Ok(operations) } - // NOTE: In this implementation public statements are always copies from previous statements, - // so we fill in the operations accordingly. + // NOTE: In this implementation public statements are always copies from + // previous statements, so we fill in the operations accordingly. + /// This method assumes that the given `statements` array has been padded to + /// `params.max_statements`. fn process_public_statements_operations( params: &Params, statements: &[Statement], diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs new file mode 100644 index 0000000..b43fd9b --- /dev/null +++ b/src/frontend/custom.rs @@ -0,0 +1,315 @@ +#![allow(unused)] +use std::sync::Arc; + +use crate::middleware::{ + hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, + Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F, +}; + +/// Argument to an statement template +pub enum HashOrWildcardStr { + Hash(Hash), // represents a literal key + Wildcard(String), +} + +/// helper to build a literal HashOrWildcardStr::Hash from the given str +pub fn literal(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Hash(hash_str(s)) +} + +/// helper to build a HashOrWildcardStr::Wildcard from the given str. For the +/// moment this method does not need to be public. +fn wildcard(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Wildcard(s.to_string()) +} + +/// Builder Argument for the StatementTmplBuilder +pub enum BuilderArg { + Literal(Value), + /// Key: (origin, key), where origin & key can be both Hash or Wildcard + Key(HashOrWildcardStr, HashOrWildcardStr), +} + +/// When defining a `BuilderArg`, it can be done from 3 different inputs: +/// i. (&str, literal): this is to set a POD and a field, ie. (POD, literal("field")) +/// ii. (&str, &str): this is to define a origin-key wildcard pair, ie. (src_origin, src_dest) +/// iii. Value: this is to define a literal value, ie. 0 +/// +/// case i. +impl From<(&str, HashOrWildcardStr)> for BuilderArg { + fn from((origin, lit): (&str, HashOrWildcardStr)) -> Self { + // ensure that `lit` is of HashOrWildcardStr::Hash type + match lit { + HashOrWildcardStr::Hash(_) => (), + _ => panic!("not supported"), + }; + Self::Key(wildcard(&origin), lit) + } +} +/// case ii. +impl From<(&str, &str)> for BuilderArg { + fn from((origin, field): (&str, &str)) -> Self { + Self::Key(wildcard(&origin), wildcard(&field)) + } +} +/// case iii. +impl From for BuilderArg +where + V: Into, +{ + fn from(v: V) -> Self { + Self::Literal(v.into()) + } +} + +struct StatementTmplBuilder { + predicate: Predicate, + args: Vec, +} + +impl StatementTmplBuilder { + fn new(p: impl Into) -> StatementTmplBuilder { + StatementTmplBuilder { + predicate: p.into(), + args: Vec::new(), + } + } + + fn arg(mut self, a: impl Into) -> Self { + self.args.push(a.into()); + self + } +} + +struct CustomPredicateBatchBuilder { + name: String, + predicates: Vec, +} + +impl CustomPredicateBatchBuilder { + fn new(name: String) -> Self { + Self { + name, + predicates: Vec::new(), + } + } + + fn predicate_and( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(true, args, priv_args, sts) + } + + fn predicate_or( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(false, args, priv_args, sts) + } + + /// creates the custom predicate from the given input, adds it to the + /// self.predicates, and returns the index of the created predicate + fn predicate( + &mut self, + conjunction: bool, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + let statements = sts + .iter() + .map(|sb| { + let args = sb + .args + .iter() + .map(|a| match a { + BuilderArg::Literal(v) => StatementTmplArg::Literal(*v), + BuilderArg::Key(pod_id, key) => StatementTmplArg::Key( + resolve_wildcard(args, priv_args, pod_id), + resolve_wildcard(args, priv_args, key), + ), + }) + .collect(); + StatementTmpl(sb.predicate.clone(), args) + }) + .collect(); + let custom_predicate = CustomPredicate { + conjunction, + statements, + args_len: args.len(), + }; + self.predicates.push(custom_predicate); + Predicate::BatchSelf(self.predicates.len() - 1) + } + + fn finish(self) -> Arc { + Arc::new(CustomPredicateBatch { + name: self.name, + predicates: self.predicates, + }) + } +} + +fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &HashOrWildcardStr) -> HashOrWildcard { + match v { + HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h), + HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard( + args.iter() + .chain(priv_args.iter()) + .enumerate() + .find_map(|(i, name)| (&s == name).then_some(i)) + .unwrap(), + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::middleware::PodType; + + #[test] + fn test_custom_pred() { + use NativePredicate as NP; + use StatementTmplBuilder as STB; + + let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); + let _eth_friend = builder.predicate_and( + // arguments: + &["src_ori", "src_key", "dst_ori", "dst_key"], + // private arguments: + &["attestation_pod"], + // statement templates: + &[ + // there is an attestation pod that's a SignedPod + STB::new(NP::ValueOf) + .arg(("attestation_pod", literal("type"))) + .arg(PodType::Signed), + // the attestation pod is signed by (src_or, src_key) + STB::new(NP::Equal) + .arg(("attestation_pod", literal("signer"))) + .arg(("src_ori", "src_key")), + // that same attestation pod has an "attestation" + STB::new(NP::Equal) + .arg(("attestation_pod", literal("attestation"))) + .arg(("dst_ori", "dst_key")), + ], + ); + + println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); + let eth_friend = builder.finish(); + // This batch only has 1 predicate, so we pick it already for convenience + let eth_friend = Predicate::Custom(eth_friend, 0); + + // next chunk builds: + // eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and< + // eq(src_or, src_key, dst_or, dst_key), + // ValueOf(distance_or, distance_key, 0) + // > + let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); + let eth_dos_distance_base = builder.predicate_and( + &[ + // arguments: + "src_ori", + "src_key", + "dst_ori", + "dst_key", + "distance_ori", + "distance_key", + ], + &[ // private arguments: + ], + &[ + // statement templates: + STB::new(NP::Equal) + .arg(("src_ori", "src_key")) + .arg(("dst_ori", "dst_key")), + STB::new(NP::ValueOf) + .arg(("distance_ori", "distance_key")) + .arg(0), + ], + ); + println!( + "b.0. eth_dos_distance_base = {}", + builder.predicates.last().unwrap() + ); + + let eth_dos_distance = Predicate::BatchSelf(3); + + // next chunk builds: + let eth_dos_distance_ind = builder.predicate_and( + &[ + // arguments: + "src_ori", + "src_key", + "dst_ori", + "dst_key", + "distance_ori", + "distance_key", + ], + &[ + // private arguments: + "one_ori", + "one_key", + "shorter_distance_ori", + "shorter_distance_key", + "intermed_ori", + "intermed_key", + ], + &[ + // statement templates: + STB::new(eth_dos_distance) + .arg(("src_ori", "src_key")) + .arg(("intermed_ori", "intermed_key")) + .arg(("shorter_distance_ori", "shorter_distance_key")), + // distance == shorter_distance + 1 + STB::new(NP::ValueOf).arg(("one_ori", "one_key")).arg(1), + STB::new(NP::SumOf) + .arg(("distance_ori", "distance_key")) + .arg(("shorter_distance_ori", "shorter_distance_key")) + .arg(("one_ori", "one_key")), + // intermed is a friend of dst + STB::new(eth_friend) + .arg(("intermed_ori", "intermed_key")) + .arg(("dst_ori", "dst_key")), + ], + ); + + println!( + "b.1. eth_dos_distance_ind = {}", + builder.predicates.last().unwrap() + ); + + let _eth_dos_distance = builder.predicate_or( + &[ + "src_ori", + "src_key", + "dst_ori", + "dst_key", + "distance_ori", + "distance_key", + ], + &[], + &[ + STB::new(eth_dos_distance_base) + .arg(("src_ori", "src_key")) + .arg(("dst_ori", "dst_key")) + .arg(("distance_ori", "distance_key")), + STB::new(eth_dos_distance_ind) + .arg(("src_ori", "src_key")) + .arg(("dst_ori", "dst_key")) + .arg(("distance_ori", "distance_key")), + ], + ); + + println!( + "b.2. eth_dos_distance = {}", + builder.predicates.last().unwrap() + ); + } +} diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index f1335cb..a781a2a 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -14,8 +14,10 @@ use crate::middleware::{ PodSigner, SELF, }; +mod custom; mod operation; mod statement; +pub use custom::*; pub use operation::*; pub use statement::*; diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 4bef99d..39f5c5f 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -51,7 +51,7 @@ impl fmt::Display for StatementTmplArg { /// Statement Template for a Custom Predicate #[derive(Debug)] -pub struct StatementTmpl(Predicate, Vec); +pub struct StatementTmpl(pub Predicate, pub Vec); #[derive(Debug)] pub struct CustomPredicate { @@ -63,6 +63,14 @@ pub struct CustomPredicate { // TODO: Add args type information? } +impl ToFields for CustomPredicate { + fn to_fields(self) -> (Vec, usize) { + todo!() + // let f: Vec = Vec::new(); + // (self.conjunction.to_f(), 1) + } +} + impl fmt::Display for CustomPredicate { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; @@ -90,7 +98,8 @@ impl fmt::Display for CustomPredicate { #[derive(Debug)] pub struct CustomPredicateBatch { - predicates: Vec, + pub name: String, + pub predicates: Vec, } impl CustomPredicateBatch { @@ -115,7 +124,11 @@ impl From for Predicate { impl ToFields for Predicate { fn to_fields(self) -> (Vec, usize) { - todo!() + match self { + Self::Native(p) => p.to_fields(), + Self::BatchSelf(i) => Value::from(i as i64).to_fields(), + Self::Custom(_pb, _i) => todo!(), // TODO + } } } @@ -124,274 +137,7 @@ impl fmt::Display for Predicate { match self { Self::Native(p) => write!(f, "{:?}", p), Self::BatchSelf(i) => write!(f, "self.{}", i), - Self::Custom(pb, i) => write!(f, "{}.{}", pb.hash(), i), + Self::Custom(pb, i) => write!(f, "{}.{}", pb.name, i), } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::middleware::PodType; - - enum HashOrWildcardStr { - Hash(Hash), - Wildcard(String), - } - - fn l(s: &str) -> HashOrWildcardStr { - HashOrWildcardStr::Hash(hash_str(s)) - } - - fn w(s: &str) -> HashOrWildcardStr { - HashOrWildcardStr::Wildcard(s.to_string()) - } - - enum BuilderArg { - Literal(Value), - Key(HashOrWildcardStr, HashOrWildcardStr), - } - - impl From<(HashOrWildcardStr, HashOrWildcardStr)> for BuilderArg { - fn from((pod_id, key): (HashOrWildcardStr, HashOrWildcardStr)) -> Self { - Self::Key(pod_id, key) - } - } - - impl From for BuilderArg - where - V: Into, - { - fn from(v: V) -> Self { - Self::Literal(v.into()) - } - } - - struct StatementTmplBuilder { - predicate: Predicate, - args: Vec, - } - - fn st_tmpl(p: impl Into) -> StatementTmplBuilder { - StatementTmplBuilder { - predicate: p.into(), - args: Vec::new(), - } - } - - impl StatementTmplBuilder { - fn arg(mut self, a: impl Into) -> Self { - self.args.push(a.into()); - self - } - } - - struct CustomPredicateBatchBuilder { - predicates: Vec, - } - - impl CustomPredicateBatchBuilder { - fn new() -> Self { - Self { - predicates: Vec::new(), - } - } - - fn predicate_and( - &mut self, - args: &[&str], - priv_args: &[&str], - sts: &[StatementTmplBuilder], - ) -> Predicate { - self.predicate(true, args, priv_args, sts) - } - - fn predicate_or( - &mut self, - args: &[&str], - priv_args: &[&str], - sts: &[StatementTmplBuilder], - ) -> Predicate { - self.predicate(false, args, priv_args, sts) - } - - fn predicate( - &mut self, - conjunction: bool, - args: &[&str], - priv_args: &[&str], - sts: &[StatementTmplBuilder], - ) -> Predicate { - use BuilderArg as BA; - let statements = sts - .iter() - .map(|sb| { - let args = sb - .args - .iter() - .map(|a| match a { - BA::Literal(v) => StatementTmplArg::Literal(*v), - BA::Key(pod_id, key) => StatementTmplArg::Key( - resolve_wildcard(args, priv_args, pod_id), - resolve_wildcard(args, priv_args, key), - ), - }) - .collect(); - StatementTmpl(sb.predicate.clone(), args) - }) - .collect(); - let custom_predicate = CustomPredicate { - conjunction, - statements, - args_len: args.len(), - }; - self.predicates.push(custom_predicate); - Predicate::BatchSelf(self.predicates.len() - 1) - } - - fn finish(self) -> Arc { - Arc::new(CustomPredicateBatch { - predicates: self.predicates, - }) - } - } - - fn resolve_wildcard( - args: &[&str], - priv_args: &[&str], - v: &HashOrWildcardStr, - ) -> HashOrWildcard { - match v { - HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h), - HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard( - args.iter() - .chain(priv_args.iter()) - .enumerate() - .find_map(|(i, name)| (&s == name).then_some(i)) - .unwrap(), - ), - } - } - - #[test] - fn test_custom_pred() { - use NativePredicate as NP; - - let mut builder = CustomPredicateBatchBuilder::new(); - let _eth_friend = builder.predicate_and( - &["src_or", "src_key", "dst_or", "dst_key"], - &["attestation_pod"], - &[ - st_tmpl(NP::ValueOf) - .arg((w("attestation_pod"), l("type"))) - .arg(PodType::Signed), - st_tmpl(NP::Equal) - .arg((w("attestation_pod"), l("signer"))) - .arg((w("src_or"), w("src_key"))), - st_tmpl(NP::Equal) - .arg((w("attestation_pod"), l("attestation"))) - .arg((w("dst_or"), w("dst_key"))), - ], - ); - - println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); - let eth_friend = builder.finish(); - // This batch only has 1 predicate, so we pick it already for convenience - let eth_friend = Predicate::Custom(eth_friend, 0); - - let mut builder = CustomPredicateBatchBuilder::new(); - let eth_dos_distance_base = builder.predicate_and( - &[ - "src_or", - "src_key", - "dst_or", - "dst_key", - "distance_or", - "distance_key", - ], - &[], - &[ - st_tmpl(NP::Equal) - .arg((w("src_or"), l("src_key"))) - .arg((w("dst_or"), w("dst_key"))), - st_tmpl(NP::ValueOf) - .arg((w("distance_or"), w("distance_key"))) - .arg(0), - ], - ); - - println!( - "b.0. eth_dos_distance_base = {}", - builder.predicates.last().unwrap() - ); - - let eth_dos_distance = Predicate::BatchSelf(3); - - let eth_dos_distance_ind = builder.predicate_and( - &[ - "src_or", - "src_key", - "dst_or", - "dst_key", - "distance_or", - "distance_key", - ], - &[ - "one_or", - "one_key", - "shorter_distance_or", - "shorter_distance_key", - "intermed_or", - "intermed_key", - ], - &[ - st_tmpl(eth_dos_distance) - .arg((w("src_or"), w("src_key"))) - .arg((w("intermed_or"), w("intermed_key"))) - .arg((w("shorter_distance_or"), w("shorter_distance_key"))), - // distance == shorter_distance + 1 - st_tmpl(NP::ValueOf).arg((w("one_or"), w("one_key"))).arg(1), - st_tmpl(NP::SumOf) - .arg((w("distance_or"), w("distance_key"))) - .arg((w("shorter_distance_or"), w("shorter_distance_key"))) - .arg((w("one_or"), w("one_key"))), - // intermed is a friend of dst - st_tmpl(eth_friend) - .arg((w("intermed_or"), w("intermed_key"))) - .arg((w("dst_or"), w("dst_key"))), - ], - ); - - println!( - "b.1. eth_dos_distance_ind = {}", - builder.predicates.last().unwrap() - ); - - let _eth_dos_distance = builder.predicate_or( - &[ - "src_or", - "src_key", - "dst_or", - "dst_key", - "distance_or", - "distance_key", - ], - &[], - &[ - st_tmpl(eth_dos_distance_base) - .arg((w("src_or"), w("src_key"))) - .arg((w("dst_or"), w("dst_key"))) - .arg((w("distance_or"), w("distance_key"))), - st_tmpl(eth_dos_distance_ind) - .arg((w("src_or"), w("src_key"))) - .arg((w("dst_or"), w("dst_key"))) - .arg((w("distance_or"), w("distance_key"))), - ], - ); - - println!( - "b.2. eth_dos_distance = {}", - builder.predicates.last().unwrap() - ); - } -} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 34cbbe4..f120b69 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -48,6 +48,12 @@ pub type Entry = (String, Value); #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)] pub struct Value(pub [F; 4]); +impl ToFields for Value { + fn to_fields(self) -> (Vec, usize) { + (self.0.to_vec(), 4) + } +} + impl Value { pub fn to_bytes(self) -> Vec { self.0