From dbd958dcca41d74eaf85499ad0def8246c15b795 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 1 Apr 2026 23:49:29 +0200 Subject: [PATCH] Allow entries as args in custom statements (#498) - Introduce a new operation ReplaceValueWithEntry that allows taking any statement and replacing literal arguments with entries given a matching Contains statement. - Allow entries as args in custom statements - Circuit optimization: For the public statements slots in the circuit we only support None and Copy which take at most 1 argument; but we were still doing max_statement_args random accesses per slot; so I reduced that to just 1 random access to a previous statement. --- src/backends/plonky2/circuits/common.rs | 34 ++-- src/backends/plonky2/circuits/mainpod.rs | 189 +++++++++++++++++----- src/backends/plonky2/mainpod/mod.rs | 132 +++++++++++++-- src/backends/plonky2/mainpod/statement.rs | 14 +- src/backends/plonky2/mock/mainpod.rs | 3 +- src/frontend/custom.rs | 6 +- src/frontend/mod.rs | 42 ++++- src/frontend/multi_pod/cost.rs | 3 +- src/frontend/operation.rs | 22 ++- src/lang/diagnostics.rs | 12 -- src/lang/error.rs | 6 - src/lang/frontend_ast_validate.rs | 93 +++-------- src/lang/mod.rs | 1 - src/middleware/basetypes.rs | 6 + src/middleware/custom.rs | 46 ++++-- src/middleware/mod.rs | 5 +- src/middleware/operation.rs | 76 ++++++++- src/middleware/statement.rs | 15 +- 18 files changed, 515 insertions(+), 190 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 7d25786..de53ee5 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -37,8 +37,8 @@ use crate::{ hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg, - StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN, - VALUE_SIZE, + StatementTmplArgPrefix, ToFields, Value, BASE_PARAMS, EMPTY_VALUE, F, HASH_SIZE, + STATEMENT_ARG_F_LEN, VALUE_SIZE, }, }; @@ -103,6 +103,20 @@ pub struct StatementArgTarget { pub elements: [Target; STATEMENT_ARG_F_LEN], } +impl Flattenable for StatementArgTarget { + fn flatten(&self) -> Vec { + self.elements.to_vec() + } + fn from_flattened(_params: &Params, vs: &[Target]) -> Self { + Self { + elements: vs.try_into().expect("STATEMENT_ARG_F_LEN elements"), + } + } + fn size(_params: &Params) -> usize { + STATEMENT_ARG_F_LEN + } +} + impl StatementArgTarget { pub fn set_targets(&self, pw: &mut PartialWitness, arg: &StatementArg) -> Result<()> { Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?) @@ -318,7 +332,7 @@ impl OperationTarget { .args() .iter() .chain(iter::repeat(&OperationArg::None)) - .take(params.max_operation_args) + .take(BASE_PARAMS.max_operation_args) .enumerate() { self.args[i].set_targets(pw, arg.as_usize())?; @@ -328,7 +342,7 @@ impl OperationTarget { fn size(params: &Params) -> usize { OperationTypeTarget::size(params) - + params.max_operation_args * IndexTarget::size(params) + + BASE_PARAMS.max_operation_args * IndexTarget::size(params) + IndexTarget::size(params) } } @@ -868,7 +882,7 @@ impl CustomPredicateVerifyEntryTarget { args: (0..params.max_custom_predicate_wildcards) .map(|_| builder.add_virtual_value()) .collect(), - op_args: (0..params.max_operation_args) + op_args: (0..BASE_PARAMS.max_operation_args) .map(|_| builder.add_virtual_statement(false)) .collect(), } @@ -898,7 +912,7 @@ impl CustomPredicateVerifyEntryTarget { cpv.op_args .iter() .chain(iter::repeat(&pad_op_arg)) - .take(params.max_operation_args), + .take(BASE_PARAMS.max_operation_args), ) { op_arg_target.set_targets(pw, op_arg)? } @@ -941,7 +955,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { .expect("len = operation_type_size"), }; let (pos, size) = (pos + size, StatementTarget::size(params)); - let op_args = (0..params.max_operation_args) + let op_args = (0..BASE_PARAMS.max_operation_args) .map(|i| { StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size]) }) @@ -953,7 +967,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { } } fn size(params: &Params) -> usize { - StatementTarget::size(params) * (1 + params.max_operation_args) + StatementTarget::size(params) * (1 + BASE_PARAMS.max_operation_args) + OperationTarget::size(params) } } @@ -1425,7 +1439,7 @@ impl CircuitBuilderPod for CircuitBuilder { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { OperationTarget { op_type: self.add_virtual_operation_type(), - args: (0..params.max_operation_args) + args: (0..BASE_PARAMS.max_operation_args) .map(|_| IndexTarget::new_virtual(params.statement_table_size(), self)) .collect(), aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self), @@ -1735,7 +1749,7 @@ impl CircuitBuilderPod for CircuitBuilder { let num_chunks = array.len().div_ceil(CHUNK_LEN); for chunk in array.chunks(CHUNK_LEN) { let mut index_chunk = i.low; - // I we have several chunks and the last one is smaller (it's index needs less than 6 + // If we have several chunks and the last one is smaller (it's index needs less than 6 // bits), make it zero except when it's used so that the range check over the index // passes. if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 { diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 68114d2..0ac3bec 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -55,7 +55,7 @@ use crate::{ middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix, - ToFields, Value, F, HASH_SIZE, + ToFields, Value, BASE_PARAMS, F, HASH_SIZE, }, }; // @@ -69,30 +69,37 @@ pub const PI_OFFSET_VDSROOT: usize = 4; pub const NUM_PUBLIC_INPUTS: usize = 8; -const MAX_VALUE_ARGS: usize = 4; +const MAX_VALUE_ARGS: usize = 5; struct StatementArgCache { rhs: ValueTarget, lhs: StatementArgTarget, valid: BoolTarget, + pred_is_none: BoolTarget, + is_reference: BoolTarget, + // if `is_reference` then this is the AnchoredKey found in the Contains statement + reference: StatementArgTarget, + // if `is_reference` then this is the value found in the Contains statement + value: ValueTarget, } -struct StatementCache { - equations: [StatementArgCache; MAX_VALUE_ARGS], - first_n_equations_valid: [BoolTarget; MAX_VALUE_ARGS], +struct StatementCache { + equations: [StatementArgCache; MAX_EQS], + first_n_equations_valid: [BoolTarget; MAX_EQS], op_args: Vec, } -impl StatementCache { +impl StatementCache { fn new( params: &Params, + max_operation_args: usize, builder: &mut CircuitBuilder, op: &OperationTarget, st: &StatementTarget, prev_statements: &[StatementTarget], ) -> Self { let op_args = if prev_statements.is_empty() { - (0..params.max_operation_args) + (0..max_operation_args) .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) .collect_vec() } else { @@ -100,10 +107,10 @@ impl StatementCache { // converting a length 1 array into a scalar. op.args .iter() + .take(max_operation_args) .map(|i| builder.vec_ref(params, prev_statements, i)) .collect::>() }; - assert!(params.max_operation_args >= MAX_VALUE_ARGS); assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); let equations = array::from_fn(|i| { let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None); @@ -117,9 +124,9 @@ impl StatementCache { let is_reference = builder.and(pred_is_contains, ref_is_value); let valid = builder.or(is_literal, is_reference); - let rhs_literal = st.args[i].as_value(); - let rhs_reference = op_args[i].args[2].as_value(); - let rhs = builder.select_value(pred_is_none, rhs_literal, rhs_reference); + let rhs_from_literal = st.args[i].as_value(); + let rhs_from_reference = op_args[i].args[2].as_value(); + let rhs = builder.select_value(pred_is_none, rhs_from_literal, rhs_from_reference); let lhs_literal = &st.args[i]; let lhs_reference = StatementArgTarget::anchored_key( builder, @@ -127,10 +134,22 @@ impl StatementCache { &op_args[i].args[1].as_value(), ); let lhs = builder.select_statement_arg(pred_is_none, lhs_literal, &lhs_reference); - StatementArgCache { rhs, lhs, valid } + StatementArgCache { + rhs, + lhs, + valid, + pred_is_none, + is_reference, + reference: lhs_reference, + value: rhs_from_reference, + } }); - let mut first_n_equations_valid = [equations[0].valid; MAX_VALUE_ARGS]; - for i in 1..MAX_VALUE_ARGS { + let mut first_n_equations_valid = if MAX_EQS != 0 { + [equations[0].valid; MAX_EQS] + } else { + [builder._false(); MAX_EQS] + }; + for i in 1..MAX_EQS { first_n_equations_valid[i] = builder.and(equations[i].valid, first_n_equations_valid[i - 1]); } @@ -145,7 +164,7 @@ impl StatementCache { /// /// If the operation argument is a statement of type `None`, then the value /// should be the corresponding argument of the current statement. - /// If the operation argument is a statement of type `Equals`, then the value + /// If the operation argument is a statement of type `Contains`, then the value /// should be the argument at index 1 of that statement. /// If the function successfully interprets the arguments as values, /// returns `True` along with those values. Otherwise, returns `False` @@ -158,6 +177,12 @@ impl StatementCache { } } +/// Statement cache for private statements +type StatementCachePriv = StatementCache; +/// Statement cache for public statements. Since the operations can only be None or Copy, no +/// equation is needed because none of these operations dereference entries. +type StatementCachePub = StatementCache<0>; + /// Specialized implementation of `verify_operation_circuit` for operations that generate public /// statement. This only allows operations to be None, NewEntry or Copy and accounts for the fact /// that public statements in the current implementation are always generated by copying private @@ -169,13 +194,15 @@ fn verify_operation_public_statement_circuit( op: &OperationTarget, prev_statements: &[StatementTarget], ) -> Result<()> { - let measure = measure_gates_begin!(builder, "OpVerify"); + let measure = measure_gates_begin!(builder, "OpVerifyPub"); // Verify that the operation `op` correctly generates the statement `st`. The operation // can reference any of the `prev_statements`. // TODO: Clean this up. let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - let cache = StatementCache::new(params, builder, op, st, prev_statements); + // None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the + // StatementCache requires. + let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements); measure_gates_end!(builder, measure_resolve_op_args); let op_checks = vec![ @@ -406,7 +433,7 @@ fn verify_operation_circuit( prev_statements: &[StatementTarget], aux_table: &MuxTableTarget, ) -> Result<()> { - let measure = measure_gates_begin!(builder, "OpVerify"); + let measure = measure_gates_begin!(builder, "OpVerifyPriv"); let _true = builder._true(); let _false = builder._false(); @@ -414,7 +441,14 @@ fn verify_operation_circuit( // can reference any of the `prev_statements`. // TODO: Clean this up. let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); - let cache = StatementCache::new(params, builder, op, st, prev_statements); + let cache = StatementCachePriv::new( + params, + BASE_PARAMS.max_operation_args, + builder, + op, + st, + prev_statements, + ); measure_gates_end!(builder, measure_resolve_op_args); // Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified @@ -442,6 +476,7 @@ fn verify_operation_circuit( verify_sum_of_circuit(params, builder, st, &op.op_type, &cache), verify_product_of_circuit(params, builder, st, &op.op_type, &cache), verify_max_of_circuit(params, builder, st, &op.op_type, &cache), + verify_replace_value_with_entry_circuit(params, builder, st, &op.op_type, &cache), ]); } // Skip these if there are no resolved aux entries @@ -542,7 +577,7 @@ fn verify_contains_from_entries_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); let (aux_tag_ok, resolved_merkle_claim) = @@ -592,7 +627,7 @@ fn verify_not_contains_from_entries_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); let (aux_tag_ok, resolved_merkle_claim) = @@ -639,7 +674,7 @@ fn verify_merkle_insert_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleInsertOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = @@ -714,7 +749,7 @@ fn verify_merkle_update_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleUpdateOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = @@ -789,7 +824,7 @@ fn verify_merkle_delete_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "MerkleDeleteOp"); let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = @@ -883,7 +918,7 @@ fn verify_eq_neq_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); let eq_op_st_code_ok = { @@ -932,9 +967,9 @@ fn verify_lt_lteq_from_entries_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { - let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries"); + let measure = measure_gates_begin!(builder, "OpLtEqFromEntries"); let zero = ValueTarget::zero(builder); let one = ValueTarget::one(builder); @@ -1000,7 +1035,7 @@ fn verify_hash_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpHashOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); @@ -1033,7 +1068,7 @@ fn verify_public_key_of_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpPublicKeyOf"); let (aux_tag_ok, resolved_pk_sk) = @@ -1069,7 +1104,7 @@ fn verify_signed_by_circuit( st: &StatementTarget, op_type: &OperationTypeTarget, aux: &TableEntryTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpSignedBy"); let (aux_tag_ok, resolved_msg_pk) = @@ -1104,7 +1139,7 @@ fn verify_sum_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpSumOf"); let value_zero = ValueTarget::zero(builder); @@ -1142,7 +1177,7 @@ fn verify_product_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpProductOf"); let value_zero = ValueTarget::zero(builder); @@ -1180,7 +1215,7 @@ fn verify_max_of_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op_type: &OperationTypeTarget, - cache: &StatementCache, + cache: &StatementCachePriv, ) -> BoolTarget { let measure = measure_gates_begin!(builder, "OpMaxOf"); let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); @@ -1220,6 +1255,47 @@ fn verify_max_of_circuit( ok } +fn verify_replace_value_with_entry_circuit( + params: &Params, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + cache: &StatementCachePriv, +) -> BoolTarget { + let measure = measure_gates_begin!(builder, "OpReplaceValueWithEntry"); + let op_code_ok = op_type.has_native(builder, NativeOperation::ReplaceValueWithEntry); + + let st_in = &cache.op_args[BASE_PARAMS.max_statement_args]; + + let mut args = Vec::new(); + let mut args_ok = builder._true(); + for (arg_in, entry_cache) in zip_eq(&st_in.args, &cache.equations) { + // if the op_arg is None, keep the original argument, if it's a Contains swap the value by + // the reference Entry while checking that the value in Contains matches the original + // argument. + let arg = builder.select_flattenable( + params, + entry_cache.pred_is_none, + arg_in, + &entry_cache.reference, + ); + args.push(arg); + let arg_ref_ok = { + let arg_in_is_value = builder.statement_arg_is_value(arg_in); + let value_eq = builder.is_equal_flattenable(&arg_in.as_value(), &entry_cache.value); + builder.all([entry_cache.is_reference, arg_in_is_value, value_eq]) + }; + let arg_ok = builder.or(entry_cache.pred_is_none, arg_ref_ok); + args_ok = builder.and(args_ok, arg_ok); + } + let expected_statement = StatementTarget::new(*st_in.pred_hash(), args); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); + + let ok = builder.all([op_code_ok, args_ok, st_ok]); + measure_gates_end!(builder, measure); + ok +} + fn verify_transitive_eq_circuit( params: &Params, builder: &mut CircuitBuilder, @@ -1429,7 +1505,7 @@ fn make_custom_statement_circuit( ) -> Result<(StatementTarget, OperationTypeTarget)> { let measure = measure_gates_begin!(builder, "CustomOpVerify"); // Some sanity checks - assert_eq!(params.max_operation_args, op_args.len()); + assert_eq!(BASE_PARAMS.max_operation_args, op_args.len()); assert_eq!(params.max_custom_predicate_wildcards, args.len()); let (batch_id, index) = (custom_predicate.id, custom_predicate.index); @@ -1463,7 +1539,6 @@ fn make_custom_statement_circuit( .collect(); // expected_sts.len() == params.max_custom_predicate_arity // op_args.len() == params.max_operation_args; - assert!(Params::max_custom_predicate_arity() <= params.max_operation_args); let sts_eq: Vec<_> = expected_sts .iter() @@ -2076,7 +2151,8 @@ mod tests { frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, middleware::{ hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, - RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, EMPTY_VALUE, + RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard, + BASE_PARAMS, EMPTY_VALUE, }, }; @@ -3068,6 +3144,33 @@ mod tests { operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) } + #[test] + fn test_operation_replace_value_with_entry() -> Result<()> { + let d = dict!({"a" => 42, "b" => 33}); + + // 0: None + // 1: Lt(5, 42) + let st_in: mainpod::Statement = Statement::lt(5, 42).into(); + // 2: Contains(d, "a", 42) + let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into(); + + let st_out: mainpod::Statement = + Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into(); + let mut op_args: Vec<_> = iter::repeat(OperationArg::None) + .take(BASE_PARAMS.max_statement_args + 1) + .collect(); + op_args[1] = OperationArg::Index(2); + op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ReplaceValueWithEntry), + op_args, + OperationAux::None, + ); + + let prev_statements = vec![Statement::None.into(), st_in, st_entry]; + operation_verify(st_out, op, prev_statements, Aux::default()) + } + fn helper_statement_arg_from_template( params: &Params, st_tmpl_arg: StatementTmplArg, @@ -3226,7 +3329,7 @@ mod tests { expected_st: Option, ) -> Result<()> { // Pad - for _ in op_args.len()..params.max_operation_args { + for _ in op_args.len()..BASE_PARAMS.max_operation_args { op_args.push(Statement::None); } for _ in args.len()..params.max_custom_predicate_wildcards { @@ -3275,6 +3378,10 @@ mod tests { Ok(data.verify(proof.clone())?) } + fn value_ref(v: impl Into) -> ValueRef { + v.into() + } + // TODO: Add negative tests #[test] fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { @@ -3309,7 +3416,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(1234)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( @@ -3330,7 +3437,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(0)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( @@ -3351,7 +3458,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(1234)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( @@ -3403,7 +3510,7 @@ mod tests { let args = vec![Value::from(dict), Value::from(secret_dict)]; let expected_st = Statement::Custom( custom_predicate.clone(), - vec![args[0].clone(), Value::from(0)], + vec![value_ref(args[0].clone()), value_ref(0)], ); helper_custom_operation_verify_gadget( diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index ae1ade3..5e9df2e 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1,5 +1,5 @@ pub mod operation; -use crate::middleware::{wildcard_values_from_op_st, PodType}; +use crate::middleware::{wildcard_values_from_op_st, PodType, BASE_PARAMS}; pub mod statement; use std::iter; @@ -39,7 +39,7 @@ use crate::{ middleware::{ self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs, MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, - ToFields, VDSet, Value, + ToFields, VDSet, Value, ValueRef, }, timed, }; @@ -104,9 +104,20 @@ pub(crate) fn extract_custom_predicate_verifications( if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Statement::Custom(st_cpr, st_args) = st { assert_eq!(cpr, st_cpr); + // The custom operation outputs statements with literal arguments. They can be + // replaced by references later with ReplaceValueWithEntry. + let st_args = st_args + .iter() + .map(|arg| match arg { + ValueRef::Literal(v) => Ok(v.clone()), + _ => Err(Error::custom( + "custom operation cannot output entries as arguments", + )), + }) + .collect::>>()?; let normalized_pred = cpr.normalized_predicate(); let wildcard_values = - wildcard_values_from_op_st(params, &normalized_pred, sts, st_args) + wildcard_values_from_op_st(params, &normalized_pred, sts, &st_args) .expect("resolved wildcards"); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let custom_predicate_table_index = custom_predicates @@ -329,8 +340,8 @@ pub fn pad_statement(s: &mut Statement) { fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args()) } -fn pad_operation_args(params: &Params, args: &mut Vec) { - fill_pad(args, OperationArg::None, params.max_operation_args) +fn pad_operation_args(args: &mut Vec) { + fill_pad(args, OperationArg::None, BASE_PARAMS.max_operation_args) } /// Returns the statements from the given MainPodInputs, padding to the respective max lengths @@ -428,7 +439,7 @@ pub(crate) fn process_private_statements_operations( .map(|mid_arg| find_op_arg(statements, mid_arg)) .collect::>>()?; - pad_operation_args(params, &mut args); + pad_operation_args(&mut args); operations.push(Operation(op.op_type(), args, *aux)); } Ok(operations) @@ -459,7 +470,11 @@ pub(crate) fn process_public_statements_operations( OperationAux::None, ) }; - fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); + fill_pad( + &mut op.1, + OperationArg::None, + BASE_PARAMS.max_operation_args, + ); operations.push(op); } Ok(operations) @@ -469,6 +484,7 @@ pub struct Prover {} impl MainPodProver for Prover { fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result> { + assert_eq!(inputs.statements.len(), inputs.operations.len()); // Pad input recursive pods with empty pods if necessary let empty_pod = if inputs.pods.len() == params.max_input_pods { // We don't need padding so we skip creating an EmptyPod @@ -1005,7 +1021,6 @@ pub mod tests { max_input_pods_public_statements: 2, max_statements: 5, max_public_statements: 2, - max_operation_args: 5, max_custom_predicates: 2, max_custom_predicate_verifications: 2, max_custom_predicate_wildcards: 3, @@ -1070,7 +1085,6 @@ pub mod tests { max_input_pods: 0, max_statements: 9, max_public_statements: 4, - max_operation_args: 5, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, max_merkle_proofs_containers: 3, @@ -1140,7 +1154,6 @@ pub mod tests { max_input_pods: 0, max_statements: 6, max_public_statements: 2, - max_operation_args: 5, max_custom_predicate_wildcards: 4, max_custom_predicate_verifications: 2, max_merkle_proofs_containers: 0, @@ -1251,11 +1264,108 @@ pub mod tests { ); let st = middleware::Statement::Custom( cpr, - [1, 1, 2].into_iter().map(middleware::Value::from).collect(), + [1, 1, 2] + .into_iter() + .map(middleware::ValueRef::from) + .collect(), ); builder.insert((st.clone(), op)).unwrap(); builder.reveal(&st).unwrap(); let prover = Prover {}; builder.prove(&prover).unwrap(); } + + #[test] + fn test_replace_value_with_entry() { + let params = middleware::Params::default(); + let vd_set = &*DEFAULT_VD_SET; + let mut builder = MainPodBuilder::new(¶ms, vd_set); + let d = dict!({"a" => 42, "b" => 33}); + builder + .priv_op(frontend::Operation::dict_contains(d.clone(), "a", 42)) + .unwrap(); + let st = builder.priv_op(frontend::Operation::lt(5, 42)).unwrap(); + // Transform `Lt(5, 42)` into `Lt(5, d.a)` by using `DictContains(d, "a", 42)` + builder + .pub_op(frontend::Operation::replace_value_with_entry( + vec![None, Some((&d, "a"))], + st, + )) + .unwrap(); + + // Mock + let prover = MockProver {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap(); + assert_eq!( + middleware::Statement::Lt( + middleware::ValueRef::Literal(Value::from(5)), + middleware::ValueRef::Key(middleware::AnchoredKey { + root: d.commitment(), + key: middleware::Key::from("a") + }) + ), + pod.public_statements[0] + ); + + // Real + let prover = Prover {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap() + } + + #[test] + fn test_entry_custom_statement_arg() { + let params = middleware::Params::default(); + let vd_set = &*DEFAULT_VD_SET; + let input = r#" + PredA(x) = AND( + Lt(x, 100) + ) + + PredB(d) = AND( + PredA(d.x) + ) + "#; + let module = load_module(input, "my_mod", ¶ms, &[]).expect("lang parse"); + let pred_a = module.batch.predicate_ref_by_name("PredA").unwrap(); + let pred_b = module.batch.predicate_ref_by_name("PredB").unwrap(); + + let mut builder = MainPodBuilder::new(¶ms, vd_set); + let d = dict!({"x" => 42, "y" => 33}); + + let st_lt = builder.priv_op(frontend::Operation::lt(42, 100)).unwrap(); + let st_a = builder + .priv_op(frontend::Operation::custom(pred_a, [st_lt])) + .unwrap(); + builder + .priv_op(frontend::Operation::dict_contains(d.clone(), "x", 42)) + .unwrap(); + // Transform `PredA(42)` into `PredA(d.x)` by using `DictContains(d, "x", 42)` + let st_a1 = builder + .priv_op(frontend::Operation::replace_value_with_entry( + vec![Some((&d, "x"))], + st_a, + )) + .unwrap(); + + builder + .pub_op(frontend::Operation::custom(pred_b.clone(), [st_a1])) + .unwrap(); + + // Mock + let prover = MockProver {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap(); + let expected = middleware::Statement::Custom( + pred_b, + vec![middleware::ValueRef::Literal(Value::from(d))], + ); + assert_eq!(expected, pod.public_statements[0]); + + // Real + let prover = Prover {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap() + } } diff --git a/src/backends/plonky2/mainpod/statement.rs b/src/backends/plonky2/mainpod/statement.rs index 27776a6..64fe675 100644 --- a/src/backends/plonky2/mainpod/statement.rs +++ b/src/backends/plonky2/mainpod/statement.rs @@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::error::{Error, Result}, - middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS}, + middleware::{ + self, NativePredicate, Predicate, StatementArg, ToFields, Value, ValueRef, BASE_PARAMS, + }, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -96,15 +98,15 @@ impl TryFrom for middleware::Statement { )))?, }, Predicate::Custom(cpr) => { - let vs: Vec = proper_args + let args: Vec = proper_args .into_iter() .filter_map(|arg| match arg { - SA::None => None, - SA::Literal(v) => Some(v), - _ => unreachable!(), + StatementArg::Literal(v) => Some(ValueRef::Literal(v)), + StatementArg::Key(k) => Some(ValueRef::Key(k)), + StatementArg::None => None, }) .collect(); - S::Custom(cpr, vs) + S::Custom(cpr, args) } Predicate::Intro(ir) => { let vs: Vec = proper_args diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index dcb1355..b8c6a03 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -380,7 +380,8 @@ pub mod tests { great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, MOCK_VD_SET, }, - frontend, middleware, + frontend::{self}, + middleware, middleware::{Signer as _, Value}, }; diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index a2614a0..8de6871 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -316,7 +316,9 @@ mod tests { backends::plonky2::mock::mainpod::MockProver, examples::{custom::eth_dos_batch, MOCK_VD_SET}, frontend::{MainPodBuilder, Operation}, - middleware::{self, containers::Set, CustomPredicateRef, Params, PodType, DEFAULT_VD_SET}, + middleware::{ + self, containers::Set, CustomPredicateRef, Params, PodType, ValueRef, DEFAULT_VD_SET, + }, }; #[test] @@ -507,7 +509,7 @@ mod tests { .find(|s| matches!(s, middleware::Statement::Custom(_, _))) .expect("should have a custom statement"); if let middleware::Statement::Custom(_, args) = custom_st { - assert_eq!(args[0], pred_b_hash); + assert_eq!(args[0], ValueRef::Literal(pred_b_hash)); } Ok(()) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index f23e374..b6e8691 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -4,7 +4,7 @@ use std::{ collections::{HashMap, HashSet}, convert::From, - fmt, + fmt, iter, }; use itertools::Itertools; @@ -15,9 +15,10 @@ pub use serialization::SerializedMainPod; use crate::middleware::{ self, check_custom_pred, containers::{Container, Dictionary}, - fill_wildcard_values, hash_op, max_op, prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, - MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, RawValue, - Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, EMPTY_VALUE, + fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key, + MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, + RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS, + EMPTY_VALUE, }; mod custom; @@ -566,6 +567,37 @@ impl MainPodBuilder { // TODO: validate proof Statement::ContainerDelete(r1, r2, r3) } + (ReplaceValueWithEntry, &args, _) => { + let mut args = args.to_vec(); + if args.len() != BASE_PARAMS.max_statement_args + 1 { + return Err(Error::custom(format!( + "ReplaceValueWithEntry requires exactly {} args but {} were found", + BASE_PARAMS.max_statement_args + 1, + args.len() + ))); + } + let st = match args.pop().expect("valid vec len") { + OperationArg::Statement(st) => st, + _ => return Err(Error::custom("expected statement")), + }; + let new_st_args = iter::zip(st.args().into_iter(), args) + .map(|(st_arg, arg)| match (st_arg, arg) { + (st_arg, OperationArg::Statement(Statement::None)) => Ok(st_arg), + ( + StatementArg::Literal(st_arg_v), + OperationArg::Statement(Statement::Contains( + ValueRef::Literal(root), + ValueRef::Literal(key), + ValueRef::Literal(v), + )), + ) if st_arg_v == v => root_key_to_ak(&root, &key) + .map(StatementArg::Key) + .ok_or_else(native_arg_error), + _ => Err(Error::custom("unexpected operation argument")), + }) + .collect::>>()?; + Statement::from_args(st.predicate(), new_st_args)? + } (t, _, _) => { if t.is_syntactic_sugar() { return Err(Error::custom(format!( @@ -615,7 +647,7 @@ impl MainPodBuilder { .map(|v| v.unwrap_or_else(|| v_default.clone())) .collect(); check_custom_pred(&self.params, &cpr, &args, &st_args)?; - Statement::Custom(cpr, st_args) + Statement::Custom(cpr, st_args.into_iter().map(ValueRef::Literal).collect()) } }; Ok(st) diff --git a/src/frontend/multi_pod/cost.rs b/src/frontend/multi_pod/cost.rs index 0c0c2ef..2839ea8 100644 --- a/src/frontend/multi_pod/cost.rs +++ b/src/frontend/multi_pod/cost.rs @@ -111,7 +111,8 @@ impl StatementCost { // Syntactic sugar variants (lowered before proving) | NativeOperation::GtEqFromEntries | NativeOperation::GtFromEntries - | NativeOperation::GtToNotEqual => {} + | NativeOperation::GtToNotEqual + | NativeOperation::ReplaceValueWithEntry => {} } } OperationType::Custom(cpr) => { diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index a1045a5..9794e60 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,10 +1,10 @@ -use std::fmt; +use std::{fmt, iter}; use crate::{ frontend::SignedDict, middleware::{ containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux, - OperationType, Signature, Statement, Value, ValueRef, + OperationType, Signature, Statement, Value, ValueRef, BASE_PARAMS, }, }; @@ -219,6 +219,24 @@ impl Operation { op_impl_oa!(set_insert, SetInsertFromEntries, 3); op_impl_oa!(set_delete, SetDeleteFromEntries, 3); op_impl_oa!(array_update, ArrayUpdateFromEntries, 4); + pub fn replace_value_with_entry(args: Vec>, st: Statement) -> Self { + assert!(args.len() <= BASE_PARAMS.max_statement_args); + let args = args + .into_iter() + .chain(iter::repeat(None)) + .take(BASE_PARAMS.max_statement_args) + .map(|a| match a { + None => OperationArg::Statement(Statement::None), + Some((dict, key)) => OperationArg::from((dict, key)), + }) + .chain(iter::once(OperationArg::Statement(st))) + .collect(); + Self( + OperationType::Native(NativeOperation::ReplaceValueWithEntry), + args, + OperationAux::None, + ) + } pub fn signed_by( msg: impl Into, pk: impl Into, diff --git a/src/lang/diagnostics.rs b/src/lang/diagnostics.rs index 0a1d770..7807318 100644 --- a/src/lang/diagnostics.rs +++ b/src/lang/diagnostics.rs @@ -174,18 +174,6 @@ fn render_validation_error( "second REQUEST here", ), - ValidationError::InvalidArgumentType { predicate, span } => { - let title = format!("invalid argument type for `{}`", predicate); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "anchored keys not allowed here", - ) - } - ValidationError::DuplicateWildcard { name, span } => { let title = format!("duplicate wildcard: {}", name); render_with_optional_span( diff --git a/src/lang/error.rs b/src/lang/error.rs index 769faf6..792d4d8 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -135,12 +135,6 @@ pub enum ValidationError { span: Option, }, - #[error("Invalid argument type for {predicate}: anchored keys not allowed")] - InvalidArgumentType { - predicate: String, - span: Option, - }, - #[error("Duplicate wildcard in predicate arguments: {name}")] DuplicateWildcard { name: String, span: Option }, diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index 0b7737d..ef3d395 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -522,7 +522,7 @@ impl Validator { } // Validate arguments - self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?; + self.validate_statement_args(stmt, wildcard_context)?; Ok(()) } @@ -530,75 +530,37 @@ impl Validator { fn validate_statement_args( &self, stmt: &StatementTmpl, - pred_info: Option<&PredicateInfo>, wildcard_context: Option<(&str, &WildcardScope)>, ) -> Result<(), ValidationError> { - // For custom predicates, only wildcards and literals are allowed - if matches!( - pred_info.map(|i| &i.kind), - Some(PredicateKind::Custom { .. }) - | Some(PredicateKind::BatchImported { .. }) - | Some(PredicateKind::ModuleImported { .. }) - ) { - for arg in &stmt.args { - match arg { - StatementTmplArg::AnchoredKey(_) => { - return Err(ValidationError::InvalidArgumentType { - predicate: stmt.predicate.predicate_name().to_string(), - span: stmt.span, - }); - } - StatementTmplArg::Wildcard(id) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&id.name) { - return Err(ValidationError::UndefinedWildcard { - name: id.name.clone(), - pred_name: pred_name.to_string(), - span: id.span, - }); - } + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&id.name) { + return Err(ValidationError::UndefinedWildcard { + name: id.name.clone(), + pred_name: pred_name.to_string(), + span: id.span, + }); } } - StatementTmplArg::Literal(lit) => { - self.validate_literal_value(lit)?; - } - StatementTmplArg::SelfPredicateHash(id) => { - self.validate_self_predicate_hash(id, wildcard_context)?; - } } - } - } else { - // Native predicates can have anchored keys - for arg in &stmt.args { - match arg { - StatementTmplArg::Wildcard(id) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&id.name) { - return Err(ValidationError::UndefinedWildcard { - name: id.name.clone(), - pred_name: pred_name.to_string(), - span: id.span, - }); - } + StatementTmplArg::AnchoredKey(ak) => { + if let Some((pred_name, scope)) = wildcard_context { + if !scope.wildcards.contains_key(&ak.root.name) { + return Err(ValidationError::UndefinedWildcard { + name: ak.root.name.clone(), + pred_name: pred_name.to_string(), + span: ak.root.span, + }); } } - StatementTmplArg::AnchoredKey(ak) => { - if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&ak.root.name) { - return Err(ValidationError::UndefinedWildcard { - name: ak.root.name.clone(), - pred_name: pred_name.to_string(), - span: ak.root.span, - }); - } - } - } - StatementTmplArg::Literal(lit) => { - self.validate_literal_value(lit)?; - } - StatementTmplArg::SelfPredicateHash(id) => { - self.validate_self_predicate_hash(id, wildcard_context)?; - } + } + StatementTmplArg::Literal(lit) => { + self.validate_literal_value(lit)?; + } + StatementTmplArg::SelfPredicateHash(id) => { + self.validate_self_predicate_hash(id, wildcard_context)?; } } } @@ -839,10 +801,7 @@ mod tests { module_hash ); let result = parse_and_validate_request(&input, &available_modules); - assert!(matches!( - result, - Err(ValidationError::InvalidArgumentType { .. }) - )); + assert!(result.is_ok()); } #[test] diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 5674f53..291f7a6 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -578,7 +578,6 @@ mod tests { max_input_pods: 3, max_statements: 31, max_public_statements: 10, - max_operation_args: 5, max_custom_predicate_wildcards: 12, ..Default::default() }; diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index e6af211..0012251 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -169,6 +169,12 @@ pub struct Hash( pub [F; HASH_SIZE], ); +impl Hash { + pub fn raw(self) -> RawValue { + RawValue::from(self) + } +} + impl From for HashOut { fn from(hash: Hash) -> HashOut { HashOut { elements: hash.0 } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index cf6d9be..e5c7285 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -436,7 +436,7 @@ impl fmt::Display for CustomPredicate { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] +#[derive(Clone, PartialEq, Eq, Serialize, JsonSchema)] enum CustomPredicateBatchData { Full { #[serde(skip)] @@ -449,6 +449,20 @@ enum CustomPredicateBatchData { }, } +// Explicit implementation of Debug to skip the merkle tree +impl fmt::Debug for CustomPredicateBatchData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self { + Self::Full { mt, predicates } => f + .debug_struct("Full") + .field("id", &mt.root()) + .field("predicates", &predicates) + .finish(), + Self::Opaque { id } => f.debug_struct("Opaque").field("id", &id).finish(), + } + } +} + // TODO: Rename Batch for Module everywhere in the code base impl CustomPredicateBatchData { fn new_full(predicates: Vec) -> Self { @@ -630,7 +644,7 @@ mod tests { middleware::{ AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl, - StatementTmplArg, + StatementTmplArg, ValueRef, }, }; @@ -653,6 +667,9 @@ mod tests { fn names(names: &[&str]) -> Vec { names.iter().map(|s| s.to_string()).collect() } + fn value_ref(v: impl Into) -> ValueRef { + v.into() + } #[allow(clippy::upper_case_acronyms)] type STA = StatementTmplArg; @@ -701,7 +718,7 @@ mod tests { }); let custom_statement = Statement::Custom( CustomPredicateRef::new(cust_pred_batch.clone(), 0), - vec![Value::from(d0.clone())], + vec![value_ref(d0.clone())], ); let custom_deduction = Operation::Custom( @@ -833,7 +850,7 @@ mod tests { // Example statement let ethdos_example = Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), - vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], + vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)], ); // Copies should work. @@ -842,7 +859,7 @@ mod tests { // This could arise as the inductive step. let ethdos_ind_example = Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1), - vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], + vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)], ); assert!(Operation::Custom( @@ -857,12 +874,12 @@ mod tests { let ethdos_facts = vec![ Statement::Custom( CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), - vec![Value::from("Alice"), Value::from("Charlie"), Value::from(6)], + vec![value_ref("Alice"), value_ref("Charlie"), value_ref(6)], ), Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)), Statement::Custom( CustomPredicateRef::new(eth_friend_batch.clone(), 0), - vec![Value::from("Charlie"), Value::from("Bob")], + vec![value_ref("Charlie"), value_ref("Bob")], ), ]; @@ -959,7 +976,10 @@ mod tests { let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())]; // The output statement - let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]); + let output_st = Statement::Custom( + pred_b_ref.clone(), + vec![ValueRef::Literal(some_value.clone())], + ); // This should pass assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(¶ms, &output_st)?); @@ -1024,12 +1044,18 @@ mod tests { // Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())]; - let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]); + let st_a = Statement::Custom( + pred_a_ref.clone(), + vec![ValueRef::Literal(pred_b_hash.clone())], + ); assert!(Operation::Custom(pred_a_ref, op_a).check(¶ms, &st_a)?); // Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())]; - let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]); + let st_b = Statement::Custom( + pred_b_ref.clone(), + vec![ValueRef::Literal(pred_a_hash.clone())], + ); assert!(Operation::Custom(pred_b_ref, op_b).check(¶ms, &st_b)?); Ok(()) diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 19ca2c2..82675d7 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -768,6 +768,8 @@ pub struct BaseParams { /// in a custom predicate pub max_custom_predicate_arity: usize, pub max_depth_custom_batch_mt: usize, + // This value depends on `max_custom_predicate_arity` + pub max_operation_args: usize, } pub const BASE_PARAMS: BaseParams = BaseParams { @@ -775,6 +777,7 @@ pub const BASE_PARAMS: BaseParams = BaseParams { max_statement_args: 5, max_custom_predicate_arity: 5, max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch + max_operation_args: 5 + 1, }; /// Params: non dynamic parameters that define the circuit. @@ -785,7 +788,6 @@ pub struct Params { pub max_input_pods_public_statements: usize, pub max_statements: usize, pub max_public_statements: usize, - pub max_operation_args: usize, // max number of different custom predicates that can be used in a MainPod pub max_custom_predicates: usize, // max number of operations using custom predicates that can be verified in the MainPod @@ -815,7 +817,6 @@ impl Default for Params { max_input_pods_public_statements: 8, max_statements: 48, max_public_statements: 8, - max_operation_args: 5, max_custom_predicates: 8, max_custom_predicate_verifications: 8, max_custom_predicate_wildcards: 8, diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 1793e4d..8d3316c 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -14,7 +14,7 @@ use crate::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result, Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef, - Wildcard, F, + Wildcard, BASE_PARAMS, F, }, }; @@ -89,6 +89,7 @@ pub enum NativeOperation { ContainerInsertFromEntries = 16, ContainerUpdateFromEntries = 17, ContainerDeleteFromEntries = 18, + ReplaceValueWithEntry = 19, // Syntactic sugar operations. These operations are not supported by the backend. The // frontend compiler is responsible of translating these operations into the operations above. @@ -164,6 +165,7 @@ impl OperationType { NativeOperation::ContainerDeleteFromEntries => { Some(Predicate::Native(NativePredicate::ContainerDelete)) } + NativeOperation::ReplaceValueWithEntry => None, no => unreachable!("Unexpected syntactic sugar op {:?}", no), }, OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), @@ -219,6 +221,10 @@ pub enum Operation { /* key */ Statement, /* proof */ MerkleTreeStateTransitionProof, ), + ReplaceValueWithEntry( + /* Contains/None len=max_statement_args */ Vec, + /* to copy */ Statement, + ), Custom(CustomPredicateRef, Vec), } @@ -270,6 +276,7 @@ impl Operation { OT::Native(ContainerUpdateFromEntries) } Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries), + Self::ReplaceValueWithEntry(_, _) => OT::Native(ReplaceValueWithEntry), Self::Custom(cpr, _) => OT::Custom(cpr.clone()), } } @@ -295,6 +302,11 @@ impl Operation { Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3], + Self::ReplaceValueWithEntry(args, s) => { + let mut sts = args; + sts.push(s); + sts + } Self::Custom(_, args) => args, } } @@ -377,6 +389,18 @@ impl Operation { &[s1, s2, s3], OA::MerkleTreeStateTransitionProof(pf), ) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf), + (NO::ReplaceValueWithEntry, args, OA::None) => { + let mut args = args.to_vec(); + if args.len() != BASE_PARAMS.max_statement_args + 1 { + return Err(Error::custom(format!( + "ReplaceValueWithEntry requires exactly {} args but {} were found", + BASE_PARAMS.max_statement_args + 1, + args.len() + ))); + } + let st = args.pop().expect("valid vec len"); + Self::ReplaceValueWithEntry(args, st) + } _ => Err(Error::custom(format!( "Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.", op_code, @@ -422,6 +446,38 @@ impl Operation { Ok(sig.verify(pk, msg.raw())) } + fn check_replace_value_with_entry( + entries: &[Statement], + st_in: &Statement, + expected_st_out: &Statement, + ) -> Result { + if entries.len() != BASE_PARAMS.max_statement_args { + return Ok(false); + } + let args = iter::zip(st_in.args(), entries) + .map(|(arg_in, entry)| match (arg_in, entry) { + (arg_in, Statement::None) => Ok(arg_in), + ( + StatementArg::Literal(v_in), + Statement::Contains( + ValueRef::Literal(root), + ValueRef::Literal(key), + ValueRef::Literal(v), + ), + ) if v == &v_in => Ok(StatementArg::Key(AnchoredKey::new( + Hash::from(root.raw()), + Key::from(key.as_str().ok_or_else(|| Error::custom("not a string"))?), + ))), + _ => Err(Error::custom( + "invalid statement argument in ReplaceValueWithEntry", + )), + }) + .collect::>>()?; + + let st_out = Statement::from_args(st_in.predicate(), args)?; + Ok(&st_out == expected_st_out) + } + /// Checks the given operation against a statement. pub fn check(&self, params: &Params, output_statement: &Statement) -> Result { use Statement::*; @@ -541,7 +597,19 @@ impl Operation { (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) if batch == &cpr.batch && index == &cpr.index => { - check_custom_pred(params, cpr, args, s_args).map(|_| true)? + // The custom operation outputs statements with literal arguments. They can be + // replaced by references later with ReplaceValueWithEntry. + let s_args = s_args + .iter() + .map(|arg| match arg { + ValueRef::Literal(v) => Ok(v.clone()), + _ => Err(deduction_err()), + }) + .collect::>>()?; + check_custom_pred(params, cpr, args, &s_args).map(|_| true)? + } + (Self::ReplaceValueWithEntry(entries, st_in), st_out) => { + Self::check_replace_value_with_entry(entries, st_in, st_out)? } _ => return Err(deduction_err()), }; @@ -648,9 +716,9 @@ pub fn wildcard_values_from_op_st( params: &Params, pred: &CustomPredicate, op_args: &[Statement], - st_args: &[Value], + resolved_st_args: &[Value], ) -> Result> { - let mut wildcard_map = st_args + let mut wildcard_map = resolved_st_args .iter() .map(|v| Some(v.clone())) .chain(core::iter::repeat(None)) diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index d3e0534..b5c1f60 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -311,7 +311,7 @@ pub enum Statement { /* old_root */ ValueRef, /* key */ ValueRef, ), - Custom(CustomPredicateRef, Vec), + Custom(CustomPredicateRef, Vec), Intro(IntroPredicateRef, Vec), } @@ -407,7 +407,7 @@ impl Statement { vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()] } Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], - Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Literal)), + Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(StatementArg::from)), Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)), } } @@ -478,14 +478,11 @@ impl Statement { } (BatchSelf(_), _) => unreachable!(), (Custom(cpr), _) => { - let v_args: Result> = args + let v_args = args .iter() - .map(|x| match x { - StatementArg::Literal(v) => Ok(v.clone()), - _ => Err(Error::incorrect_statements_args()), - }) - .collect(); - Self::Custom(cpr, v_args?) + .map(|x| x.try_into()) + .collect::>>()?; + Self::Custom(cpr, v_args) } (Intro(ir), _) => { let v_args: Result> = args