diff --git a/src/backends/plonky2/mock_main/mod.rs b/src/backends/plonky2/mock_main/mod.rs index 28d2bb5..a4b4cad 100644 --- a/src/backends/plonky2/mock_main/mod.rs +++ b/src/backends/plonky2/mock_main/mod.rs @@ -455,7 +455,7 @@ impl Pod for MockMainPod { StatementArg::Key(AnchoredKey(pod_id, h)) if *pod_id == SELF => { StatementArg::Key(AnchoredKey(self.id(), *h)) } - _ => sa.clone(), + _ => *sa, }) .collect(), ) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index a0b70ac..b397c42 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1,7 +1,7 @@ //! The frontend includes the user-level abstractions and user-friendly types to define and work //! with Pods. -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Error, Result}; use itertools::Itertools; use std::collections::HashMap; use std::convert::From; @@ -83,6 +83,17 @@ impl From for Value { } } +impl TryInto for Value { + type Error = Error; + fn try_into(self) -> std::result::Result { + if let Value::Int(n) = self { + Ok(n) + } else { + Err(anyhow!("Value not an int")) + } + } +} + impl fmt::Display for Value { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -317,6 +328,7 @@ impl MainPodBuilder { panic!("Invalid statement argument."); } } + // todo: better error handling OperationArg::Literal(v) => { let k = format!("c{}", self.const_cnt); self.const_cnt += 1; @@ -354,45 +366,226 @@ impl MainPodBuilder { use NativeOperation::*; let Operation(op_type, ref mut args) = &mut op; // TODO: argument type checking - let st = match op_type { + let pred = op_type + .output_predicate() + .map(|p| Ok(p)) + .unwrap_or_else(|| { + // We are dealing with a copy here. + match (&args).get(0) { + Some(OperationArg::Statement(s)) if args.len() == 1 => Ok(s.0.clone()), + _ => Err(anyhow!("Invalid arguments to copy operation: {:?}", args)), + } + })?; + + let st_args: Vec = match op_type { OperationType::Native(o) => match o { - None => Statement(Predicate::Native(NativePredicate::None), vec![]), - NewEntry => Statement( - Predicate::Native(NativePredicate::ValueOf), - self.op_args_entries(public, args)?, - ), - CopyStatement => todo!(), - EqualFromEntries => Statement( - Predicate::Native(NativePredicate::Equal), - self.op_args_entries(public, args)?, - ), - NotEqualFromEntries => Statement( - Predicate::Native(NativePredicate::NotEqual), - self.op_args_entries(public, args)?, - ), - GtFromEntries => Statement( - Predicate::Native(NativePredicate::Gt), - self.op_args_entries(public, args)?, - ), - LtFromEntries => Statement( - Predicate::Native(NativePredicate::Lt), - self.op_args_entries(public, args)?, - ), - TransitiveEqualFromStatements => todo!(), - GtToNotEqual => todo!(), - LtToNotEqual => todo!(), - ContainsFromEntries => Statement( - Predicate::Native(NativePredicate::Contains), - self.op_args_entries(public, args)?, - ), - NotContainsFromEntries => Statement( - Predicate::Native(NativePredicate::NotContains), - self.op_args_entries(public, args)?, - ), - RenameContainedBy => todo!(), - SumOf => todo!(), - ProductOf => todo!(), - MaxOf => todo!(), + None => vec![], + NewEntry => self.op_args_entries(public, args)?, + CopyStatement => match &args[0] { + OperationArg::Statement(s) => s.1.clone(), + _ => { + return Err(anyhow!("Invalid arguments to operation: {}", op)); + } + }, + EqualFromEntries => self.op_args_entries(public, args)?, + NotEqualFromEntries => self.op_args_entries(public, args)?, + GtFromEntries => self.op_args_entries(public, args)?, + LtFromEntries => self.op_args_entries(public, args)?, + TransitiveEqualFromStatements => { + match (args[0].clone(), args[1].clone()) { + ( + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::Equal), + st0_args, + )), + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::Equal), + st1_args, + )), + ) => { + // st_args0 == vec![ak0, ak1] + // st_args1 == vec![ak1, ak2] + // output statement Equals(ak0, ak2) + if st0_args[1] == st1_args[0] { + vec![st0_args[0].clone(), st1_args[1].clone()] + } else { + return Err(anyhow!("Invalid arguments to operation")); + } + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + } + } + GtToNotEqual => match args[0].clone() { + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::Gt), + st_args, + )) => { + vec![st_args[0].clone()] + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }, + LtToNotEqual => match args[0].clone() { + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::Lt), + st_args, + )) => { + vec![st_args[0].clone()] + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }, + ContainsFromEntries => self.op_args_entries(public, args)?, + NotContainsFromEntries => self.op_args_entries(public, args)?, + SumOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { + ( + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st0_args, + )), + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st1_args, + )), + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st2_args, + )), + ) => { + let st_args: Vec = match ( + st0_args[1].clone(), + st1_args[1].clone(), + st2_args[1].clone(), + ) { + ( + StatementArg::Literal(v0), + StatementArg::Literal(v1), + StatementArg::Literal(v2), + ) => { + let v0: i64 = v0.clone().try_into()?; + let v1: i64 = v1.clone().try_into()?; + let v2: i64 = v2.clone().try_into()?; + if v0 == v1 + v2 { + vec![ + st0_args[0].clone(), + st1_args[0].clone(), + st2_args[0].clone(), + ] + } else { + return Err(anyhow!("Invalid arguments to operation")); + } + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }; + st_args + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }, + ProductOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { + ( + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st0_args, + )), + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st1_args, + )), + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st2_args, + )), + ) => { + let st_args: Vec = match ( + st0_args[1].clone(), + st1_args[1].clone(), + st2_args[1].clone(), + ) { + ( + StatementArg::Literal(v0), + StatementArg::Literal(v1), + StatementArg::Literal(v2), + ) => { + let v0: i64 = v0.clone().try_into()?; + let v1: i64 = v1.clone().try_into()?; + let v2: i64 = v2.clone().try_into()?; + if v0 == v1 * v2 { + vec![ + st0_args[0].clone(), + st1_args[0].clone(), + st2_args[0].clone(), + ] + } else { + return Err(anyhow!("Invalid arguments to operation")); + } + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }; + st_args + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }, + MaxOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { + ( + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st0_args, + )), + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st1_args, + )), + OperationArg::Statement(Statement( + Predicate::Native(NativePredicate::ValueOf), + st2_args, + )), + ) => { + let st_args: Vec = match ( + st0_args[1].clone(), + st1_args[1].clone(), + st2_args[1].clone(), + ) { + ( + StatementArg::Literal(v0), + StatementArg::Literal(v1), + StatementArg::Literal(v2), + ) => { + let v0: i64 = v0.clone().try_into()?; + let v1: i64 = v1.clone().try_into()?; + let v2: i64 = v2.clone().try_into()?; + if v0 == std::cmp::max(v1, v2) { + vec![ + st0_args[0].clone(), + st1_args[0].clone(), + st2_args[0].clone(), + ] + } else { + return Err(anyhow!("Invalid arguments to operation")); + } + } + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }; + st_args + } + RenameContainedBy => todo!(), + _ => { + return Err(anyhow!("Invalid arguments to operation")); + } + }, }, OperationType::Custom(cpr) => { // All args should be statements to be pattern matched against statement templates. @@ -413,7 +606,8 @@ impl MainPodBuilder { )) }) .collect::>>()?; - let output_args = output_arg_values + + output_arg_values .chunks(2) .map(|chunk| { Ok(StatementArg::Key(AnchoredKey( @@ -430,10 +624,10 @@ impl MainPodBuilder { .ok_or(anyhow!("Missing key corresponding to hash."))?, ))) }) - .collect::>>()?; - Statement(Predicate::Custom(cpr.clone()), output_args) + .collect::>>()? } }; + let st = Statement(pred, st_args); self.operations.push(op); if public { self.public_statements.push(st.clone()); @@ -679,8 +873,8 @@ pub mod build_utils { $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::EqualFromEntries), $crate::op_args!($($arg),*)) }; (ne, $($arg:expr),+) => { $crate::frontend::Operation( - $crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotEqualFromEntries), - crate::op_args!($($arg),*)) }; + $crate::middleware::OperationType::Native($crate::middleware::NativeOperation::NotEqualFromEntries), + $crate::op_args!($($arg),*)) }; (gt, $($arg:expr),+) => { crate::frontend::Operation( crate::middleware::OperationType::Native(crate::middleware::NativeOperation::GtFromEntries), crate::op_args!($($arg),*)) }; @@ -830,6 +1024,54 @@ pub mod tests { Ok(()) } + #[test] + // Transitive equality not implemented yet + #[should_panic] + fn test_equal() { + let params = Params::default(); + let mut signed_builder = SignedPodBuilder::new(¶ms); + signed_builder.insert("a", 1); + signed_builder.insert("b", 1); + let mut signer = MockSigner { pk: "key".into() }; + let signed_pod = signed_builder.sign(&mut signer).unwrap(); + + let mut builder = MainPodBuilder::new(¶ms); + builder.add_signed_pod(&signed_pod); + + //let op_val1 = Operation{ + // OperationType::Native(NativeOperation::CopyStatement), + // signed_pod. + //} + + let op_eq1 = Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![ + OperationArg::from((&signed_pod, "a")), + OperationArg::from((&signed_pod, "b")), + ], + ); + let st1 = builder.op(true, op_eq1).unwrap(); + let op_eq2 = Operation( + OperationType::Native(NativeOperation::EqualFromEntries), + vec![ + OperationArg::from((&signed_pod, "b")), + OperationArg::from((&signed_pod, "a")), + ], + ); + let st2 = builder.op(true, op_eq2).unwrap(); + + let op_eq3 = Operation( + OperationType::Native(NativeOperation::TransitiveEqualFromStatements), + vec![OperationArg::Statement(st1), OperationArg::Statement(st2)], + ); + let st3 = builder.op(true, op_eq3); + + let mut prover = MockProver {}; + let pod = builder.prove(&mut prover, ¶ms).unwrap(); + + println!("{}", pod); + } + #[test] #[should_panic] fn test_false_st() { diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 9faeda4..5ae6107 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -32,7 +32,7 @@ impl fmt::Display for PodId { } /// AnchoredKey is a tuple containing (OriginId: PodId, key: Hash) -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct AnchoredKey(pub PodId, pub Hash); impl AnchoredKey { diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index c4cce4c..cdfcd02 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -2,8 +2,8 @@ use std::fmt; use anyhow::{anyhow, Result}; -use super::{CustomPredicateRef, Statement}; -use crate::middleware::{AnchoredKey, Params, Value, SELF}; +use super::{CustomPredicateRef, NativePredicate, Statement, StatementArg}; +use crate::middleware::{AnchoredKey, Params, Predicate, Value, SELF}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationType { @@ -25,12 +25,49 @@ pub enum NativeOperation { LtToNotEqual = 9, ContainsFromEntries = 10, NotContainsFromEntries = 11, - RenameContainedBy = 12, SumOf = 13, ProductOf = 14, MaxOf = 15, } +impl OperationType { + /// Gives the type of predicate that the operation will output, if known. + /// CopyStatement may output any predicate (it will match the statement copied), + /// so output_predicate returns None on CopyStatement. + pub fn output_predicate(&self) -> Option { + match self { + OperationType::Native(native_op) => match native_op { + NativeOperation::None => Some(Predicate::Native(NativePredicate::None)), + NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::ValueOf)), + NativeOperation::CopyStatement => None, + NativeOperation::EqualFromEntries => { + Some(Predicate::Native(NativePredicate::Equal)) + } + NativeOperation::NotEqualFromEntries => { + Some(Predicate::Native(NativePredicate::NotEqual)) + } + NativeOperation::GtFromEntries => Some(Predicate::Native(NativePredicate::Gt)), + NativeOperation::LtFromEntries => Some(Predicate::Native(NativePredicate::Lt)), + NativeOperation::TransitiveEqualFromStatements => { + Some(Predicate::Native(NativePredicate::Equal)) + } + NativeOperation::GtToNotEqual => Some(Predicate::Native(NativePredicate::NotEqual)), + NativeOperation::LtToNotEqual => Some(Predicate::Native(NativePredicate::NotEqual)), + NativeOperation::ContainsFromEntries => { + Some(Predicate::Native(NativePredicate::Contains)) + } + NativeOperation::NotContainsFromEntries => { + Some(Predicate::Native(NativePredicate::NotContains)) + } + NativeOperation::SumOf => Some(Predicate::Native(NativePredicate::SumOf)), + NativeOperation::ProductOf => Some(Predicate::Native(NativePredicate::ProductOf)), + NativeOperation::MaxOf => Some(Predicate::Native(NativePredicate::MaxOf)), + }, + OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), + } + } +} + // TODO: Refine this enum. #[derive(Clone, Debug, PartialEq, Eq)] pub enum Operation { @@ -46,7 +83,6 @@ pub enum Operation { LtToNotEqual(Statement), ContainsFromEntries(Statement, Statement), NotContainsFromEntries(Statement, Statement), - RenameContainedBy(Statement, Statement), SumOf(Statement, Statement, Statement), ProductOf(Statement, Statement, Statement), MaxOf(Statement, Statement, Statement), @@ -70,7 +106,6 @@ impl Operation { Self::LtToNotEqual(_) => OT::Native(LtToNotEqual), Self::ContainsFromEntries(_, _) => OT::Native(ContainsFromEntries), Self::NotContainsFromEntries(_, _) => OT::Native(NotContainsFromEntries), - Self::RenameContainedBy(_, _) => OT::Native(RenameContainedBy), Self::SumOf(_, _, _) => OT::Native(SumOf), Self::ProductOf(_, _, _) => OT::Native(ProductOf), Self::MaxOf(_, _, _) => OT::Native(MaxOf), @@ -92,7 +127,6 @@ impl Operation { Self::LtToNotEqual(s) => vec![s], Self::ContainsFromEntries(s1, s2) => vec![s1, s2], Self::NotContainsFromEntries(s1, s2) => vec![s1, s2], - Self::RenameContainedBy(s1, s2) => vec![s1, s2], Self::SumOf(s1, s2, s3) => vec![s1, s2, s3], Self::ProductOf(s1, s2, s3) => vec![s1, s2, s3], Self::MaxOf(s1, s2, s3) => vec![s1, s2, s3], @@ -126,9 +160,6 @@ impl Operation { (NO::NotContainsFromEntries, (Some(s1), Some(s2), None), 2) => { Self::NotContainsFromEntries(s1, s2) } - (NO::RenameContainedBy, (Some(s1), Some(s2), None), 2) => { - Self::RenameContainedBy(s1, s2) - } (NO::SumOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::SumOf(s1, s2, s3), (NO::ProductOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::ProductOf(s1, s2, s3), (NO::MaxOf, (Some(s1), Some(s2), Some(s3)), 3) => Self::MaxOf(s1, s2, s3), @@ -141,6 +172,142 @@ impl Operation { OperationType::Custom(cpr) => Self::Custom(cpr, args.to_vec()), }) } + /// Gives the output statement of the given operation, where determined + /// A ValueOf statement is not determined by the NewEntry operation, so returns Ok(None) + /// The outer Result is error handling + pub fn output_statement(&self) -> Result> { + use Statement::*; + let pred: Option = self.code().output_predicate(); + + let st_args: Option> = match self { + Self::None => Some(vec![]), + Self::NewEntry => Option::None, + Self::CopyStatement(s1) => Some(s1.args()), + Self::EqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { + if v1 == v2 { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::EqualFromEntries(_, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::NotEqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { + if v1 != v2 { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::NotEqualFromEntries(_, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::GtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { + if v1 > v2 { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::GtFromEntries(_, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::LtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => { + if v1 < v2 { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::LtFromEntries(_, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::TransitiveEqualFromStatements(Equal(ak1, ak2), Equal(ak3, ak4)) => { + if ak2 == ak3 { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak3)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::TransitiveEqualFromStatements(_, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::GtToNotEqual(Gt(ak1, ak2)) => { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } + Self::GtToNotEqual(_) => { + return Err(anyhow!("Invalid operation")); + } + Self::LtToNotEqual(Gt(ak1, ak2)) => { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } + Self::LtToNotEqual(_) => { + return Err(anyhow!("Invalid operation")); + } + Self::ContainsFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => + /* TODO */ + { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } + Self::ContainsFromEntries(_, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::NotContainsFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)) => + /* TODO */ + { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } + Self::NotContainsFromEntries(_, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)) => { + let v1: i64 = (*v1).try_into()?; + let v2: i64 = (*v2).try_into()?; + let v3: i64 = (*v3).try_into()?; + if v1 == v2 + v3 { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::SumOf(_, _, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::ProductOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)) => { + let v1: i64 = (*v1).try_into()?; + let v2: i64 = (*v2).try_into()?; + let v3: i64 = (*v3).try_into()?; + if v1 == v2 * v3 { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::ProductOf(_, _, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::MaxOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)) => { + let v1: i64 = (*v1).try_into()?; + let v2: i64 = (*v2).try_into()?; + let v3: i64 = (*v3).try_into()?; + if v1 == std::cmp::max(v2, v3) { + Some(vec![StatementArg::Key(*ak1), StatementArg::Key(*ak2)]) + } else { + return Err(anyhow!("Invalid operation")); + } + } + Self::MaxOf(_, _, _) => { + return Err(anyhow!("Invalid operation")); + } + Self::Custom(_, _) => todo!(), + }; + + let x: Option> = pred + .zip(st_args) + .map(|(pred, st_args)| Statement::from_args(pred, st_args)); + x.transpose() + } /// Checks the given operation against a statement. pub fn check(&self, _params: &Params, output_statement: &Statement) -> Result { use Statement::*; @@ -176,9 +343,6 @@ impl Operation { ) => Ok(ak2 == ak3 && ak5 == ak1 && ak6 == ak4), (Self::GtToNotEqual(Gt(ak1, ak2)), NotEqual(ak3, ak4)) => Ok(ak1 == ak3 && ak2 == ak4), (Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => Ok(ak1 == ak3 && ak2 == ak4), - (Self::RenameContainedBy(Contains(ak1, ak2), Equal(ak3, ak4)), Contains(ak5, ak6)) => { - Ok(ak1 == ak3 && ak4 == ak5 && ak2 == ak6) - } ( Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)), SumOf(ak4, ak5, ak6), @@ -231,7 +395,7 @@ impl fmt::Display for Operation { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "middleware::Operation:")?; writeln!(f, " {:?} ", self.code())?; - for (_, arg) in self.args().iter().enumerate() { + for arg in self.args().iter() { writeln!(f, " {}", arg)?; } Ok(()) diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 7a0321f..6b03d51 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -85,6 +85,100 @@ impl Statement { Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Key)), } } + pub fn from_args(pred: Predicate, args: Vec) -> Result { + use Predicate::*; + let st: Result = match pred { + Native(NativePredicate::None) => Ok(Self::None), + Native(NativePredicate::ValueOf) => { + if let (StatementArg::Key(a0), StatementArg::Literal(v1)) = (args[0], args[1]) { + Ok(Self::ValueOf(a0, v1)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::Equal) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + Ok(Self::Equal(a0, a1)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::NotEqual) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + Ok(Self::NotEqual(a0, a1)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::Gt) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + Ok(Self::Gt(a0, a1)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::Lt) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + Ok(Self::Lt(a0, a1)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::Contains) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + Ok(Self::Contains(a0, a1)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::NotContains) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) { + Ok(Self::NotContains(a0, a1)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::SumOf) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = + (args[0], args[1], args[2]) + { + Ok(Self::SumOf(a0, a1, a2)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::ProductOf) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = + (args[0], args[1], args[2]) + { + Ok(Self::ProductOf(a0, a1, a2)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + Native(NativePredicate::MaxOf) => { + if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = + (args[0], args[1], args[2]) + { + Ok(Self::MaxOf(a0, a1, a2)) + } else { + Err(anyhow!("Incorrect statement args")) + } + } + BatchSelf(_) => unreachable!(), + Custom(cpr) => { + let ak_args: Result> = args + .iter() + .map(|x| match x { + StatementArg::Key(ak) => Ok(*ak), + _ => Err(anyhow!("Incorrect statement args")), + }) + .collect(); + Ok(Self::Custom(cpr, ak_args?)) + } + }; + st + } } impl ToFields for Statement { @@ -120,7 +214,7 @@ impl fmt::Display for Statement { } /// Statement argument type. Useful for statement decompositions. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum StatementArg { None, Literal(Value), @@ -149,7 +243,7 @@ impl StatementArg { } pub fn key(&self) -> Result { match self { - Self::Key(ak) => Ok(ak.clone()), + Self::Key(ak) => Ok(*ak), _ => Err(anyhow!("Statement argument {:?} is not a key.", self)), } }