From 7eeb595dc27350f95643e1750bb0f21d70459c9f Mon Sep 17 00:00:00 2001 From: tideofwords Date: Mon, 3 Mar 2025 15:55:30 -0800 Subject: [PATCH] Backend support for custom statements and deductions (#105) * Custom statements on backend * Add support for custom statements and deductions on backend * typo checker smh * clean up match statement Co-authored-by: Ahmad Afuni * clean up more match statement Co-authored-by: Ahmad Afuni * delete done todo Co-authored-by: Ahmad Afuni --------- Co-authored-by: Ahmad Afuni --- .github/workflows/typos.toml | 3 +- src/backends/plonky2/mock_main/mod.rs | 26 +++--- src/backends/plonky2/mock_main/operation.rs | 4 +- src/backends/plonky2/mock_main/statement.rs | 90 ++++++++++++++------- 4 files changed, 76 insertions(+), 47 deletions(-) diff --git a/.github/workflows/typos.toml b/.github/workflows/typos.toml index 471d317..66d2d65 100644 --- a/.github/workflows/typos.toml +++ b/.github/workflows/typos.toml @@ -2,4 +2,5 @@ groth = "groth" # to avoid it dectecting it as 'growth' BA = "BA" Ded = "Ded" # "ANDed", it thought "Ded" should be "Dead" -OT = "OT" \ No newline at end of file +OT = "OT" +aks = "aks" # anchored keys diff --git a/src/backends/plonky2/mock_main/mod.rs b/src/backends/plonky2/mock_main/mod.rs index a033f7b..1c4377b 100644 --- a/src/backends/plonky2/mock_main/mod.rs +++ b/src/backends/plonky2/mock_main/mod.rs @@ -7,7 +7,8 @@ use std::fmt; use crate::middleware::{ self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod, - OperationType, Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF, + OperationType, Params, Pod, PodId, PodProver, Predicate, StatementArg, ToFields, KEY_TYPE, + SELF, }; mod operation; @@ -261,11 +262,7 @@ impl MockMainPod { .map(|mid_arg| Self::find_op_arg(statements, mid_arg)) .collect::>>()?; Self::pad_operation_args(params, &mut args); - let op_code = match op.code() { - OperationType::Native(code) => code, - _ => unimplemented!(), - }; - operations.push(Operation(op_code, args)); + operations.push(Operation(op.code(), args)); } Ok(operations) } @@ -280,15 +277,18 @@ impl MockMainPod { mut operations: Vec, ) -> Result> { let offset_public_statements = statements.len() - params.max_public_statements; - operations.push(Operation(NativeOperation::NewEntry, vec![])); + operations.push(Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![], + )); for i in 0..(params.max_public_statements - 1) { let st = &statements[offset_public_statements + i + 1]; let mut op = if st.is_none() { - Operation(NativeOperation::None, vec![]) + Operation(OperationType::Native(NativeOperation::None), vec![]) } else { let mid_arg = st.clone(); Operation( - NativeOperation::CopyStatement, + OperationType::Native(NativeOperation::CopyStatement), // TODO vec![Self::find_op_arg(statements, &mid_arg.try_into().unwrap())?], ) @@ -348,11 +348,11 @@ impl MockMainPod { fn statement_none(params: &Params) -> Statement { let mut args = Vec::with_capacity(params.max_statement_args); Self::pad_statement_args(¶ms, &mut args); - Statement(NativePredicate::None, args) + Statement(Predicate::Native(NativePredicate::None), args) } fn operation_none(params: &Params) -> Operation { - let mut op = Operation(NativeOperation::None, vec![]); + let mut op = Operation(OperationType::Native(NativeOperation::None), vec![]); fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); op } @@ -387,7 +387,7 @@ impl Pod for MockMainPod { .public_statements .iter() .find(|s| { - s.0 == NativePredicate::ValueOf + s.0 == Predicate::Native(NativePredicate::ValueOf) && s.1.len() > 0 && if let StatementArg::Key(AnchoredKey(pod_id, key_hash)) = s.1[0] { pod_id == SELF && key_hash == hash_str(KEY_TYPE) @@ -415,7 +415,7 @@ impl Pod for MockMainPod { s, ) }) - .filter(|(_, s)| s.0 == NativePredicate::ValueOf) + .filter(|(_, s)| s.0 == Predicate::Native(NativePredicate::ValueOf)) .flat_map(|(i, s)| { if let StatementArg::Key(ak) = &s.1[0] { vec![(i, ak.1, ak.0)] diff --git a/src/backends/plonky2/mock_main/operation.rs b/src/backends/plonky2/mock_main/operation.rs index f5ae1de..c1ec964 100644 --- a/src/backends/plonky2/mock_main/operation.rs +++ b/src/backends/plonky2/mock_main/operation.rs @@ -17,7 +17,7 @@ impl OperationArg { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Operation(pub NativeOperation, pub Vec); +pub struct Operation(pub OperationType, pub Vec); impl Operation { pub fn deref(&self, statements: &[Statement]) -> Result { @@ -29,7 +29,7 @@ impl Operation { OperationArg::Index(i) => Some(statements[*i].clone().try_into()), }) .collect::>>()?; - middleware::Operation::op(OperationType::Native(self.0), &deref_args) + middleware::Operation::op(self.0.clone(), &deref_args) } } diff --git a/src/backends/plonky2/mock_main/statement.rs b/src/backends/plonky2/mock_main/statement.rs index a9b27dd..69343df 100644 --- a/src/backends/plonky2/mock_main/statement.rs +++ b/src/backends/plonky2/mock_main/statement.rs @@ -1,14 +1,16 @@ use anyhow::{anyhow, Result}; use std::fmt; -use crate::middleware::{self, NativePredicate, Params, StatementArg, ToFields}; +use crate::middleware::{ + self, AnchoredKey, NativePredicate, Params, Predicate, StatementArg, ToFields, +}; #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativePredicate, pub Vec); +pub struct Statement(pub Predicate, pub Vec); impl Statement { pub fn is_none(&self) -> bool { - self.0 == NativePredicate::None + self.0 == Predicate::Native(NativePredicate::None) } /// Argument method. Trailing Nones are filtered out. pub fn args(&self) -> Vec { @@ -52,31 +54,53 @@ impl TryFrom for middleware::Statement { proper_args.get(1).cloned(), proper_args.get(2).cloned(), ); - Ok(match (s.0, args, proper_args.len()) { - (NP::None, _, 0) => S::None, - (NP::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => S::ValueOf(ak, v), - (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Equal(ak1, ak2), - (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { - S::NotEqual(ak1, ak2) + Ok(match s.0 { + Predicate::Native(np) => match (np, args, proper_args.len()) { + (NP::None, _, 0) => S::None, + (NP::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => { + S::ValueOf(ak, v) + } + (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + S::Equal(ak1, ak2) + } + (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + S::NotEqual(ak1, ak2) + } + (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Gt(ak1, ak2), + (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2), + (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + S::Contains(ak1, ak2) + } + (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + S::NotContains(ak1, ak2) + } + (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + S::SumOf(ak1, ak2, ak3) + } + ( + NP::ProductOf, + (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), + 3, + ) => S::ProductOf(ak1, ak2, ak3), + (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + S::MaxOf(ak1, ak2, ak3) + } + _ => Err(anyhow!("Ill-formed statement expression {:?}", s))?, + }, + Predicate::Custom(cpr) => { + let aks: Vec = proper_args + .into_iter() + .filter_map(|arg| match arg { + SA::None => None, + SA::Key(ak) => Some(ak), + SA::Literal(_) => unreachable!(), + }) + .collect(); + S::Custom(cpr, aks) } - (NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Gt(ak1, ak2), - (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2), - (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { - S::Contains(ak1, ak2) + Predicate::BatchSelf(_) => { + unreachable!() } - (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { - S::NotContains(ak1, ak2) - } - (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { - S::SumOf(ak1, ak2, ak3) - } - (NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { - S::ProductOf(ak1, ak2, ak3) - } - (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { - S::MaxOf(ak1, ak2, ak3) - } - _ => Err(anyhow!("Ill-formed statement expression {:?}", s))?, }) } } @@ -84,11 +108,15 @@ impl TryFrom for middleware::Statement { impl From for Statement { fn from(s: middleware::Statement) -> Self { match s.code() { - middleware::Predicate::Native(c) => { - Statement(c, s.args().into_iter().map(|arg| arg).collect()) - } - // TODO: Custom statements - _ => todo!(), + middleware::Predicate::Native(c) => Statement( + middleware::Predicate::Native(c), + s.args().into_iter().map(|arg| arg).collect(), + ), + middleware::Predicate::Custom(cpr) => Statement( + middleware::Predicate::Custom(cpr), + s.args().into_iter().map(|arg| arg).collect(), + ), + middleware::Predicate::BatchSelf(_) => unreachable!(), } } }