diff --git a/src/backends/mock_main.rs b/src/backends/mock_main.rs index 648d9d3..5025f0f 100644 --- a/src/backends/mock_main.rs +++ b/src/backends/mock_main.rs @@ -1,15 +1,21 @@ +mod operation; +mod statement; + use crate::middleware::{ self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativeStatement, NonePod, - Params, Pod, PodId, PodProver, Statement, StatementArg, ToFields, KEY_TYPE, SELF, + Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF, }; use anyhow::Result; use itertools::Itertools; +pub use operation::*; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::Hasher; +pub use statement::*; use std::any::Any; -use std::error::Error; use std::fmt; +pub const VALUE_TYPE: &str = "MockMainPOD"; + pub struct MockProver {} impl PodProver for MockProver { @@ -18,72 +24,6 @@ impl PodProver for MockProver { } } -#[derive(Clone, Debug, PartialEq, Eq)] -enum OperationArg { - None, - Index(usize), -} - -impl OperationArg { - fn is_none(&self) -> bool { - matches!(self, OperationArg::None) - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -enum OperationArgError { - KeyNotFound, - StatementNotFound, -} - -impl std::fmt::Display for OperationArgError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - OperationArgError::KeyNotFound => write!(f, "Key not found"), - OperationArgError::StatementNotFound => write!(f, "Statement not found"), - } - } -} - -impl std::error::Error for OperationArgError {} - -#[derive(Clone, Debug, PartialEq, Eq)] -struct Operation(pub NativeOperation, pub Vec); - -impl Operation { - pub fn deref(&self, statements: &[Statement]) -> crate::middleware::Operation { - let deref_args = self - .1 - .iter() - .map(|arg| match arg { - OperationArg::None => middleware::OperationArg::None, - OperationArg::Index(i) => { - middleware::OperationArg::Statement(statements[*i].clone()) - } - }) - .collect(); - middleware::Operation(self.0, deref_args) - } -} - -impl fmt::Display for Operation { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?} ", self.0)?; - for (i, arg) in self.1.iter().enumerate() { - if !(!f.alternate() && arg.is_none()) { - if i != 0 { - write!(f, " ")?; - } - match arg { - OperationArg::None => write!(f, "none")?, - OperationArg::Index(i) => write!(f, "{:02}", i)?, - } - } - } - Ok(()) - } -} - #[derive(Clone, Debug)] pub struct MockMainPod { params: Params, @@ -188,12 +128,16 @@ impl MockMainPod { fn offset_public_statements(&self) -> usize { self.offset_input_statements() + self.params.max_priv_statements() } + fn pad_statement(params: &Params, s: &mut Statement) { + fill_pad(&mut s.1, StatementArg::None, params.max_statement_args) + } + fn pad_operation(params: &Params, op: &mut Operation) { + fill_pad(&mut op.1, OperationArg::None, params.max_operation_args) + } fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec { let mut statements = Vec::new(); - let st_none = Self::statement_none(params); - // Input signed pods region let none_sig_pod: Box = Box::new(NonePod {}); assert!(inputs.signed_pods.len() <= params.max_input_signed_pods); @@ -206,8 +150,12 @@ impl MockMainPod { let sts = pod.pub_statements(); assert!(sts.len() <= params.max_signed_pod_values); for j in 0..params.max_signed_pod_values { - let mut st = sts.get(j).unwrap_or(&st_none).clone(); - Self::pad_statement_args(params, &mut st.1); + let mut st = sts + .get(j) + .unwrap_or(&middleware::Statement::None) + .clone() + .into(); + Self::pad_statement(params, &mut st); statements.push(st); } } @@ -224,8 +172,12 @@ impl MockMainPod { let sts = pod.pub_statements(); assert!(sts.len() <= params.max_public_statements); for j in 0..params.max_public_statements { - let mut st = sts.get(j).unwrap_or(&st_none).clone(); - Self::pad_statement_args(params, &mut st.1); + let mut st = sts + .get(j) + .unwrap_or(&middleware::Statement::None) + .clone() + .into(); + Self::pad_statement(params, &mut st); statements.push(st); } } @@ -233,54 +185,55 @@ impl MockMainPod { // Input statements assert!(inputs.statements.len() <= params.max_priv_statements()); for i in 0..params.max_priv_statements() { - let mut st = inputs.statements.get(i).unwrap_or(&st_none).clone(); - Self::pad_statement_args(params, &mut st.1); + let mut st = inputs + .statements + .get(i) + .unwrap_or(&middleware::Statement::None) + .clone() + .into(); + Self::pad_statement(params, &mut st); statements.push(st); } // Public statements assert!(inputs.public_statements.len() < params.max_public_statements); - statements.push(Statement( - NativeStatement::ValueOf, - vec![StatementArg::Key(AnchoredKey(SELF, hash_str(KEY_TYPE)))], - )); + let mut type_st = middleware::Statement::ValueOf( + AnchoredKey(SELF, hash_str(KEY_TYPE)), + middleware::Value(hash_str(VALUE_TYPE).0), + ) + .into(); + Self::pad_statement(params, &mut type_st); + statements.push(type_st); + for i in 0..(params.max_public_statements - 1) { - let mut st = inputs.public_statements.get(i).unwrap_or(&st_none).clone(); - Self::pad_statement_args(params, &mut st.1); + let mut st = inputs + .public_statements + .get(i) + .unwrap_or(&middleware::Statement::None) + .clone() + .into(); + Self::pad_statement(params, &mut st); statements.push(st); } statements } - pub fn find_op_arg( + fn find_op_arg( statements: &[Statement], - op_arg: &middleware::OperationArg, + op_arg: &middleware::Statement, ) -> Result { match op_arg { - middleware::OperationArg::None => Ok(OperationArg::None), - middleware::OperationArg::Key(k) => { - statements - .iter() - .enumerate() - .find_map(|(i, s)| match s.0 { - NativeStatement::ValueOf => match &s.1[0] { - StatementArg::Key(sk) => (sk == k).then_some(i), - _ => None, - }, - _ => None, - }) - .map(OperationArg::Index) - .ok_or(OperationArgError::KeyNotFound) - } - middleware::OperationArg::Statement(st) => { - statements - .iter() - .enumerate() - .find_map(|(i, s)| (s == st).then_some(i)) - .map(OperationArg::Index) - .ok_or(OperationArgError::StatementNotFound) - } + middleware::Statement::None => Ok(OperationArg::None), + _ => statements + .iter() + .enumerate() + .find_map(|(i, s)| { + // TODO: Error handling + (&middleware::Statement::try_from(s.clone()).unwrap() == op_arg).then_some(i) + }) + .map(OperationArg::Index) + .ok_or(OperationArgError::StatementNotFound), } } @@ -289,19 +242,19 @@ impl MockMainPod { statements: &[Statement], input_operations: &[middleware::Operation], ) -> Result, OperationArgError> { - let op_none = Self::operation_none(params); - let mut operations = Vec::new(); for i in 0..params.max_priv_statements() { - let op = input_operations.get(i).unwrap_or(&op_none).clone(); - let mut mid_args = op.1; - Self::pad_operation_args(params, &mut mid_args); - let mut args = Vec::with_capacity(mid_args.len()); - for mid_arg in &mid_args { - let op_arg = Self::find_op_arg(statements, mid_arg)?; - args.push(op_arg) - } - operations.push(Operation(op.0, args)); + let op = input_operations + .get(i) + .unwrap_or(&middleware::Operation::None) + .clone(); + let mid_args = op.args(); + let mut args = mid_args + .iter() + .map(|mid_arg| Self::find_op_arg(statements, mid_arg)) + .collect::, OperationArgError>>()?; + Self::pad_operation_args(params, &mut args); + operations.push(Operation(op.code(), args)); } Ok(operations) } @@ -320,11 +273,11 @@ impl MockMainPod { let mut op = if st.is_none() { Operation(NativeOperation::None, vec![]) } else { - let mid_arg = middleware::OperationArg::Statement(st.clone()); - let op_arg = Self::find_op_arg(statements, &mid_arg)?; + let mid_arg = st.clone(); Operation( NativeOperation::CopyStatement, - vec![op_arg], + // TODO + vec![Self::find_op_arg(statements, &mid_arg.try_into().unwrap())?], ) }; fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); @@ -351,7 +304,16 @@ impl MockMainPod { .map(|p| (*p).clone()) .collect_vec(); let input_main_pods = inputs.main_pods.iter().map(|p| (*p).clone()).collect_vec(); - let input_statements = inputs.statements.iter().cloned().collect_vec(); + let input_statements = inputs + .statements + .iter() + .cloned() + .map(|s| { + let mut s = s.into(); + Self::pad_statement(params, &mut s); + s + }) + .collect_vec(); let public_statements = statements[statements.len() - params.max_public_statements..].to_vec(); @@ -376,26 +338,22 @@ impl MockMainPod { Statement(NativeStatement::None, args) } - fn operation_none(params: &Params) -> middleware::Operation { - let mut args = Vec::with_capacity(params.max_operation_args); - Self::pad_operation_args(¶ms, &mut args); - middleware::Operation(NativeOperation::None, args) + fn operation_none(params: &Params) -> Operation { + let mut op = Operation(NativeOperation::None, vec![]); + fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); + op } fn pad_statement_args(params: &Params, args: &mut Vec) { fill_pad(args, StatementArg::None, params.max_statement_args) } - fn pad_operation_args(params: &Params, args: &mut Vec) { - fill_pad( - args, - middleware::OperationArg::None, - params.max_operation_args, - ) + fn pad_operation_args(params: &Params, args: &mut Vec) { + fill_pad(args, OperationArg::None, params.max_operation_args) } } -pub fn hash_statements(statements: &[middleware::Statement]) -> Result { +pub fn hash_statements(statements: &[Statement]) -> Result { let field_elems = statements .into_iter() .flat_map(|statement| statement.clone().to_fields().0) @@ -444,7 +402,7 @@ impl Pod for MockMainPod { s, ) }) - .filter(|(i, s)| s.0 == NativeStatement::ValueOf) + .filter(|(_, s)| s.0 == NativeStatement::ValueOf) .flat_map(|(i, s)| { if let StatementArg::Key(ak) = &s.1[0] { vec![(i, ak.1, ak.0)] @@ -463,7 +421,8 @@ impl Pod for MockMainPod { .map(|(i, s)| { self.operations[i] .deref(&self.statements[..input_statement_offset + i]) - .check(s.clone()) + .unwrap() + .check(&s.clone().try_into().unwrap()) }) .collect::>>() .unwrap(); @@ -472,7 +431,7 @@ impl Pod for MockMainPod { fn id(&self) -> PodId { self.id } - fn pub_statements(&self) -> Vec { + fn pub_statements(&self) -> Vec { // return the public statements, where when origin=SELF is replaced by origin=self.id() self.statements .iter() @@ -492,6 +451,8 @@ impl Pod for MockMainPod { }) .collect(), ) + .try_into() + .unwrap() }) .collect() } @@ -505,7 +466,10 @@ impl Pod for MockMainPod { pub mod tests { use super::*; use crate::backends::mock_signed::MockSigner; - use crate::examples::{great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders}; + use crate::examples::{ + great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, + zu_kyc_sign_pod_builders, + }; use crate::middleware; #[test] @@ -559,6 +523,6 @@ pub mod tests { let pod = proof_pod.pod.into_any().downcast::().unwrap(); println!("{}", pod); - assert_eq!(pod.verify(), true); + assert_eq!(pod.verify(), true); } } diff --git a/src/backends/mock_main/operation.rs b/src/backends/mock_main/operation.rs new file mode 100644 index 0000000..12cb933 --- /dev/null +++ b/src/backends/mock_main/operation.rs @@ -0,0 +1,71 @@ +use std::fmt; + +use anyhow::Result; + +use crate::middleware::{self, NativeOperation}; + +use super::Statement; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OperationArg { + None, + Index(usize), +} + +impl OperationArg { + pub fn is_none(&self) -> bool { + matches!(self, OperationArg::None) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OperationArgError { + KeyNotFound, + StatementNotFound, +} + +impl std::fmt::Display for OperationArgError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OperationArgError::KeyNotFound => write!(f, "Key not found"), + OperationArgError::StatementNotFound => write!(f, "Statement not found"), + } + } +} + +impl std::error::Error for OperationArgError {} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Operation(pub NativeOperation, pub Vec); + +impl Operation { + pub fn deref(&self, statements: &[Statement]) -> Result { + let deref_args = self + .1 + .iter() + .flat_map(|arg| match arg { + OperationArg::None => None, + OperationArg::Index(i) => Some(statements[*i].clone().try_into()), + }) + .collect::>>()?; + middleware::Operation::op(self.0, &deref_args) + } +} + +impl fmt::Display for Operation { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?} ", self.0)?; + for (i, arg) in self.1.iter().enumerate() { + if !(!f.alternate() && arg.is_none()) { + if i != 0 { + write!(f, " ")?; + } + match arg { + OperationArg::None => write!(f, "none")?, + OperationArg::Index(i) => write!(f, "{:02}", i)?, + } + } + } + Ok(()) + } +} diff --git a/src/backends/mock_main/statement.rs b/src/backends/mock_main/statement.rs new file mode 100644 index 0000000..de1f510 --- /dev/null +++ b/src/backends/mock_main/statement.rs @@ -0,0 +1,103 @@ +use std::fmt; + +use anyhow::{anyhow, Result}; + +use crate::middleware::{self, NativeStatement, StatementArg, ToFields}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Statement(pub NativeStatement, pub Vec); + +impl Statement { + pub fn is_none(&self) -> bool { + self.0 == NativeStatement::None + } + /// Argument method. Trailing Nones are filtered out. + pub fn args(&self) -> Vec { + let maybe_last_arg_index = (0..self.1.len()).rev().find(|i| !self.1[*i].is_none()); + match maybe_last_arg_index { + None => vec![], + Some(i) => self.1[0..i + 1].to_vec(), + } + } +} + +impl ToFields for Statement { + fn to_fields(self) -> (Vec, usize) { + let (native_statement_f, native_statement_f_len) = self.0.to_fields(); + let (vec_statementarg_f, vec_statementarg_f_len) = self + .1 + .into_iter() + .map(|statement_arg| statement_arg.to_fields()) + .fold((Vec::new(), 0), |mut acc, (f, l)| { + acc.0.extend(f); + acc.1 += l; + acc + }); + ( + [native_statement_f, vec_statementarg_f].concat(), + native_statement_f_len + vec_statementarg_f_len, + ) + } +} + +impl TryFrom for middleware::Statement { + type Error = anyhow::Error; + fn try_from(s: Statement) -> Result { + type S = middleware::Statement; + type NS = NativeStatement; + type SA = StatementArg; + let proper_args = s.args(); + let args = ( + proper_args.get(0).cloned(), + proper_args.get(1).cloned(), + proper_args.get(2).cloned(), + ); + Ok(match (s.0, args, proper_args.len()) { + (NS::None, _, 0) => S::None, + (NS::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => S::ValueOf(ak, v), + (NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Equal(ak1, ak2), + (NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + S::NotEqual(ak1, ak2) + } + (NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Gt(ak1, ak2), + (NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2), + (NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + S::Contains(ak1, ak2) + } + (NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { + S::NotContains(ak1, ak2) + } + (NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + S::SumOf(ak1, ak2, ak3) + } + (NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + S::ProductOf(ak1, ak2, ak3) + } + (NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { + S::MaxOf(ak1, ak2, ak3) + } + _ => Err(anyhow!("Ill-formed statement expression {:?}", s))?, + }) + } +} + +impl From for Statement { + fn from(s: middleware::Statement) -> Self { + Statement(s.code(), s.args().into_iter().map(|arg| arg).collect()) + } +} + +impl fmt::Display for Statement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?} ", self.0)?; + for (i, arg) in self.1.iter().enumerate() { + if !(!f.alternate() && arg.is_none()) { + if i != 0 { + write!(f, " ")?; + } + write!(f, "{}", arg)?; + } + } + Ok(()) + } +} diff --git a/src/backends/mock_signed.rs b/src/backends/mock_signed.rs index 62a8ccc..3da701f 100644 --- a/src/backends/mock_signed.rs +++ b/src/backends/mock_signed.rs @@ -1,6 +1,6 @@ use crate::middleware::{ - containers::Dictionary, hash_str, AnchoredKey, Hash, NativeStatement, Params, Pod, PodId, - PodSigner, PodType, Statement, StatementArg, Value, KEY_SIGNER, KEY_TYPE, + containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, PodType, + Statement, Value, KEY_SIGNER, KEY_TYPE, }; use crate::primitives::merkletree::MerkleTree; use anyhow::Result; @@ -81,15 +81,7 @@ impl Pod for MockSignedPod { let id = self.id(); self.dict .iter() - .map(|(k, v)| { - Statement( - NativeStatement::ValueOf, - vec![ - StatementArg::Key(AnchoredKey(id, Hash(k.0))), - StatementArg::Literal(*v), - ], - ) - }) + .map(|(k, v)| Statement::ValueOf(AnchoredKey(id, Hash(k.0)), *v)) .collect() } diff --git a/src/examples.rs b/src/examples.rs index 8ffba33..b617aae 100644 --- a/src/examples.rs +++ b/src/examples.rs @@ -208,14 +208,24 @@ pub fn tickets_sign_pod_builder(params: &Params) -> SignedPodBuilder { builder } -pub fn tickets_pod_builder(params: &Params, signed_pod: &SignedPod, expected_event_id: i64, expect_consumed: bool, blacklisted_emails: &Value) -> MainPodBuilder { +pub fn tickets_pod_builder( + params: &Params, + signed_pod: &SignedPod, + expected_event_id: i64, + expect_consumed: bool, + blacklisted_emails: &Value, +) -> MainPodBuilder { // Create a main pod referencing this signed pod with some statements let mut builder = MainPodBuilder::new(params); builder.add_signed_pod(signed_pod); builder.pub_op(op!(eq, (signed_pod, "eventId"), expected_event_id)); builder.pub_op(op!(eq, (signed_pod, "isConsumed"), expect_consumed)); builder.pub_op(op!(eq, (signed_pod, "isRevoked"), false)); - builder.pub_op(op!(not_contains, blacklisted_emails, (signed_pod, "attendeeEmail"))); + builder.pub_op(op!( + not_contains, + blacklisted_emails, + (signed_pod, "attendeeEmail") + )); builder } @@ -223,5 +233,11 @@ pub fn tickets_pod_full_flow() -> MainPodBuilder { let params = Params::default(); let builder = tickets_sign_pod_builder(¶ms); let signed_pod = builder.sign(&mut MockSigner { pk: "test".into() }).unwrap(); - tickets_pod_builder(¶ms, &signed_pod, 123, true, &Value::Dictionary(Dictionary::new(&HashMap::new()))) + tickets_pod_builder( + ¶ms, + &signed_pod, + 123, + true, + &Value::Dictionary(Dictionary::new(&HashMap::new())), + ) } diff --git a/src/frontend.rs b/src/frontend.rs index b949ef9..40efd84 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -1,6 +1,9 @@ //! The frontend includes the user-level abstractions and user-friendly types to define and work //! with Pods. +mod operation; +mod statement; + use anyhow::Result; use itertools::Itertools; use std::collections::HashMap; @@ -13,6 +16,8 @@ use crate::middleware::{ hash_str, Hash, MainPodInputs, NativeOperation, NativeStatement, Params, PodId, PodProver, PodSigner, SELF, }; +pub use operation::*; +pub use statement::*; /// This type is just for presentation purposes. #[derive(Clone, Debug, Default, Hash, PartialEq, Eq)] @@ -34,6 +39,7 @@ pub enum Value { Dictionary(Dictionary), Set(Set), Array(Array), + Raw(middleware::Value), } impl From<&str> for Value { @@ -63,6 +69,7 @@ impl From<&Value> for middleware::Value { Value::Dictionary(d) => middleware::Value(d.commitment().0), Value::Set(s) => middleware::Value(s.commitment().0), Value::Array(a) => middleware::Value(a.commitment().0), + Value::Raw(v) => v.clone(), } } } @@ -76,6 +83,7 @@ impl fmt::Display for Value { Value::Dictionary(d) => write!(f, "dict:{}", d.commitment()), Value::Set(s) => write!(f, "set:{}", s.commitment()), Value::Array(a) => write!(f, "arr:{}", a.commitment()), + Value::Raw(v) => write!(f, "{}", v), } } } @@ -159,111 +167,9 @@ impl SignedPod { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct AnchoredKey(pub Origin, pub String); -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum StatementArg { - Literal(Value), - Key(AnchoredKey), -} - -impl fmt::Display for StatementArg { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Literal(v) => write!(f, "{}", v), - Self::Key(r) => write!(f, "{}.{}", r.0 .1, r.1), - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativeStatement, pub Vec); - -impl fmt::Display for Statement { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?} ", self.0)?; - for (i, arg) in self.1.iter().enumerate() { - if i != 0 { - write!(f, " ")?; - } - write!(f, "{}", arg)?; - } - Ok(()) - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum OperationArg { - Statement(Statement), - Key(AnchoredKey), - Literal(Value), - Entry(String, Value), -} - -impl fmt::Display for OperationArg { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - OperationArg::Statement(s) => write!(f, "{}", s), - OperationArg::Key(k) => write!(f, "{}.{}", k.0 .1, k.1), - OperationArg::Literal(v) => write!(f, "{}", v), - OperationArg::Entry(k, v) => write!(f, "({}, {})", k, v), - } - } -} - -impl From for OperationArg { - fn from(v: Value) -> Self { - Self::Literal(v) - } -} - -impl From<&Value> for OperationArg { - fn from(v: &Value) -> Self { - Self::Literal(v.clone()) - } -} - -impl From<&str> for OperationArg { - fn from(s: &str) -> Self { - Self::Literal(Value::from(s)) - } -} - -impl From for OperationArg { - fn from(v: i64) -> Self { - Self::Literal(Value::from(v)) - } -} - -impl From for OperationArg { - fn from(b: bool) -> Self { - Self::Literal(Value::from(b)) - } -} - -impl From<(Origin, &str)> for OperationArg { - fn from((origin, key): (Origin, &str)) -> Self { - Self::Key(AnchoredKey(origin, key.to_string())) - } -} - -impl From<(&SignedPod, &str)> for OperationArg { - fn from((pod, key): (&SignedPod, &str)) -> Self { - Self::Key(AnchoredKey(pod.origin(), key.to_string())) - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Operation(pub NativeOperation, pub Vec); - -impl fmt::Display for Operation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?} ", self.0)?; - for (i, arg) in self.1.iter().enumerate() { - if i != 0 { - write!(f, " ")?; - } - write!(f, "{}", arg)?; - } - Ok(()) +impl From for middleware::AnchoredKey { + fn from(ak: AnchoredKey) -> Self { + middleware::AnchoredKey(ak.0 .1, hash_str(&ak.1)) } } @@ -329,8 +235,13 @@ impl MainPodBuilder { let mut st_args = Vec::new(); for arg in args.iter_mut() { match arg { - OperationArg::Statement(_s) => panic!("can't convert Statement to StatementArg"), - OperationArg::Key(k) => st_args.push(StatementArg::Key(k.clone())), + OperationArg::Statement(s) => { + if s.0 == NativeStatement::ValueOf { + st_args.push(s.1[0].clone()) + } else { + panic!("Invalid statement argument."); + } + } OperationArg::Literal(v) => { let k = format!("c{}", self.const_cnt); self.const_cnt += 1; @@ -341,7 +252,7 @@ impl MainPodBuilder { vec![OperationArg::Entry(k.clone(), v.clone())], ), ); - *arg = OperationArg::Key(AnchoredKey(Origin(PodClass::Main, SELF), k.clone())); + *arg = OperationArg::Statement(value_of_st.clone()); st_args.push(value_of_st.1[0].clone()) } OperationArg::Entry(k, v) => { @@ -472,10 +383,9 @@ impl MainPodCompiler { self.operations.push(op); } - fn compile_op_arg(&self, op_arg: &OperationArg) -> middleware::OperationArg { + fn compile_op_arg(&self, op_arg: &OperationArg) -> Option { match op_arg { - OperationArg::Statement(s) => middleware::OperationArg::Statement(self.compile_st(s)), - OperationArg::Key(k) => middleware::OperationArg::Key(Self::compile_anchored_key(k)), + OperationArg::Statement(s) => Some(self.compile_st(s)), OperationArg::Literal(_v) => { // OperationArg::Literal is a syntax sugar for the frontend. This is translated to // a new ValueOf statement and it's key used instead. @@ -485,47 +395,29 @@ impl MainPodCompiler { // OperationArg::Entry is only used in the frontend. The (key, value) will only // appear in the ValueOf statement in the backend. This is because a new ValueOf // statement doesn't have any requirement on the key and value. - middleware::OperationArg::None + None } } } - fn compile_anchored_key(key: &AnchoredKey) -> middleware::AnchoredKey { - middleware::AnchoredKey(key.0 .1, hash_str(&key.1)) - } - fn compile_st(&self, st: &Statement) -> middleware::Statement { - let mut st_args = Vec::new(); - let Statement(front_st_typ, front_st_args) = st; - for front_st_arg in front_st_args { - match front_st_arg { - StatementArg::Literal(v) => { - st_args.push(middleware::StatementArg::Literal(middleware::Value::from( - v, - ))); - } - StatementArg::Key(k) => { - let key = Self::compile_anchored_key(k); - st_args.push(middleware::StatementArg::Key(key)); - } - }; - if st_args.len() > self.params.max_statement_args { - panic!("too many statement st_args"); - } - } + st.clone().try_into().unwrap() + } - middleware::Statement(*front_st_typ, st_args) + fn compile_op(&self, op: &Operation) -> middleware::Operation { + // TODO + let mop_code: middleware::NativeOperation = op.0.into(); + let mop_args = + op.1.iter() + .flat_map(|arg| self.compile_op_arg(arg).map(|s| s.try_into().unwrap())) + .collect::>(); + middleware::Operation::op(mop_code, &mop_args).unwrap() } fn compile_st_op(&mut self, st: &Statement, op: &Operation) { let middle_st = self.compile_st(st); - self.push_st_op( - middle_st, - middleware::Operation( - op.0, - op.1.iter().map(|arg| self.compile_op_arg(arg)).collect(), - ), - ); + let middle_op = self.compile_op(op); + self.push_st_op(middle_st, middle_op); } pub fn compile<'a>( @@ -593,9 +485,11 @@ pub mod build_utils { #[cfg(test)] pub mod tests { use super::*; - use crate::backends::mock_main::MockProver; use crate::backends::mock_signed::MockSigner; - use crate::examples::{great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders}; + use crate::examples::{ + great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, + zu_kyc_sign_pod_builders, + }; #[test] fn test_front_zu_kyc() -> Result<()> { diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs new file mode 100644 index 0000000..0b713ed --- /dev/null +++ b/src/frontend/operation.rs @@ -0,0 +1,82 @@ +use std::fmt; + +use crate::middleware::{hash_str, NativeOperation, NativeStatement}; + +use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OperationArg { + Statement(Statement), + Literal(Value), + Entry(String, Value), +} + +impl fmt::Display for OperationArg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OperationArg::Statement(s) => write!(f, "{}", s), + OperationArg::Literal(v) => write!(f, "{}", v), + OperationArg::Entry(k, v) => write!(f, "({}, {})", k, v), + } + } +} + +impl From for OperationArg { + fn from(v: Value) -> Self { + Self::Literal(v) + } +} + +impl From<&Value> for OperationArg { + fn from(v: &Value) -> Self { + Self::Literal(v.clone()) + } +} + +impl From<&str> for OperationArg { + fn from(s: &str) -> Self { + Self::Literal(Value::from(s)) + } +} + +impl From for OperationArg { + fn from(v: i64) -> Self { + Self::Literal(Value::from(v)) + } +} + +impl From for OperationArg { + fn from(b: bool) -> Self { + Self::Literal(Value::from(b)) + } +} + +impl From<(&SignedPod, &str)> for OperationArg { + fn from((pod, key): (&SignedPod, &str)) -> Self { + // TODO: Actual value, TryFrom. + let value = pod.kvs().get(&hash_str(key)).unwrap().clone(); + Self::Statement(Statement( + NativeStatement::ValueOf, + vec![ + StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())), + StatementArg::Literal(Value::Raw(value)), + ], + )) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Operation(pub NativeOperation, pub Vec); + +impl fmt::Display for Operation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?} ", self.0)?; + for (i, arg) in self.1.iter().enumerate() { + if i != 0 { + write!(f, " ")?; + } + write!(f, "{}", arg)?; + } + Ok(()) + } +} diff --git a/src/frontend/statement.rs b/src/frontend/statement.rs new file mode 100644 index 0000000..9dd48de --- /dev/null +++ b/src/frontend/statement.rs @@ -0,0 +1,86 @@ +use std::fmt; + +use anyhow::{anyhow, Result}; + +use crate::middleware::{self, NativeStatement}; + +use super::{AnchoredKey, Value}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum StatementArg { + Literal(Value), + Key(AnchoredKey), +} + +impl fmt::Display for StatementArg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Literal(v) => write!(f, "{}", v), + Self::Key(r) => write!(f, "{}.{}", r.0 .1, r.1), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Statement(pub NativeStatement, pub Vec); + +impl TryFrom for middleware::Statement { + type Error = anyhow::Error; + fn try_from(s: Statement) -> Result { + type MS = middleware::Statement; + type NS = NativeStatement; + type SA = StatementArg; + let args = ( + s.1.get(0).cloned(), + s.1.get(1).cloned(), + s.1.get(2).cloned(), + ); + Ok(match (s.0, args) { + (NS::None, (None, None, None)) => MS::None, + (NS::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => { + MS::ValueOf(ak.into(), (&v).into()) + } + (NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Equal(ak1.into(), ak2.into()) + } + (NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::NotEqual(ak1.into(), ak2.into()) + } + (NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Gt(ak1.into(), ak2.into()) + } + (NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Lt(ak1.into(), ak2.into()) + } + (NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::Contains(ak1.into(), ak2.into()) + } + (NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => { + MS::NotContains(ak1.into(), ak2.into()) + } + (NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + MS::SumOf(ak1.into(), ak2.into(), ak3.into()) + } + (NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + MS::ProductOf(ak1.into(), ak2.into(), ak3.into()) + } + (NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => { + MS::MaxOf(ak1.into(), ak2.into(), ak3.into()) + } + _ => Err(anyhow!("Ill-formed statement: {}", s))?, + }) + } +} + +impl fmt::Display for Statement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?} ", self.0)?; + for (i, arg) in self.1.iter().enumerate() { + if i != 0 { + write!(f, " ")?; + } + write!(f, "{}", arg)?; + } + Ok(()) + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index e5bddd9..07e4032 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,25 +1,25 @@ //! The middleware includes the type definitions and the traits used to connect the frontend and //! the backend. +mod operation; +mod statement; + use anyhow::{anyhow, Error, Result}; use dyn_clone::DynClone; use hex::{FromHex, FromHexError}; +pub use operation::*; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::{Hasher, PoseidonGoldilocksConfig}; +pub use statement::*; use std::any::Any; use std::cmp::{Ord, Ordering}; use std::collections::HashMap; use std::fmt; -use strum_macros::FromRepr; pub mod containers; -pub const KEY_SIGNER: &str = "_signer"; -pub const KEY_TYPE: &str = "_type"; -pub const STATEMENT_ARG_F_LEN: usize = 8; - /// F is the native field we use everywhere. Currently it's Goldilocks from plonky2 pub type F = GoldilocksField; /// C is the Plonky2 config used in POD2 to work with Plonky2 recursion. @@ -27,6 +27,22 @@ pub type C = PoseidonGoldilocksConfig; /// D defines the extension degree of the field used in the Plonky2 proofs (quadratic extension). pub const D: usize = 2; +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +/// AnchoredKey is a tuple containing (OriginId: PodId, key: Hash) +pub struct AnchoredKey(pub PodId, pub Hash); + +impl AnchoredKey { + pub fn origin(&self) -> PodId { + self.0 + } + pub fn key(&self) -> Hash { + self.1 + } +} + +/// An entry consists of a key-value pair. +pub type Entry = (String, Value); + #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)] pub struct Value(pub [F; 4]); @@ -231,306 +247,6 @@ impl Default for Params { } } -#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)] -pub enum NativeStatement { - None = 0, - ValueOf = 1, - Equal = 2, - NotEqual = 3, - Gt = 4, - Lt = 5, - Contains = 6, - NotContains = 7, - SumOf = 8, - ProductOf = 9, - MaxOf = 10, -} - -impl ToFields for NativeStatement { - fn to_fields(self) -> (Vec, usize) { - (vec![F::from_canonical_u64(self as u64)], 1) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -/// AnchoredKey is a tuple containing (OriginId: PodId, key: Hash) -pub struct AnchoredKey(pub PodId, pub Hash); - -impl AnchoredKey { - pub fn origin(&self) -> PodId { - self.0 - } - pub fn key(&self) -> Hash { - self.1 - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum StatementArg { - None, - Literal(Value), - Key(AnchoredKey), -} - -impl fmt::Display for StatementArg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - StatementArg::None => write!(f, "none"), - StatementArg::Literal(v) => write!(f, "{}", v), - StatementArg::Key(r) => write!(f, "{}.{}", r.0, r.1), - } - } -} - -impl StatementArg { - pub fn is_none(&self) -> bool { - matches!(self, Self::None) - } - pub fn literal(&self) -> Result { - match self { - Self::Literal(value) => Ok(*value), - _ => Err(anyhow!("Statement argument {:?} is not a literal.", self)), - } - } - pub fn key(&self) -> Result { - match self { - Self::Key(ak) => Ok(ak.clone()), - _ => Err(anyhow!("Statement argument {:?} is not a key.", self)), - } - } -} - -impl ToFields for StatementArg { - fn to_fields(self) -> (Vec, usize) { - // NOTE: current version returns always the same amount of field elements in the returned - // vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case - // is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced - // in-circuit), we might be interested into reducing the length of it. If that's the case, - // we can check if it makes sense to make it dependant on the concrete StatementArg; that - // is, when dealing with a `None` it would be a single field element (zero value), and when - // dealing with `Literal` it would be of length 4. - let f = match self { - StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN], - StatementArg::Literal(v) => { - let value_f = v.0.to_vec(); - [ - value_f.clone(), - vec![F::ZERO; STATEMENT_ARG_F_LEN - value_f.len()], - ] - .concat() - } - StatementArg::Key(ak) => { - let (podid_f, _) = ak.0.to_fields(); - let (hash_f, _) = ak.1.to_fields(); - [podid_f, hash_f].concat() - } - }; - assert_eq!(f.len(), STATEMENT_ARG_F_LEN); // sanity check - (f, STATEMENT_ARG_F_LEN) - } -} - -// TODO: Replace this with a more stringly typed enum as in the Devcon implementation. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativeStatement, pub Vec); - -impl fmt::Display for Statement { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?} ", self.0)?; - for (i, arg) in self.1.iter().enumerate() { - if !(!f.alternate() && arg.is_none()) { - if i != 0 { - write!(f, " ")?; - } - write!(f, "{}", arg)?; - } - } - Ok(()) - } -} - -impl Statement { - pub fn code(&self) -> NativeStatement { - self.0 - } - pub fn args(&self) -> &[StatementArg] { - &self.1 - } - pub fn is_none(&self) -> bool { - matches!(self.0, NativeStatement::None) - } -} - -impl ToFields for Statement { - fn to_fields(self) -> (Vec, usize) { - let (native_statement_f, native_statement_f_len) = self.0.to_fields(); - let (vec_statementarg_f, vec_statementarg_f_len) = self - .1 - .into_iter() - .map(|statement_arg| statement_arg.to_fields()) - .fold((Vec::new(), 0), |mut acc, (f, l)| { - acc.0.extend(f); - acc.1 += l; - acc - }); - ( - [native_statement_f, vec_statementarg_f].concat(), - native_statement_f_len + vec_statementarg_f_len, - ) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum NativeOperation { - None = 0, - NewEntry = 1, - CopyStatement = 2, - EqualFromEntries = 3, - NotEqualFromEntries = 4, - GtFromEntries = 5, - LtFromEntries = 6, - TransitiveEqualFromStatements = 7, - GtToNotEqual = 8, - LtToNotEqual = 9, - ContainsFromEntries = 10, - NotContainsFromEntries = 11, - RenameContainedBy = 12, - SumOf = 13, - ProductOf = 14, - MaxOf = 15, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum OperationArg { - None, - Statement(Statement), - Key(AnchoredKey), -} - -impl OperationArg { - pub fn is_none(&self) -> bool { - matches!(self, Self::None) - } - pub fn statement(&self) -> Result { - match self { - Self::Statement(statement) => Ok(statement.clone()), - _ => Err(anyhow!("Operation argument {:?} is not a statement.", self)), - } - } - pub fn key(&self) -> Result { - match self { - Self::Key(ak) => Ok(ak.clone()), - _ => Err(anyhow!("Operation argument {:?} is not a key.", self)), - } - } -} - -// TODO: Replace this with a more stringly typed enum as in the Devcon implementation. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Operation(pub NativeOperation, pub Vec); - -impl Operation { - pub fn code(&self) -> NativeOperation { - self.0 - } - pub fn args(&self) -> &[OperationArg] { - &self.1 - } - // TODO: Argument checking. - // TODO: Use `Err` for all type mismatches rather than `false`. - /// Checks the given operation against a statement. - pub fn check(&self, output_statement: Statement) -> Result { - use NativeOperation::*; - match self.0 { - // Nothing to check. - None => Ok(output_statement.code() == NativeStatement::None), - // Check that the resulting statement is of type `ValueOf` - // and its origin is `SELF`. - NewEntry => - Ok(output_statement.code() == NativeStatement::ValueOf && output_statement.args()[0].key()?.origin() == SELF) - , - // Check that the operation acts on a statement *and* the - // output is equal to this statement. - CopyStatement => Ok(output_statement == self.args()[0].statement()?) - , - EqualFromEntries => { - let s1 = self.args()[0].statement()?; - let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); - let s2 = self.args()[1].statement()?; - let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value == s2_value; - Ok(statements_equal && output_statement.code() == NativeStatement::Equal && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)} - , - NotEqualFromEntries => { - let s1 = self.args()[0].statement()?; - let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); - let s2 = self.args()[1].statement()?; - let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_not_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value != s2_value; - Ok(statements_not_equal && output_statement.code() == NativeStatement::NotEqual && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)} , - GtFromEntries => { - let s1 = self.args()[0].statement()?; - let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); - let s2 = self.args()[1].statement()?; - let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_not_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value > s2_value; - Ok(statements_not_equal && output_statement.code() == NativeStatement::Gt && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)}, - LtFromEntries => { - let s1 = self.args()[0].statement()?; - let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); - let s2 = self.args()[1].statement()?; - let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_not_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value < s2_value; - Ok(statements_not_equal && output_statement.code() == NativeStatement::Lt && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)}, - TransitiveEqualFromStatements => { - let s1 = self.args()[0].statement()?; - let s2 = self.args()[1].statement()?; - let key1 = s1.args()[0].key()?; - let key2 = s1.args()[1].key()?; - let key3 = s2.args()[0].key()?; - let key4 = s2.args()[1].key()?; - let statements_satisfy_transitivity = s1.code() == NativeStatement::Equal && s2.code() == NativeStatement::Equal && key2 == key3; - Ok(statements_satisfy_transitivity && output_statement.code() == NativeStatement::Equal && output_statement.args()[0].key()? == key1 && output_statement.args()[1].key()? == key4) - }, - GtToNotEqual => { - let s = self.args()[0].statement()?; - let arg_is_gt = s.code() == NativeStatement::Gt; - Ok(arg_is_gt && output_statement.code() == NativeStatement::NotEqual && output_statement.args() == s.args()) - }, - LtToNotEqual => { - let s = self.args()[0].statement()?; - let arg_is_lt = s.code() == NativeStatement::Lt; - Ok(arg_is_lt && output_statement.code() == NativeStatement::NotEqual && output_statement.args() == s.args()) - }, - RenameContainedBy => { - let s1 = self.args()[0].statement()?; - let s2 = self.args()[1].statement()?; - let key1 = s1.args()[0].key()?; - let key2 = s1.args()[1].key()?; - let key3 = s2.args()[0].key()?; - let key4 = s2.args()[1].key()?; - let args_satisfy_rename = s1.code() == NativeStatement::Contains && s2.code() == NativeStatement::Equal && key1 == key3; - Ok(args_satisfy_rename && output_statement.code() == NativeStatement::Contains && output_statement.args()[0].key()? == key4 && output_statement.args()[1].key()? == key2) - }, - SumOf => { - let s1 = self.args()[0].statement()?; - let s1_key = s1.args()[0].key()?; - let s1_value: i64 = s1.args()[1].literal()?.try_into()?; - let s2 = self.args()[1].statement()?; - let s2_key = s2.args()[0].key()?; - let s2_value:i64 = s2.args()[1].literal()?.try_into()?; - let s3 = self.args()[2].statement()?; - let s3_key = s3.args()[0].key()?; - let s3_value: i64 = s3.args()[1].literal()?.try_into()?; - let sum_holds = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s3.code() == NativeStatement::ValueOf && s1_value == s2_value + s3_value; - Ok(sum_holds && output_statement.code() == NativeStatement::SumOf && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key && output_statement.args()[2].key()? == s3_key) - }, - // TODO: Remaining ops. - _ => Ok(true) - } - } -} - pub trait Pod: fmt::Debug + DynClone { fn verify(&self) -> bool; fn id(&self) -> PodId; @@ -539,11 +255,8 @@ pub trait Pod: fmt::Debug + DynClone { fn kvs(&self) -> HashMap { self.pub_statements() .into_iter() - .filter_map(|st| match st.0 { - NativeStatement::ValueOf => Some(( - st.1[0].key().expect("key"), - st.1[1].literal().expect("literal"), - )), + .filter_map(|st| match st { + Statement::ValueOf(ak, v) => Some((ak, v)), _ => None, }) .collect() diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs new file mode 100644 index 0000000..5f3a5c3 --- /dev/null +++ b/src/middleware/operation.rs @@ -0,0 +1,181 @@ +use crate::middleware::{AnchoredKey, SELF}; +use anyhow::{anyhow, Result}; + +use super::Statement; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NativeOperation { + None = 0, + NewEntry = 1, + CopyStatement = 2, + EqualFromEntries = 3, + NotEqualFromEntries = 4, + GtFromEntries = 5, + LtFromEntries = 6, + TransitiveEqualFromStatements = 7, + GtToNotEqual = 8, + LtToNotEqual = 9, + ContainsFromEntries = 10, + NotContainsFromEntries = 11, + RenameContainedBy = 12, + SumOf = 13, + ProductOf = 14, + MaxOf = 15, +} + +// TODO: Refine this enum. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Operation { + None, + NewEntry, + CopyStatement(Statement), + EqualFromEntries(Statement, Statement), + NotEqualFromEntries(Statement, Statement), + GtFromEntries(Statement, Statement), + LtFromEntries(Statement, Statement), + TransitiveEqualFromStatements(Statement, Statement), + GtToNotEqual(Statement), + LtToNotEqual(Statement), + ContainsFromEntries(Statement, Statement), + NotContainsFromEntries(Statement, Statement), + RenameContainedBy(Statement, Statement), + SumOf(Statement, Statement, Statement), + ProductOf(Statement, Statement, Statement), + MaxOf(Statement, Statement, Statement), +} + +impl Operation { + pub fn code(&self) -> NativeOperation { + use NativeOperation::*; + match self { + Self::None => None, + Self::NewEntry => NewEntry, + Self::CopyStatement(_) => CopyStatement, + Self::EqualFromEntries(_, _) => EqualFromEntries, + Self::NotEqualFromEntries(_, _) => NotEqualFromEntries, + Self::GtFromEntries(_, _) => GtFromEntries, + Self::LtFromEntries(_, _) => LtFromEntries, + Self::TransitiveEqualFromStatements(_, _) => TransitiveEqualFromStatements, + Self::GtToNotEqual(_) => GtToNotEqual, + Self::LtToNotEqual(_) => LtToNotEqual, + Self::ContainsFromEntries(_, _) => ContainsFromEntries, + Self::NotContainsFromEntries(_, _) => NotContainsFromEntries, + Self::RenameContainedBy(_, _) => RenameContainedBy, + Self::SumOf(_, _, _) => SumOf, + Self::ProductOf(_, _, _) => ProductOf, + Self::MaxOf(_, _, _) => MaxOf, + } + } + + pub fn args(&self) -> Vec { + match self.clone() { + Self::None => vec![], + Self::NewEntry => vec![], + Self::CopyStatement(s) => vec![s], + Self::EqualFromEntries(s1, s2) => vec![s1, s2], + Self::NotEqualFromEntries(s1, s2) => vec![s1, s2], + Self::GtFromEntries(s1, s2) => vec![s1, s2], + Self::LtFromEntries(s1, s2) => vec![s1, s2], + Self::TransitiveEqualFromStatements(s1, s2) => vec![s1, s2], + Self::GtToNotEqual(s) => vec![s], + Self::LtToNotEqual(s) => vec![s], + Self::ContainsFromEntries(s1, s2) => vec![s1, s2], + Self::NotContainsFromEntries(s1, s2) => vec![s1, s2], + Self::RenameContainedBy(s1, s2) => vec![s1, s2], + Self::SumOf(s1, s2, s3) => vec![s1, s2, s3], + Self::ProductOf(s1, s2, s3) => vec![s1, s2, s3], + Self::MaxOf(s1, s2, s3) => vec![s1, s2, s3], + } + } + /// Forms operation from op-code and arguments. + pub fn op(op_code: NativeOperation, args: &[Statement]) -> Result { + type NO = NativeOperation; + let arg_tup = ( + args.get(0).cloned(), + args.get(1).cloned(), + args.get(2).cloned(), + ); + Ok(match (op_code, arg_tup, args.len()) { + (NO::None, (None, None, None), 0) => Self::None, + (NO::NewEntry, (None, None, None), 0) => Self::NewEntry, + (NO::CopyStatement, (Some(s), None, None), 1) => Self::CopyStatement(s), + (NO::EqualFromEntries, (Some(s1), Some(s2), None), 2) => Self::EqualFromEntries(s1, s2), + (NO::NotEqualFromEntries, (Some(s1), Some(s2), None), 2) => { + Self::NotEqualFromEntries(s1, s2) + } + (NO::GtFromEntries, (Some(s1), Some(s2), None), 2) => Self::GtFromEntries(s1, s2), + (NO::LtFromEntries, (Some(s1), Some(s2), None), 2) => Self::LtFromEntries(s1, s2), + (NO::ContainsFromEntries, (Some(s1), Some(s2), None), 2) => { + Self::ContainsFromEntries(s1, s2) + } + (NO::NotContainsFromEntries, (Some(s1), Some(s2), None), 2) => { + Self::NotContainsFromEntries(s1, s2) + } + (NO::RenameContainedBy, (Some(s1), Some(s2), None), 2) => { + Self::RenameContainedBy(s1, s2) + } + (NO::SumOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::SumOf(s1, s2, s3), + (NO::ProductOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::ProductOf(s1, s2, s3), + (NO::MaxOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::MaxOf(s1, s2, s3), + _ => Err(anyhow!( + "Ill-formed operation {:?} with arguments {:?}.", + op_code, + args + ))?, + }) + } + /// Checks the given operation against a statement. + pub fn check(&self, output_statement: &Statement) -> Result { + use Statement::*; + match (self, output_statement) { + (Self::None, None) => Ok(true), + (Self::NewEntry, ValueOf(AnchoredKey(pod_id, _), _)) => Ok(pod_id == &SELF), + (Self::CopyStatement(s1), s2) => Ok(s1 == s2), + (Self::EqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Equal(ak3, ak4)) => { + Ok(v1 == v2 && ak3 == ak1 && ak4 == ak2) + } + (Self::NotEqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), NotEqual(ak3, ak4)) => { + Ok(v1 != v2 && ak3 == ak1 && ak4 == ak2) + } + (Self::GtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Gt(ak3, ak4)) => { + Ok(v1 > v2 && ak3 == ak1 && ak4 == ak2) + } + (Self::LtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Lt(ak3, ak4)) => { + Ok(v1 < v2 && ak3 == ak1 && ak4 == ak2) + } + (Self::ContainsFromEntries(_, _), Contains(_, _)) => + /* TODO */ + { + Ok(true) + } + (Self::NotContainsFromEntries(_, _), NotContains(_, _)) => + /* TODO */ + { + Ok(true) + } + ( + Self::TransitiveEqualFromStatements(Equal(ak1, ak2), Equal(ak3, ak4)), + Equal(ak5, ak6), + ) => Ok(ak2 == ak3 && ak5 == ak1 && ak6 == ak4), + (Self::GtToNotEqual(Gt(ak1, ak2)), NotEqual(ak3, ak4)) => Ok(ak1 == ak3 && ak2 == ak4), + (Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => Ok(ak1 == ak3 && ak2 == ak4), + (Self::RenameContainedBy(Contains(ak1, ak2), Equal(ak3, ak4)), Contains(ak5, ak6)) => { + Ok(ak1 == ak3 && ak4 == ak5 && ak2 == ak6) + } + ( + Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)), + SumOf(ak4, ak5, ak6), + ) => { + let v1: i64 = v1.clone().try_into()?; + let v2: i64 = v2.clone().try_into()?; + let v3: i64 = v3.clone().try_into()?; + Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3) + } + _ => Err(anyhow!( + "Invalid deduction: {:?} ⇏ {:#}", + self, + output_statement + )), + } + } +} diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs new file mode 100644 index 0000000..617ad6d --- /dev/null +++ b/src/middleware/statement.rs @@ -0,0 +1,183 @@ +use anyhow::{anyhow, Result}; +use plonky2::field::types::Field; +use std::fmt; +use strum_macros::FromRepr; + +use super::{AnchoredKey, ToFields, Value, F}; + +pub const KEY_SIGNER: &str = "_signer"; +pub const KEY_TYPE: &str = "_type"; +pub const STATEMENT_ARG_F_LEN: usize = 8; + +#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)] +pub enum NativeStatement { + None = 0, + ValueOf = 1, + Equal = 2, + NotEqual = 3, + Gt = 4, + Lt = 5, + Contains = 6, + NotContains = 7, + SumOf = 8, + ProductOf = 9, + MaxOf = 10, +} + +impl ToFields for NativeStatement { + fn to_fields(self) -> (Vec, usize) { + (vec![F::from_canonical_u64(self as u64)], 1) + } +} + +// TODO: Incorporate custom statements into this enum. +/// Type encapsulating statements with their associated arguments. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Statement { + None, + ValueOf(AnchoredKey, Value), + Equal(AnchoredKey, AnchoredKey), + NotEqual(AnchoredKey, AnchoredKey), + Gt(AnchoredKey, AnchoredKey), + Lt(AnchoredKey, AnchoredKey), + Contains(AnchoredKey, AnchoredKey), + NotContains(AnchoredKey, AnchoredKey), + SumOf(AnchoredKey, AnchoredKey, AnchoredKey), + ProductOf(AnchoredKey, AnchoredKey, AnchoredKey), + MaxOf(AnchoredKey, AnchoredKey, AnchoredKey), +} + +impl Statement { + pub fn is_none(&self) -> bool { + self == &Self::None + } + pub fn code(&self) -> NativeStatement { + match self { + Self::None => NativeStatement::None, + Self::ValueOf(_, _) => NativeStatement::ValueOf, + Self::Equal(_, _) => NativeStatement::Equal, + Self::NotEqual(_, _) => NativeStatement::NotEqual, + Self::Gt(_, _) => NativeStatement::Gt, + Self::Lt(_, _) => NativeStatement::Lt, + Self::Contains(_, _) => NativeStatement::Contains, + Self::NotContains(_, _) => NativeStatement::NotContains, + Self::SumOf(_, _, _) => NativeStatement::SumOf, + Self::ProductOf(_, _, _) => NativeStatement::ProductOf, + Self::MaxOf(_, _, _) => NativeStatement::MaxOf, + } + } + pub fn args(&self) -> Vec { + use StatementArg::*; + match self.clone() { + Self::None => vec![], + Self::ValueOf(ak, v) => vec![Key(ak), Literal(v)], + Self::Equal(ak1, ak2) => vec![Key(ak1), Key(ak2)], + Self::NotEqual(ak1, ak2) => vec![Key(ak1), Key(ak2)], + Self::Gt(ak1, ak2) => vec![Key(ak1), Key(ak2)], + Self::Lt(ak1, ak2) => vec![Key(ak1), Key(ak2)], + Self::Contains(ak1, ak2) => vec![Key(ak1), Key(ak2)], + Self::NotContains(ak1, ak2) => vec![Key(ak1), Key(ak2)], + Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], + Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], + Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], + } + } +} + +impl ToFields for Statement { + fn to_fields(self) -> (Vec, usize) { + let (native_statement_f, native_statement_f_len) = self.code().to_fields(); + let (vec_statementarg_f, vec_statementarg_f_len) = self + .args() + .into_iter() + .map(|statement_arg| statement_arg.to_fields()) + .fold((Vec::new(), 0), |mut acc, (f, l)| { + acc.0.extend(f); + acc.1 += l; + acc + }); + ( + [native_statement_f, vec_statementarg_f].concat(), + native_statement_f_len + vec_statementarg_f_len, + ) + } +} + +impl fmt::Display for Statement { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?} ", self.code())?; + for (i, arg) in self.args().iter().enumerate() { + if i != 0 { + write!(f, " ")?; + } + write!(f, "{}", arg)?; + } + Ok(()) + } +} + +/// Statement argument type. Useful for statement decompositions. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum StatementArg { + None, + Literal(Value), + Key(AnchoredKey), +} + +impl fmt::Display for StatementArg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + StatementArg::None => write!(f, "none"), + StatementArg::Literal(v) => write!(f, "{}", v), + StatementArg::Key(r) => write!(f, "{}.{}", r.0, r.1), + } + } +} + +impl StatementArg { + pub fn is_none(&self) -> bool { + matches!(self, Self::None) + } + pub fn literal(&self) -> Result { + match self { + Self::Literal(value) => Ok(*value), + _ => Err(anyhow!("Statement argument {:?} is not a literal.", self)), + } + } + pub fn key(&self) -> Result { + match self { + Self::Key(ak) => Ok(ak.clone()), + _ => Err(anyhow!("Statement argument {:?} is not a key.", self)), + } + } +} + +impl ToFields for StatementArg { + fn to_fields(self) -> (Vec, usize) { + // NOTE: current version returns always the same amount of field elements in the returned + // vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case + // is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced + // in-circuit), we might be interested into reducing the length of it. If that's the case, + // we can check if it makes sense to make it dependant on the concrete StatementArg; that + // is, when dealing with a `None` it would be a single field element (zero value), and when + // dealing with `Literal` it would be of length 4. + let f = match self { + StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN], + StatementArg::Literal(v) => { + let value_f = v.0.to_vec(); + [ + value_f.clone(), + vec![F::ZERO; STATEMENT_ARG_F_LEN - value_f.len()], + ] + .concat() + } + StatementArg::Key(ak) => { + let (podid_f, _) = ak.0.to_fields(); + let (hash_f, _) = ak.1.to_fields(); + [podid_f, hash_f].concat() + } + }; + assert_eq!(f.len(), STATEMENT_ARG_F_LEN); // sanity check + (f, STATEMENT_ARG_F_LEN) + } +}