From 7d0d3ad76952927ef089b936a4ed4baaa705121e Mon Sep 17 00:00:00 2001 From: Daniel Gulotta Date: Fri, 13 Jun 2025 10:27:19 -0700 Subject: [PATCH] Allow literals in statements (#276) Implements #229 and #261. --- src/backends/plonky2/circuits/common.rs | 17 + src/backends/plonky2/circuits/mainpod.rs | 525 +++++++++++---------- src/backends/plonky2/circuits/signedpod.rs | 3 +- src/backends/plonky2/emptypod.rs | 2 +- src/backends/plonky2/mainpod/mod.rs | 73 +-- src/backends/plonky2/mainpod/operation.rs | 6 +- src/backends/plonky2/mainpod/statement.rs | 46 +- src/backends/plonky2/mock/emptypod.rs | 2 +- src/backends/plonky2/mock/mainpod.rs | 60 +-- src/backends/plonky2/mock/signedpod.rs | 2 +- src/backends/plonky2/signedpod.rs | 2 +- src/examples/custom.rs | 8 +- src/frontend/mod.rs | 409 ++++++++-------- src/frontend/operation.rs | 107 ++++- src/lang/mod.rs | 32 +- src/lang/processor.rs | 43 +- src/middleware/custom.rs | 16 +- src/middleware/mod.rs | 5 +- src/middleware/operation.rs | 189 +++++--- src/middleware/statement.rs | 270 ++++++----- 20 files changed, 992 insertions(+), 825 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 2861e3f..09d7a4e 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -864,6 +864,12 @@ pub trait CircuitBuilderPod, const D: usize> { fn add_virtual_custom_predicate_entry(&mut self, params: &Params) -> CustomPredicateEntryTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; + fn select_statement_arg( + &mut self, + b: BoolTarget, + x: &StatementArgTarget, + y: &StatementArgTarget, + ) -> StatementArgTarget; fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget; fn constant_value(&mut self, v: RawValue) -> ValueTarget; fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget; @@ -1038,6 +1044,17 @@ impl CircuitBuilderPod for CircuitBuilder { } } + fn select_statement_arg( + &mut self, + b: BoolTarget, + x: &StatementArgTarget, + y: &StatementArgTarget, + ) -> StatementArgTarget { + StatementArgTarget { + elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])), + } + } + fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget { BoolTarget::new_unsafe(self.select(b, x.target, y.target)) } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 1f8939c..9772db4 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -41,7 +41,7 @@ use crate::{ middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields, - Value, WildcardValue, EMPTY_VALUE, F, HASH_SIZE, KEY_TYPE, SELF, VALUE_SIZE, + Value, ValueRef, WildcardValue, EMPTY_VALUE, F, HASH_SIZE, KEY_TYPE, SELF, VALUE_SIZE, }, }; @@ -60,39 +60,95 @@ struct OperationVerifyGadget { params: Params, } -impl OperationVerifyGadget { - /// Checks whether the first `N` arguments to an op are ValueOf - /// statements, returning a boolean target indicating whether this - /// is the case as well as the value targets derived from each - /// argument. - fn first_n_args_as_values( - &self, +const MAX_VALUE_ARGS: usize = 3; + +struct StatementArgCache { + rhs: ValueTarget, + lhs: StatementArgTarget, + valid: BoolTarget, +} + +struct StatementCache { + equations: [StatementArgCache; MAX_VALUE_ARGS], + first_n_equations_valid: [BoolTarget; MAX_VALUE_ARGS], + op_args: Vec, +} + +impl StatementCache { + fn new( + params: &Params, builder: &mut CircuitBuilder, - resolved_op_args: &[StatementTarget], - ) -> (BoolTarget, [ValueTarget; N]) { - let arg_is_valueof = resolved_op_args[..N] - .iter() - .map(|arg| { - let st_type_ok = - arg.has_native_type(builder, &self.params, NativePredicate::ValueOf); - let value_arg_ok = builder.statement_arg_is_value(&arg.args[1]); - builder.and(st_type_ok, value_arg_ok) - }) - .collect::>(); - let first_n_args_are_valueofs = arg_is_valueof - .into_iter() - .reduce(|a, b| builder.and(a, b)) - .expect("No args specified."); - let values = array::from_fn(|i| resolved_op_args[i].args[1].as_value()); - (first_n_args_are_valueofs, values) + op: &OperationTarget, + st: &StatementTarget, + prev_statements: &[StatementTarget], + ) -> Self { + let op_args = if prev_statements.is_empty() { + (0..params.max_operation_args) + .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) + .collect_vec() + } else { + // `op.args` is a vector of arrays of length 1, so `.flatten()` is just + // converting a length 1 array into a scalar. + op.args + .iter() + .flatten() + .map(|&i| builder.vec_ref(params, prev_statements, i)) + .collect::>() + }; + assert!(params.max_operation_args >= 3); + assert!(params.max_statement_args >= 3); + let equations = array::from_fn(|i| { + let pred_is_none = op_args[i].has_native_type(builder, params, NativePredicate::None); + let arg_is_value = builder.statement_arg_is_value(&st.args[i]); + let is_literal = builder.and(pred_is_none, arg_is_value); + let pred_is_eq = op_args[i].has_native_type(builder, params, NativePredicate::Equal); + let ref_is_value = builder.statement_arg_is_value(&op_args[i].args[1]); + let is_reference = builder.and(pred_is_eq, ref_is_value); + let valid = builder.or(is_literal, is_reference); + let rhs_literal = st.args[i].as_value(); + let rhs_reference = op_args[i].args[1].as_value(); + let rhs = builder.select_value(pred_is_none, rhs_literal, rhs_reference); + let lhs = builder.select_statement_arg(pred_is_none, &st.args[i], &op_args[i].args[0]); + StatementArgCache { rhs, lhs, valid } + }); + let mut first_n_equations_valid = [equations[0].valid; MAX_VALUE_ARGS]; + for i in 1..MAX_VALUE_ARGS { + first_n_equations_valid[i] = + builder.and(equations[i].valid, first_n_equations_valid[i - 1]); + } + StatementCache { + equations, + first_n_equations_valid, + op_args, + } } + /// Attempts to interpret the first `N` arguments as values. + /// + /// If the operation argument is a statement of type `None`, then the value + /// should be the corresponding argument of the current statement. + /// If the operation argument is a statement of type `Equals`, then the value + /// should be the argument at index 1 of that statement. + /// If the function successfully interprets the arguments as values, + /// returns `True` along with those values. Otherwise, returns `False` + /// along with some arbitrary values. + fn first_n_args_as_values(&self) -> (BoolTarget, [ValueTarget; N]) { + ( + self.first_n_equations_valid[N - 1], + array::from_fn(|i| self.equations[i].rhs), + ) + } +} + +impl OperationVerifyGadget { + #[allow(clippy::too_many_arguments)] fn eval( &self, builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, prev_statements: &[StatementTarget], + input_statements_offset: usize, merkle_claims: &[MerkleClaimTarget], custom_predicate_verification_table: &[HashOutTarget], ) -> Result<()> { @@ -104,19 +160,7 @@ impl OperationVerifyGadget { // can reference any of the `prev_statements`. // TODO: Clean this up. let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - let resolved_op_args = if prev_statements.is_empty() { - (0..self.params.max_operation_args) - .map(|_| { - StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]) - }) - .collect_vec() - } else { - op.args - .iter() - .flatten() - .map(|&i| builder.vec_ref(&self.params, prev_statements, i)) - .collect::>() - }; + let cache = StatementCache::new(&self.params, builder, op, st, prev_statements); measure_gates_end!(builder, measure_resolve_op_args); // TODO: Can we have a single table with merkel claims and verified custom predicates // together (with an identifying prefix) and then we only need one random access instead of @@ -158,22 +202,28 @@ impl OperationVerifyGadget { let op_checks = [ vec![ self.eval_none(builder, st, &op.op_type), - self.eval_new_entry(builder, st, &op.op_type, prev_statements), + self.eval_new_entry( + builder, + st, + &op.op_type, + prev_statements, + input_statements_offset, + ), ], // Skip these if there are no resolved op args - if resolved_op_args.is_empty() { + if cache.op_args.is_empty() { vec![] } else { vec![ - self.eval_copy(builder, st, &op.op_type, &resolved_op_args), - self.eval_eq_neq_from_entries(builder, st, &op.op_type, &resolved_op_args), - self.eval_lt_lteq_from_entries(builder, st, &op.op_type, &resolved_op_args), - self.eval_transitive_eq(builder, st, &op.op_type, &resolved_op_args), - self.eval_lt_to_neq(builder, st, &op.op_type, &resolved_op_args), - self.eval_hash_of(builder, st, &op.op_type, &resolved_op_args), - self.eval_sum_of(builder, st, &op.op_type, &resolved_op_args), - self.eval_product_of(builder, st, &op.op_type, &resolved_op_args), - self.eval_max_of(builder, st, &op.op_type, &resolved_op_args), + self.eval_copy(builder, st, &op.op_type, &cache.op_args), + self.eval_eq_neq_from_entries(builder, st, &op.op_type, &cache), + self.eval_lt_lteq_from_entries(builder, st, &op.op_type, &cache), + self.eval_transitive_eq(builder, st, &op.op_type, &cache.op_args), + self.eval_lt_to_neq(builder, st, &op.op_type, &cache.op_args), + self.eval_hash_of(builder, st, &op.op_type, &cache), + self.eval_sum_of(builder, st, &op.op_type, &cache), + self.eval_product_of(builder, st, &op.op_type, &cache), + self.eval_max_of(builder, st, &op.op_type, &cache), ] }, // Skip these if there are no resolved Merkle claims @@ -184,14 +234,14 @@ impl OperationVerifyGadget { st, &op.op_type, resolved_merkle_claim, - &resolved_op_args, + &cache, ), self.eval_not_contains_from_entries( builder, st, &op.op_type, resolved_merkle_claim, - &resolved_op_args, + &cache, ), ] } else { @@ -204,7 +254,7 @@ impl OperationVerifyGadget { st, &op.op_type, resolved_custom_pred_verification, - &resolved_op_args, + &cache.op_args, )] } else { vec![] @@ -225,13 +275,13 @@ impl OperationVerifyGadget { st: &StatementTarget, op_type: &OperationTypeTarget, resolved_merkle_claim: MerkleClaimTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); let op_code_ok = op_type.has_native(builder, NativeOperation::ContainsFromEntries); let (arg_types_ok, [merkle_root_value, key_value, value_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + cache.first_n_args_as_values(); // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ @@ -251,14 +301,14 @@ impl OperationVerifyGadget { let merkle_proof_ok = builder.all(merkle_proof_checks); // Check output statement - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); - let arg3_key = resolved_op_args[2].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::Contains, - &[arg1_key, arg2_key, arg3_key], + &[arg1_expected, arg2_expected, arg3_expected], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); @@ -273,13 +323,12 @@ impl OperationVerifyGadget { st: &StatementTarget, op_type: &OperationTypeTarget, resolved_merkle_claim: MerkleClaimTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); let op_code_ok = op_type.has_native(builder, NativeOperation::NotContainsFromEntries); - let (arg_types_ok, [merkle_root_value, key_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + let (arg_types_ok, [merkle_root_value, key_value]) = cache.first_n_args_as_values(); // Check Merkle proof (verified elsewhere) against op args. let merkle_proof_checks = [ @@ -298,13 +347,13 @@ impl OperationVerifyGadget { let merkle_proof_ok = builder.all(merkle_proof_checks); // Check output statement - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::NotContains, - &[arg1_key, arg2_key], + &[arg1_expected, arg2_expected], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); @@ -343,7 +392,7 @@ impl OperationVerifyGadget { builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); let eq_op_st_code_ok = { @@ -358,16 +407,15 @@ impl OperationVerifyGadget { }; let op_st_code_ok = builder.or(eq_op_st_code_ok, neq_op_st_code_ok); - let (arg_types_ok, [arg1_value, arg2_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + let (arg_types_ok, [arg1_value, arg2_value]) = cache.first_n_args_as_values(); let op_args_eq = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements); let op_args_ok = builder.is_equal(op_args_eq.target, eq_op_st_code_ok.target); - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); - let expected_st_args: Vec<_> = [arg1_key, arg2_key] + let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] .into_iter() .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) .take(self.params.max_statement_args) @@ -394,7 +442,7 @@ impl OperationVerifyGadget { builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries"); let zero = ValueTarget::zero(builder); @@ -412,8 +460,7 @@ impl OperationVerifyGadget { }; let op_st_code_ok = builder.or(lt_op_st_code_ok, lteq_op_st_code_ok); - let (arg_types_ok, [arg1_value, arg2_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + let (arg_types_ok, [arg1_value, arg2_value]) = cache.first_n_args_as_values(); // If we are not dealing with the right op & statement types, // replace args with dummy values in the following checks. @@ -435,10 +482,10 @@ impl OperationVerifyGadget { }; builder.assert_i64_less_if(lt_check_flag, value1, value2); - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); - let expected_st_args: Vec<_> = [arg1_key, arg2_key] + let expected_st_args: Vec<_> = [arg1_expected, arg2_expected] .into_iter() .chain(std::iter::repeat_with(|| StatementArgTarget::none(builder))) .take(self.params.max_statement_args) @@ -463,27 +510,26 @@ impl OperationVerifyGadget { builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpHashOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); let expected_hash_value = builder.hash_values(arg2_value, arg3_value); let hash_value_ok = builder.is_equal_slice(&arg1_value.elements, &expected_hash_value.elements); - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); - let arg3_key = resolved_op_args[2].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::HashOf, - &[arg1_key, arg2_key, arg3_key], + &[arg1_expected, arg2_expected, arg3_expected], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); @@ -497,15 +543,14 @@ impl OperationVerifyGadget { builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpSumOf"); let value_zero = ValueTarget::zero(builder); let op_code_ok = op_type.has_native(builder, NativeOperation::SumOf); - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); // Select to avoid overflow. let summand1 = builder.select_value(op_code_ok, arg2_value, value_zero); @@ -515,14 +560,14 @@ impl OperationVerifyGadget { let sum_ok = builder.is_equal_slice(&arg1_value.elements, &expected_sum.elements); - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); - let arg3_key = resolved_op_args[2].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::SumOf, - &[arg1_key, arg2_key, arg3_key], + &[arg1_expected, arg2_expected, arg3_expected], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); @@ -536,15 +581,14 @@ impl OperationVerifyGadget { builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpProductOf"); let value_zero = ValueTarget::zero(builder); let op_code_ok = op_type.has_native(builder, NativeOperation::ProductOf); - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); // Select to avoid overflow. let factor1 = builder.select_value(op_code_ok, arg2_value, value_zero); @@ -554,14 +598,14 @@ impl OperationVerifyGadget { let product_ok = builder.is_equal_slice(&arg1_value.elements, &expected_product.elements); - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); - let arg3_key = resolved_op_args[2].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::ProductOf, - &[arg1_key, arg2_key, arg3_key], + &[arg1_expected, arg2_expected, arg3_expected], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); @@ -575,13 +619,12 @@ impl OperationVerifyGadget { builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - resolved_op_args: &[StatementTarget], + cache: &StatementCache, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpMaxOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); - let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = - self.first_n_args_as_values(builder, resolved_op_args); + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = cache.first_n_args_as_values(); // Check that arg1_value is equal to one of the other two // values. @@ -600,14 +643,14 @@ impl OperationVerifyGadget { let lt_check_enabled = builder.and(not_all_eq, op_code_ok); builder.assert_i64_less_if(lt_check_enabled, lower_bound, arg1_value); - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[1].args[0].clone(); - let arg3_key = resolved_op_args[2].args[0].clone(); + let arg1_expected = cache.equations[0].lhs.clone(); + let arg2_expected = cache.equations[1].lhs.clone(); + let arg3_expected = cache.equations[2].lhs.clone(); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::MaxOf, - &[arg1_key, arg2_key, arg3_key], + &[arg1_expected, arg2_expected, arg3_expected], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); @@ -633,22 +676,22 @@ impl OperationVerifyGadget { resolved_op_args[1].has_native_type(builder, &self.params, NativePredicate::Equal); let arg_types_ok = builder.all([arg1_type_ok, arg2_type_ok]); - let arg1_key1 = &resolved_op_args[0].args[0]; - let arg1_key2 = &resolved_op_args[0].args[1]; - let arg2_key1 = &resolved_op_args[1].args[0]; - let arg2_key2 = &resolved_op_args[1].args[1]; + let arg1_lhs = &resolved_op_args[0].args[0]; + let arg1_rhs = &resolved_op_args[0].args[1]; + let arg2_lhs = &resolved_op_args[1].args[0]; + let arg2_rhs = &resolved_op_args[1].args[1]; - let inner_keys_match = builder.is_equal_slice(&arg1_key2.elements, &arg2_key1.elements); + let inner_args_match = builder.is_equal_slice(&arg1_rhs.elements, &arg2_lhs.elements); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::Equal, - &[arg1_key1.clone(), arg2_key2.clone()], + &[arg1_lhs.clone(), arg2_rhs.clone()], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); - let ok = builder.all([op_code_ok, arg_types_ok, inner_keys_match, st_ok]); + let ok = builder.all([op_code_ok, arg_types_ok, inner_args_match, st_ok]); measure_gates_end!(builder, measure); ok } @@ -676,11 +719,11 @@ impl OperationVerifyGadget { st: &StatementTarget, op_type: &OperationTypeTarget, prev_statements: &[StatementTarget], + input_statements_offset: usize, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpNewEntry"); let op_code_ok = op_type.has_native(builder, NativeOperation::NewEntry); - - let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::ValueOf); + let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::Equal); let expected_arg_prefix = builder.constants( &StatementArg::Key(AnchoredKey::from((SELF, ""))).to_fields(&self.params)[..VALUE_SIZE], @@ -688,19 +731,12 @@ impl OperationVerifyGadget { let arg_prefix_ok = builder.is_equal_slice(&st.args[0].elements[..VALUE_SIZE], &expected_arg_prefix); - let dupe_check = { - let individual_checks = prev_statements - .iter() - .map(|ps| { - let same_predicate = builder.is_equal_flattenable(&st.predicate, &ps.predicate); - let same_anchored_key = - builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements); - builder.and(same_predicate, same_anchored_key) - }) - .collect::>(); - builder.any(individual_checks) - }; - + let input_statements = &prev_statements[input_statements_offset..]; + let individual_dupe_checks = input_statements + .iter() + .map(|ps| builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements)) + .collect::>(); + let dupe_check = builder.any(individual_dupe_checks); let no_dupes_ok = builder.not(dupe_check); let ok = builder.all([op_code_ok, st_code_ok, arg_prefix_ok, no_dupes_ok]); @@ -721,14 +757,14 @@ impl OperationVerifyGadget { let arg_type_ok = resolved_op_args[0].has_native_type(builder, &self.params, NativePredicate::Lt); - let arg1_key = resolved_op_args[0].args[0].clone(); - let arg2_key = resolved_op_args[0].args[1].clone(); + let arg1_expected = resolved_op_args[0].args[0].clone(); + let arg2_expected = resolved_op_args[0].args[1].clone(); let expected_statement = StatementTarget::new_native( builder, &self.params, NativePredicate::NotEqual, - &[arg1_key, arg2_key], + &[arg1_expected, arg2_expected], ); let st_ok = builder.is_equal_flattenable(st, &expected_statement); @@ -1300,9 +1336,9 @@ impl MainPodVerifyGadget { let expected_type_statement = StatementTarget::from_flattened( &self.params, &builder.constants( - &Statement::ValueOf( - AnchoredKey::from((SELF, KEY_TYPE)), - Value::from(PodType::MockMain), + &Statement::equal( + ValueRef::Key(AnchoredKey::from((SELF, KEY_TYPE))), + ValueRef::Literal(Value::from(PodType::MockMain)), ) .to_fields(params), ), @@ -1322,6 +1358,7 @@ impl MainPodVerifyGadget { st, op, prev_statements, + input_statements_offset, &merkle_claims, &custom_predicate_verification_table, )?; @@ -1624,6 +1661,7 @@ mod tests { &st_target, &op_target, &prev_statements_target, + 0, &merkle_claims_target, &custom_predicate_verification_table, )?; @@ -1651,13 +1689,13 @@ mod tests { #[test] fn test_lt_lteq_verify_failures() { let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); - let st2: mainpod::Statement = Statement::ValueOf( + Statement::equal(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), Value::from(56), ) .into(); - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), Value::from(RawValue([ GoldilocksField::NEG_ONE, @@ -1667,12 +1705,12 @@ mod tests { ])), ) .into(); - let st4: mainpod::Statement = Statement::ValueOf( + let st4: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(74).into()), "mundo")), Value::from(-55), ) .into(); - let st5: mainpod::Statement = Statement::ValueOf( + let st5: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(70).into()), "que")), Value::from(-56), ) @@ -1688,7 +1726,7 @@ mod tests { vec![OperationArg::Index(1), OperationArg::Index(0)], OperationAux::None, ), - Statement::Lt( + Statement::lt( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), AnchoredKey::from((SELF, "hello")), ) @@ -1700,7 +1738,7 @@ mod tests { vec![OperationArg::Index(0), OperationArg::Index(0)], OperationAux::None, ), - Statement::Lt( + Statement::lt( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((SELF, "hello")), ) @@ -1712,7 +1750,7 @@ mod tests { vec![OperationArg::Index(1), OperationArg::Index(0)], OperationAux::None, ), - Statement::LtEq( + Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), AnchoredKey::from((SELF, "hello")), ) @@ -1724,7 +1762,7 @@ mod tests { vec![OperationArg::Index(3), OperationArg::Index(3)], OperationAux::None, ), - Statement::Lt( + Statement::lt( AnchoredKey::from((PodId(RawValue::from(74).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(74).into()), "mundo")), ) @@ -1736,7 +1774,7 @@ mod tests { vec![OperationArg::Index(3), OperationArg::Index(4)], OperationAux::None, ), - Statement::Lt( + Statement::lt( AnchoredKey::from((PodId(RawValue::from(74).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(70).into()), "que")), ) @@ -1748,7 +1786,7 @@ mod tests { vec![OperationArg::Index(3), OperationArg::Index(4)], OperationAux::None, ), - Statement::LtEq( + Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(74).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(70).into()), "que")), ) @@ -1763,7 +1801,7 @@ mod tests { vec![OperationArg::Index(1), OperationArg::Index(2)], OperationAux::None, ), - Statement::Lt( + Statement::lt( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), ) @@ -1775,7 +1813,7 @@ mod tests { vec![OperationArg::Index(2), OperationArg::Index(2)], OperationAux::None, ), - Statement::LtEq( + Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), ) @@ -1803,13 +1841,13 @@ mod tests { #[test] fn test_eq_neq_verify_failures() { let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); - let st2: mainpod::Statement = Statement::ValueOf( + Statement::equal(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), Value::from(56), ) .into(); - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), Value::from(RawValue([ GoldilocksField::NEG_ONE, @@ -1829,7 +1867,7 @@ mod tests { vec![OperationArg::Index(1), OperationArg::Index(0)], OperationAux::None, ), - Statement::Equal( + Statement::equal( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), AnchoredKey::from((SELF, "hello")), ) @@ -1841,7 +1879,7 @@ mod tests { vec![OperationArg::Index(0), OperationArg::Index(0)], OperationAux::None, ), - Statement::NotEqual( + Statement::not_equal( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((SELF, "hello")), ) @@ -1869,8 +1907,8 @@ mod tests { #[test] fn test_operation_verify_newentry() -> Result<()> { let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); - let st2: mainpod::Statement = Statement::ValueOf( + Statement::equal(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(75).into()), "hello")), Value::from(55), ) @@ -1899,13 +1937,13 @@ mod tests { #[test] fn test_operation_verify_eq() -> Result<()> { let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); - let st2: mainpod::Statement = Statement::ValueOf( + Statement::equal(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), Value::from(55), ) .into(); - let st: mainpod::Statement = Statement::Equal( + let st: mainpod::Statement = Statement::equal( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), ) @@ -1922,13 +1960,13 @@ mod tests { #[test] fn test_operation_verify_neq() -> Result<()> { let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); - let st2: mainpod::Statement = Statement::ValueOf( + Statement::equal(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), Value::from(58), ) .into(); - let st: mainpod::Statement = Statement::NotEqual( + let st: mainpod::Statement = Statement::not_equal( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(75).into()), "world")), ) @@ -1945,13 +1983,13 @@ mod tests { #[test] fn test_operation_verify_lt() -> Result<()> { let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); - let st2: mainpod::Statement = Statement::ValueOf( + Statement::equal(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), Value::from(56), ) .into(); - let st: mainpod::Statement = Statement::Lt( + let st: mainpod::Statement = Statement::lt( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) @@ -1965,17 +2003,17 @@ mod tests { operation_verify(st, op, prev_statements, vec![])?; // Also check negative < negative - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), Value::from(-56), ) .into(); - let st4: mainpod::Statement = Statement::ValueOf( + let st4: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(84).into()), "mundo")), Value::from(-55), ) .into(); - let st: mainpod::Statement = Statement::Lt( + let st: mainpod::Statement = Statement::lt( AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(84).into()), "mundo")), ) @@ -1989,7 +2027,7 @@ mod tests { operation_verify(st, op, prev_statements, vec![])?; // Also check negative < positive - let st: mainpod::Statement = Statement::Lt( + let st: mainpod::Statement = Statement::lt( AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) @@ -2006,13 +2044,13 @@ mod tests { #[test] fn test_operation_verify_lteq() -> Result<()> { let st1: mainpod::Statement = - Statement::ValueOf(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); - let st2: mainpod::Statement = Statement::ValueOf( + Statement::equal(AnchoredKey::from((SELF, "hello")), Value::from(55)).into(); + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), Value::from(56), ) .into(); - let st: mainpod::Statement = Statement::LtEq( + let st: mainpod::Statement = Statement::lt_eq( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) @@ -2026,17 +2064,17 @@ mod tests { operation_verify(st, op, prev_statements, vec![])?; // Also check negative <= negative - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), Value::from(-56), ) .into(); - let st4: mainpod::Statement = Statement::ValueOf( + let st4: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(84).into()), "mundo")), Value::from(-55), ) .into(); - let st: mainpod::Statement = Statement::LtEq( + let st: mainpod::Statement = Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(84).into()), "mundo")), ) @@ -2050,7 +2088,7 @@ mod tests { operation_verify(st, op, prev_statements, vec![])?; // Also check negative <= positive - let st: mainpod::Statement = Statement::LtEq( + let st: mainpod::Statement = Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) @@ -2064,7 +2102,7 @@ mod tests { operation_verify(st, op, prev_statements.clone(), vec![])?; // Also check equality, both positive and negative. - let st: mainpod::Statement = Statement::LtEq( + let st: mainpod::Statement = Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(89).into()), "hola")), ) @@ -2075,7 +2113,7 @@ mod tests { OperationAux::None, ); operation_verify(st, op, prev_statements.clone(), vec![])?; - let st: mainpod::Statement = Statement::LtEq( + let st: mainpod::Statement = Statement::lt_eq( AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) @@ -2102,23 +2140,23 @@ mod tests { let v1 = hash_values(&input_values); let [v2, v3] = input_values; - let st1: mainpod::Statement = Statement::ValueOf( + let st1: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), - v1.into(), + Value::from(v1), ) .into(); - let st2: mainpod::Statement = Statement::ValueOf( + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), v2, ) .into(); - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), v3, ) .into(); - let st: mainpod::Statement = Statement::HashOf( + let st: mainpod::Statement = Statement::hash_of( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), @@ -2146,25 +2184,25 @@ mod tests { overflow.not().then_some((a, b, sum)) }) .try_for_each(|(a, b, sum)| { - let st1: mainpod::Statement = Statement::ValueOf( + let st1: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), - sum.into(), + sum, ) .into(); - let st2: mainpod::Statement = Statement::ValueOf( + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), - a.into(), + a, ) .into(); - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), - b.into(), + b, ) .into(); - let st: mainpod::Statement = Statement::SumOf( + let st: mainpod::Statement = Statement::sum_of( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), @@ -2193,25 +2231,25 @@ mod tests { overflow.not().then_some((a, b, prod)) }) .try_for_each(|(a, b, prod)| { - let st1: mainpod::Statement = Statement::ValueOf( + let st1: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), - prod.into(), + prod, ) .into(); - let st2: mainpod::Statement = Statement::ValueOf( + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), - a.into(), + a, ) .into(); - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), - b.into(), + b, ) .into(); - let st: mainpod::Statement = Statement::ProductOf( + let st: mainpod::Statement = Statement::product_of( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), @@ -2235,25 +2273,25 @@ mod tests { fn test_operation_verify_maxof() -> Result<()> { I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { let max = i64::max(a, b); - let st1: mainpod::Statement = Statement::ValueOf( + let st1: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), - max.into(), + max, ) .into(); - let st2: mainpod::Statement = Statement::ValueOf( + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), - a.into(), + a, ) .into(); - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), - b.into(), + b, ) .into(); - let st: mainpod::Statement = Statement::MaxOf( + let st: mainpod::Statement = Statement::max_of( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), @@ -2278,25 +2316,25 @@ mod tests { [(5, 3, 4), (5, 5, 8), (3, 4, 5)] .into_iter() .for_each(|(max, a, b)| { - let st1: mainpod::Statement = Statement::ValueOf( + let st1: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), - max.into(), + max, ) .into(); - let st2: mainpod::Statement = Statement::ValueOf( + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), - a.into(), + a, ) .into(); - let st3: mainpod::Statement = Statement::ValueOf( + let st3: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), - b.into(), + b, ) .into(); - let st: mainpod::Statement = Statement::MaxOf( + let st: mainpod::Statement = Statement::max_of( AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), @@ -2331,12 +2369,12 @@ mod tests { #[test] fn test_operation_verify_lt_to_neq() -> Result<()> { - let st: mainpod::Statement = Statement::NotEqual( + let st: mainpod::Statement = Statement::not_equal( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) .into(); - let st1: mainpod::Statement = Statement::Lt( + let st1: mainpod::Statement = Statement::lt( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hello")), ) @@ -2352,17 +2390,17 @@ mod tests { #[test] fn test_operation_verify_transitive_eq() -> Result<()> { - let st: mainpod::Statement = Statement::Equal( + let st: mainpod::Statement = Statement::equal( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), ) .into(); - let st1: mainpod::Statement = Statement::Equal( + let st1: mainpod::Statement = Statement::equal( AnchoredKey::from((SELF, "hello")), AnchoredKey::from((PodId(RawValue::from(89).into()), "world")), ) .into(); - let st2: mainpod::Statement = Statement::Equal( + let st2: mainpod::Statement = Statement::equal( AnchoredKey::from((PodId(RawValue::from(89).into()), "world")), AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), ) @@ -2397,9 +2435,9 @@ mod tests { let no_key_pf = mt.prove_nonexistence(&key)?; - let root_st: mainpod::Statement = Statement::ValueOf(root_ak.clone(), root.clone()).into(); - let key_st: mainpod::Statement = Statement::ValueOf(key_ak.clone(), key.into()).into(); - let st: mainpod::Statement = Statement::NotContains(root_ak, key_ak).into(); + let root_st: mainpod::Statement = Statement::equal(root_ak.clone(), root.clone()).into(); + let key_st: mainpod::Statement = Statement::equal(key_ak.clone(), key).into(); + let st: mainpod::Statement = Statement::not_contains(root_ak, key_ak).into(); let op = mainpod::Operation( OperationType::Native(NativeOperation::NotContainsFromEntries), vec![OperationArg::Index(0), OperationArg::Index(1)], @@ -2438,12 +2476,11 @@ mod tests { let (value, key_pf) = mt.prove(&key)?; let value_ak = AnchoredKey::from((PodId(RawValue::from(72).into()), "value")); - let root_st: mainpod::Statement = Statement::ValueOf(root_ak.clone(), root.clone()).into(); - let key_st: mainpod::Statement = Statement::ValueOf(key_ak.clone(), key.into()).into(); - let value_st: mainpod::Statement = - Statement::ValueOf(value_ak.clone(), value.into()).into(); + let root_st: mainpod::Statement = Statement::equal(root_ak.clone(), root.clone()).into(); + let key_st: mainpod::Statement = Statement::equal(key_ak.clone(), key).into(); + let value_st: mainpod::Statement = Statement::equal(value_ak.clone(), value).into(); - let st: mainpod::Statement = Statement::Contains(root_ak, key_ak, value_ak).into(); + let st: mainpod::Statement = Statement::contains(root_ak, key_ak, value_ak).into(); let op = mainpod::Operation( OperationType::Native(NativeOperation::ContainsFromEntries), vec![ @@ -2609,7 +2646,7 @@ mod tests { let pod_id = PodId(hash_str("pod_id")); let st_tmpl = StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![ StatementTmplArg::AnchoredKey( SelfOrWildcard::Wildcard(Wildcard::new("a".to_string(), 1)), @@ -2619,7 +2656,7 @@ mod tests { ], }; let args = vec![Value::from(1), Value::from(pod_id.0), Value::from(3)]; - let expected_st = Statement::ValueOf( + let expected_st = Statement::equal( AnchoredKey::new(pod_id, Key::from("key")), Value::from("value"), ); @@ -2695,10 +2732,10 @@ mod tests { use NativePredicate as NP; use StatementTmplBuilder as STB; let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new(NP::ValueOf) + let stb0 = STB::new(NP::Equal) .arg(("id", key("score"))) .arg(literal(42)); - let stb1 = STB::new(NP::ValueOf) + let stb1 = STB::new(NP::Equal) .arg(("id", "secret_key")) .arg(literal(1234)); let _ = builder.predicate_and( @@ -2715,11 +2752,11 @@ mod tests { // AND let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); let op_args = vec![ - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("score")), Value::from(42), ), - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("foo")), Value::from(1234), ), @@ -2745,7 +2782,7 @@ mod tests { // OR (1) let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); let op_args = vec![ - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("score")), Value::from(42), ), @@ -2770,7 +2807,7 @@ mod tests { let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); let op_args = vec![ Statement::None, - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("foo")), Value::from(1234), ), @@ -2811,7 +2848,7 @@ mod tests { use NativePredicate as NP; use StatementTmplBuilder as STB; let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); - let stb0 = STB::new(NP::ValueOf) + let stb0 = STB::new(NP::Equal) .arg(("id", key("score"))) .arg(literal(42)); let stb1 = STB::new(NP::Equal) @@ -2831,11 +2868,11 @@ mod tests { // AND (0) Sanity check with correct values let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); let op_args = vec![ - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("score")), Value::from(42), ), - Statement::Equal( + Statement::equal( AnchoredKey::new(pod_id, Key::from("foo")), AnchoredKey::new(pod_id, Key::from("score")), ), @@ -2861,11 +2898,11 @@ mod tests { // AND (1) Different pod_id for same wildcard let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); let op_args = vec![ - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("score")), Value::from(42), ), - Statement::Equal( + Statement::equal( AnchoredKey::new(PodId(hash_str("BAD")), Key::from("foo")), AnchoredKey::new(pod_id, Key::from("score")), ), @@ -2887,8 +2924,8 @@ mod tests { // AND (2) key doesn't match template let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); let op_args = vec![ - Statement::ValueOf(AnchoredKey::new(pod_id, Key::from("BAD")), Value::from(42)), - Statement::Equal( + Statement::equal(AnchoredKey::new(pod_id, Key::from("BAD")), Value::from(42)), + Statement::equal( AnchoredKey::new(pod_id, Key::from("foo")), AnchoredKey::new(pod_id, Key::from("score")), ), @@ -2910,11 +2947,11 @@ mod tests { // AND (3) literal doesn't match template let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); let op_args = vec![ - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("score")), Value::from(0xbad), ), - Statement::Equal( + Statement::equal( AnchoredKey::new(pod_id, Key::from("foo")), AnchoredKey::new(pod_id, Key::from("score")), ), @@ -2936,11 +2973,11 @@ mod tests { // AND (4) predicate doesn't match template let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); let op_args = vec![ - Statement::ValueOf( + Statement::equal( AnchoredKey::new(pod_id, Key::from("score")), Value::from(42), ), - Statement::NotEqual( + Statement::not_equal( AnchoredKey::new(pod_id, Key::from("foo")), AnchoredKey::new(pod_id, Key::from("score")), ), @@ -3036,8 +3073,8 @@ mod tests { }; let statements = [ - Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)), - Statement::Equal( + Statement::equal(AnchoredKey::from((SELF, "foo")), Value::from(42)), + Statement::equal( AnchoredKey::from((SELF, "bar")), AnchoredKey::from((SELF, "baz")), ), @@ -3058,12 +3095,12 @@ mod tests { let pod_id = PodId(hash_str("pod_id")); let statements = [ - Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)), - Statement::Equal( + Statement::equal(AnchoredKey::from((SELF, "foo")), Value::from(42)), + Statement::equal( AnchoredKey::from((SELF, "bar")), AnchoredKey::from((SELF, "baz")), ), - Statement::Lt( + Statement::lt( AnchoredKey::from((pod_id, "one")), AnchoredKey::from((pod_id, "two")), ), diff --git a/src/backends/plonky2/circuits/signedpod.rs b/src/backends/plonky2/circuits/signedpod.rs index e1ac623..ee5a1e2 100644 --- a/src/backends/plonky2/circuits/signedpod.rs +++ b/src/backends/plonky2/circuits/signedpod.rs @@ -94,8 +94,7 @@ impl SignedPodVerifyTarget { self_id: bool, ) -> Vec { let mut statements = Vec::new(); - let predicate = - PredicateTarget::new_native(builder, &self.params, NativePredicate::ValueOf); + let predicate = PredicateTarget::new_native(builder, &self.params, NativePredicate::Equal); let pod_id = if self_id { builder.constant_value(SELF.0.into()) } else { diff --git a/src/backends/plonky2/emptypod.rs b/src/backends/plonky2/emptypod.rs index 21ba8c9..5a60674 100644 --- a/src/backends/plonky2/emptypod.rs +++ b/src/backends/plonky2/emptypod.rs @@ -40,7 +40,7 @@ struct EmptyPodVerifyCircuit { } fn type_statement() -> Statement { - Statement::ValueOf( + Statement::equal( AnchoredKey::from((SELF, KEY_TYPE)), Value::from(PodType::Empty), ) diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index db3e801..a814a77 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -26,9 +26,9 @@ use crate::{ STANDARD_REC_MAIN_POD_CIRCUIT_DATA, }, middleware::{ - self, resolve_wildcard_values, AnchoredKey, CustomPredicateBatch, DynError, Hash, - MainPodInputs, NativeOperation, OperationType, Params, Pod, PodId, PodProver, PodType, - RecursivePod, StatementArg, ToFields, VDSet, F, KEY_TYPE, SELF, + self, resolve_wildcard_values, value_from_op, AnchoredKey, CustomPredicateBatch, DynError, + Hash, MainPodInputs, NativeOperation, OperationType, Params, Pod, PodId, PodProver, + PodType, RecursivePod, StatementArg, ToFields, VDSet, F, KEY_TYPE, SELF, }, }; @@ -125,31 +125,40 @@ pub(crate) fn extract_custom_predicate_verifications( pub(crate) fn extract_merkle_proofs( params: &Params, operations: &[middleware::Operation], + statements: &[middleware::Statement], ) -> Result> { + assert_eq!(operations.len(), statements.len()); let merkle_proofs: Vec<_> = operations .iter() - .flat_map(|op| match op { - middleware::Operation::ContainsFromEntries( - middleware::Statement::ValueOf(_, root), - middleware::Statement::ValueOf(_, key), - middleware::Statement::ValueOf(_, value), - pf, - ) => Some(MerkleClaimAndProof::new( - Hash::from(root.raw()), - key.raw(), - Some(value.raw()), - pf.clone(), - )), - middleware::Operation::NotContainsFromEntries( - middleware::Statement::ValueOf(_, root), - middleware::Statement::ValueOf(_, key), - pf, - ) => Some(MerkleClaimAndProof::new( - Hash::from(root.raw()), - key.raw(), - None, - pf.clone(), - )), + .zip(statements.iter()) + .flat_map(|(op, st)| match (op, st) { + ( + middleware::Operation::ContainsFromEntries(root_s, key_s, value_s, pf), + middleware::Statement::Contains(root_ref, key_ref, value_ref), + ) => { + let root = value_from_op(root_s, root_ref)?; + let key = value_from_op(key_s, key_ref)?; + let value = value_from_op(value_s, value_ref)?; + Some(MerkleClaimAndProof::new( + Hash::from(root.raw()), + key.raw(), + Some(value.raw()), + pf.clone(), + )) + } + ( + middleware::Operation::NotContainsFromEntries(root_s, key_s, pf), + middleware::Statement::NotContains(root_ref, key_ref), + ) => { + let root = value_from_op(root_s, root_ref)?; + let key = value_from_op(key_s, key_ref)?; + Some(MerkleClaimAndProof::new( + Hash::from(root.raw()), + key.raw(), + None, + pf.clone(), + )) + } _ => None, }) .collect(); @@ -320,9 +329,9 @@ pub(crate) fn layout_statements( // Public statements assert!(inputs.public_statements.len() < params.max_public_statements); - let mut type_st = middleware::Statement::ValueOf( - AnchoredKey::from((SELF, KEY_TYPE)), - middleware::Value::from(PodType::MockMain), + let mut type_st = middleware::Statement::Equal( + AnchoredKey::from((SELF, KEY_TYPE)).into(), + middleware::Value::from(PodType::MockMain).into(), ) .into(); pad_statement(params, &mut type_st); @@ -470,7 +479,7 @@ impl Prover { }) .collect_vec(); - let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?; + let merkle_proofs = extract_merkle_proofs(params, inputs.operations, inputs.statements)?; let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?; let custom_predicate_verifications = extract_custom_predicate_verifications( params, @@ -805,12 +814,12 @@ pub mod tests { max_signed_pod_values: 2, max_public_statements: 2, num_public_statements_id: 4, - max_statement_args: 2, + max_statement_args: 3, max_operation_args: 3, max_custom_predicate_batches: 2, max_custom_predicate_verifications: 2, max_custom_predicate_arity: 2, - max_custom_predicate_wildcards: 2, + max_custom_predicate_wildcards: 3, max_custom_batch_size: 2, max_merkle_proofs_containers: 2, max_depth_mt_containers: 4, @@ -905,7 +914,7 @@ pub mod tests { let vd_set = &*DEFAULT_VD_SET; let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "cpb".into()); - let stb0 = STB::new(NP::ValueOf) + let stb0 = STB::new(NP::Equal) .arg(("id", key("score"))) .arg(literal(42)); let stb1 = STB::new(NP::Equal) diff --git a/src/backends/plonky2/mainpod/operation.rs b/src/backends/plonky2/mainpod/operation.rs index 561ed82..6a91188 100644 --- a/src/backends/plonky2/mainpod/operation.rs +++ b/src/backends/plonky2/mainpod/operation.rs @@ -75,7 +75,11 @@ impl Operation { .iter() .flat_map(|arg| match arg { OperationArg::None => None, - OperationArg::Index(i) => Some(statements[*i].clone().try_into()), + OperationArg::Index(i) => { + let st: Result = + statements[*i].clone().try_into(); + Some(st) + } }) .collect::>>()?; let deref_aux = match self.2 { diff --git a/src/backends/plonky2/mainpod/statement.rs b/src/backends/plonky2/mainpod/statement.rs index d121252..a02bee3 100644 --- a/src/backends/plonky2/mainpod/statement.rs +++ b/src/backends/plonky2/mainpod/statement.rs @@ -48,41 +48,25 @@ impl TryFrom for middleware::Statement { type NP = NativePredicate; type SA = StatementArg; let proper_args = s.args(); - let args = ( - proper_args.first().cloned(), - proper_args.get(1).cloned(), - proper_args.get(2).cloned(), - ); Ok(match s.0 { - Predicate::Native(np) => match (np, args, proper_args.len()) { - (NP::None, _, 0) => S::None, - (NP::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None), 2) => { - S::ValueOf(ak, v) + Predicate::Native(np) => match (np, &proper_args.as_slice()) { + (NP::None, &[]) => S::None, + (NP::Equal, &[a1, a2]) => S::Equal(a1.try_into()?, a2.try_into()?), + (NP::NotEqual, &[a1, a2]) => S::NotEqual(a1.try_into()?, a2.try_into()?), + (NP::LtEq, &[a1, a2]) => S::LtEq(a1.try_into()?, a2.try_into()?), + (NP::Lt, &[a1, a2]) => S::Lt(a1.try_into()?, a2.try_into()?), + (NP::Contains, &[a1, a2, a3]) => { + S::Contains(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - (NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { - S::Equal(ak1, ak2) + (NP::NotContains, &[a1, a2]) => S::NotContains(a1.try_into()?, a2.try_into()?), + (NP::SumOf, &[a1, a2, a3]) => { + S::SumOf(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - (NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { - S::NotEqual(ak1, ak2) + (NP::ProductOf, &[a1, a2, a3]) => { + S::ProductOf(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - (NP::LtEq, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::LtEq(ak1, ak2), - (NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => S::Lt(ak1, ak2), - (NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { - S::Contains(ak1, ak2, ak3) - } - (NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => { - S::NotContains(ak1, ak2) - } - (NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { - S::SumOf(ak1, ak2, ak3) - } - ( - NP::ProductOf, - (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), - 3, - ) => S::ProductOf(ak1, ak2, ak3), - (NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => { - S::MaxOf(ak1, ak2, ak3) + (NP::MaxOf, &[a1, a2, a3]) => { + S::MaxOf(a1.try_into()?, a2.try_into()?, a3.try_into()?) } _ => Err(Error::custom(format!( "Ill-formed statement expression {:?}", diff --git a/src/backends/plonky2/mock/emptypod.rs b/src/backends/plonky2/mock/emptypod.rs index 60c76bc..aa4387c 100644 --- a/src/backends/plonky2/mock/emptypod.rs +++ b/src/backends/plonky2/mock/emptypod.rs @@ -19,7 +19,7 @@ pub struct MockEmptyPod { } fn type_statement() -> Statement { - Statement::ValueOf( + Statement::equal( AnchoredKey::from((SELF, KEY_TYPE)), Value::from(PodType::Empty), ) diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index 5def079..1760ed0 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -4,6 +4,7 @@ use std::fmt; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use crate::{ @@ -18,8 +19,9 @@ use crate::{ primitives::merkletree::MerkleClaimAndProof, }, middleware::{ - self, hash_str, AnchoredKey, DynError, Hash, MainPodInputs, NativePredicate, Params, Pod, - PodId, PodProver, PodType, Predicate, RecursivePod, StatementArg, VDSet, KEY_TYPE, SELF, + self, hash_str, AnchoredKey, DynError, Hash, MainPodInputs, NativeOperation, + NativePredicate, OperationType, Params, Pod, PodId, PodProver, PodType, Predicate, + RecursivePod, StatementArg, VDSet, KEY_TYPE, SELF, }, }; @@ -160,7 +162,7 @@ impl MockMainPod { // value=PodType::MockMainPod` let (statements, public_statements) = layout_statements(params, true, &inputs)?; // Extract Merkle proofs and pad. - let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?; + let merkle_proofs = extract_merkle_proofs(params, inputs.operations, inputs.statements)?; let operations = process_private_statements_operations( params, @@ -225,44 +227,30 @@ impl MockMainPod { // find a ValueOf statement from the public statements with key=KEY_TYPE and check that the // value is PodType::MockMainPod let has_type_statement = self.public_statements.iter().any(|s| { - s.0 == Predicate::Native(NativePredicate::ValueOf) - && !s.1.is_empty() - && if let StatementArg::Key(AnchoredKey { pod_id, ref key }) = s.1[0] { - pod_id == SELF && key.hash() == hash_str(KEY_TYPE) + s.0 == Predicate::Native(NativePredicate::Equal) && { + if let [StatementArg::Key(AnchoredKey { pod_id, ref key }), StatementArg::Literal(_)] = &s.1[..2] { + pod_id == &SELF && key.hash() == hash_str(KEY_TYPE) } else { false } - }); - // 3. check that all `input_statements` of type `ValueOf` with origin=SELF have unique keys + }}); + // 3. check that all `NewEntry` operations have unique keys // (no duplicates) - // TODO: Instead of doing this, do a uniqueness check when verifying the output of a - // `NewValue` operation. - let value_ofs_unique = { - let key_id_pairs = input_statements - .iter() - .enumerate() - .map(|(i, s)| { - ( - // Separate private from public statements. - if i < self.params.max_priv_statements() { - 0 - } else { - 1 - }, - s, - ) - }) - .filter(|(_, s)| s.0 == Predicate::Native(NativePredicate::ValueOf)) - .flat_map(|(i, s)| { - if let StatementArg::Key(ak) = &s.1[0] { - vec![(i, ak.pod_id, ak.key.hash())] - } else { - vec![] + let value_ofs_unique = input_statements + .iter() + .zip(self.operations.iter()) + .filter_map(|(s, o)| { + if matches!(o.0, OperationType::Native(NativeOperation::NewEntry)) { + match s.1.get(0) { + Some(StatementArg::Key(k)) => Some(k), + // malformed NewEntry operations are caught in step 5 + _ => None, } - }) - .collect::>(); - !(0..key_id_pairs.len() - 1).any(|i| key_id_pairs[i + 1..].contains(&key_id_pairs[i])) - }; + } else { + None + } + }) + .all_unique(); // 4. TODO: Verify type // 5. verify that all `input_statements` are correctly generated diff --git a/src/backends/plonky2/mock/signedpod.rs b/src/backends/plonky2/mock/signedpod.rs index 3ac05d0..c0ebaa6 100644 --- a/src/backends/plonky2/mock/signedpod.rs +++ b/src/backends/plonky2/mock/signedpod.rs @@ -156,7 +156,7 @@ impl Pod for MockSignedPod { [(key_type, value_type), (key_signer, value_signer)] .into_iter() .chain(kvs.into_iter().sorted_by_key(|kv| kv.0.hash())) - .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((SELF, k)), v)) + .map(|(k, v)| Statement::equal(AnchoredKey::from((SELF, k)), v)) .collect() } diff --git a/src/backends/plonky2/signedpod.rs b/src/backends/plonky2/signedpod.rs index 975e836..88d47c2 100644 --- a/src/backends/plonky2/signedpod.rs +++ b/src/backends/plonky2/signedpod.rs @@ -192,7 +192,7 @@ impl Pod for SignedPod { [(key_type, value_type), (key_signer, value_signer)] .into_iter() .chain(kvs.into_iter().sorted_by_key(|kv| kv.0.hash())) - .map(|(k, v)| Statement::ValueOf(AnchoredKey::from((SELF, k)), v)) + .map(|(k, v)| Statement::equal(AnchoredKey::from((SELF, k)), v)) .collect() } diff --git a/src/examples/custom.rs b/src/examples/custom.rs index 75c2f0d..a64131c 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -27,7 +27,7 @@ pub fn eth_friend_batch(params: &Params, mock: bool) -> Result Result Result) -> Option<&Value> { self.kvs.get(&key.into()) } - // Returns the ValueOf statement that defines key if it exists. + // Returns the Equal statement that defines key if it exists. pub fn get_statement(&self, key: impl Into) -> Option { let key: Key = key.into(); self.kvs() .get(&key) - .map(|value| Statement::ValueOf(AnchoredKey::from((self.id(), key)), value.clone())) + .map(|value| Statement::equal(AnchoredKey::from((self.id(), key)), value.clone())) } } @@ -125,7 +125,7 @@ pub struct MainPodBuilder { // Internal state /// Counter for constants created from literals const_cnt: usize, - /// Map from (public, Value) to Key of already created literals via ValueOf statements. + /// Map from (public, Value) to Key of already created literals via Equal statements. literals: HashMap<(bool, Value), Key>, } @@ -186,38 +186,6 @@ impl MainPodBuilder { } } - /// Convert [OperationArg]s to [StatementArg]s for the operations that work with entries - fn op_args_entries( - &mut self, - public: bool, - args: &mut [OperationArg], - ) -> Result> { - let mut st_args = Vec::new(); - // TODO: Rewrite without calling args() and instead using matches? - for arg in args.iter_mut() { - match arg { - OperationArg::Statement(s) => { - if s.predicate() == Predicate::Native(NativePredicate::ValueOf) { - st_args.push(s.args()[0].clone()) - } else { - panic!("Invalid statement argument."); - } - } - // todo: better error handling - OperationArg::Literal(v) => { - let value_of_st = self.literal(public, v.clone())?; - *arg = OperationArg::Statement(value_of_st.clone()); - st_args.push(value_of_st.args()[0].clone()) - } - OperationArg::Entry(k, v) => { - st_args.push(StatementArg::Key(AnchoredKey::from((SELF, k.as_str())))); - st_args.push(StatementArg::Literal(v.clone())) - } - }; - } - Ok(st_args) - } - pub fn pub_op(&mut self, op: Operation) -> Result { self.op(true, op) } @@ -307,169 +275,190 @@ impl MainPodBuilder { } } - fn op(&mut self, public: bool, op: Operation) -> Result { + fn op_statement(&mut self, op: Operation) -> Result { use NativeOperation::*; - let mut op = Self::fill_in_aux(Self::lower_op(op))?; - let Operation(op_type, ref mut args, _) = &mut op; - // TODO: argument type checking - let pred = op_type.output_predicate().map(Ok).unwrap_or_else(|| { - // We are dealing with a copy here. - match (args).first() { - Some(OperationArg::Statement(s)) if args.len() == 1 => Ok(s.predicate().clone()), - _ => Err(Error::op_invalid_args("copy".to_string())), - } - })?; - - let st_args: Vec = match op_type { - OperationType::Native(o) => match o { - None => vec![], - NewEntry | EqualFromEntries | NotEqualFromEntries | LtFromEntries - | LtEqFromEntries => self.op_args_entries(public, args)?, - CopyStatement => match &args[0] { - OperationArg::Statement(s) => s.args().clone(), - _ => { - return Err(Error::op_invalid_args("copy".to_string())); - } - }, - TransitiveEqualFromStatements => { - match (args[0].clone(), args[1].clone()) { - ( - OperationArg::Statement(Statement::Equal(ak0, ak1)), - OperationArg::Statement(Statement::Equal(ak2, ak3)), - ) => { - // st_args0 == vec![ak0, ak1] - // st_args1 == vec![ak1, ak2] - // output statement Equals(ak0, ak2) - if ak1 == ak2 { - vec![StatementArg::Key(ak0), StatementArg::Key(ak3)] - } else { - return Err(Error::op_invalid_args( - "transitivity equality".to_string(), - )); - } - } - _ => { - return Err(Error::op_invalid_args( - "transitivity equality".to_string(), - )); - } + let arg_error = |s: &str| Error::op_invalid_args(s.to_string()); + let st = match op.0 { + OperationType::Native(o) => match (o, &op.1.as_slice()) { + (None, &[]) => Statement::None, + (NewEntry, &[OperationArg::Entry(k, v)]) => { + Statement::equal(AnchoredKey::from((SELF, k.as_str())), v.clone()) + } + (EqualFromEntries, &[a1, a2]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("equal-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("equal-from-entries"))?; + if v1 == v2 { + Statement::equal(r1, r2) + } else { + return Err(arg_error("equal-from-entries")); } } - LtToNotEqual => match args[0].clone() { - OperationArg::Statement(Statement::Lt(ak0, ak1)) => { - vec![StatementArg::Key(ak0), StatementArg::Key(ak1)] + (NotEqualFromEntries, &[a1, a2]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("not-equal-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("not-equal-from-entries"))?; + if v1 != v2 { + Statement::not_equal(r1, r2) + } else { + return Err(arg_error("not-equal-from-entries")); } - _ => { - return Err(Error::op_invalid_args("lt-to-neq".to_string())); + } + (LtFromEntries, &[a1, a2]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("lt-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("lt-from-entries"))?; + if v1 < v2 { + Statement::lt(r1, r2) + } else { + return Err(arg_error("lt-from-entries")); } - }, - SumOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { - ( - OperationArg::Statement(Statement::ValueOf(ak0, v0)), - OperationArg::Statement(Statement::ValueOf(ak1, v1)), - OperationArg::Statement(Statement::ValueOf(ak2, v2)), - ) => { - let v0: i64 = v0.typed().try_into()?; - let v1: i64 = v1.typed().try_into()?; - let v2: i64 = v2.typed().try_into()?; - if v0 == v1 + v2 { - vec![ - StatementArg::Key(ak0), - StatementArg::Key(ak1), - StatementArg::Key(ak2), - ] - } else { - return Err(Error::op_invalid_args("sum-of".to_string())); - } + } + (LtEqFromEntries, &[a1, a2]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("lt-eq-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("lt-eq-from-entries"))?; + if v1 <= v2 { + Statement::not_equal(r1, r2) + } else { + return Err(arg_error("lt-eq-from-entries")); } - _ => { - return Err(Error::op_invalid_args("sum-of".to_string())); + } + (CopyStatement, &[OperationArg::Statement(s)]) => s.clone(), + ( + TransitiveEqualFromStatements, + &[OperationArg::Statement(Statement::Equal(r1, r2)), OperationArg::Statement(Statement::Equal(r3, r4))], + ) => { + if r2 == r3 { + Statement::Equal(r1.clone(), r4.clone()) + } else { + return Err(arg_error("transitive-eq")); } - }, - ProductOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { - ( - OperationArg::Statement(Statement::ValueOf(ak0, v0)), - OperationArg::Statement(Statement::ValueOf(ak1, v1)), - OperationArg::Statement(Statement::ValueOf(ak2, v2)), - ) => { - let v0: i64 = v0.typed().try_into()?; - let v1: i64 = v1.typed().try_into()?; - let v2: i64 = v2.typed().try_into()?; - if v0 == v1 * v2 { - vec![ - StatementArg::Key(ak0), - StatementArg::Key(ak1), - StatementArg::Key(ak2), - ] - } else { - return Err(Error::op_invalid_args("product-of".to_string())); - } + } + (LtToNotEqual, &[OperationArg::Statement(Statement::Lt(r1, r2))]) => { + Statement::NotEqual(r1.clone(), r2.clone()) + } + (SumOf, &[a1, a2, a3]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("sum-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("sum-from-entries"))?; + let (r3, v3) = a3 + .value_and_ref() + .ok_or_else(|| arg_error("sum-from-entries"))?; + if middleware::Operation::check_int_fn(v1, v2, v3, sum_op)? { + Statement::SumOf(r1, r2, r3) + } else { + return Err(arg_error("sum-from-entries")); } - _ => { - return Err(Error::op_invalid_args("product-of".to_string())); + } + (ProductOf, &[a1, a2, a3]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("prod-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("prod-from-entries"))?; + let (r3, v3) = a3 + .value_and_ref() + .ok_or_else(|| arg_error("prod-from-entries"))?; + if middleware::Operation::check_int_fn(v1, v2, v3, prod_op)? { + Statement::ProductOf(r1, r2, r3) + } else { + return Err(arg_error("prod-from-entries")); } - }, - MaxOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { - ( - OperationArg::Statement(Statement::ValueOf(ak0, v0)), - OperationArg::Statement(Statement::ValueOf(ak1, v1)), - OperationArg::Statement(Statement::ValueOf(ak2, v2)), - ) => { - let v0: i64 = v0.typed().try_into()?; - let v1: i64 = v1.typed().try_into()?; - let v2: i64 = v2.typed().try_into()?; - if v0 == std::cmp::max(v1, v2) { - vec![ - StatementArg::Key(ak0), - StatementArg::Key(ak1), - StatementArg::Key(ak2), - ] - } else { - return Err(Error::op_invalid_args("max-of".to_string())); - } + } + (MaxOf, &[a1, a2, a3]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("max-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("max-from-entries"))?; + let (r3, v3) = a3 + .value_and_ref() + .ok_or_else(|| arg_error("max-from-entries"))?; + if middleware::Operation::check_int_fn(v1, v2, v3, max_op)? { + Statement::MaxOf(r1, r2, r3) + } else { + return Err(arg_error("max-from-entries")); } - _ => { - return Err(Error::op_invalid_args("max-of".to_string())); + } + (HashOf, &[a1, a2, a3]) => { + let (r1, v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("hash-from-entries"))?; + let (r2, v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("hash-from-entries"))?; + let (r3, v3) = a3 + .value_and_ref() + .ok_or_else(|| arg_error("hash-from-entries"))?; + if v1 == &hash_op(v2.clone(), v3.clone()) { + Statement::HashOf(r1, r2, r3) + } else { + return Err(arg_error("hash-from-entries")); } - }, - HashOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { - ( - OperationArg::Statement(Statement::ValueOf(ak0, v0)), - OperationArg::Statement(Statement::ValueOf(ak1, v1)), - OperationArg::Statement(Statement::ValueOf(ak2, v2)), - ) => { - if Hash::from(v0.raw()) == hash_values(&[v1, v2]) { - vec![ - StatementArg::Key(ak0), - StatementArg::Key(ak1), - StatementArg::Key(ak2), - ] - } else { - return Err(Error::op_invalid_args("hash-of".to_string())); - } + } + (ContainsFromEntries, &[a1, a2, a3]) => { + let (r1, _v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("contains-from-entries"))?; + let (r2, _v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("contains-from-entries"))?; + let (r3, _v3) = a3 + .value_and_ref() + .ok_or_else(|| arg_error("contains-from-entries"))?; + // TODO: validate proof + Statement::Contains(r1, r2, r3) + } + (NotContainsFromEntries, &[a1, a2]) => { + let (r1, _v1) = a1 + .value_and_ref() + .ok_or_else(|| arg_error("contains-from-entries"))?; + let (r2, _v2) = a2 + .value_and_ref() + .ok_or_else(|| arg_error("contains-from-entries"))?; + // TODO: validate proof + Statement::NotContains(r1, r2) + } + (t, _) => { + if t.is_syntactic_sugar() { + return Err(Error::custom(format!( + "Unexpected syntactic sugar: {:?}", + t + ))); + } else { + return Err(arg_error("malformed operation")); } - _ => { - return Err(Error::op_invalid_args("hash-of".to_string())); - } - }, - ContainsFromEntries => self.op_args_entries(public, args)?, - NotContainsFromEntries => self.op_args_entries(public, args)?, - _ => Err(Error::custom(format!( - "Unexpected syntactic sugar: {:?}", - op_type - )))?, + } }, OperationType::Custom(cpr) => { let pred = &cpr.batch.predicates()[cpr.index]; - if pred.statements.len() != args.len() { + if pred.statements.len() != op.1.len() { return Err(Error::custom(format!( "Custom predicate operation needs {} statements but has {}.", pred.statements.len(), - args.len() + op.1.len() ))); } // All args should be statements to be pattern matched against statement templates. - let args = args.iter().map( + let args = op.1.iter().map( |a| match a { OperationArg::Statement(s) => Ok(s.clone()), _ => Err(Error::custom(format!("Invalid argument {} to operation corresponding to custom predicate {:?}.", a, cpr))) @@ -488,14 +477,20 @@ impl MainPodBuilder { } } let v_default = WildcardValue::PodId(SELF); - wildcard_map + let st_args: Vec<_> = wildcard_map .into_iter() .take(pred.args_len) - .map(|v| StatementArg::WildcardLiteral(v.unwrap_or_else(|| v_default.clone()))) - .collect() + .map(|v| v.unwrap_or_else(|| v_default.clone())) + .collect(); + Statement::Custom(cpr, st_args) } }; - let st = Statement::from_args(pred, st_args).expect("valid arguments"); + Ok(st) + } + + fn op(&mut self, public: bool, op: Operation) -> Result { + let op = Self::fill_in_aux(Self::lower_op(op))?; + let st = self.op_statement(op.clone())?; self.insert(public, (st, op)); Ok(self.statements[self.statements.len() - 1].clone()) @@ -514,7 +509,7 @@ impl MainPodBuilder { fn literal(&mut self, public: bool, value: Value) -> Result { let public_value = (public, value); if let Some(key) = self.literals.get(&public_value) { - Ok(Statement::ValueOf( + Ok(Statement::equal( AnchoredKey::new(SELF, key.clone()), public_value.1, )) @@ -575,14 +570,11 @@ impl MainPodBuilder { let type_statement = pod .pub_statements() .into_iter() - .find_map(|s| match s { - Statement::ValueOf(AnchoredKey { pod_id: id, key }, value) - if id == pod_id && key.hash() == type_key_hash => + .find_map(|s| match s.as_entry() { + Some((AnchoredKey { pod_id: id, key }, _)) + if id == &pod_id && key.hash() == type_key_hash => { - Some(Statement::ValueOf( - AnchoredKey::from((pod_id, KEY_TYPE)), - value, - )) + Some(s) } _ => None, }) @@ -648,13 +640,13 @@ impl MainPod { self.pod.id() } - /// Returns the value of a ValueOf statement with self id that defines key if it exists. + /// Returns the value of a Equal statement with self id that defines key if it exists. pub fn get(&self, key: impl Into) -> Option { let key: Key = key.into(); self.public_statements .iter() .find_map(|st| match st { - Statement::ValueOf(ak, value) + Statement::Equal(ValueRef::Key(ak), ValueRef::Literal(value)) if ak.pod_id == self.id() && ak.key.hash() == key.hash() => { Some(value) @@ -701,11 +693,7 @@ impl MainPodCompiler { fn compile_op_arg(&self, op_arg: &OperationArg) -> Option { match op_arg { OperationArg::Statement(s) => Some(s.clone()), - OperationArg::Literal(_v) => { - // OperationArg::Literal is a syntax sugar for the frontend. This is translated to - // a new ValueOf statement and it's key used instead. - unreachable!() - } + OperationArg::Literal(_v) => Some(Statement::None), OperationArg::Entry(_k, _v) => { // OperationArg::Entry is only used in the frontend. The (key, value) will only // appear in the ValueOf statement in the backend. This is because a new ValueOf @@ -1108,7 +1096,7 @@ pub mod tests { #[should_panic] #[test] - fn test_incorrect_pod() { + fn test_reject_duplicate_new_entry() { // try to insert the same key multiple times // right now this is not caught when you build the pod, // but it is caught on verify @@ -1117,7 +1105,7 @@ pub mod tests { let params = Params::default(); let vd_set = &*DEFAULT_VD_SET; let mut builder = MainPodBuilder::new(¶ms, &vd_set); - let st = Statement::ValueOf(AnchoredKey::from((SELF, "a")), Value::from(3)); + let st = Statement::equal(AnchoredKey::from((SELF, "a")), Value::from(3)); let op_new_entry = Operation( OperationType::Native(NativeOperation::NewEntry), vec![], @@ -1125,24 +1113,35 @@ pub mod tests { ); builder.insert(false, (st, op_new_entry.clone())); - let st = Statement::ValueOf(AnchoredKey::from((SELF, "a")), Value::from(28)); + let st = Statement::equal(AnchoredKey::from((SELF, "a")), Value::from(28)); builder.insert(false, (st, op_new_entry.clone())); let mut prover = MockProver {}; let pod = builder.prove(&mut prover, ¶ms).unwrap(); pod.pod.verify().unwrap(); + } + #[should_panic] + #[test] + fn test_reject_unsound_statement() { // try to insert a statement that doesn't follow from the operation // right now the mock prover catches this when it calls compile() + let params = Params::default(); + let vd_set = &*DEFAULT_VD_SET; let mut builder = MainPodBuilder::new(¶ms, &vd_set); let self_a = AnchoredKey::from((SELF, "a")); let self_b = AnchoredKey::from((SELF, "b")); - let value_of_a = Statement::ValueOf(self_a.clone(), Value::from(3)); - let value_of_b = Statement::ValueOf(self_b.clone(), Value::from(27)); + let value_of_a = Statement::equal(self_a.clone(), Value::from(3)); + let value_of_b = Statement::equal(self_b.clone(), Value::from(27)); + let op_new_entry = Operation( + OperationType::Native(NativeOperation::NewEntry), + vec![], + OperationAux::None, + ); builder.insert(false, (value_of_a.clone(), op_new_entry.clone())); builder.insert(false, (value_of_b.clone(), op_new_entry)); - let st = Statement::Equal(self_a, self_b); + let st = Statement::equal(self_a, self_b); let op = Operation( OperationType::Native(NativeOperation::EqualFromEntries), vec![ diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index e558ca4..0c4aaf8 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -2,7 +2,10 @@ use std::fmt; use crate::{ frontend::{MainPod, SignedPod}, - middleware::{AnchoredKey, OperationAux, OperationType, Statement, Value}, + middleware::{ + AnchoredKey, CustomPredicateRef, NativeOperation, OperationAux, OperationType, Statement, + Value, ValueRef, + }, }; #[derive(Clone, Debug, PartialEq)] @@ -18,7 +21,15 @@ impl OperationArg { pub(crate) fn value(&self) -> Option<&Value> { match self { Self::Literal(v) => Some(v), - Self::Statement(Statement::ValueOf(_, v)) => Some(v), + Self::Statement(Statement::Equal(_, ValueRef::Literal(v))) => Some(v), + _ => None, + } + } + + pub(crate) fn value_and_ref(&self) -> Option<(ValueRef, &Value)> { + match self { + Self::Literal(v) => Some((ValueRef::Literal(v.clone()), v)), + Self::Statement(Statement::Equal(k, ValueRef::Literal(v))) => Some((k.clone(), v)), _ => None, } } @@ -72,9 +83,9 @@ impl From<(&SignedPod, &str)> for OperationArg { .get(&key.into()) .cloned() .unwrap_or_else(|| panic!("Key {} is not present in POD: {}", key, pod)); - Self::Statement(Statement::ValueOf( - AnchoredKey::from((pod.id(), key)), - value, + Self::Statement(Statement::Equal( + AnchoredKey::from((pod.id(), key)).into(), + value.into(), )) } } @@ -84,10 +95,7 @@ impl From<(&MainPod, &str)> for OperationArg { let value = pod .get(key) .unwrap_or_else(|| panic!("Key {} is not present in POD: {}", key, pod)); - Self::Statement(Statement::ValueOf( - AnchoredKey::from((pod.id(), key)), - value, - )) + Self::Statement(Statement::equal(AnchoredKey::from((pod.id(), key)), value)) } } @@ -97,6 +105,12 @@ impl From for OperationArg { } } +impl From<&Statement> for OperationArg { + fn from(value: &Statement) -> Self { + value.clone().into() + } +} + impl> From<(&str, V)> for OperationArg { fn from((key, value): (&str, V)) -> Self { Self::Entry(key.to_string(), value.into()) @@ -118,3 +132,78 @@ impl fmt::Display for Operation { Ok(()) } } + +macro_rules! op_impl_oa { + ($fn_name: ident, $op_name: ident, 2) => { + pub fn $fn_name(a1: impl Into, a2: impl Into) -> Self { + Self( + OperationType::Native(NativeOperation::$op_name), + vec![a1.into(), a2.into()], + OperationAux::None, + ) + } + }; + + ($fn_name: ident, $op_name: ident, 3) => { + pub fn $fn_name( + a1: impl Into, + a2: impl Into, + a3: impl Into, + ) -> Self { + Self( + OperationType::Native(NativeOperation::$op_name), + vec![a1.into(), a2.into(), a3.into()], + OperationAux::None, + ) + } + }; +} + +macro_rules! op_impl_st { + ($fn_name: ident, $op_name: ident, 1) => { + pub fn $fn_name(a1: &Statement) -> Self { + Self( + OperationType::Native(NativeOperation::$op_name), + vec![a1.into()], + OperationAux::None, + ) + } + }; + + ($fn_name: ident, $op_name: ident, 2) => { + pub fn $fn_name(a1: &Statement, a2: &Statement) -> Self { + Self( + OperationType::Native(NativeOperation::$op_name), + vec![a1.into(), a2.into()], + OperationAux::None, + ) + } + }; +} + +impl Operation { + pub fn new_entry(a1: impl Into, a2: impl Into) -> Self { + Self( + OperationType::Native(NativeOperation::NewEntry), + vec![a1.into(), a2.into().into()], + OperationAux::None, + ) + } + op_impl_oa!(eq, EqualFromEntries, 2); + op_impl_oa!(ne, NotEqualFromEntries, 2); + op_impl_oa!(gt, GtFromEntries, 2); + op_impl_oa!(lt, LtFromEntries, 2); + op_impl_st!(transitive_eq, TransitiveEqualFromStatements, 2); + op_impl_st!(gt_to_ne, GtToNotEqual, 1); + op_impl_oa!(sum_of, SumOf, 3); + op_impl_oa!(product_of, ProductOf, 3); + op_impl_oa!(max_of, MaxOf, 3); + pub fn custom(cpr: CustomPredicateRef, args: Vec) -> Self { + Self(OperationType::Custom(cpr), args, OperationAux::None) + } + op_impl_oa!(dict_contains, DictContainsFromEntries, 3); + op_impl_oa!(dict_not_contains, DictNotContainsFromEntries, 2); + op_impl_oa!(set_contains, SetContainsFromEntries, 3); + op_impl_oa!(set_not_contains, SetNotContainsFromEntries, 2); + op_impl_oa!(array_contains, ArrayContainsFromEntries, 3); +} diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 6022ba9..0bbefed 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -115,7 +115,7 @@ mod tests { fn test_e2e_simple_request() -> Result<(), LangError> { let input = r#" REQUEST( - ValueOf(?ConstPod["my_val"], 0x0000000000000000000000000000000000000000000000000000000000000001) + Equal(?ConstPod["my_val"], 0x0000000000000000000000000000000000000000000000000000000000000001) Lt(?GovPod["dob"], ?ConstPod["my_val"]) ) "#; @@ -133,7 +133,7 @@ mod tests { // Expected structure let expected_templates = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![ sta_ak(("ConstPod", 0), k("my_val")), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val") sta_lit(SELF_ID_HASH), @@ -158,7 +158,7 @@ mod tests { let input = r#" uses_private(A, private: Temp) = AND( Equal(?A["input_key"], ?Temp["const_key"]) - ValueOf(?Temp["const_key"], "some_value") + Equal(?Temp["const_key"], "some_value") ) "#; @@ -182,7 +182,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![ sta_ak(("Temp", 1), k("const_key")), // ?Temp["const_key"] -> Wildcard(1), Key("const_key") sta_lit("some_value"), // Literal("some_value") @@ -390,8 +390,8 @@ mod tests { Lt(?gov["dateOfBirth"], ?SELF_HOLDER_18Y["const_18y"]) Equal(?pay["startDate"], ?SELF_HOLDER_1Y["const_1y"]) Equal(?gov["socialSecurityNumber"], ?pay["socialSecurityNumber"]) - ValueOf(?SELF_HOLDER_18Y["const_18y"], 1169909388) - ValueOf(?SELF_HOLDER_1Y["const_1y"], 1706367566) + Equal(?SELF_HOLDER_18Y["const_18y"], 1169909388) + Equal(?SELF_HOLDER_1Y["const_1y"], 1706367566) ) "#; @@ -463,9 +463,9 @@ mod tests { sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key.clone()), ], }, - // 5. ValueOf(?SELF_HOLDER_18Y["const_18y"], 1169909388) + // 5. Equal(?SELF_HOLDER_18Y["const_18y"], 1169909388) StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![ sta_ak( (wc_self_18y.name.as_str(), wc_self_18y.index), @@ -474,9 +474,9 @@ mod tests { sta_lit(now_minus_18y_val.clone()), ], }, - // 6. ValueOf(?SELF_HOLDER_1Y["const_1y"], 1706367566) + // 6. Equal(?SELF_HOLDER_1Y["const_1y"], 1706367566) StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![ sta_ak( (wc_self_1y.name.as_str(), wc_self_1y.index), @@ -518,19 +518,19 @@ mod tests { let input = r#" eth_friend(src_key, dst_key, private: attestation_pod) = AND( - ValueOf(?attestation_pod["_type"], 1) + Equal(?attestation_pod["_type"], 1) Equal(?attestation_pod["_signer"], SELF[?src_key]) Equal(?attestation_pod["attestation"], SELF[?dst_key]) ) eth_dos_distance_base(src_key, dst_key, distance_key) = AND( Equal(SELF[?src_key], SELF[?dst_key]) - ValueOf(SELF[?distance_key], 0) + Equal(SELF[?distance_key], 0) ) eth_dos_distance_ind(src_key, dst_key, distance_key, private: one_key, shorter_distance_key, intermed_key) = AND( eth_dos_distance(?src_key, ?dst_key, ?distance_key) - ValueOf(SELF[?one_key], 1) + Equal(SELF[?one_key], 1) SumOf(SELF[?distance_key], SELF[?shorter_distance_key], SELF[?one_key]) eth_friend(?intermed_key, ?dst_key) ) @@ -558,7 +558,7 @@ mod tests { // eth_friend (Index 0) let expected_friend_stmts = vec![ StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![ sta_ak(("attestation_pod", 2), k("_type")), // Pub(0-1), Priv(2) sta_lit(PodType::MockSigned), @@ -598,7 +598,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![sta_ak_self(ko_wc("distance_key", 2)), sta_lit(0i64)], }, ]; @@ -625,7 +625,7 @@ mod tests { ], }, StatementTmpl { - pred: Predicate::Native(NativePredicate::ValueOf), + pred: Predicate::Native(NativePredicate::Equal), args: vec![sta_ak_self(ko_wc("one_key", 3)), sta_lit(1i64)], // private arg }, StatementTmpl { diff --git a/src/lang/processor.rs b/src/lang/processor.rs index 263c399..9c5c7c3 100644 --- a/src/lang/processor.rs +++ b/src/lang/processor.rs @@ -27,7 +27,8 @@ fn get_span(pair: &Pair) -> (usize, usize) { pub fn native_predicate_from_string(s: &str) -> Option { match s { - "ValueOf" => Some(NativePredicate::ValueOf), + // TODO: update any code that still uses ValueOf to use Equal instead + "ValueOf" => Some(NativePredicate::Equal), "Equal" => Some(NativePredicate::Equal), "NotEqual" => Some(NativePredicate::NotEqual), // Syntactic sugar for Gt/GtEq is handled at a later stage @@ -347,17 +348,16 @@ fn validate_and_build_statement_template( ) -> Result { match pred { Predicate::Native(native_pred) => { - let (expected_arity, mapped_pred_for_arity_check) = match native_pred { - NativePredicate::Gt => (2, NativePredicate::Lt), - NativePredicate::GtEq => (2, NativePredicate::LtEq), - NativePredicate::ValueOf + let expected_arity = match native_pred { + NativePredicate::Gt + | NativePredicate::GtEq | NativePredicate::Equal | NativePredicate::NotEqual | NativePredicate::Lt | NativePredicate::LtEq | NativePredicate::SetContains | NativePredicate::DictNotContains - | NativePredicate::SetNotContains => (2, *native_pred), + | NativePredicate::SetNotContains => 2, NativePredicate::NotContains | NativePredicate::Contains | NativePredicate::ArrayContains @@ -365,8 +365,8 @@ fn validate_and_build_statement_template( | NativePredicate::SumOf | NativePredicate::ProductOf | NativePredicate::MaxOf - | NativePredicate::HashOf => (3, *native_pred), - NativePredicate::None | NativePredicate::False => (0, *native_pred), + | NativePredicate::HashOf => 3, + NativePredicate::None | NativePredicate::False => 0, }; if args.len() != expected_arity { @@ -378,30 +378,9 @@ fn validate_and_build_statement_template( }); } - if mapped_pred_for_arity_check == NativePredicate::ValueOf { - if !matches!(args.get(0), Some(BuilderArg::Key(..))) { - return Err(ProcessorError::TypeError { - expected: "Anchored Key".to_string(), - found: args - .get(0) - .map_or("None".to_string(), |a| format!("{:?}", a)), - item: format!("argument 1 of native predicate '{}'", stmt_name_str), - span: Some(stmt_span), - }); - } - if !matches!(args.get(1), Some(BuilderArg::Literal(..))) { - return Err(ProcessorError::TypeError { - expected: "Literal".to_string(), - found: args - .get(1) - .map_or("None".to_string(), |a| format!("{:?}", a)), - item: format!("argument 2 of native predicate '{}'", stmt_name_str), - span: Some(stmt_span), - }); - } - } else if expected_arity > 0 { + if expected_arity > 0 { for (i, arg) in args.iter().enumerate() { - if !matches!(arg, BuilderArg::Key(..)) { + if !matches!(arg, BuilderArg::Key(..) | BuilderArg::Literal(..)) { return Err(ProcessorError::TypeError { expected: "Anchored Key".to_string(), found: format!("{:?}", arg), @@ -1056,7 +1035,7 @@ mod processor_tests { fn test_fp_multiple_predicates() -> Result<(), ProcessorError> { let input = r#" pred1(X) = AND( Equal(?X["k"],?X["k"]) ) - pred2(Y, Z) = OR( ValueOf(?Y["v"], 123) ) + pred2(Y, Z) = OR( Equal(?Y["v"], 123) ) "#; let pairs = get_document_content_pairs(input)?; let params = Params::default(); diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 3c188fa..c587bd1 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -527,7 +527,7 @@ mod tests { "_".into(), vec![ st( - P::Native(NP::ValueOf), + P::Native(NP::Equal), vec![ STA::AnchoredKey(sow_wc(4), kow_wc(5)), STA::Literal(2.into()), @@ -558,8 +558,8 @@ mod tests { let custom_deduction = Operation::Custom( CustomPredicateRef::new(cust_pred_batch, 0), vec![ - Statement::ValueOf(AnchoredKey::from((SELF, "Some constant")), 2.into()), - Statement::ProductOf( + Statement::equal(AnchoredKey::from((SELF, "Some constant")), 2), + Statement::product_of( AnchoredKey::from((SELF, "Some value")), AnchoredKey::from((SELF, "Some constant")), AnchoredKey::from((SELF, "Some other value")), @@ -585,7 +585,7 @@ mod tests { "eth_friend_cp".into(), vec![ st( - P::Native(NP::ValueOf), + P::Native(NP::Equal), vec![ STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("type".into())), STA::Literal(PodType::Signed.into()), @@ -626,7 +626,7 @@ mod tests { ], ), st( - P::Native(NP::ValueOf), + P::Native(NP::Equal), vec![ STA::AnchoredKey(sow_wc(4), kow_wc(5)), STA::Literal(0.into()), @@ -654,7 +654,7 @@ mod tests { ], ), st( - P::Native(NP::ValueOf), + P::Native(NP::Equal), vec![ STA::AnchoredKey(sow_wc(6), kow_wc(7)), STA::Literal(1.into()), @@ -776,8 +776,8 @@ mod tests { WildcardValue::Key(Key::from("Six")), ], ), - Statement::ValueOf(AnchoredKey::from((SELF, "One")), 1.into()), - Statement::SumOf( + Statement::equal(AnchoredKey::from((SELF, "One")), 1), + Statement::sum_of( AnchoredKey::from((SELF, "Seven")), AnchoredKey::from((pod_id4, "Six")), AnchoredKey::from((SELF, "One")), diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index dc3247a..3c17581 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -671,8 +671,7 @@ impl Default for Params { max_signed_pod_values: 8, max_public_statements: 10, num_public_statements_id: 16, - // TODO: Reduce to 5 or less after https://github.com/0xPARC/pod2/issues/229 - max_statement_args: 6, + max_statement_args: 5, max_operation_args: 5, max_custom_predicate_batches: 2, max_custom_predicate_verifications: 5, @@ -793,7 +792,7 @@ pub trait Pod: fmt::Debug + DynClone + Any { self.pub_statements() .into_iter() .filter_map(|st| match st { - Statement::ValueOf(ak, v) => Some((ak, v)), + Statement::Equal(ValueRef::Key(ak), ValueRef::Literal(v)) => Some((ak, v)), _ => None, }) .collect() diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 3bdfa68..cde111e 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -7,9 +7,9 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::primitives::merkletree::MerkleProof, middleware::{ - custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, - NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg, - StatementTmplArg, ToFields, Wildcard, WildcardValue, F, SELF, + custom::KeyOrWildcard, hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, + Error, NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg, + StatementTmplArg, ToFields, Value, ValueRef, Wildcard, WildcardValue, F, SELF, }, }; @@ -86,6 +86,12 @@ pub enum NativeOperation { GtToNotEqual = 1008, } +impl NativeOperation { + pub fn is_syntactic_sugar(self) -> bool { + (self as usize) >= 1000 + } +} + impl ToFields for NativeOperation { fn to_fields(&self, _params: &Params) -> Vec { vec![F::from_canonical_u64(*self as u64)] @@ -100,7 +106,7 @@ impl OperationType { match self { OperationType::Native(native_op) => match native_op { NativeOperation::None => Some(Predicate::Native(NativePredicate::None)), - NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::ValueOf)), + NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::Equal)), NativeOperation::CopyStatement => None, NativeOperation::EqualFromEntries => { Some(Predicate::Native(NativePredicate::Equal)) @@ -161,6 +167,22 @@ pub enum Operation { Custom(CustomPredicateRef, Vec), } +pub(crate) fn sum_op(x: i64, y: i64) -> i64 { + x + y +} + +pub(crate) fn prod_op(x: i64, y: i64) -> i64 { + x * y +} + +pub(crate) fn max_op(x: i64, y: i64) -> i64 { + x.max(y) +} + +pub(crate) fn hash_op(x: Value, y: Value) -> Value { + Value::from(hash_values(&[x, y])) +} + impl Operation { pub fn op_type(&self) -> OperationType { type OT = OperationType; @@ -219,56 +241,53 @@ impl Operation { pub fn op(op_code: OperationType, args: &[Statement], aux: &OperationAux) -> Result { type OA = OperationAux; type NO = NativeOperation; - let arg_tup = ( - args.first().cloned(), - args.get(1).cloned(), - args.get(2).cloned(), - ); Ok(match op_code { - OperationType::Native(o) => match (o, arg_tup, aux.clone(), args.len()) { - (NO::None, (None, None, None), OA::None, 0) => Self::None, - (NO::NewEntry, (None, None, None), OA::None, 0) => Self::NewEntry, - (NO::CopyStatement, (Some(s), None, None), OA::None, 1) => Self::CopyStatement(s), - (NO::EqualFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => { - Self::EqualFromEntries(s1, s2) + OperationType::Native(o) => match (o, &args, aux.clone()) { + (NO::None, &[], OA::None) => Self::None, + (NO::NewEntry, &[], OA::None) => Self::NewEntry, + (NO::CopyStatement, &[s], OA::None) => Self::CopyStatement(s.clone()), + (NO::EqualFromEntries, &[s1, s2], OA::None) => { + Self::EqualFromEntries(s1.clone(), s2.clone()) } - (NO::NotEqualFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => { - Self::NotEqualFromEntries(s1, s2) + (NO::NotEqualFromEntries, &[s1, s2], OA::None) => { + Self::NotEqualFromEntries(s1.clone(), s2.clone()) } - (NO::LtEqFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => { - Self::LtEqFromEntries(s1, s2) + (NO::LtEqFromEntries, &[s1, s2], OA::None) => { + Self::LtEqFromEntries(s1.clone(), s2.clone()) } - (NO::LtFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => { - Self::LtFromEntries(s1, s2) + (NO::LtFromEntries, &[s1, s2], OA::None) => { + Self::LtFromEntries(s1.clone(), s2.clone()) } - ( - NO::ContainsFromEntries, - (Some(s1), Some(s2), Some(s3)), - OA::MerkleProof(pf), - 3, - ) => Self::ContainsFromEntries(s1, s2, s3, pf), - ( - NO::NotContainsFromEntries, - (Some(s1), Some(s2), None), - OA::MerkleProof(pf), - 2, - ) => Self::NotContainsFromEntries(s1, s2, pf), - (NO::SumOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => Self::SumOf(s1, s2, s3), - (NO::ProductOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => { - Self::ProductOf(s1, s2, s3) + (NO::ContainsFromEntries, &[s1, s2, s3], OA::MerkleProof(pf)) => { + Self::ContainsFromEntries(s1.clone(), s2.clone(), s3.clone(), pf) } - (NO::MaxOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => Self::MaxOf(s1, s2, s3), - (NO::HashOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => { - Self::HashOf(s1, s2, s3) + (NO::NotContainsFromEntries, &[s1, s2], OA::MerkleProof(pf)) => { + Self::NotContainsFromEntries(s1.clone(), s2.clone(), pf) + } + (NO::SumOf, &[s1, s2, s3], OA::None) => { + Self::SumOf(s1.clone(), s2.clone(), s3.clone()) + } + (NO::ProductOf, &[s1, s2, s3], OA::None) => { + Self::ProductOf(s1.clone(), s2.clone(), s3.clone()) + } + (NO::MaxOf, &[s1, s2, s3], OA::None) => { + Self::MaxOf(s1.clone(), s2.clone(), s3.clone()) + } + (NO::HashOf, &[s1, s2, s3], OA::None) => { + Self::HashOf(s1.clone(), s2.clone(), s3.clone()) } _ => Err(Error::custom(format!( - "Ill-formed operation {:?} with arguments {:?}.", - op_code, args + "Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.", + op_code, + args.len(), + args, + aux )))?, }, OperationType::Custom(cpr) => Self::Custom(cpr, args.to_vec()), }) } + /// Checks the given operation against a statement, and prints information if the check does not pass pub fn check_and_log(&self, params: &Params, output_statement: &Statement) -> Result { let valid: bool = self.check(params, output_statement)?; @@ -278,59 +297,69 @@ impl Operation { } Ok(valid) } + + pub(crate) fn check_int_fn( + v1: &Value, + v2: &Value, + v3: &Value, + f: impl FnOnce(i64, i64) -> i64, + ) -> Result { + let i1: i64 = v1.typed().try_into()?; + let i2: i64 = v2.typed().try_into()?; + let i3: i64 = v3.typed().try_into()?; + Ok(i1 == f(i2, i3)) + } + /// Checks the given operation against a statement. pub fn check(&self, params: &Params, output_statement: &Statement) -> Result { use Statement::*; - match (self, output_statement) { - (Self::None, None) => Ok(true), - (Self::NewEntry, ValueOf(AnchoredKey { pod_id, .. }, _)) => Ok(pod_id == &SELF), - (Self::CopyStatement(s1), s2) => Ok(s1 == s2), - (Self::EqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Equal(ak3, ak4)) => { - Ok(v1 == v2 && ak3 == ak1 && ak4 == ak2) - } - (Self::NotEqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), NotEqual(ak3, ak4)) => { - Ok(v1 != v2 && ak3 == ak1 && ak4 == ak2) - } - (Self::LtEqFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), LtEq(ak3, ak4)) => { - Ok(v1 <= v2 && ak3 == ak1 && ak4 == ak2) - } - (Self::LtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Lt(ak3, ak4)) => { - Ok(v1 < v2 && ak3 == ak1 && ak4 == ak2) + let deduction_err = || Error::invalid_deduction(self.clone(), output_statement.clone()); + let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err); + let b = match (self, output_statement) { + (Self::None, None) => true, + (Self::NewEntry, Equal(ValueRef::Key(AnchoredKey { pod_id, .. }), _)) => { + pod_id == &SELF } + (Self::CopyStatement(s1), s2) => s1 == s2, + (Self::EqualFromEntries(s1, s2), Equal(v3, v4)) => val(v3, s1)? == val(v4, s2)?, + (Self::NotEqualFromEntries(s1, s2), NotEqual(v3, v4)) => val(v3, s1)? != val(v4, s2)?, + (Self::LtEqFromEntries(s1, s2), LtEq(v3, v4)) => val(v3, s1)? <= val(v4, s2)?, + (Self::LtFromEntries(s1, s2), Lt(v3, v4)) => val(v3, s1)? < val(v4, s2)?, (Self::ContainsFromEntries(_, _, _, _), Contains(_, _, _)) => /* TODO */ { - Ok(true) + true } (Self::NotContainsFromEntries(_, _, _), NotContains(_, _)) => /* TODO */ { - Ok(true) + true } ( Self::TransitiveEqualFromStatements(Equal(ak1, ak2), Equal(ak3, ak4)), Equal(ak5, ak6), - ) => Ok(ak2 == ak3 && ak5 == ak1 && ak6 == ak4), - (Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => Ok(ak1 == ak3 && ak2 == ak4), - ( - Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)), - SumOf(ak4, ak5, ak6), - ) => { - let v1: i64 = v1.typed().try_into()?; - let v2: i64 = v2.typed().try_into()?; - let v3: i64 = v3.typed().try_into()?; - Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3) + ) => ak2 == ak3 && ak5 == ak1 && ak6 == ak4, + (Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => ak1 == ak3 && ak2 == ak4, + (Self::SumOf(s1, s2, s3), SumOf(v4, v5, v6)) => { + Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, sum_op)? + } + (Self::ProductOf(s1, s2, s3), ProductOf(v4, v5, v6)) => { + Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, prod_op)? + } + (Self::MaxOf(s1, s2, s3), ProductOf(v4, v5, v6)) => { + Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, max_op)? + } + (Self::HashOf(s1, s2, s3), ProductOf(v4, v5, v6)) => { + val(v4, s1)? == hash_op(val(v5, s2)?, val(v6, s3)?) } (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) if batch == &cpr.batch && index == &cpr.index => { - check_custom_pred(params, cpr, args, s_args) + check_custom_pred(params, cpr, args, s_args)? } - _ => Err(Error::invalid_deduction( - self.clone(), - output_statement.clone(), - )), - } + _ => return Err(deduction_err()), + }; + Ok(b) } } @@ -494,3 +523,15 @@ impl fmt::Display for Operation { Ok(()) } } + +/// Returns the value associated with `output_ref`. +/// If `output_ref` is a concrete value, returns that value. +/// Otherwise, `output_ref` was constructed using an `Equal` statement, and `input_st` +/// must be that statement. +pub(crate) fn value_from_op(input_st: &Statement, output_ref: &ValueRef) -> Option { + match (input_st, output_ref) { + (Statement::None, ValueRef::Literal(v)) => Some(v.clone()), + (Statement::Equal(r1, ValueRef::Literal(v)), r2) if r1 == r2 => Some(v.clone()), + _ => None, + } +} diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index b36c0a1..894b4f7 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -23,17 +23,16 @@ pub const OPERATION_AUX_F_LEN: usize = 2; pub enum NativePredicate { None = 0, // Always true False = 1, // Always false - ValueOf = 2, - Equal = 3, - NotEqual = 4, - LtEq = 5, - Lt = 6, - Contains = 7, - NotContains = 8, - SumOf = 9, - ProductOf = 10, - MaxOf = 11, - HashOf = 12, + Equal = 2, + NotEqual = 3, + LtEq = 4, + Lt = 5, + Contains = 6, + NotContains = 7, + SumOf = 8, + ProductOf = 9, + MaxOf = 10, + HashOf = 11, // Syntactic sugar predicates. These predicates are not supported by the backend. The // frontend compiler is responsible of translating these predicates into the predicates above. @@ -168,33 +167,58 @@ impl fmt::Display for Predicate { #[serde(tag = "predicate", content = "args")] pub enum Statement { None, - ValueOf(AnchoredKey, Value), - Equal(AnchoredKey, AnchoredKey), - NotEqual(AnchoredKey, AnchoredKey), - LtEq(AnchoredKey, AnchoredKey), - Lt(AnchoredKey, AnchoredKey), + Equal(ValueRef, ValueRef), + NotEqual(ValueRef, ValueRef), + LtEq(ValueRef, ValueRef), + Lt(ValueRef, ValueRef), Contains( - /* root */ AnchoredKey, - /* key */ AnchoredKey, - /* value */ AnchoredKey, + /* root */ ValueRef, + /* key */ ValueRef, + /* value */ ValueRef, ), - NotContains(/* root */ AnchoredKey, /* key */ AnchoredKey), - SumOf(AnchoredKey, AnchoredKey, AnchoredKey), - ProductOf(AnchoredKey, AnchoredKey, AnchoredKey), - MaxOf(AnchoredKey, AnchoredKey, AnchoredKey), - HashOf(AnchoredKey, AnchoredKey, AnchoredKey), + NotContains(/* root */ ValueRef, /* key */ ValueRef), + SumOf(ValueRef, ValueRef, ValueRef), + ProductOf(ValueRef, ValueRef, ValueRef), + MaxOf(ValueRef, ValueRef, ValueRef), + HashOf(ValueRef, ValueRef, ValueRef), Custom(CustomPredicateRef, Vec), } +macro_rules! statement_constructor { + ($var_name: ident, $cons_name: ident, 2) => { + pub fn $var_name(v1: impl Into, v2: impl Into) -> Self { + Self::$cons_name(v1.into(), v2.into()) + } + }; + ($var_name: ident, $cons_name: ident, 3) => { + pub fn $var_name( + v1: impl Into, + v2: impl Into, + v3: impl Into, + ) -> Self { + Self::$cons_name(v1.into(), v2.into(), v3.into()) + } + }; +} + impl Statement { pub fn is_none(&self) -> bool { self == &Self::None } + statement_constructor!(equal, Equal, 2); + statement_constructor!(not_equal, NotEqual, 2); + statement_constructor!(lt_eq, LtEq, 2); + statement_constructor!(lt, Lt, 2); + statement_constructor!(contains, Contains, 3); + statement_constructor!(not_contains, NotContains, 2); + statement_constructor!(sum_of, SumOf, 3); + statement_constructor!(product_of, ProductOf, 3); + statement_constructor!(max_of, MaxOf, 3); + statement_constructor!(hash_of, HashOf, 3); pub fn predicate(&self) -> Predicate { use Predicate::*; match self { Self::None => Native(NativePredicate::None), - Self::ValueOf(_, _) => Native(NativePredicate::ValueOf), Self::Equal(_, _) => Native(NativePredicate::Equal), Self::NotEqual(_, _) => Native(NativePredicate::NotEqual), Self::LtEq(_, _) => Native(NativePredicate::LtEq), @@ -212,117 +236,66 @@ impl Statement { use StatementArg::*; match self.clone() { Self::None => vec![], - Self::ValueOf(ak, v) => vec![Key(ak), Literal(v)], - Self::Equal(ak1, ak2) => vec![Key(ak1), Key(ak2)], - Self::NotEqual(ak1, ak2) => vec![Key(ak1), Key(ak2)], - Self::LtEq(ak1, ak2) => vec![Key(ak1), Key(ak2)], - Self::Lt(ak1, ak2) => vec![Key(ak1), Key(ak2)], - Self::Contains(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], - Self::NotContains(ak1, ak2) => vec![Key(ak1), Key(ak2)], - Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], - Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], - Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], - Self::HashOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], + Self::Equal(ak1, ak2) => vec![ak1.into(), ak2.into()], + Self::NotEqual(ak1, ak2) => vec![ak1.into(), ak2.into()], + Self::LtEq(ak1, ak2) => vec![ak1.into(), ak2.into()], + Self::Lt(ak1, ak2) => vec![ak1.into(), ak2.into()], + Self::Contains(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], + Self::NotContains(ak1, ak2) => vec![ak1.into(), ak2.into()], + Self::SumOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], + Self::ProductOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], + Self::MaxOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], + Self::HashOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(WildcardLiteral)), } } + + pub fn as_entry(&self) -> Option<(&AnchoredKey, &Value)> { + if let Self::Equal(ValueRef::Key(k), ValueRef::Literal(v)) = self { + Some((k, v)) + } else { + None + } + } + pub fn from_args(pred: Predicate, args: Vec) -> Result { use Predicate::*; - let st: Result = match pred { - Native(NativePredicate::None) => Ok(Self::None), - Native(NativePredicate::ValueOf) => { - if let (StatementArg::Key(a0), StatementArg::Literal(v1)) = - (args[0].clone(), args[1].clone()) - { - Ok(Self::ValueOf(a0, v1)) - } else { - Err(Error::incorrect_statements_args()) - } + let st = match (pred, &args.as_slice()) { + (Native(NativePredicate::None), &[]) => Self::None, + (Native(NativePredicate::Equal), &[a1, a2]) => { + Self::Equal(a1.try_into()?, a2.try_into()?) } - Native(NativePredicate::Equal) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = - (args[0].clone(), args[1].clone()) - { - Ok(Self::Equal(a0, a1)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::NotEqual), &[a1, a2]) => { + Self::NotEqual(a1.try_into()?, a2.try_into()?) } - Native(NativePredicate::NotEqual) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = - (args[0].clone(), args[1].clone()) - { - Ok(Self::NotEqual(a0, a1)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::LtEq), &[a1, a2]) => { + Self::LtEq(a1.try_into()?, a2.try_into()?) } - Native(NativePredicate::LtEq) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = - (args[0].clone(), args[1].clone()) - { - Ok(Self::LtEq(a0, a1)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::Lt), &[a1, a2]) => Self::Lt(a1.try_into()?, a2.try_into()?), + (Native(NativePredicate::Contains), &[a1, a2, a3]) => { + Self::Contains(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - Native(NativePredicate::Lt) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = - (args[0].clone(), args[1].clone()) - { - Ok(Self::Lt(a0, a1)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::NotContains), &[a1, a2]) => { + Self::NotContains(a1.try_into()?, a2.try_into()?) } - Native(NativePredicate::Contains) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0].clone(), args[1].clone(), args[2].clone()) - { - Ok(Self::Contains(a0, a1, a2)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::SumOf), &[a1, a2, a3]) => { + Self::SumOf(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - Native(NativePredicate::NotContains) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1)) = - (args[0].clone(), args[1].clone()) - { - Ok(Self::NotContains(a0, a1)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::ProductOf), &[a1, a2, a3]) => { + Self::ProductOf(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - Native(NativePredicate::SumOf) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0].clone(), args[1].clone(), args[2].clone()) - { - Ok(Self::SumOf(a0, a1, a2)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::MaxOf), &[a1, a2, a3]) => { + Self::MaxOf(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - Native(NativePredicate::ProductOf) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0].clone(), args[1].clone(), args[2].clone()) - { - Ok(Self::ProductOf(a0, a1, a2)) - } else { - Err(Error::incorrect_statements_args()) - } + (Native(NativePredicate::HashOf), &[a1, a2, a3]) => { + Self::HashOf(a1.try_into()?, a2.try_into()?, a3.try_into()?) } - Native(NativePredicate::MaxOf) => { - if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) = - (args[0].clone(), args[1].clone(), args[2].clone()) - { - Ok(Self::MaxOf(a0, a1, a2)) - } else { - Err(Error::incorrect_statements_args()) - } + + (Native(np), _) => { + return Err(Error::custom(format!("Predicate {:?} is syntax sugar", np))) } - Native(np) => Err(Error::custom(format!("Predicate {:?} is syntax sugar", np))), - BatchSelf(_) => unreachable!(), - Custom(cpr) => { + (BatchSelf(_), _) => unreachable!(), + (Custom(cpr), _) => { let v_args: Result> = args .iter() .map(|x| match x { @@ -330,10 +303,10 @@ impl Statement { _ => Err(Error::incorrect_statements_args()), }) .collect(); - Ok(Self::Custom(cpr, v_args?)) + Self::Custom(cpr, v_args?) } }; - st + Ok(st) } } @@ -437,6 +410,57 @@ impl ToFields for StatementArg { } } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub enum ValueRef { + Literal(Value), + Key(AnchoredKey), +} + +impl From for StatementArg { + fn from(value: ValueRef) -> Self { + match value { + ValueRef::Literal(v) => StatementArg::Literal(v), + ValueRef::Key(v) => StatementArg::Key(v), + } + } +} + +impl TryFrom for ValueRef { + type Error = crate::middleware::Error; + fn try_from(value: StatementArg) -> std::result::Result { + match value { + StatementArg::Literal(v) => Ok(Self::Literal(v)), + StatementArg::Key(k) => Ok(Self::Key(k)), + _ => Err(Self::Error::invalid_statement_arg( + value, + "literal or key".to_string(), + )), + } + } +} + +impl TryFrom<&StatementArg> for ValueRef { + type Error = crate::middleware::Error; + fn try_from(value: &StatementArg) -> std::result::Result { + value.clone().try_into() + } +} + +impl From for ValueRef { + fn from(value: AnchoredKey) -> Self { + Self::Key(value) + } +} + +impl From for ValueRef +where + T: Into, +{ + fn from(value: T) -> Self { + Self::Literal(value.into()) + } +} + #[cfg(test)] mod tests { use super::*;