From 05c21ebe6abe17d08b12fa97e67f30a854e00874 Mon Sep 17 00:00:00 2001 From: Ahmad Afuni Date: Wed, 26 Feb 2025 00:44:27 +1000 Subject: [PATCH] feat: partial incorporation of custom predicates into statement and operation structures in middleware (#84) * Add custom predicates to middleware Statement enum * Add custom op enum variant and wildcard matching procedures --- src/backends/mock_main/statement.rs | 8 ++- src/frontend/custom.rs | 4 +- src/frontend/mod.rs | 8 +-- src/middleware/custom.rs | 89 +++++++++++++++++++++++++---- src/middleware/operation.rs | 9 ++- src/middleware/statement.rs | 37 ++++++------ 6 files changed, 119 insertions(+), 36 deletions(-) diff --git a/src/backends/mock_main/statement.rs b/src/backends/mock_main/statement.rs index 290bd61..8c38f0e 100644 --- a/src/backends/mock_main/statement.rs +++ b/src/backends/mock_main/statement.rs @@ -82,7 +82,13 @@ impl TryFrom for middleware::Statement { impl From for Statement { fn from(s: middleware::Statement) -> Self { - Statement(s.code(), s.args().into_iter().map(|arg| arg).collect()) + match s.code() { + middleware::Predicate::Native(c) => { + Statement(c, s.args().into_iter().map(|arg| arg).collect()) + } + // TODO: Custom statements + _ => todo!(), + } } } diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index b43fd9b..e4d0ab2 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -171,7 +171,7 @@ fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &HashOrWildcardStr) -> #[cfg(test)] mod tests { use super::*; - use crate::middleware::PodType; + use crate::middleware::{CustomPredicateRef, PodType}; #[test] fn test_custom_pred() { @@ -204,7 +204,7 @@ mod tests { println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); let eth_friend = builder.finish(); // This batch only has 1 predicate, so we pick it already for convenience - let eth_friend = Predicate::Custom(eth_friend, 0); + let eth_friend = Predicate::Custom(CustomPredicateRef(eth_friend, 0)); // next chunk builds: // eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and< diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index a781a2a..f3d665f 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -5,7 +5,7 @@ use anyhow::Result; use itertools::Itertools; use std::collections::HashMap; use std::convert::From; -use std::fmt; +use std::{fmt, hash as h}; use crate::middleware::{ self, @@ -22,7 +22,7 @@ pub use operation::*; pub use statement::*; /// This type is just for presentation purposes. -#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Default, h::Hash, PartialEq, Eq)] pub enum PodClass { #[default] Signed, @@ -30,7 +30,7 @@ pub enum PodClass { } // An Origin, which represents a reference to an ancestor POD. -#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] +#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Default)] pub struct Origin(pub PodClass, pub PodId); #[derive(Clone, Debug, PartialEq, Eq)] @@ -166,7 +166,7 @@ impl SignedPod { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, h::Hash)] pub struct AnchoredKey(pub Origin, pub String); impl From for middleware::AnchoredKey { diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 39f5c5f..dc08d9d 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -1,16 +1,34 @@ -use std::fmt; use std::sync::Arc; +use std::{fmt, hash as h, iter::zip}; -use super::{hash_str, Hash, NativePredicate, ToFields, Value, F}; +use anyhow::{anyhow, Result}; + +use super::{ + hash_str, AnchoredKey, Hash, NativePredicate, PodId, Statement, StatementArg, ToFields, Value, + F, +}; // BEGIN Custom 1b -#[derive(Debug)] +#[derive(Clone, Debug, PartialEq, Eq, h::Hash)] pub enum HashOrWildcard { Hash(Hash), Wildcard(usize), } +impl HashOrWildcard { + /// Matches a hash or wildcard against a value, returning a pair + /// representing a wildcard binding (if any) or an error if no + /// match is possible. + pub fn match_against(&self, v: &Value) -> Result> { + match self { + HashOrWildcard::Hash(h) if &Value::from(h.clone()) == v => Ok(None), + HashOrWildcard::Wildcard(i) => Ok(Some((*i, v.clone()))), + _ => Err(anyhow!("Failed to match {} against {}.", self, v)), + } + } +} + impl fmt::Display for HashOrWildcard { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -20,13 +38,32 @@ impl fmt::Display for HashOrWildcard { } } -#[derive(Debug)] +#[derive(Clone, Debug, PartialEq, Eq, h::Hash)] pub enum StatementTmplArg { None, Literal(Value), Key(HashOrWildcard, HashOrWildcard), } +impl StatementTmplArg { + /// Matches a statement template argument against a statement + /// argument, returning a wildcard correspondence in the case of + /// one or more wildcard matches, nothing in the case of a + /// literal/hash match, and an error otherwise. + pub fn match_against(&self, s_arg: &StatementArg) -> Result> { + match (self, s_arg) { + (Self::None, StatementArg::None) => Ok(vec![]), + (Self::Literal(v), StatementArg::Literal(w)) if v == w => Ok(vec![]), + (Self::Key(tmpl_o, tmpl_k), StatementArg::Key(AnchoredKey(PodId(o), k))) => { + let o_corr = tmpl_o.match_against(&o.clone().into())?; + let k_corr = tmpl_k.match_against(&k.clone().into())?; + Ok([o_corr, k_corr].into_iter().flat_map(|x| x).collect()) + } + _ => Err(anyhow!("Failed to match {} against {}.", self, s_arg)), + } + } +} + impl fmt::Display for StatementTmplArg { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -50,10 +87,37 @@ impl fmt::Display for StatementTmplArg { // END /// Statement Template for a Custom Predicate -#[derive(Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct StatementTmpl(pub Predicate, pub Vec); -#[derive(Debug)] +impl StatementTmpl { + pub fn pred(&self) -> &Predicate { + &self.0 + } + pub fn args(&self) -> &[StatementTmplArg] { + &self.1 + } + /// Matches a statement template against a statement, returning + /// the variable bindings as an association list. Returns an error + /// if there is type or argument mismatch. + pub fn match_against(&self, s: &Statement) -> Result> { + type P = Predicate; + if matches!(self, Self(P::BatchSelf(_), _)) { + Err(anyhow!( + "Cannot check self-referencing statement templates." + )) + } else if self.pred() != &s.code() { + Err(anyhow!("Type mismatch between {:?} and {}.", self, s)) + } else { + zip(self.args(), s.args()) + .map(|(t_arg, s_arg)| t_arg.match_against(&s_arg)) + .collect::>>() + .map(|v| v.concat()) + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] pub struct CustomPredicate { /// true for "and", false for "or" pub conjunction: bool, @@ -96,7 +160,7 @@ impl fmt::Display for CustomPredicate { } } -#[derive(Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct CustomPredicateBatch { pub name: String, pub predicates: Vec, @@ -109,11 +173,14 @@ impl CustomPredicateBatch { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CustomPredicateRef(pub Arc, pub usize); + +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Predicate { Native(NativePredicate), BatchSelf(usize), - Custom(Arc, usize), + Custom(CustomPredicateRef), } impl From for Predicate { @@ -127,7 +194,7 @@ impl ToFields for Predicate { match self { Self::Native(p) => p.to_fields(), Self::BatchSelf(i) => Value::from(i as i64).to_fields(), - Self::Custom(_pb, _i) => todo!(), // TODO + Self::Custom(_) => todo!(), // TODO } } } @@ -137,7 +204,7 @@ impl fmt::Display for Predicate { match self { Self::Native(p) => write!(f, "{:?}", p), Self::BatchSelf(i) => write!(f, "self.{}", i), - Self::Custom(pb, i) => write!(f, "{}.{}", pb.name, i), + Self::Custom(CustomPredicateRef(pb, i)) => write!(f, "{}.{}", pb.name, i), } } } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index f8934de..ddcc37d 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; -use super::Statement; +use super::{CustomPredicateRef, Statement}; use crate::middleware::{AnchoredKey, SELF}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -42,6 +42,7 @@ pub enum Operation { SumOf(Statement, Statement, Statement), ProductOf(Statement, Statement, Statement), MaxOf(Statement, Statement, Statement), + Custom(CustomPredicateRef, Vec), } impl Operation { @@ -64,6 +65,7 @@ impl Operation { Self::SumOf(_, _, _) => SumOf, Self::ProductOf(_, _, _) => ProductOf, Self::MaxOf(_, _, _) => MaxOf, + Self::Custom(_, _) => todo!(), } } @@ -85,6 +87,7 @@ impl Operation { Self::SumOf(s1, s2, s3) => vec![s1, s2, s3], Self::ProductOf(s1, s2, s3) => vec![s1, s2, s3], Self::MaxOf(s1, s2, s3) => vec![s1, s2, s3], + Self::Custom(_, args) => args, } } /// Forms operation from op-code and arguments. @@ -171,6 +174,10 @@ impl Operation { let v3: i64 = v3.clone().try_into()?; Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3) } + ( + Self::Custom(CustomPredicateRef(cpb, i), _args), + Custom(CustomPredicateRef(s_cpb, s_i), _s_args), + ) if cpb == s_cpb && i == s_i => todo!(), _ => Err(anyhow!( "Invalid deduction: {:?} ⇏ {:#}", self, diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 0d35805..ce84f2a 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -1,15 +1,15 @@ use anyhow::{anyhow, Result}; use plonky2::field::types::Field; -use std::fmt; +use std::{collections::HashMap, fmt}; use strum_macros::FromRepr; -use super::{AnchoredKey, ToFields, Value, F}; +use super::{AnchoredKey, CustomPredicateRef, Hash, Predicate, ToFields, Value, F}; pub const KEY_SIGNER: &str = "_signer"; pub const KEY_TYPE: &str = "_type"; pub const STATEMENT_ARG_F_LEN: usize = 8; -#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash)] pub enum NativePredicate { None = 0, ValueOf = 1, @@ -30,9 +30,8 @@ impl ToFields for NativePredicate { } } -// TODO: Incorporate custom statements into this enum. /// Type encapsulating statements with their associated arguments. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Statement { None, ValueOf(AnchoredKey, Value), @@ -45,25 +44,28 @@ pub enum Statement { SumOf(AnchoredKey, AnchoredKey, AnchoredKey), ProductOf(AnchoredKey, AnchoredKey, AnchoredKey), MaxOf(AnchoredKey, AnchoredKey, AnchoredKey), + Custom(CustomPredicateRef, Vec), } impl Statement { pub fn is_none(&self) -> bool { self == &Self::None } - pub fn code(&self) -> NativePredicate { + pub fn code(&self) -> Predicate { + use Predicate::*; match self { - Self::None => NativePredicate::None, - Self::ValueOf(_, _) => NativePredicate::ValueOf, - Self::Equal(_, _) => NativePredicate::Equal, - Self::NotEqual(_, _) => NativePredicate::NotEqual, - Self::Gt(_, _) => NativePredicate::Gt, - Self::Lt(_, _) => NativePredicate::Lt, - Self::Contains(_, _) => NativePredicate::Contains, - Self::NotContains(_, _) => NativePredicate::NotContains, - Self::SumOf(_, _, _) => NativePredicate::SumOf, - Self::ProductOf(_, _, _) => NativePredicate::ProductOf, - Self::MaxOf(_, _, _) => NativePredicate::MaxOf, + Self::None => Native(NativePredicate::None), + Self::ValueOf(_, _) => Native(NativePredicate::ValueOf), + Self::Equal(_, _) => Native(NativePredicate::Equal), + Self::NotEqual(_, _) => Native(NativePredicate::NotEqual), + Self::Gt(_, _) => Native(NativePredicate::Gt), + Self::Lt(_, _) => Native(NativePredicate::Lt), + Self::Contains(_, _) => Native(NativePredicate::Contains), + Self::NotContains(_, _) => Native(NativePredicate::NotContains), + Self::SumOf(_, _, _) => Native(NativePredicate::SumOf), + Self::ProductOf(_, _, _) => Native(NativePredicate::ProductOf), + Self::MaxOf(_, _, _) => Native(NativePredicate::MaxOf), + Self::Custom(cpr, _) => Custom(cpr.clone()), } } pub fn args(&self) -> Vec { @@ -80,6 +82,7 @@ impl Statement { Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], + Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(|h| Key(h))), } } }