From 2e9719a1cac962f7d61c636988d003eb89514abf Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Fri, 21 Feb 2025 01:55:36 +0100 Subject: [PATCH] Prototype custom predicates (#74) * wip * prototype custom predicates 1b * feat: implement custom pred recursion * files reorg, add github CI for rustfmt checks --------- Co-authored-by: arnaucube --- .github/workflows/rustfmt.yml | 21 + .github/workflows/typos.toml | 1 + .../{mock_main.rs => mock_main/mod.rs} | 69 +-- src/backends/mock_main/operation.rs | 23 +- src/backends/mock_main/statement.rs | 33 +- src/backends/mock_signed.rs | 7 +- src/{frontend.rs => frontend/mod.rs} | 26 +- src/frontend/operation.rs | 5 +- src/frontend/statement.rs | 32 +- src/middleware/custom.rs | 397 ++++++++++++++++++ src/middleware/mod.rs | 9 +- src/middleware/operation.rs | 2 +- src/middleware/statement.rs | 28 +- 13 files changed, 529 insertions(+), 124 deletions(-) create mode 100644 .github/workflows/rustfmt.yml rename src/backends/{mock_main.rs => mock_main/mod.rs} (93%) rename src/{frontend.rs => frontend/mod.rs} (96%) create mode 100644 src/middleware/custom.rs diff --git a/.github/workflows/rustfmt.yml b/.github/workflows/rustfmt.yml new file mode 100644 index 0000000..af1ff02 --- /dev/null +++ b/.github/workflows/rustfmt.yml @@ -0,0 +1,21 @@ +name: Rustfmt Check + +on: + pull_request: + branches: [ main ] + types: [ready_for_review, opened, synchronize, reopened] + push: + branches: [ main ] + +jobs: + rustfmt: + if: github.event.pull_request.draft == false + name: Rust formatting + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rustfmt + - name: Check formatting + uses: actions-rust-lang/rustfmt@v1 diff --git a/.github/workflows/typos.toml b/.github/workflows/typos.toml index 3910746..4814700 100644 --- a/.github/workflows/typos.toml +++ b/.github/workflows/typos.toml @@ -1,2 +1,3 @@ [default.extend-words] groth = "groth" # to avoid it dectecting it as 'growth' +BA = "BA" diff --git a/src/backends/mock_main.rs b/src/backends/mock_main/mod.rs similarity index 93% rename from src/backends/mock_main.rs rename to src/backends/mock_main/mod.rs index 5025f0f..cbb0c7d 100644 --- a/src/backends/mock_main.rs +++ b/src/backends/mock_main/mod.rs @@ -1,19 +1,20 @@ -mod operation; -mod statement; - -use crate::middleware::{ - self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativeStatement, NonePod, - Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF, -}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use itertools::Itertools; -pub use operation::*; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::Hasher; -pub use statement::*; use std::any::Any; use std::fmt; +use crate::middleware::{ + self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod, + Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF, +}; + +mod operation; +mod statement; +pub use operation::*; +pub use statement::*; + pub const VALUE_TYPE: &str = "MockMainPOD"; pub struct MockProver {} @@ -222,18 +223,17 @@ impl MockMainPod { fn find_op_arg( statements: &[Statement], op_arg: &middleware::Statement, - ) -> Result { + ) -> Result { match op_arg { middleware::Statement::None => Ok(OperationArg::None), _ => statements .iter() .enumerate() .find_map(|(i, s)| { - // TODO: Error handling - (&middleware::Statement::try_from(s.clone()).unwrap() == op_arg).then_some(i) + (&middleware::Statement::try_from(s.clone()).ok()? == op_arg).then_some(i) }) .map(OperationArg::Index) - .ok_or(OperationArgError::StatementNotFound), + .ok_or(anyhow!("statement not found")), } } @@ -241,7 +241,7 @@ impl MockMainPod { params: &Params, statements: &[Statement], input_operations: &[middleware::Operation], - ) -> Result, OperationArgError> { + ) -> Result> { let mut operations = Vec::new(); for i in 0..params.max_priv_statements() { let op = input_operations @@ -252,7 +252,7 @@ impl MockMainPod { let mut args = mid_args .iter() .map(|mid_arg| Self::find_op_arg(statements, mid_arg)) - .collect::, OperationArgError>>()?; + .collect::>>()?; Self::pad_operation_args(params, &mut args); operations.push(Operation(op.code(), args)); } @@ -265,7 +265,7 @@ impl MockMainPod { params: &Params, statements: &[Statement], mut operations: Vec, - ) -> Result, OperationArgError> { + ) -> Result> { let offset_public_statements = statements.len() - params.max_public_statements; operations.push(Operation(NativeOperation::NewEntry, vec![])); for i in 0..(params.max_public_statements - 1) { @@ -318,7 +318,7 @@ impl MockMainPod { statements[statements.len() - params.max_public_statements..].to_vec(); // get the id out of the public statements - let id: PodId = PodId(hash_statements(&public_statements)?); + let id: PodId = PodId(hash_statements(&public_statements)); Ok(Self { params: params.clone(), @@ -335,7 +335,7 @@ impl MockMainPod { fn statement_none(params: &Params) -> Statement { let mut args = Vec::with_capacity(params.max_statement_args); Self::pad_statement_args(¶ms, &mut args); - Statement(NativeStatement::None, args) + Statement(NativePredicate::None, args) } fn operation_none(params: &Params) -> Operation { @@ -353,12 +353,12 @@ impl MockMainPod { } } -pub fn hash_statements(statements: &[Statement]) -> Result { +pub fn hash_statements(statements: &[Statement]) -> middleware::Hash { let field_elems = statements .into_iter() .flat_map(|statement| statement.clone().to_fields().0) .collect::>(); - Ok(Hash(PoseidonHash::hash_no_pad(&field_elems).elements)) + Hash(PoseidonHash::hash_no_pad(&field_elems).elements) } impl Pod for MockMainPod { @@ -367,14 +367,14 @@ impl Pod for MockMainPod { // get the input_statements from the self.statements let input_statements = &self.statements[input_statement_offset..]; // get the id out of the public statements, and ensure it is equal to self.id - let ids_match = self.id == PodId(hash_statements(&self.public_statements).unwrap()); + let ids_match = self.id == PodId(hash_statements(&self.public_statements)); // find a ValueOf statement from the public statements with key=KEY_TYPE and check that the // value is PodType::MockMainPod let has_type_statement = self .public_statements .iter() .find(|s| { - s.0 == NativeStatement::ValueOf + s.0 == NativePredicate::ValueOf && s.1.len() > 0 && if let StatementArg::Key(AnchoredKey(pod_id, key_hash)) = s.1[0] { pod_id == SELF && key_hash == hash_str(KEY_TYPE) @@ -402,7 +402,7 @@ impl Pod for MockMainPod { s, ) }) - .filter(|(_, s)| s.0 == NativeStatement::ValueOf) + .filter(|(_, s)| s.0 == NativePredicate::ValueOf) .flat_map(|(i, s)| { if let StatementArg::Key(ak) = &s.1[0] { vec![(i, ak.1, ak.0)] @@ -473,22 +473,22 @@ pub mod tests { use crate::middleware; #[test] - fn test_mock_main_zu_kyc() { + fn test_mock_main_zu_kyc() -> Result<()> { let params = middleware::Params::default(); let (gov_id_builder, pay_stub_builder) = zu_kyc_sign_pod_builders(¶ms); let mut signer = MockSigner { pk: "ZooGov".into(), }; - let gov_id_pod = gov_id_builder.sign(&mut signer).unwrap(); + let gov_id_pod = gov_id_builder.sign(&mut signer)?; let mut signer = MockSigner { pk: "ZooDeel".into(), }; - let pay_stub_pod = pay_stub_builder.sign(&mut signer).unwrap(); + let pay_stub_pod = pay_stub_builder.sign(&mut signer)?; let kyc_builder = zu_kyc_pod_builder(¶ms, &gov_id_pod, &pay_stub_pod); let mut prover = MockProver {}; - let kyc_pod = kyc_builder.prove(&mut prover).unwrap(); + let kyc_pod = kyc_builder.prove(&mut prover)?; let pod = kyc_pod.pod.into_any().downcast::().unwrap(); println!("{:#}", pod); @@ -496,14 +496,15 @@ pub mod tests { assert_eq!(pod.verify(), true); // TODO // println!("id: {}", pod.id()); // println!("pub_statements: {:?}", pod.pub_statements()); + Ok(()) } #[test] - fn test_mock_main_great_boy() { + fn test_mock_main_great_boy() -> Result<()> { let great_boy_builder = great_boy_pod_full_flow(); let mut prover = MockProver {}; - let great_boy_pod = great_boy_builder.prove(&mut prover).unwrap(); + let great_boy_pod = great_boy_builder.prove(&mut prover)?; let pod = great_boy_pod .pod .into_any() @@ -513,16 +514,20 @@ pub mod tests { println!("{}", pod); assert_eq!(pod.verify(), true); + + Ok(()) } #[test] - fn test_mock_main_tickets() { + fn test_mock_main_tickets() -> Result<()> { let tickets_builder = tickets_pod_full_flow(); let mut prover = MockProver {}; - let proof_pod = tickets_builder.prove(&mut prover).unwrap(); + let proof_pod = tickets_builder.prove(&mut prover)?; let pod = proof_pod.pod.into_any().downcast::().unwrap(); println!("{}", pod); assert_eq!(pod.verify(), true); + + Ok(()) } } diff --git a/src/backends/mock_main/operation.rs b/src/backends/mock_main/operation.rs index 12cb933..cb5ff3a 100644 --- a/src/backends/mock_main/operation.rs +++ b/src/backends/mock_main/operation.rs @@ -1,10 +1,8 @@ +use anyhow::Result; use std::fmt; -use anyhow::Result; - -use crate::middleware::{self, NativeOperation}; - use super::Statement; +use crate::middleware::{self, NativeOperation}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { @@ -18,23 +16,6 @@ impl OperationArg { } } -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum OperationArgError { - KeyNotFound, - StatementNotFound, -} - -impl std::fmt::Display for OperationArgError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - OperationArgError::KeyNotFound => write!(f, "Key not found"), - OperationArgError::StatementNotFound => write!(f, "Statement not found"), - } - } -} - -impl std::error::Error for OperationArgError {} - #[derive(Clone, Debug, PartialEq, Eq)] pub struct Operation(pub NativeOperation, pub Vec); diff --git a/src/backends/mock_main/statement.rs b/src/backends/mock_main/statement.rs index de1f510..290bd61 100644 --- a/src/backends/mock_main/statement.rs +++ b/src/backends/mock_main/statement.rs @@ -1,15 +1,14 @@ +use anyhow::{anyhow, Result}; use std::fmt; -use anyhow::{anyhow, Result}; - -use crate::middleware::{self, NativeStatement, StatementArg, ToFields}; +use crate::middleware::{self, NativePredicate, StatementArg, ToFields}; #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativeStatement, pub Vec); +pub struct Statement(pub NativePredicate, pub Vec); impl Statement { pub fn is_none(&self) -> bool { - self.0 == NativeStatement::None + self.0 == NativePredicate::None } /// Argument method. Trailing Nones are filtered out. pub fn args(&self) -> Vec { @@ -44,7 +43,7 @@ impl TryFrom for middleware::Statement { type Error = anyhow::Error; fn try_from(s: Statement) -> Result { type S = middleware::Statement; - type NS = NativeStatement; + type NP = NativePredicate; type SA = StatementArg; let proper_args = s.args(); let args = ( @@ -53,27 +52,27 @@ impl TryFrom for middleware::Statement { proper_args.get(2).cloned(), ); Ok(match (s.0, args, proper_args.len()) { - (NS::None, _, 0) => S::None, - (NS::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => S::ValueOf(ak, v), - (NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Equal(ak1, ak2), - (NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + (NP::None, _, 0) => S::None, + (NP::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => S::ValueOf(ak, v), + (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Equal(ak1, ak2), + (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { S::NotEqual(ak1, ak2) } - (NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Gt(ak1, ak2), - (NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2), - (NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Gt(ak1, ak2), + (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2), + (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { S::Contains(ak1, ak2) } - (NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { S::NotContains(ak1, ak2) } - (NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { S::SumOf(ak1, ak2, ak3) } - (NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { S::ProductOf(ak1, ak2, ak3) } - (NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { S::MaxOf(ak1, ak2, ak3) } _ => Err(anyhow!("Ill-formed statement expression {:?}", s))?, diff --git a/src/backends/mock_signed.rs b/src/backends/mock_signed.rs index 3da701f..afbaf76 100644 --- a/src/backends/mock_signed.rs +++ b/src/backends/mock_signed.rs @@ -1,11 +1,12 @@ +use anyhow::Result; +use std::any::Any; +use std::collections::HashMap; + use crate::middleware::{ containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, PodType, Statement, Value, KEY_SIGNER, KEY_TYPE, }; use crate::primitives::merkletree::MerkleTree; -use anyhow::Result; -use std::any::Any; -use std::collections::HashMap; pub struct MockSigner { pub pk: String, diff --git a/src/frontend.rs b/src/frontend/mod.rs similarity index 96% rename from src/frontend.rs rename to src/frontend/mod.rs index 40efd84..085b093 100644 --- a/src/frontend.rs +++ b/src/frontend/mod.rs @@ -1,9 +1,6 @@ //! The frontend includes the user-level abstractions and user-friendly types to define and work //! with Pods. -mod operation; -mod statement; - use anyhow::Result; use itertools::Itertools; use std::collections::HashMap; @@ -13,9 +10,12 @@ use std::fmt; use crate::middleware::{ self, containers::{Array, Dictionary, Set}, - hash_str, Hash, MainPodInputs, NativeOperation, NativeStatement, Params, PodId, PodProver, + hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver, PodSigner, SELF, }; + +mod operation; +mod statement; pub use operation::*; pub use statement::*; @@ -236,7 +236,7 @@ impl MainPodBuilder { for arg in args.iter_mut() { match arg { OperationArg::Statement(s) => { - if s.0 == NativeStatement::ValueOf { + if s.0 == NativePredicate::ValueOf { st_args.push(s.1[0].clone()) } else { panic!("Invalid statement argument."); @@ -276,27 +276,27 @@ impl MainPodBuilder { let Operation(op_type, ref mut args) = op; // TODO: argument type checking let st = match op_type { - None => Statement(NativeStatement::None, vec![]), - NewEntry => Statement(NativeStatement::ValueOf, self.op_args_entries(public, args)), + None => Statement(NativePredicate::None, vec![]), + NewEntry => Statement(NativePredicate::ValueOf, self.op_args_entries(public, args)), CopyStatement => todo!(), EqualFromEntries => { - Statement(NativeStatement::Equal, self.op_args_entries(public, args)) + Statement(NativePredicate::Equal, self.op_args_entries(public, args)) } NotEqualFromEntries => Statement( - NativeStatement::NotEqual, + NativePredicate::NotEqual, self.op_args_entries(public, args), ), - GtFromEntries => Statement(NativeStatement::Gt, self.op_args_entries(public, args)), - LtFromEntries => Statement(NativeStatement::Lt, self.op_args_entries(public, args)), + GtFromEntries => Statement(NativePredicate::Gt, self.op_args_entries(public, args)), + LtFromEntries => Statement(NativePredicate::Lt, self.op_args_entries(public, args)), TransitiveEqualFromStatements => todo!(), GtToNotEqual => todo!(), LtToNotEqual => todo!(), ContainsFromEntries => Statement( - NativeStatement::Contains, + NativePredicate::Contains, self.op_args_entries(public, args), ), NotContainsFromEntries => Statement( - NativeStatement::NotContains, + NativePredicate::NotContains, self.op_args_entries(public, args), ), RenameContainedBy => todo!(), diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 0b713ed..57d6f4f 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,8 +1,7 @@ use std::fmt; -use crate::middleware::{hash_str, NativeOperation, NativeStatement}; - use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value}; +use crate::middleware::{hash_str, NativeOperation, NativePredicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { @@ -56,7 +55,7 @@ impl From<(&SignedPod, &str)> for OperationArg { // TODO: Actual value, TryFrom. let value = pod.kvs().get(&hash_str(key)).unwrap().clone(); Self::Statement(Statement( - NativeStatement::ValueOf, + NativePredicate::ValueOf, vec![ StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())), StatementArg::Literal(Value::Raw(value)), diff --git a/src/frontend/statement.rs b/src/frontend/statement.rs index 9dd48de..59a75e2 100644 --- a/src/frontend/statement.rs +++ b/src/frontend/statement.rs @@ -1,10 +1,8 @@ +use anyhow::{anyhow, Result}; use std::fmt; -use anyhow::{anyhow, Result}; - -use crate::middleware::{self, NativeStatement}; - use super::{AnchoredKey, Value}; +use crate::middleware::{self, NativePredicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum StatementArg { @@ -22,13 +20,13 @@ impl fmt::Display for StatementArg { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativeStatement, pub Vec); +pub struct Statement(pub NativePredicate, pub Vec); impl TryFrom for middleware::Statement { type Error = anyhow::Error; fn try_from(s: Statement) -> Result { type MS = middleware::Statement; - type NS = NativeStatement; + type NP = NativePredicate; type SA = StatementArg; let args = ( s.1.get(0).cloned(), @@ -36,35 +34,35 @@ impl TryFrom for middleware::Statement { s.1.get(2).cloned(), ); Ok(match (s.0, args) { - (NS::None, (None, None, None)) => MS::None, - (NS::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => { + (NP::None, (None, None, None)) => MS::None, + (NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => { MS::ValueOf(ak.into(), (&v).into()) } - (NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { MS::Equal(ak1.into(), ak2.into()) } - (NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { MS::NotEqual(ak1.into(), ak2.into()) } - (NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { MS::Gt(ak1.into(), ak2.into()) } - (NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { MS::Lt(ak1.into(), ak2.into()) } - (NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { MS::Contains(ak1.into(), ak2.into()) } - (NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { MS::NotContains(ak1.into(), ak2.into()) } - (NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { MS::SumOf(ak1.into(), ak2.into(), ak3.into()) } - (NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { MS::ProductOf(ak1.into(), ak2.into(), ak3.into()) } - (NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { MS::MaxOf(ak1.into(), ak2.into(), ak3.into()) } _ => Err(anyhow!("Ill-formed statement: {}", s))?, diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs new file mode 100644 index 0000000..4bef99d --- /dev/null +++ b/src/middleware/custom.rs @@ -0,0 +1,397 @@ +use std::fmt; +use std::sync::Arc; + +use super::{hash_str, Hash, NativePredicate, ToFields, Value, F}; + +// BEGIN Custom 1b + +#[derive(Debug)] +pub enum HashOrWildcard { + Hash(Hash), + Wildcard(usize), +} + +impl fmt::Display for HashOrWildcard { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Hash(h) => write!(f, "{}", h), + Self::Wildcard(n) => write!(f, "*{}", n), + } + } +} + +#[derive(Debug)] +pub enum StatementTmplArg { + None, + Literal(Value), + Key(HashOrWildcard, HashOrWildcard), +} + +impl fmt::Display for StatementTmplArg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::None => write!(f, "none"), + Self::Literal(v) => write!(f, "{}", v), + Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), + } + } +} + +// END + +// BEGIN Custom 2 + +// pub enum StatementTmplArg { +// None, +// Literal(Value), +// Wildcard(usize), +// } + +// END + +/// Statement Template for a Custom Predicate +#[derive(Debug)] +pub struct StatementTmpl(Predicate, Vec); + +#[derive(Debug)] +pub struct CustomPredicate { + /// true for "and", false for "or" + pub conjunction: bool, + pub statements: Vec, + pub args_len: usize, + // TODO: Add private args length? + // TODO: Add args type information? +} + +impl fmt::Display for CustomPredicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; + for st in &self.statements { + write!(f, " {}", st.0)?; + for (i, arg) in st.1.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + writeln!(f, "),")?; + } + write!(f, ">(")?; + for i in 0..self.args_len { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "*{}", i)?; + } + writeln!(f, ")")?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct CustomPredicateBatch { + predicates: Vec, +} + +impl CustomPredicateBatch { + pub fn hash(&self) -> Hash { + // TODO + hash_str(&format!("{:?}", self)) + } +} + +#[derive(Clone, Debug)] +pub enum Predicate { + Native(NativePredicate), + BatchSelf(usize), + Custom(Arc, usize), +} + +impl From for Predicate { + fn from(v: NativePredicate) -> Self { + Self::Native(v) + } +} + +impl ToFields for Predicate { + fn to_fields(self) -> (Vec, usize) { + todo!() + } +} + +impl fmt::Display for Predicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Native(p) => write!(f, "{:?}", p), + Self::BatchSelf(i) => write!(f, "self.{}", i), + Self::Custom(pb, i) => write!(f, "{}.{}", pb.hash(), i), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::middleware::PodType; + + enum HashOrWildcardStr { + Hash(Hash), + Wildcard(String), + } + + fn l(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Hash(hash_str(s)) + } + + fn w(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Wildcard(s.to_string()) + } + + enum BuilderArg { + Literal(Value), + Key(HashOrWildcardStr, HashOrWildcardStr), + } + + impl From<(HashOrWildcardStr, HashOrWildcardStr)> for BuilderArg { + fn from((pod_id, key): (HashOrWildcardStr, HashOrWildcardStr)) -> Self { + Self::Key(pod_id, key) + } + } + + impl From for BuilderArg + where + V: Into, + { + fn from(v: V) -> Self { + Self::Literal(v.into()) + } + } + + struct StatementTmplBuilder { + predicate: Predicate, + args: Vec, + } + + fn st_tmpl(p: impl Into) -> StatementTmplBuilder { + StatementTmplBuilder { + predicate: p.into(), + args: Vec::new(), + } + } + + impl StatementTmplBuilder { + fn arg(mut self, a: impl Into) -> Self { + self.args.push(a.into()); + self + } + } + + struct CustomPredicateBatchBuilder { + predicates: Vec, + } + + impl CustomPredicateBatchBuilder { + fn new() -> Self { + Self { + predicates: Vec::new(), + } + } + + fn predicate_and( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(true, args, priv_args, sts) + } + + fn predicate_or( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(false, args, priv_args, sts) + } + + fn predicate( + &mut self, + conjunction: bool, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + use BuilderArg as BA; + let statements = sts + .iter() + .map(|sb| { + let args = sb + .args + .iter() + .map(|a| match a { + BA::Literal(v) => StatementTmplArg::Literal(*v), + BA::Key(pod_id, key) => StatementTmplArg::Key( + resolve_wildcard(args, priv_args, pod_id), + resolve_wildcard(args, priv_args, key), + ), + }) + .collect(); + StatementTmpl(sb.predicate.clone(), args) + }) + .collect(); + let custom_predicate = CustomPredicate { + conjunction, + statements, + args_len: args.len(), + }; + self.predicates.push(custom_predicate); + Predicate::BatchSelf(self.predicates.len() - 1) + } + + fn finish(self) -> Arc { + Arc::new(CustomPredicateBatch { + predicates: self.predicates, + }) + } + } + + fn resolve_wildcard( + args: &[&str], + priv_args: &[&str], + v: &HashOrWildcardStr, + ) -> HashOrWildcard { + match v { + HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h), + HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard( + args.iter() + .chain(priv_args.iter()) + .enumerate() + .find_map(|(i, name)| (&s == name).then_some(i)) + .unwrap(), + ), + } + } + + #[test] + fn test_custom_pred() { + use NativePredicate as NP; + + let mut builder = CustomPredicateBatchBuilder::new(); + let _eth_friend = builder.predicate_and( + &["src_or", "src_key", "dst_or", "dst_key"], + &["attestation_pod"], + &[ + st_tmpl(NP::ValueOf) + .arg((w("attestation_pod"), l("type"))) + .arg(PodType::Signed), + st_tmpl(NP::Equal) + .arg((w("attestation_pod"), l("signer"))) + .arg((w("src_or"), w("src_key"))), + st_tmpl(NP::Equal) + .arg((w("attestation_pod"), l("attestation"))) + .arg((w("dst_or"), w("dst_key"))), + ], + ); + + println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); + let eth_friend = builder.finish(); + // This batch only has 1 predicate, so we pick it already for convenience + let eth_friend = Predicate::Custom(eth_friend, 0); + + let mut builder = CustomPredicateBatchBuilder::new(); + let eth_dos_distance_base = builder.predicate_and( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[], + &[ + st_tmpl(NP::Equal) + .arg((w("src_or"), l("src_key"))) + .arg((w("dst_or"), w("dst_key"))), + st_tmpl(NP::ValueOf) + .arg((w("distance_or"), w("distance_key"))) + .arg(0), + ], + ); + + println!( + "b.0. eth_dos_distance_base = {}", + builder.predicates.last().unwrap() + ); + + let eth_dos_distance = Predicate::BatchSelf(3); + + let eth_dos_distance_ind = builder.predicate_and( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[ + "one_or", + "one_key", + "shorter_distance_or", + "shorter_distance_key", + "intermed_or", + "intermed_key", + ], + &[ + st_tmpl(eth_dos_distance) + .arg((w("src_or"), w("src_key"))) + .arg((w("intermed_or"), w("intermed_key"))) + .arg((w("shorter_distance_or"), w("shorter_distance_key"))), + // distance == shorter_distance + 1 + st_tmpl(NP::ValueOf).arg((w("one_or"), w("one_key"))).arg(1), + st_tmpl(NP::SumOf) + .arg((w("distance_or"), w("distance_key"))) + .arg((w("shorter_distance_or"), w("shorter_distance_key"))) + .arg((w("one_or"), w("one_key"))), + // intermed is a friend of dst + st_tmpl(eth_friend) + .arg((w("intermed_or"), w("intermed_key"))) + .arg((w("dst_or"), w("dst_key"))), + ], + ); + + println!( + "b.1. eth_dos_distance_ind = {}", + builder.predicates.last().unwrap() + ); + + let _eth_dos_distance = builder.predicate_or( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[], + &[ + st_tmpl(eth_dos_distance_base) + .arg((w("src_or"), w("src_key"))) + .arg((w("dst_or"), w("dst_key"))) + .arg((w("distance_or"), w("distance_key"))), + st_tmpl(eth_dos_distance_ind) + .arg((w("src_or"), w("src_key"))) + .arg((w("dst_or"), w("dst_key"))) + .arg((w("distance_or"), w("distance_key"))), + ], + ); + + println!( + "b.2. eth_dos_distance = {}", + builder.predicates.last().unwrap() + ); + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 07e4032..14cd9f2 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,18 +1,20 @@ //! The middleware includes the type definitions and the traits used to connect the frontend and //! the backend. +mod custom; mod operation; mod statement; +pub use custom::*; +pub use operation::*; +pub use statement::*; use anyhow::{anyhow, Error, Result}; use dyn_clone::DynClone; use hex::{FromHex, FromHexError}; -pub use operation::*; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::{Hasher, PoseidonGoldilocksConfig}; -pub use statement::*; use std::any::Any; use std::cmp::{Ord, Ordering}; use std::collections::HashMap; @@ -201,7 +203,8 @@ impl From for Value { pub fn hash_str(s: &str) -> Hash { let mut input = s.as_bytes().to_vec(); input.push(1); // padding - // Merge 7 bytes into 1 field, because the field is slightly below 64 bits + + // Merge 7 bytes into 1 field, because the field is slightly below 64 bits let input: Vec = input .chunks(7) .map(|bytes| { diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 5f3a5c3..f8934de 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,7 +1,7 @@ -use crate::middleware::{AnchoredKey, SELF}; use anyhow::{anyhow, Result}; use super::Statement; +use crate::middleware::{AnchoredKey, SELF}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NativeOperation { diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 617ad6d..0d35805 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -10,7 +10,7 @@ pub const KEY_TYPE: &str = "_type"; pub const STATEMENT_ARG_F_LEN: usize = 8; #[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)] -pub enum NativeStatement { +pub enum NativePredicate { None = 0, ValueOf = 1, Equal = 2, @@ -24,7 +24,7 @@ pub enum NativeStatement { MaxOf = 10, } -impl ToFields for NativeStatement { +impl ToFields for NativePredicate { fn to_fields(self) -> (Vec, usize) { (vec![F::from_canonical_u64(self as u64)], 1) } @@ -51,19 +51,19 @@ impl Statement { pub fn is_none(&self) -> bool { self == &Self::None } - pub fn code(&self) -> NativeStatement { + pub fn code(&self) -> NativePredicate { match self { - Self::None => NativeStatement::None, - Self::ValueOf(_, _) => NativeStatement::ValueOf, - Self::Equal(_, _) => NativeStatement::Equal, - Self::NotEqual(_, _) => NativeStatement::NotEqual, - Self::Gt(_, _) => NativeStatement::Gt, - Self::Lt(_, _) => NativeStatement::Lt, - Self::Contains(_, _) => NativeStatement::Contains, - Self::NotContains(_, _) => NativeStatement::NotContains, - Self::SumOf(_, _, _) => NativeStatement::SumOf, - Self::ProductOf(_, _, _) => NativeStatement::ProductOf, - Self::MaxOf(_, _, _) => NativeStatement::MaxOf, + Self::None => NativePredicate::None, + Self::ValueOf(_, _) => NativePredicate::ValueOf, + Self::Equal(_, _) => NativePredicate::Equal, + Self::NotEqual(_, _) => NativePredicate::NotEqual, + Self::Gt(_, _) => NativePredicate::Gt, + Self::Lt(_, _) => NativePredicate::Lt, + Self::Contains(_, _) => NativePredicate::Contains, + Self::NotContains(_, _) => NativePredicate::NotContains, + Self::SumOf(_, _, _) => NativePredicate::SumOf, + Self::ProductOf(_, _, _) => NativePredicate::ProductOf, + Self::MaxOf(_, _, _) => NativePredicate::MaxOf, } } pub fn args(&self) -> Vec {