diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 8cafce8..0ace524 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -8,9 +8,9 @@ use serde::{Deserialize, Serialize}; pub use serialization::{SerializedMainPod, SerializedSignedPod}; use crate::middleware::{ - self, check_st_tmpl, hash_op, hash_str, max_op, prod_op, sum_op, AnchoredKey, Key, - MainPodInputs, NativeOperation, OperationAux, OperationType, Params, PodId, PodProver, - PodSigner, Statement, StatementArg, VDSet, Value, ValueRef, KEY_TYPE, SELF, + self, check_custom_pred, check_st_tmpl, hash_op, hash_str, max_op, prod_op, sum_op, + AnchoredKey, Key, MainPodInputs, NativeOperation, OperationAux, OperationType, Params, PodId, + PodProver, PodSigner, Statement, StatementArg, VDSet, Value, ValueRef, KEY_TYPE, SELF, }; mod custom; @@ -285,190 +285,138 @@ impl MainPodBuilder { fn op_statement(&mut self, op: Operation) -> Result { use NativeOperation::*; - let arg_error = |s: &str| Error::op_invalid_args(s.to_string()); let st = match op.0 { - OperationType::Native(o) => match (o, &op.1.as_slice()) { - (None, &[]) => Statement::None, - (NewEntry, &[OperationArg::Entry(k, v)]) => { - Statement::equal(AnchoredKey::from((SELF, k.as_str())), v.clone()) - } - (EqualFromEntries, &[a1, a2]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("equal-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("equal-from-entries"))?; - if v1 == v2 { - Statement::equal(r1, r2) - } else { - return Err(arg_error("equal-from-entries")); + OperationType::Native(o) => { + let native_arg_error = move || Error::op_invalid_args(format!("{o:?}")); + match (o, &op.1.as_slice()) { + (None, &[]) => Statement::None, + (NewEntry, &[OperationArg::Entry(k, v)]) => { + Statement::equal(AnchoredKey::from((SELF, k.as_str())), v.clone()) + } + (EqualFromEntries, &[a1, a2]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + if v1 == v2 { + Statement::equal(r1, r2) + } else { + return Err(native_arg_error()); + } + } + (NotEqualFromEntries, &[a1, a2]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + if v1 != v2 { + Statement::not_equal(r1, r2) + } else { + return Err(native_arg_error()); + } + } + (LtFromEntries, &[a1, a2]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + if v1 < v2 { + Statement::lt(r1, r2) + } else { + return Err(native_arg_error()); + } + } + (LtEqFromEntries, &[a1, a2]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + if v1 <= v2 { + Statement::not_equal(r1, r2) + } else { + return Err(native_arg_error()); + } + } + (CopyStatement, &[OperationArg::Statement(s)]) => s.clone(), + ( + TransitiveEqualFromStatements, + &[OperationArg::Statement(Statement::Equal(r1, r2)), OperationArg::Statement(Statement::Equal(r3, r4))], + ) => { + if r2 == r3 { + Statement::Equal(r1.clone(), r4.clone()) + } else { + return Err(native_arg_error()); + } + } + (LtToNotEqual, &[OperationArg::Statement(Statement::Lt(r1, r2))]) => { + Statement::NotEqual(r1.clone(), r2.clone()) + } + (SumOf, &[a1, a2, a3]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?; + if middleware::Operation::check_int_fn(v1, v2, v3, sum_op)? { + Statement::SumOf(r1, r2, r3) + } else { + return Err(native_arg_error()); + } + } + (ProductOf, &[a1, a2, a3]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?; + if middleware::Operation::check_int_fn(v1, v2, v3, prod_op)? { + Statement::ProductOf(r1, r2, r3) + } else { + return Err(native_arg_error()); + } + } + (MaxOf, &[a1, a2, a3]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?; + if middleware::Operation::check_int_fn(v1, v2, v3, max_op)? { + Statement::MaxOf(r1, r2, r3) + } else { + return Err(native_arg_error()); + } + } + (HashOf, &[a1, a2, a3]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?; + if v1 == &hash_op(v2.clone(), v3.clone()) { + Statement::HashOf(r1, r2, r3) + } else { + return Err(native_arg_error()); + } + } + (ContainsFromEntries, &[a1, a2, a3]) => { + let (r1, _v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, _v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + let (r3, _v3) = a3.value_and_ref().ok_or_else(native_arg_error)?; + // TODO: validate proof + Statement::Contains(r1, r2, r3) + } + (NotContainsFromEntries, &[a1, a2]) => { + let (r1, _v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, _v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + // TODO: validate proof + Statement::NotContains(r1, r2) + } + (PublicKeyOf, &[a1, a2]) => { + let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + if middleware::Operation::check_public_key(v1, v2)? { + Statement::PublicKeyOf(r1, r2) + } else { + return Err(native_arg_error()); + } + } + (t, _) => { + if t.is_syntactic_sugar() { + return Err(Error::custom(format!( + "Unexpected syntactic sugar: {:?}", + t + ))); + } else { + return Err(native_arg_error()); + } } } - (NotEqualFromEntries, &[a1, a2]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("not-equal-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("not-equal-from-entries"))?; - if v1 != v2 { - Statement::not_equal(r1, r2) - } else { - return Err(arg_error("not-equal-from-entries")); - } - } - (LtFromEntries, &[a1, a2]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("lt-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("lt-from-entries"))?; - if v1 < v2 { - Statement::lt(r1, r2) - } else { - return Err(arg_error("lt-from-entries")); - } - } - (LtEqFromEntries, &[a1, a2]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("lt-eq-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("lt-eq-from-entries"))?; - if v1 <= v2 { - Statement::not_equal(r1, r2) - } else { - return Err(arg_error("lt-eq-from-entries")); - } - } - (CopyStatement, &[OperationArg::Statement(s)]) => s.clone(), - ( - TransitiveEqualFromStatements, - &[OperationArg::Statement(Statement::Equal(r1, r2)), OperationArg::Statement(Statement::Equal(r3, r4))], - ) => { - if r2 == r3 { - Statement::Equal(r1.clone(), r4.clone()) - } else { - return Err(arg_error("transitive-eq")); - } - } - (LtToNotEqual, &[OperationArg::Statement(Statement::Lt(r1, r2))]) => { - Statement::NotEqual(r1.clone(), r2.clone()) - } - (SumOf, &[a1, a2, a3]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("sum-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("sum-from-entries"))?; - let (r3, v3) = a3 - .value_and_ref() - .ok_or_else(|| arg_error("sum-from-entries"))?; - if middleware::Operation::check_int_fn(v1, v2, v3, sum_op)? { - Statement::SumOf(r1, r2, r3) - } else { - return Err(arg_error("sum-from-entries")); - } - } - (ProductOf, &[a1, a2, a3]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("prod-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("prod-from-entries"))?; - let (r3, v3) = a3 - .value_and_ref() - .ok_or_else(|| arg_error("prod-from-entries"))?; - if middleware::Operation::check_int_fn(v1, v2, v3, prod_op)? { - Statement::ProductOf(r1, r2, r3) - } else { - return Err(arg_error("prod-from-entries")); - } - } - (MaxOf, &[a1, a2, a3]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("max-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("max-from-entries"))?; - let (r3, v3) = a3 - .value_and_ref() - .ok_or_else(|| arg_error("max-from-entries"))?; - if middleware::Operation::check_int_fn(v1, v2, v3, max_op)? { - Statement::MaxOf(r1, r2, r3) - } else { - return Err(arg_error("max-from-entries")); - } - } - (HashOf, &[a1, a2, a3]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("hash-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("hash-from-entries"))?; - let (r3, v3) = a3 - .value_and_ref() - .ok_or_else(|| arg_error("hash-from-entries"))?; - if v1 == &hash_op(v2.clone(), v3.clone()) { - Statement::HashOf(r1, r2, r3) - } else { - return Err(arg_error("hash-from-entries")); - } - } - (ContainsFromEntries, &[a1, a2, a3]) => { - let (r1, _v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("contains-from-entries"))?; - let (r2, _v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("contains-from-entries"))?; - let (r3, _v3) = a3 - .value_and_ref() - .ok_or_else(|| arg_error("contains-from-entries"))?; - // TODO: validate proof - Statement::Contains(r1, r2, r3) - } - (NotContainsFromEntries, &[a1, a2]) => { - let (r1, _v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("contains-from-entries"))?; - let (r2, _v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("contains-from-entries"))?; - // TODO: validate proof - Statement::NotContains(r1, r2) - } - (PublicKeyOf, &[a1, a2]) => { - let (r1, v1) = a1 - .value_and_ref() - .ok_or_else(|| arg_error("public-key-from-entries"))?; - let (r2, v2) = a2 - .value_and_ref() - .ok_or_else(|| arg_error("public-key-from-entries"))?; - if middleware::Operation::check_public_key(v1, v2)? { - Statement::PublicKeyOf(r1, r2) - } else { - return Err(arg_error("public-key-from-entries")); - } - } - (t, _) => { - if t.is_syntactic_sugar() { - return Err(Error::custom(format!( - "Unexpected syntactic sugar: {:?}", - t - ))); - } else { - return Err(arg_error("malformed operation")); - } - } - }, + } OperationType::Custom(cpr) => { let pred = &cpr.batch.predicates()[cpr.index]; if pred.statements.len() != op.1.len() { @@ -509,6 +457,7 @@ impl MainPodBuilder { .take(pred.args_len) .map(|v| v.unwrap_or_else(|| v_default.clone())) .collect(); + check_custom_pred(&self.params, &cpr, &args, &st_args)?; Statement::Custom(cpr, st_args) } }; diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 9f7b572..ed96a1b 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -164,10 +164,10 @@ macro_rules! op_impl_st { } impl Operation { - pub fn new_entry(a1: impl Into, a2: impl Into) -> Self { + pub fn new_entry(a1: impl Into, a2: impl Into) -> Self { Self( OperationType::Native(NativeOperation::NewEntry), - vec![a1.into(), a2.into().into()], + vec![OperationArg::Entry(a1.into(), a2.into())], OperationAux::None, ) } @@ -180,6 +180,12 @@ impl Operation { op_impl_oa!(sum_of, SumOf, 3); op_impl_oa!(product_of, ProductOf, 3); op_impl_oa!(max_of, MaxOf, 3); + /// Creates a custom operation. + /// + /// `args` should contain the statements that are needed to prove the + /// custom statement. It should have the same length as + /// `cpr.predicate().statements()`. If `cpr` refers to an `or` predicate, + /// then all but one of the statements should be `Statement::None`. pub fn custom(cpr: CustomPredicateRef, args: Vec) -> Self { Self(OperationType::Custom(cpr), args, OperationAux::None) } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 7ad0655..123f4d6 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -227,6 +227,14 @@ impl CustomPredicate { ) -> Result { Self::new(params, name, false, statements, args_len, wildcard_names) } + /// Creates a new custom predicate. + /// + /// # Arguments + /// * `name` - The name of the custom predicate. + /// * `conjunction` - `true` for an `and` predicate, `false` for an `or` predicate. + /// * `statements` - The statements required to apply the custom predicate. + /// * `args_len` - The number of public arguments. + /// * `wildcard_names` - The names of the arguments (public and private). pub fn new( params: &Params, name: String, diff --git a/src/middleware/error.rs b/src/middleware/error.rs index 1c35bc3..da23f3d 100644 --- a/src/middleware/error.rs +++ b/src/middleware/error.rs @@ -3,8 +3,8 @@ use std::{backtrace::Backtrace, fmt::Debug}; use crate::middleware::{ - CustomPredicate, Key, Operation, PodId, Statement, StatementArg, StatementTmplArg, Value, - Wildcard, + CustomPredicate, Key, Operation, PodId, Predicate, Statement, StatementArg, StatementTmplArg, + Value, Wildcard, }; pub type Result = core::result::Result; @@ -19,7 +19,7 @@ pub enum MiddlewareInnerError { InvalidStatementArg(StatementArg, String), #[error("{0} {1} is over the limit {2}")] MaxLength(String, usize, usize), - #[error("{0} amount of {1} should be {1} but it's {2}")] + #[error("{0} amount of {1} should be {2} but it's {3}")] DiffAmount(String, String, usize, usize), #[error("{0} should be assigned the value {1} but has previously been assigned {2}")] InvalidWildcardAssignment(Wildcard, Value, Value), @@ -27,12 +27,10 @@ pub enum MiddlewareInnerError { MismatchedAnchoredKeyInStatementTmplArg(Wildcard, PodId, Key, Key), #[error("{0} does not match against {1}")] MismatchedStatementTmplArg(StatementTmplArg, StatementArg), + #[error("Expected a statement of type {0}, got {1}")] + MismatchedStatementType(Predicate, Predicate), #[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")] MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate), - #[error( - "Not all statement templates of the following custom predicate have been matched:\n{0}" - )] - UnsatisfiedCustomPredicateConjunction(CustomPredicate), #[error( "None of the statement templates of the following custom predicate have been matched:\n{0}" )] @@ -110,6 +108,9 @@ impl Error { ) -> Self { new!(MismatchedStatementTmplArg(st_tmpl_arg, st_arg)) } + pub(crate) fn mismatched_statement_type(expected: Predicate, seen: Predicate) -> Self { + new!(MismatchedStatementType(expected, seen)) + } pub(crate) fn mismatched_wildcard_value_and_statement_arg( wc_value: Value, st_arg: Value, @@ -120,9 +121,6 @@ impl Error { wc_value, st_arg, arg_index, pred )) } - pub(crate) fn unsatisfied_custom_predicate_conjunction(pred: CustomPredicate) -> Self { - new!(UnsatisfiedCustomPredicateConjunction(pred)) - } pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self { new!(UnsatisfiedCustomPredicateDisjunction(pred)) } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 333b998..8adb595 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -14,8 +14,8 @@ use crate::{ }, middleware::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, NativePredicate, - Params, Predicate, Result, Statement, StatementArg, StatementTmplArg, ToFields, Value, - ValueRef, Wildcard, F, SELF, + Params, Predicate, Result, Statement, StatementArg, StatementTmpl, StatementTmplArg, + ToFields, Value, ValueRef, Wildcard, F, SELF, }, }; @@ -486,7 +486,37 @@ pub fn resolve_wildcard_values( .collect()) } -fn check_custom_pred( +fn check_custom_pred_argument( + custom_pred_ref: &CustomPredicateRef, + template: &StatementTmpl, + statement: &Statement, +) -> Result<()> { + let template_pred = match &template.pred { + &Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { + batch: custom_pred_ref.batch.clone(), + index: i, + }), + p => p.clone(), + }; + if template_pred != statement.predicate() { + return Err(Error::mismatched_statement_type( + template_pred, + statement.predicate(), + )); + } + let st_args_len = statement.args().len(); + if template.args.len() != st_args_len { + return Err(Error::diff_amount( + "statement template in custom predicate".to_string(), + "arguments".to_string(), + st_args_len, + template.args.len(), + )); + } + Ok(()) +} + +pub(crate) fn check_custom_pred( params: &Params, custom_pred_ref: &CustomPredicateRef, args: &[Statement], @@ -510,19 +540,24 @@ fn check_custom_pred( )); } - // Count the number of statements that match the templates by predicate. - let mut num_matches = 0; + let mut match_exists = false; for (st_tmpl, st) in pred.statements.iter().zip(args) { - let st_tmpl_pred = match &st_tmpl.pred { - Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { - batch: custom_pred_ref.batch.clone(), - index: *i, - }), - p => p.clone(), - }; - if st_tmpl_pred == st.predicate() { - num_matches += 1; + // For `or` predicates, only one statement needs to match the template. + // The rest of the statements can be `None`. + if !pred.conjunction + && matches!(st, Statement::None) + && st_tmpl.pred != Predicate::Native(NativePredicate::None) + { + continue; } + check_custom_pred_argument(custom_pred_ref, st_tmpl, st)?; + match_exists = true; + } + + if !pred.conjunction && !match_exists { + return Err(Error::unsatisfied_custom_predicate_disjunction( + pred.clone(), + )); } let wildcard_map = resolve_wildcard_values(params, pred, args)?; @@ -539,18 +574,6 @@ fn check_custom_pred( } } - if pred.conjunction { - if num_matches != pred.statements.len() { - return Err(Error::unsatisfied_custom_predicate_conjunction( - pred.clone(), - )); - } - } else if num_matches == 0 { - return Err(Error::unsatisfied_custom_predicate_disjunction( - pred.clone(), - )); - } - Ok(()) }