Allow entries as args in custom statements (#498)

- Introduce a new operation ReplaceValueWithEntry that allows taking any statement and replacing literal arguments with entries given a matching Contains statement.
- Allow entries as args in custom statements
- Circuit optimization: For the public statements slots in the circuit we only support None and Copy which take at most 1 argument; but we were still doing max_statement_args random accesses per slot; so I reduced that to just 1 random access to a previous statement.
This commit is contained in:
Eduard S. 2026-04-01 23:49:29 +02:00 committed by GitHub
parent 22d25e5cb2
commit dbd958dcca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 515 additions and 190 deletions

View file

@ -37,8 +37,8 @@ use crate::{
hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate, hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate,
OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix, OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix,
PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg, PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg,
StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN, StatementTmplArgPrefix, ToFields, Value, BASE_PARAMS, EMPTY_VALUE, F, HASH_SIZE,
VALUE_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
}, },
}; };
@ -103,6 +103,20 @@ pub struct StatementArgTarget {
pub elements: [Target; STATEMENT_ARG_F_LEN], pub elements: [Target; STATEMENT_ARG_F_LEN],
} }
impl Flattenable for StatementArgTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self {
elements: vs.try_into().expect("STATEMENT_ARG_F_LEN elements"),
}
}
fn size(_params: &Params) -> usize {
STATEMENT_ARG_F_LEN
}
}
impl StatementArgTarget { impl StatementArgTarget {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, arg: &StatementArg) -> Result<()> { pub fn set_targets(&self, pw: &mut PartialWitness<F>, arg: &StatementArg) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?) Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?)
@ -318,7 +332,7 @@ impl OperationTarget {
.args() .args()
.iter() .iter()
.chain(iter::repeat(&OperationArg::None)) .chain(iter::repeat(&OperationArg::None))
.take(params.max_operation_args) .take(BASE_PARAMS.max_operation_args)
.enumerate() .enumerate()
{ {
self.args[i].set_targets(pw, arg.as_usize())?; self.args[i].set_targets(pw, arg.as_usize())?;
@ -328,7 +342,7 @@ impl OperationTarget {
fn size(params: &Params) -> usize { fn size(params: &Params) -> usize {
OperationTypeTarget::size(params) OperationTypeTarget::size(params)
+ params.max_operation_args * IndexTarget::size(params) + BASE_PARAMS.max_operation_args * IndexTarget::size(params)
+ IndexTarget::size(params) + IndexTarget::size(params)
} }
} }
@ -868,7 +882,7 @@ impl CustomPredicateVerifyEntryTarget {
args: (0..params.max_custom_predicate_wildcards) args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value()) .map(|_| builder.add_virtual_value())
.collect(), .collect(),
op_args: (0..params.max_operation_args) op_args: (0..BASE_PARAMS.max_operation_args)
.map(|_| builder.add_virtual_statement(false)) .map(|_| builder.add_virtual_statement(false))
.collect(), .collect(),
} }
@ -898,7 +912,7 @@ impl CustomPredicateVerifyEntryTarget {
cpv.op_args cpv.op_args
.iter() .iter()
.chain(iter::repeat(&pad_op_arg)) .chain(iter::repeat(&pad_op_arg))
.take(params.max_operation_args), .take(BASE_PARAMS.max_operation_args),
) { ) {
op_arg_target.set_targets(pw, op_arg)? op_arg_target.set_targets(pw, op_arg)?
} }
@ -941,7 +955,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
.expect("len = operation_type_size"), .expect("len = operation_type_size"),
}; };
let (pos, size) = (pos + size, StatementTarget::size(params)); let (pos, size) = (pos + size, StatementTarget::size(params));
let op_args = (0..params.max_operation_args) let op_args = (0..BASE_PARAMS.max_operation_args)
.map(|i| { .map(|i| {
StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size]) StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size])
}) })
@ -953,7 +967,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
} }
} }
fn size(params: &Params) -> usize { fn size(params: &Params) -> usize {
StatementTarget::size(params) * (1 + params.max_operation_args) StatementTarget::size(params) * (1 + BASE_PARAMS.max_operation_args)
+ OperationTarget::size(params) + OperationTarget::size(params)
} }
} }
@ -1425,7 +1439,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
OperationTarget { OperationTarget {
op_type: self.add_virtual_operation_type(), op_type: self.add_virtual_operation_type(),
args: (0..params.max_operation_args) args: (0..BASE_PARAMS.max_operation_args)
.map(|_| IndexTarget::new_virtual(params.statement_table_size(), self)) .map(|_| IndexTarget::new_virtual(params.statement_table_size(), self))
.collect(), .collect(),
aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self), aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self),
@ -1735,7 +1749,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
let num_chunks = array.len().div_ceil(CHUNK_LEN); let num_chunks = array.len().div_ceil(CHUNK_LEN);
for chunk in array.chunks(CHUNK_LEN) { for chunk in array.chunks(CHUNK_LEN) {
let mut index_chunk = i.low; let mut index_chunk = i.low;
// I we have several chunks and the last one is smaller (it's index needs less than 6 // If we have several chunks and the last one is smaller (it's index needs less than 6
// bits), make it zero except when it's used so that the range check over the index // bits), make it zero except when it's used so that the range check over the index
// passes. // passes.
if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 { if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 {

View file

@ -55,7 +55,7 @@ use crate::{
middleware::{ middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix, NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix,
ToFields, Value, F, HASH_SIZE, ToFields, Value, BASE_PARAMS, F, HASH_SIZE,
}, },
}; };
// //
@ -69,30 +69,37 @@ pub const PI_OFFSET_VDSROOT: usize = 4;
pub const NUM_PUBLIC_INPUTS: usize = 8; pub const NUM_PUBLIC_INPUTS: usize = 8;
const MAX_VALUE_ARGS: usize = 4; const MAX_VALUE_ARGS: usize = 5;
struct StatementArgCache { struct StatementArgCache {
rhs: ValueTarget, rhs: ValueTarget,
lhs: StatementArgTarget, lhs: StatementArgTarget,
valid: BoolTarget, valid: BoolTarget,
pred_is_none: BoolTarget,
is_reference: BoolTarget,
// if `is_reference` then this is the AnchoredKey found in the Contains statement
reference: StatementArgTarget,
// if `is_reference` then this is the value found in the Contains statement
value: ValueTarget,
} }
struct StatementCache { struct StatementCache<const MAX_EQS: usize> {
equations: [StatementArgCache; MAX_VALUE_ARGS], equations: [StatementArgCache; MAX_EQS],
first_n_equations_valid: [BoolTarget; MAX_VALUE_ARGS], first_n_equations_valid: [BoolTarget; MAX_EQS],
op_args: Vec<StatementTarget>, op_args: Vec<StatementTarget>,
} }
impl StatementCache { impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
fn new( fn new(
params: &Params, params: &Params,
max_operation_args: usize,
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
op: &OperationTarget, op: &OperationTarget,
st: &StatementTarget, st: &StatementTarget,
prev_statements: &[StatementTarget], prev_statements: &[StatementTarget],
) -> Self { ) -> Self {
let op_args = if prev_statements.is_empty() { let op_args = if prev_statements.is_empty() {
(0..params.max_operation_args) (0..max_operation_args)
.map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[]))
.collect_vec() .collect_vec()
} else { } else {
@ -100,10 +107,10 @@ impl StatementCache {
// converting a length 1 array into a scalar. // converting a length 1 array into a scalar.
op.args op.args
.iter() .iter()
.take(max_operation_args)
.map(|i| builder.vec_ref(params, prev_statements, i)) .map(|i| builder.vec_ref(params, prev_statements, i))
.collect::<Vec<_>>() .collect::<Vec<_>>()
}; };
assert!(params.max_operation_args >= MAX_VALUE_ARGS);
assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); assert!(Params::max_statement_args() >= MAX_VALUE_ARGS);
let equations = array::from_fn(|i| { let equations = array::from_fn(|i| {
let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None); let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None);
@ -117,9 +124,9 @@ impl StatementCache {
let is_reference = builder.and(pred_is_contains, ref_is_value); let is_reference = builder.and(pred_is_contains, ref_is_value);
let valid = builder.or(is_literal, is_reference); let valid = builder.or(is_literal, is_reference);
let rhs_literal = st.args[i].as_value(); let rhs_from_literal = st.args[i].as_value();
let rhs_reference = op_args[i].args[2].as_value(); let rhs_from_reference = op_args[i].args[2].as_value();
let rhs = builder.select_value(pred_is_none, rhs_literal, rhs_reference); let rhs = builder.select_value(pred_is_none, rhs_from_literal, rhs_from_reference);
let lhs_literal = &st.args[i]; let lhs_literal = &st.args[i];
let lhs_reference = StatementArgTarget::anchored_key( let lhs_reference = StatementArgTarget::anchored_key(
builder, builder,
@ -127,10 +134,22 @@ impl StatementCache {
&op_args[i].args[1].as_value(), &op_args[i].args[1].as_value(),
); );
let lhs = builder.select_statement_arg(pred_is_none, lhs_literal, &lhs_reference); let lhs = builder.select_statement_arg(pred_is_none, lhs_literal, &lhs_reference);
StatementArgCache { rhs, lhs, valid } StatementArgCache {
rhs,
lhs,
valid,
pred_is_none,
is_reference,
reference: lhs_reference,
value: rhs_from_reference,
}
}); });
let mut first_n_equations_valid = [equations[0].valid; MAX_VALUE_ARGS]; let mut first_n_equations_valid = if MAX_EQS != 0 {
for i in 1..MAX_VALUE_ARGS { [equations[0].valid; MAX_EQS]
} else {
[builder._false(); MAX_EQS]
};
for i in 1..MAX_EQS {
first_n_equations_valid[i] = first_n_equations_valid[i] =
builder.and(equations[i].valid, first_n_equations_valid[i - 1]); builder.and(equations[i].valid, first_n_equations_valid[i - 1]);
} }
@ -145,7 +164,7 @@ impl StatementCache {
/// ///
/// If the operation argument is a statement of type `None`, then the value /// If the operation argument is a statement of type `None`, then the value
/// should be the corresponding argument of the current statement. /// should be the corresponding argument of the current statement.
/// If the operation argument is a statement of type `Equals`, then the value /// If the operation argument is a statement of type `Contains`, then the value
/// should be the argument at index 1 of that statement. /// should be the argument at index 1 of that statement.
/// If the function successfully interprets the arguments as values, /// If the function successfully interprets the arguments as values,
/// returns `True` along with those values. Otherwise, returns `False` /// returns `True` along with those values. Otherwise, returns `False`
@ -158,6 +177,12 @@ impl StatementCache {
} }
} }
/// Statement cache for private statements
type StatementCachePriv = StatementCache<MAX_VALUE_ARGS>;
/// Statement cache for public statements. Since the operations can only be None or Copy, no
/// equation is needed because none of these operations dereference entries.
type StatementCachePub = StatementCache<0>;
/// Specialized implementation of `verify_operation_circuit` for operations that generate public /// Specialized implementation of `verify_operation_circuit` for operations that generate public
/// statement. This only allows operations to be None, NewEntry or Copy and accounts for the fact /// statement. This only allows operations to be None, NewEntry or Copy and accounts for the fact
/// that public statements in the current implementation are always generated by copying private /// that public statements in the current implementation are always generated by copying private
@ -169,13 +194,15 @@ fn verify_operation_public_statement_circuit(
op: &OperationTarget, op: &OperationTarget,
prev_statements: &[StatementTarget], prev_statements: &[StatementTarget],
) -> Result<()> { ) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerify"); let measure = measure_gates_begin!(builder, "OpVerifyPub");
// Verify that the operation `op` correctly generates the statement `st`. The operation // Verify that the operation `op` correctly generates the statement `st`. The operation
// can reference any of the `prev_statements`. // can reference any of the `prev_statements`.
// TODO: Clean this up. // TODO: Clean this up.
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
let cache = StatementCache::new(params, builder, op, st, prev_statements); // None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the
// StatementCache requires.
let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements);
measure_gates_end!(builder, measure_resolve_op_args); measure_gates_end!(builder, measure_resolve_op_args);
let op_checks = vec![ let op_checks = vec![
@ -406,7 +433,7 @@ fn verify_operation_circuit(
prev_statements: &[StatementTarget], prev_statements: &[StatementTarget],
aux_table: &MuxTableTarget, aux_table: &MuxTableTarget,
) -> Result<()> { ) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerify"); let measure = measure_gates_begin!(builder, "OpVerifyPriv");
let _true = builder._true(); let _true = builder._true();
let _false = builder._false(); let _false = builder._false();
@ -414,7 +441,14 @@ fn verify_operation_circuit(
// can reference any of the `prev_statements`. // can reference any of the `prev_statements`.
// TODO: Clean this up. // TODO: Clean this up.
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
let cache = StatementCache::new(params, builder, op, st, prev_statements); let cache = StatementCachePriv::new(
params,
BASE_PARAMS.max_operation_args,
builder,
op,
st,
prev_statements,
);
measure_gates_end!(builder, measure_resolve_op_args); measure_gates_end!(builder, measure_resolve_op_args);
// Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified // Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified
@ -442,6 +476,7 @@ fn verify_operation_circuit(
verify_sum_of_circuit(params, builder, st, &op.op_type, &cache), verify_sum_of_circuit(params, builder, st, &op.op_type, &cache),
verify_product_of_circuit(params, builder, st, &op.op_type, &cache), verify_product_of_circuit(params, builder, st, &op.op_type, &cache),
verify_max_of_circuit(params, builder, st, &op.op_type, &cache), verify_max_of_circuit(params, builder, st, &op.op_type, &cache),
verify_replace_value_with_entry_circuit(params, builder, st, &op.op_type, &cache),
]); ]);
} }
// Skip these if there are no resolved aux entries // Skip these if there are no resolved aux entries
@ -542,7 +577,7 @@ fn verify_contains_from_entries_circuit(
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
aux: &TableEntryTarget, aux: &TableEntryTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpContainsFromEntries"); let measure = measure_gates_begin!(builder, "OpContainsFromEntries");
let (aux_tag_ok, resolved_merkle_claim) = let (aux_tag_ok, resolved_merkle_claim) =
@ -592,7 +627,7 @@ fn verify_not_contains_from_entries_circuit(
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
aux: &TableEntryTarget, aux: &TableEntryTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries"); let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries");
let (aux_tag_ok, resolved_merkle_claim) = let (aux_tag_ok, resolved_merkle_claim) =
@ -639,7 +674,7 @@ fn verify_merkle_insert_circuit(
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
aux: &TableEntryTarget, aux: &TableEntryTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleInsertOp"); let measure = measure_gates_begin!(builder, "MerkleInsertOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
@ -714,7 +749,7 @@ fn verify_merkle_update_circuit(
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
aux: &TableEntryTarget, aux: &TableEntryTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleUpdateOp"); let measure = measure_gates_begin!(builder, "MerkleUpdateOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
@ -789,7 +824,7 @@ fn verify_merkle_delete_circuit(
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
aux: &TableEntryTarget, aux: &TableEntryTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleDeleteOp"); let measure = measure_gates_begin!(builder, "MerkleDeleteOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) = let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
@ -883,7 +918,7 @@ fn verify_eq_neq_from_entries_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries"); let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries");
let eq_op_st_code_ok = { let eq_op_st_code_ok = {
@ -932,9 +967,9 @@ fn verify_lt_lteq_from_entries_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries"); let measure = measure_gates_begin!(builder, "OpLtEqFromEntries");
let zero = ValueTarget::zero(builder); let zero = ValueTarget::zero(builder);
let one = ValueTarget::one(builder); let one = ValueTarget::one(builder);
@ -1000,7 +1035,7 @@ fn verify_hash_of_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpHashOf"); let measure = measure_gates_begin!(builder, "OpHashOf");
let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf);
@ -1033,7 +1068,7 @@ fn verify_public_key_of_circuit(
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
aux: &TableEntryTarget, aux: &TableEntryTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpPublicKeyOf"); let measure = measure_gates_begin!(builder, "OpPublicKeyOf");
let (aux_tag_ok, resolved_pk_sk) = let (aux_tag_ok, resolved_pk_sk) =
@ -1069,7 +1104,7 @@ fn verify_signed_by_circuit(
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
aux: &TableEntryTarget, aux: &TableEntryTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpSignedBy"); let measure = measure_gates_begin!(builder, "OpSignedBy");
let (aux_tag_ok, resolved_msg_pk) = let (aux_tag_ok, resolved_msg_pk) =
@ -1104,7 +1139,7 @@ fn verify_sum_of_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpSumOf"); let measure = measure_gates_begin!(builder, "OpSumOf");
let value_zero = ValueTarget::zero(builder); let value_zero = ValueTarget::zero(builder);
@ -1142,7 +1177,7 @@ fn verify_product_of_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpProductOf"); let measure = measure_gates_begin!(builder, "OpProductOf");
let value_zero = ValueTarget::zero(builder); let value_zero = ValueTarget::zero(builder);
@ -1180,7 +1215,7 @@ fn verify_max_of_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op_type: &OperationTypeTarget, op_type: &OperationTypeTarget,
cache: &StatementCache, cache: &StatementCachePriv,
) -> BoolTarget { ) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpMaxOf"); let measure = measure_gates_begin!(builder, "OpMaxOf");
let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf);
@ -1220,6 +1255,47 @@ fn verify_max_of_circuit(
ok ok
} }
fn verify_replace_value_with_entry_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpReplaceValueWithEntry");
let op_code_ok = op_type.has_native(builder, NativeOperation::ReplaceValueWithEntry);
let st_in = &cache.op_args[BASE_PARAMS.max_statement_args];
let mut args = Vec::new();
let mut args_ok = builder._true();
for (arg_in, entry_cache) in zip_eq(&st_in.args, &cache.equations) {
// if the op_arg is None, keep the original argument, if it's a Contains swap the value by
// the reference Entry while checking that the value in Contains matches the original
// argument.
let arg = builder.select_flattenable(
params,
entry_cache.pred_is_none,
arg_in,
&entry_cache.reference,
);
args.push(arg);
let arg_ref_ok = {
let arg_in_is_value = builder.statement_arg_is_value(arg_in);
let value_eq = builder.is_equal_flattenable(&arg_in.as_value(), &entry_cache.value);
builder.all([entry_cache.is_reference, arg_in_is_value, value_eq])
};
let arg_ok = builder.or(entry_cache.pred_is_none, arg_ref_ok);
args_ok = builder.and(args_ok, arg_ok);
}
let expected_statement = StatementTarget::new(*st_in.pred_hash(), args);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
let ok = builder.all([op_code_ok, args_ok, st_ok]);
measure_gates_end!(builder, measure);
ok
}
fn verify_transitive_eq_circuit( fn verify_transitive_eq_circuit(
params: &Params, params: &Params,
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
@ -1429,7 +1505,7 @@ fn make_custom_statement_circuit(
) -> Result<(StatementTarget, OperationTypeTarget)> { ) -> Result<(StatementTarget, OperationTypeTarget)> {
let measure = measure_gates_begin!(builder, "CustomOpVerify"); let measure = measure_gates_begin!(builder, "CustomOpVerify");
// Some sanity checks // Some sanity checks
assert_eq!(params.max_operation_args, op_args.len()); assert_eq!(BASE_PARAMS.max_operation_args, op_args.len());
assert_eq!(params.max_custom_predicate_wildcards, args.len()); assert_eq!(params.max_custom_predicate_wildcards, args.len());
let (batch_id, index) = (custom_predicate.id, custom_predicate.index); let (batch_id, index) = (custom_predicate.id, custom_predicate.index);
@ -1463,7 +1539,6 @@ fn make_custom_statement_circuit(
.collect(); .collect();
// expected_sts.len() == params.max_custom_predicate_arity // expected_sts.len() == params.max_custom_predicate_arity
// op_args.len() == params.max_operation_args; // op_args.len() == params.max_operation_args;
assert!(Params::max_custom_predicate_arity() <= params.max_operation_args);
let sts_eq: Vec<_> = expected_sts let sts_eq: Vec<_> = expected_sts
.iter() .iter()
@ -2076,7 +2151,8 @@ mod tests {
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
middleware::{ middleware::{
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard, hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard,
RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, EMPTY_VALUE, RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard,
BASE_PARAMS, EMPTY_VALUE,
}, },
}; };
@ -3068,6 +3144,33 @@ mod tests {
operation_verify(st, op, prev_statements, Aux::signed_by(signed_by)) operation_verify(st, op, prev_statements, Aux::signed_by(signed_by))
} }
#[test]
fn test_operation_replace_value_with_entry() -> Result<()> {
let d = dict!({"a" => 42, "b" => 33});
// 0: None
// 1: Lt(5, 42)
let st_in: mainpod::Statement = Statement::lt(5, 42).into();
// 2: Contains(d, "a", 42)
let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into();
let st_out: mainpod::Statement =
Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into();
let mut op_args: Vec<_> = iter::repeat(OperationArg::None)
.take(BASE_PARAMS.max_statement_args + 1)
.collect();
op_args[1] = OperationArg::Index(2);
op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1);
let op = mainpod::Operation(
OperationType::Native(NativeOperation::ReplaceValueWithEntry),
op_args,
OperationAux::None,
);
let prev_statements = vec![Statement::None.into(), st_in, st_entry];
operation_verify(st_out, op, prev_statements, Aux::default())
}
fn helper_statement_arg_from_template( fn helper_statement_arg_from_template(
params: &Params, params: &Params,
st_tmpl_arg: StatementTmplArg, st_tmpl_arg: StatementTmplArg,
@ -3226,7 +3329,7 @@ mod tests {
expected_st: Option<Statement>, expected_st: Option<Statement>,
) -> Result<()> { ) -> Result<()> {
// Pad // Pad
for _ in op_args.len()..params.max_operation_args { for _ in op_args.len()..BASE_PARAMS.max_operation_args {
op_args.push(Statement::None); op_args.push(Statement::None);
} }
for _ in args.len()..params.max_custom_predicate_wildcards { for _ in args.len()..params.max_custom_predicate_wildcards {
@ -3275,6 +3378,10 @@ mod tests {
Ok(data.verify(proof.clone())?) Ok(data.verify(proof.clone())?)
} }
fn value_ref(v: impl Into<ValueRef>) -> ValueRef {
v.into()
}
// TODO: Add negative tests // TODO: Add negative tests
#[test] #[test]
fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> { fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> {
@ -3309,7 +3416,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(1234)]; let args = vec![Value::from(dict), Value::from(1234)];
let expected_st = Statement::Custom( let expected_st = Statement::Custom(
custom_predicate.clone(), custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)], vec![value_ref(args[0].clone()), value_ref(0)],
); );
helper_custom_operation_verify_gadget( helper_custom_operation_verify_gadget(
@ -3330,7 +3437,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(0)]; let args = vec![Value::from(dict), Value::from(0)];
let expected_st = Statement::Custom( let expected_st = Statement::Custom(
custom_predicate.clone(), custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)], vec![value_ref(args[0].clone()), value_ref(0)],
); );
helper_custom_operation_verify_gadget( helper_custom_operation_verify_gadget(
@ -3351,7 +3458,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(1234)]; let args = vec![Value::from(dict), Value::from(1234)];
let expected_st = Statement::Custom( let expected_st = Statement::Custom(
custom_predicate.clone(), custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)], vec![value_ref(args[0].clone()), value_ref(0)],
); );
helper_custom_operation_verify_gadget( helper_custom_operation_verify_gadget(
@ -3403,7 +3510,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(secret_dict)]; let args = vec![Value::from(dict), Value::from(secret_dict)];
let expected_st = Statement::Custom( let expected_st = Statement::Custom(
custom_predicate.clone(), custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)], vec![value_ref(args[0].clone()), value_ref(0)],
); );
helper_custom_operation_verify_gadget( helper_custom_operation_verify_gadget(

View file

@ -1,5 +1,5 @@
pub mod operation; pub mod operation;
use crate::middleware::{wildcard_values_from_op_st, PodType}; use crate::middleware::{wildcard_values_from_op_st, PodType, BASE_PARAMS};
pub mod statement; pub mod statement;
use std::iter; use std::iter;
@ -39,7 +39,7 @@ use crate::{
middleware::{ middleware::{
self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs, self, value_from_op, CustomPredicateRef, Error as MiddlewareError, Hash, MainPodInputs,
MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg, MainPodProver, NativeOperation, OperationType, Params, Pod, RawValue, StatementArg,
ToFields, VDSet, Value, ToFields, VDSet, Value, ValueRef,
}, },
timed, timed,
}; };
@ -104,9 +104,20 @@ pub(crate) fn extract_custom_predicate_verifications(
if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Operation::Custom(cpr, sts) = op {
if let middleware::Statement::Custom(st_cpr, st_args) = st { if let middleware::Statement::Custom(st_cpr, st_args) = st {
assert_eq!(cpr, st_cpr); assert_eq!(cpr, st_cpr);
// The custom operation outputs statements with literal arguments. They can be
// replaced by references later with ReplaceValueWithEntry.
let st_args = st_args
.iter()
.map(|arg| match arg {
ValueRef::Literal(v) => Ok(v.clone()),
_ => Err(Error::custom(
"custom operation cannot output entries as arguments",
)),
})
.collect::<Result<Vec<_>>>()?;
let normalized_pred = cpr.normalized_predicate(); let normalized_pred = cpr.normalized_predicate();
let wildcard_values = let wildcard_values =
wildcard_values_from_op_st(params, &normalized_pred, sts, st_args) wildcard_values_from_op_st(params, &normalized_pred, sts, &st_args)
.expect("resolved wildcards"); .expect("resolved wildcards");
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
let custom_predicate_table_index = custom_predicates let custom_predicate_table_index = custom_predicates
@ -329,8 +340,8 @@ pub fn pad_statement(s: &mut Statement) {
fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args()) fill_pad(&mut s.1, StatementArg::None, Params::max_statement_args())
} }
fn pad_operation_args(params: &Params, args: &mut Vec<OperationArg>) { fn pad_operation_args(args: &mut Vec<OperationArg>) {
fill_pad(args, OperationArg::None, params.max_operation_args) fill_pad(args, OperationArg::None, BASE_PARAMS.max_operation_args)
} }
/// Returns the statements from the given MainPodInputs, padding to the respective max lengths /// Returns the statements from the given MainPodInputs, padding to the respective max lengths
@ -428,7 +439,7 @@ pub(crate) fn process_private_statements_operations(
.map(|mid_arg| find_op_arg(statements, mid_arg)) .map(|mid_arg| find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
pad_operation_args(params, &mut args); pad_operation_args(&mut args);
operations.push(Operation(op.op_type(), args, *aux)); operations.push(Operation(op.op_type(), args, *aux));
} }
Ok(operations) Ok(operations)
@ -459,7 +470,11 @@ pub(crate) fn process_public_statements_operations(
OperationAux::None, OperationAux::None,
) )
}; };
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args); fill_pad(
&mut op.1,
OperationArg::None,
BASE_PARAMS.max_operation_args,
);
operations.push(op); operations.push(op);
} }
Ok(operations) Ok(operations)
@ -469,6 +484,7 @@ pub struct Prover {}
impl MainPodProver for Prover { impl MainPodProver for Prover {
fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result<Box<dyn Pod>> { fn prove(&self, params: &Params, inputs: MainPodInputs) -> Result<Box<dyn Pod>> {
assert_eq!(inputs.statements.len(), inputs.operations.len());
// Pad input recursive pods with empty pods if necessary // Pad input recursive pods with empty pods if necessary
let empty_pod = if inputs.pods.len() == params.max_input_pods { let empty_pod = if inputs.pods.len() == params.max_input_pods {
// We don't need padding so we skip creating an EmptyPod // We don't need padding so we skip creating an EmptyPod
@ -1005,7 +1021,6 @@ pub mod tests {
max_input_pods_public_statements: 2, max_input_pods_public_statements: 2,
max_statements: 5, max_statements: 5,
max_public_statements: 2, max_public_statements: 2,
max_operation_args: 5,
max_custom_predicates: 2, max_custom_predicates: 2,
max_custom_predicate_verifications: 2, max_custom_predicate_verifications: 2,
max_custom_predicate_wildcards: 3, max_custom_predicate_wildcards: 3,
@ -1070,7 +1085,6 @@ pub mod tests {
max_input_pods: 0, max_input_pods: 0,
max_statements: 9, max_statements: 9,
max_public_statements: 4, max_public_statements: 4,
max_operation_args: 5,
max_custom_predicate_wildcards: 4, max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2, max_custom_predicate_verifications: 2,
max_merkle_proofs_containers: 3, max_merkle_proofs_containers: 3,
@ -1140,7 +1154,6 @@ pub mod tests {
max_input_pods: 0, max_input_pods: 0,
max_statements: 6, max_statements: 6,
max_public_statements: 2, max_public_statements: 2,
max_operation_args: 5,
max_custom_predicate_wildcards: 4, max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2, max_custom_predicate_verifications: 2,
max_merkle_proofs_containers: 0, max_merkle_proofs_containers: 0,
@ -1251,11 +1264,108 @@ pub mod tests {
); );
let st = middleware::Statement::Custom( let st = middleware::Statement::Custom(
cpr, cpr,
[1, 1, 2].into_iter().map(middleware::Value::from).collect(), [1, 1, 2]
.into_iter()
.map(middleware::ValueRef::from)
.collect(),
); );
builder.insert((st.clone(), op)).unwrap(); builder.insert((st.clone(), op)).unwrap();
builder.reveal(&st).unwrap(); builder.reveal(&st).unwrap();
let prover = Prover {}; let prover = Prover {};
builder.prove(&prover).unwrap(); builder.prove(&prover).unwrap();
} }
#[test]
fn test_replace_value_with_entry() {
let params = middleware::Params::default();
let vd_set = &*DEFAULT_VD_SET;
let mut builder = MainPodBuilder::new(&params, vd_set);
let d = dict!({"a" => 42, "b" => 33});
builder
.priv_op(frontend::Operation::dict_contains(d.clone(), "a", 42))
.unwrap();
let st = builder.priv_op(frontend::Operation::lt(5, 42)).unwrap();
// Transform `Lt(5, 42)` into `Lt(5, d.a)` by using `DictContains(d, "a", 42)`
builder
.pub_op(frontend::Operation::replace_value_with_entry(
vec![None, Some((&d, "a"))],
st,
))
.unwrap();
// Mock
let prover = MockProver {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap();
assert_eq!(
middleware::Statement::Lt(
middleware::ValueRef::Literal(Value::from(5)),
middleware::ValueRef::Key(middleware::AnchoredKey {
root: d.commitment(),
key: middleware::Key::from("a")
})
),
pod.public_statements[0]
);
// Real
let prover = Prover {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap()
}
#[test]
fn test_entry_custom_statement_arg() {
let params = middleware::Params::default();
let vd_set = &*DEFAULT_VD_SET;
let input = r#"
PredA(x) = AND(
Lt(x, 100)
)
PredB(d) = AND(
PredA(d.x)
)
"#;
let module = load_module(input, "my_mod", &params, &[]).expect("lang parse");
let pred_a = module.batch.predicate_ref_by_name("PredA").unwrap();
let pred_b = module.batch.predicate_ref_by_name("PredB").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
let d = dict!({"x" => 42, "y" => 33});
let st_lt = builder.priv_op(frontend::Operation::lt(42, 100)).unwrap();
let st_a = builder
.priv_op(frontend::Operation::custom(pred_a, [st_lt]))
.unwrap();
builder
.priv_op(frontend::Operation::dict_contains(d.clone(), "x", 42))
.unwrap();
// Transform `PredA(42)` into `PredA(d.x)` by using `DictContains(d, "x", 42)`
let st_a1 = builder
.priv_op(frontend::Operation::replace_value_with_entry(
vec![Some((&d, "x"))],
st_a,
))
.unwrap();
builder
.pub_op(frontend::Operation::custom(pred_b.clone(), [st_a1]))
.unwrap();
// Mock
let prover = MockProver {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap();
let expected = middleware::Statement::Custom(
pred_b,
vec![middleware::ValueRef::Literal(Value::from(d))],
);
assert_eq!(expected, pod.public_statements[0]);
// Real
let prover = Prover {};
let pod = builder.prove(&prover).unwrap();
pod.pod.verify().unwrap()
}
} }

View file

@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
backends::plonky2::error::{Error, Result}, backends::plonky2::error::{Error, Result},
middleware::{self, NativePredicate, Predicate, StatementArg, ToFields, Value, BASE_PARAMS}, middleware::{
self, NativePredicate, Predicate, StatementArg, ToFields, Value, ValueRef, BASE_PARAMS,
},
}; };
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@ -96,15 +98,15 @@ impl TryFrom<Statement> for middleware::Statement {
)))?, )))?,
}, },
Predicate::Custom(cpr) => { Predicate::Custom(cpr) => {
let vs: Vec<Value> = proper_args let args: Vec<ValueRef> = proper_args
.into_iter() .into_iter()
.filter_map(|arg| match arg { .filter_map(|arg| match arg {
SA::None => None, StatementArg::Literal(v) => Some(ValueRef::Literal(v)),
SA::Literal(v) => Some(v), StatementArg::Key(k) => Some(ValueRef::Key(k)),
_ => unreachable!(), StatementArg::None => None,
}) })
.collect(); .collect();
S::Custom(cpr, vs) S::Custom(cpr, args)
} }
Predicate::Intro(ir) => { Predicate::Intro(ir) => {
let vs: Vec<Value> = proper_args let vs: Vec<Value> = proper_args

View file

@ -380,7 +380,8 @@ pub mod tests {
great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request, great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request,
zu_kyc_sign_dict_builders, MOCK_VD_SET, zu_kyc_sign_dict_builders, MOCK_VD_SET,
}, },
frontend, middleware, frontend::{self},
middleware,
middleware::{Signer as _, Value}, middleware::{Signer as _, Value},
}; };

View file

@ -316,7 +316,9 @@ mod tests {
backends::plonky2::mock::mainpod::MockProver, backends::plonky2::mock::mainpod::MockProver,
examples::{custom::eth_dos_batch, MOCK_VD_SET}, examples::{custom::eth_dos_batch, MOCK_VD_SET},
frontend::{MainPodBuilder, Operation}, frontend::{MainPodBuilder, Operation},
middleware::{self, containers::Set, CustomPredicateRef, Params, PodType, DEFAULT_VD_SET}, middleware::{
self, containers::Set, CustomPredicateRef, Params, PodType, ValueRef, DEFAULT_VD_SET,
},
}; };
#[test] #[test]
@ -507,7 +509,7 @@ mod tests {
.find(|s| matches!(s, middleware::Statement::Custom(_, _))) .find(|s| matches!(s, middleware::Statement::Custom(_, _)))
.expect("should have a custom statement"); .expect("should have a custom statement");
if let middleware::Statement::Custom(_, args) = custom_st { if let middleware::Statement::Custom(_, args) = custom_st {
assert_eq!(args[0], pred_b_hash); assert_eq!(args[0], ValueRef::Literal(pred_b_hash));
} }
Ok(()) Ok(())

View file

@ -4,7 +4,7 @@
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::From, convert::From,
fmt, fmt, iter,
}; };
use itertools::Itertools; use itertools::Itertools;
@ -15,9 +15,10 @@ pub use serialization::SerializedMainPod;
use crate::middleware::{ use crate::middleware::{
self, check_custom_pred, self, check_custom_pred,
containers::{Container, Dictionary}, containers::{Container, Dictionary},
fill_wildcard_values, hash_op, max_op, prod_op, sum_op, AnchoredKey, Hash, Key, MainPodInputs, fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key,
MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, RawValue, MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey,
Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, EMPTY_VALUE, RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS,
EMPTY_VALUE,
}; };
mod custom; mod custom;
@ -566,6 +567,37 @@ impl MainPodBuilder {
// TODO: validate proof // TODO: validate proof
Statement::ContainerDelete(r1, r2, r3) Statement::ContainerDelete(r1, r2, r3)
} }
(ReplaceValueWithEntry, &args, _) => {
let mut args = args.to_vec();
if args.len() != BASE_PARAMS.max_statement_args + 1 {
return Err(Error::custom(format!(
"ReplaceValueWithEntry requires exactly {} args but {} were found",
BASE_PARAMS.max_statement_args + 1,
args.len()
)));
}
let st = match args.pop().expect("valid vec len") {
OperationArg::Statement(st) => st,
_ => return Err(Error::custom("expected statement")),
};
let new_st_args = iter::zip(st.args().into_iter(), args)
.map(|(st_arg, arg)| match (st_arg, arg) {
(st_arg, OperationArg::Statement(Statement::None)) => Ok(st_arg),
(
StatementArg::Literal(st_arg_v),
OperationArg::Statement(Statement::Contains(
ValueRef::Literal(root),
ValueRef::Literal(key),
ValueRef::Literal(v),
)),
) if st_arg_v == v => root_key_to_ak(&root, &key)
.map(StatementArg::Key)
.ok_or_else(native_arg_error),
_ => Err(Error::custom("unexpected operation argument")),
})
.collect::<Result<Vec<_>>>()?;
Statement::from_args(st.predicate(), new_st_args)?
}
(t, _, _) => { (t, _, _) => {
if t.is_syntactic_sugar() { if t.is_syntactic_sugar() {
return Err(Error::custom(format!( return Err(Error::custom(format!(
@ -615,7 +647,7 @@ impl MainPodBuilder {
.map(|v| v.unwrap_or_else(|| v_default.clone())) .map(|v| v.unwrap_or_else(|| v_default.clone()))
.collect(); .collect();
check_custom_pred(&self.params, &cpr, &args, &st_args)?; check_custom_pred(&self.params, &cpr, &args, &st_args)?;
Statement::Custom(cpr, st_args) Statement::Custom(cpr, st_args.into_iter().map(ValueRef::Literal).collect())
} }
}; };
Ok(st) Ok(st)

View file

@ -111,7 +111,8 @@ impl StatementCost {
// Syntactic sugar variants (lowered before proving) // Syntactic sugar variants (lowered before proving)
| NativeOperation::GtEqFromEntries | NativeOperation::GtEqFromEntries
| NativeOperation::GtFromEntries | NativeOperation::GtFromEntries
| NativeOperation::GtToNotEqual => {} | NativeOperation::GtToNotEqual
| NativeOperation::ReplaceValueWithEntry => {}
} }
} }
OperationType::Custom(cpr) => { OperationType::Custom(cpr) => {

View file

@ -1,10 +1,10 @@
use std::fmt; use std::{fmt, iter};
use crate::{ use crate::{
frontend::SignedDict, frontend::SignedDict,
middleware::{ middleware::{
containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux, containers::Dictionary, root_key_to_ak, CustomPredicateRef, NativeOperation, OperationAux,
OperationType, Signature, Statement, Value, ValueRef, OperationType, Signature, Statement, Value, ValueRef, BASE_PARAMS,
}, },
}; };
@ -219,6 +219,24 @@ impl Operation {
op_impl_oa!(set_insert, SetInsertFromEntries, 3); op_impl_oa!(set_insert, SetInsertFromEntries, 3);
op_impl_oa!(set_delete, SetDeleteFromEntries, 3); op_impl_oa!(set_delete, SetDeleteFromEntries, 3);
op_impl_oa!(array_update, ArrayUpdateFromEntries, 4); op_impl_oa!(array_update, ArrayUpdateFromEntries, 4);
pub fn replace_value_with_entry(args: Vec<Option<(&Dictionary, &str)>>, st: Statement) -> Self {
assert!(args.len() <= BASE_PARAMS.max_statement_args);
let args = args
.into_iter()
.chain(iter::repeat(None))
.take(BASE_PARAMS.max_statement_args)
.map(|a| match a {
None => OperationArg::Statement(Statement::None),
Some((dict, key)) => OperationArg::from((dict, key)),
})
.chain(iter::once(OperationArg::Statement(st)))
.collect();
Self(
OperationType::Native(NativeOperation::ReplaceValueWithEntry),
args,
OperationAux::None,
)
}
pub fn signed_by( pub fn signed_by(
msg: impl Into<OperationArg>, msg: impl Into<OperationArg>,
pk: impl Into<OperationArg>, pk: impl Into<OperationArg>,

View file

@ -174,18 +174,6 @@ fn render_validation_error(
"second REQUEST here", "second REQUEST here",
), ),
ValidationError::InvalidArgumentType { predicate, span } => {
let title = format!("invalid argument type for `{}`", predicate);
render_with_optional_span(
renderer,
source,
path,
&title,
span.as_ref(),
"anchored keys not allowed here",
)
}
ValidationError::DuplicateWildcard { name, span } => { ValidationError::DuplicateWildcard { name, span } => {
let title = format!("duplicate wildcard: {}", name); let title = format!("duplicate wildcard: {}", name);
render_with_optional_span( render_with_optional_span(

View file

@ -135,12 +135,6 @@ pub enum ValidationError {
span: Option<Span>, span: Option<Span>,
}, },
#[error("Invalid argument type for {predicate}: anchored keys not allowed")]
InvalidArgumentType {
predicate: String,
span: Option<Span>,
},
#[error("Duplicate wildcard in predicate arguments: {name}")] #[error("Duplicate wildcard in predicate arguments: {name}")]
DuplicateWildcard { name: String, span: Option<Span> }, DuplicateWildcard { name: String, span: Option<Span> },

View file

@ -522,7 +522,7 @@ impl Validator {
} }
// Validate arguments // Validate arguments
self.validate_statement_args(stmt, pred_info.as_ref(), wildcard_context)?; self.validate_statement_args(stmt, wildcard_context)?;
Ok(()) Ok(())
} }
@ -530,45 +530,8 @@ impl Validator {
fn validate_statement_args( fn validate_statement_args(
&self, &self,
stmt: &StatementTmpl, stmt: &StatementTmpl,
pred_info: Option<&PredicateInfo>,
wildcard_context: Option<(&str, &WildcardScope)>, wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> { ) -> Result<(), ValidationError> {
// For custom predicates, only wildcards and literals are allowed
if matches!(
pred_info.map(|i| &i.kind),
Some(PredicateKind::Custom { .. })
| Some(PredicateKind::BatchImported { .. })
| Some(PredicateKind::ModuleImported { .. })
) {
for arg in &stmt.args {
match arg {
StatementTmplArg::AnchoredKey(_) => {
return Err(ValidationError::InvalidArgumentType {
predicate: stmt.predicate.predicate_name().to_string(),
span: stmt.span,
});
}
StatementTmplArg::Wildcard(id) => {
if let Some((pred_name, scope)) = wildcard_context {
if !scope.wildcards.contains_key(&id.name) {
return Err(ValidationError::UndefinedWildcard {
name: id.name.clone(),
pred_name: pred_name.to_string(),
span: id.span,
});
}
}
}
StatementTmplArg::Literal(lit) => {
self.validate_literal_value(lit)?;
}
StatementTmplArg::SelfPredicateHash(id) => {
self.validate_self_predicate_hash(id, wildcard_context)?;
}
}
}
} else {
// Native predicates can have anchored keys
for arg in &stmt.args { for arg in &stmt.args {
match arg { match arg {
StatementTmplArg::Wildcard(id) => { StatementTmplArg::Wildcard(id) => {
@ -601,7 +564,6 @@ impl Validator {
} }
} }
} }
}
Ok(()) Ok(())
} }
@ -839,10 +801,7 @@ mod tests {
module_hash module_hash
); );
let result = parse_and_validate_request(&input, &available_modules); let result = parse_and_validate_request(&input, &available_modules);
assert!(matches!( assert!(result.is_ok());
result,
Err(ValidationError::InvalidArgumentType { .. })
));
} }
#[test] #[test]

View file

@ -578,7 +578,6 @@ mod tests {
max_input_pods: 3, max_input_pods: 3,
max_statements: 31, max_statements: 31,
max_public_statements: 10, max_public_statements: 10,
max_operation_args: 5,
max_custom_predicate_wildcards: 12, max_custom_predicate_wildcards: 12,
..Default::default() ..Default::default()
}; };

View file

@ -169,6 +169,12 @@ pub struct Hash(
pub [F; HASH_SIZE], pub [F; HASH_SIZE],
); );
impl Hash {
pub fn raw(self) -> RawValue {
RawValue::from(self)
}
}
impl From<Hash> for HashOut { impl From<Hash> for HashOut {
fn from(hash: Hash) -> HashOut { fn from(hash: Hash) -> HashOut {
HashOut { elements: hash.0 } HashOut { elements: hash.0 }

View file

@ -436,7 +436,7 @@ impl fmt::Display for CustomPredicate {
} }
} }
#[derive(Clone, Debug, PartialEq, Eq, Serialize, JsonSchema)] #[derive(Clone, PartialEq, Eq, Serialize, JsonSchema)]
enum CustomPredicateBatchData { enum CustomPredicateBatchData {
Full { Full {
#[serde(skip)] #[serde(skip)]
@ -449,6 +449,20 @@ enum CustomPredicateBatchData {
}, },
} }
// Explicit implementation of Debug to skip the merkle tree
impl fmt::Debug for CustomPredicateBatchData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self {
Self::Full { mt, predicates } => f
.debug_struct("Full")
.field("id", &mt.root())
.field("predicates", &predicates)
.finish(),
Self::Opaque { id } => f.debug_struct("Opaque").field("id", &id).finish(),
}
}
}
// TODO: Rename Batch for Module everywhere in the code base // TODO: Rename Batch for Module everywhere in the code base
impl CustomPredicateBatchData { impl CustomPredicateBatchData {
fn new_full(predicates: Vec<CustomPredicate>) -> Self { fn new_full(predicates: Vec<CustomPredicate>) -> Self {
@ -630,7 +644,7 @@ mod tests {
middleware::{ middleware::{
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl, NativePredicate, Operation, Params, Predicate, Statement, StatementTmpl,
StatementTmplArg, StatementTmplArg, ValueRef,
}, },
}; };
@ -653,6 +667,9 @@ mod tests {
fn names(names: &[&str]) -> Vec<String> { fn names(names: &[&str]) -> Vec<String> {
names.iter().map(|s| s.to_string()).collect() names.iter().map(|s| s.to_string()).collect()
} }
fn value_ref(v: impl Into<ValueRef>) -> ValueRef {
v.into()
}
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
type STA = StatementTmplArg; type STA = StatementTmplArg;
@ -701,7 +718,7 @@ mod tests {
}); });
let custom_statement = Statement::Custom( let custom_statement = Statement::Custom(
CustomPredicateRef::new(cust_pred_batch.clone(), 0), CustomPredicateRef::new(cust_pred_batch.clone(), 0),
vec![Value::from(d0.clone())], vec![value_ref(d0.clone())],
); );
let custom_deduction = Operation::Custom( let custom_deduction = Operation::Custom(
@ -833,7 +850,7 @@ mod tests {
// Example statement // Example statement
let ethdos_example = Statement::Custom( let ethdos_example = Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2),
vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)],
); );
// Copies should work. // Copies should work.
@ -842,7 +859,7 @@ mod tests {
// This could arise as the inductive step. // This could arise as the inductive step.
let ethdos_ind_example = Statement::Custom( let ethdos_ind_example = Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1), CustomPredicateRef::new(eth_dos_distance_batch.clone(), 1),
vec![Value::from("Alice"), Value::from("Bob"), Value::from(7)], vec![value_ref("Alice"), value_ref("Bob"), value_ref(7)],
); );
assert!(Operation::Custom( assert!(Operation::Custom(
@ -857,12 +874,12 @@ mod tests {
let ethdos_facts = vec![ let ethdos_facts = vec![
Statement::Custom( Statement::Custom(
CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2), CustomPredicateRef::new(eth_dos_distance_batch.clone(), 2),
vec![Value::from("Alice"), Value::from("Charlie"), Value::from(6)], vec![value_ref("Alice"), value_ref("Charlie"), value_ref(6)],
), ),
Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)), Statement::sum_of(Value::from(7), Value::from(6), Value::from(1)),
Statement::Custom( Statement::Custom(
CustomPredicateRef::new(eth_friend_batch.clone(), 0), CustomPredicateRef::new(eth_friend_batch.clone(), 0),
vec![Value::from("Charlie"), Value::from("Bob")], vec![value_ref("Charlie"), value_ref("Bob")],
), ),
]; ];
@ -959,7 +976,10 @@ mod tests {
let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())]; let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())];
// The output statement // The output statement
let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]); let output_st = Statement::Custom(
pred_b_ref.clone(),
vec![ValueRef::Literal(some_value.clone())],
);
// This should pass // This should pass
assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(&params, &output_st)?); assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(&params, &output_st)?);
@ -1024,12 +1044,18 @@ mod tests {
// Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass // Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass
let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())]; let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())];
let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]); let st_a = Statement::Custom(
pred_a_ref.clone(),
vec![ValueRef::Literal(pred_b_hash.clone())],
);
assert!(Operation::Custom(pred_a_ref, op_a).check(&params, &st_a)?); assert!(Operation::Custom(pred_a_ref, op_a).check(&params, &st_a)?);
// Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass // Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass
let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())]; let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())];
let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]); let st_b = Statement::Custom(
pred_b_ref.clone(),
vec![ValueRef::Literal(pred_a_hash.clone())],
);
assert!(Operation::Custom(pred_b_ref, op_b).check(&params, &st_b)?); assert!(Operation::Custom(pred_b_ref, op_b).check(&params, &st_b)?);
Ok(()) Ok(())

View file

@ -768,6 +768,8 @@ pub struct BaseParams {
/// in a custom predicate /// in a custom predicate
pub max_custom_predicate_arity: usize, pub max_custom_predicate_arity: usize,
pub max_depth_custom_batch_mt: usize, pub max_depth_custom_batch_mt: usize,
// This value depends on `max_custom_predicate_arity`
pub max_operation_args: usize,
} }
pub const BASE_PARAMS: BaseParams = BaseParams { pub const BASE_PARAMS: BaseParams = BaseParams {
@ -775,6 +777,7 @@ pub const BASE_PARAMS: BaseParams = BaseParams {
max_statement_args: 5, max_statement_args: 5,
max_custom_predicate_arity: 5, max_custom_predicate_arity: 5,
max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch max_depth_custom_batch_mt: 16, // up to 65k (2^16) custom predicates in a batch
max_operation_args: 5 + 1,
}; };
/// Params: non dynamic parameters that define the circuit. /// Params: non dynamic parameters that define the circuit.
@ -785,7 +788,6 @@ pub struct Params {
pub max_input_pods_public_statements: usize, pub max_input_pods_public_statements: usize,
pub max_statements: usize, pub max_statements: usize,
pub max_public_statements: usize, pub max_public_statements: usize,
pub max_operation_args: usize,
// max number of different custom predicates that can be used in a MainPod // max number of different custom predicates that can be used in a MainPod
pub max_custom_predicates: usize, pub max_custom_predicates: usize,
// max number of operations using custom predicates that can be verified in the MainPod // max number of operations using custom predicates that can be verified in the MainPod
@ -815,7 +817,6 @@ impl Default for Params {
max_input_pods_public_statements: 8, max_input_pods_public_statements: 8,
max_statements: 48, max_statements: 48,
max_public_statements: 8, max_public_statements: 8,
max_operation_args: 5,
max_custom_predicates: 8, max_custom_predicates: 8,
max_custom_predicate_verifications: 8, max_custom_predicate_verifications: 8,
max_custom_predicate_wildcards: 8, max_custom_predicate_wildcards: 8,

View file

@ -14,7 +14,7 @@ use crate::{
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key, hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result, MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result,
Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef, Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, Value, ValueRef,
Wildcard, F, Wildcard, BASE_PARAMS, F,
}, },
}; };
@ -89,6 +89,7 @@ pub enum NativeOperation {
ContainerInsertFromEntries = 16, ContainerInsertFromEntries = 16,
ContainerUpdateFromEntries = 17, ContainerUpdateFromEntries = 17,
ContainerDeleteFromEntries = 18, ContainerDeleteFromEntries = 18,
ReplaceValueWithEntry = 19,
// Syntactic sugar operations. These operations are not supported by the backend. The // Syntactic sugar operations. These operations are not supported by the backend. The
// frontend compiler is responsible of translating these operations into the operations above. // frontend compiler is responsible of translating these operations into the operations above.
@ -164,6 +165,7 @@ impl OperationType {
NativeOperation::ContainerDeleteFromEntries => { NativeOperation::ContainerDeleteFromEntries => {
Some(Predicate::Native(NativePredicate::ContainerDelete)) Some(Predicate::Native(NativePredicate::ContainerDelete))
} }
NativeOperation::ReplaceValueWithEntry => None,
no => unreachable!("Unexpected syntactic sugar op {:?}", no), no => unreachable!("Unexpected syntactic sugar op {:?}", no),
}, },
OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())),
@ -219,6 +221,10 @@ pub enum Operation {
/* key */ Statement, /* key */ Statement,
/* proof */ MerkleTreeStateTransitionProof, /* proof */ MerkleTreeStateTransitionProof,
), ),
ReplaceValueWithEntry(
/* Contains/None len=max_statement_args */ Vec<Statement>,
/* to copy */ Statement,
),
Custom(CustomPredicateRef, Vec<Statement>), Custom(CustomPredicateRef, Vec<Statement>),
} }
@ -270,6 +276,7 @@ impl Operation {
OT::Native(ContainerUpdateFromEntries) OT::Native(ContainerUpdateFromEntries)
} }
Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries), Self::ContainerDeleteFromEntries(_, _, _, _) => OT::Native(ContainerDeleteFromEntries),
Self::ReplaceValueWithEntry(_, _) => OT::Native(ReplaceValueWithEntry),
Self::Custom(cpr, _) => OT::Custom(cpr.clone()), Self::Custom(cpr, _) => OT::Custom(cpr.clone()),
} }
} }
@ -295,6 +302,11 @@ impl Operation {
Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerInsertFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4],
Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4], Self::ContainerUpdateFromEntries(s1, s2, s3, s4, _pf) => vec![s1, s2, s3, s4],
Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3], Self::ContainerDeleteFromEntries(s1, s2, s3, _pf) => vec![s1, s2, s3],
Self::ReplaceValueWithEntry(args, s) => {
let mut sts = args;
sts.push(s);
sts
}
Self::Custom(_, args) => args, Self::Custom(_, args) => args,
} }
} }
@ -377,6 +389,18 @@ impl Operation {
&[s1, s2, s3], &[s1, s2, s3],
OA::MerkleTreeStateTransitionProof(pf), OA::MerkleTreeStateTransitionProof(pf),
) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf), ) => Self::ContainerDeleteFromEntries(s1.clone(), s2.clone(), s3.clone(), pf),
(NO::ReplaceValueWithEntry, args, OA::None) => {
let mut args = args.to_vec();
if args.len() != BASE_PARAMS.max_statement_args + 1 {
return Err(Error::custom(format!(
"ReplaceValueWithEntry requires exactly {} args but {} were found",
BASE_PARAMS.max_statement_args + 1,
args.len()
)));
}
let st = args.pop().expect("valid vec len");
Self::ReplaceValueWithEntry(args, st)
}
_ => Err(Error::custom(format!( _ => Err(Error::custom(format!(
"Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.", "Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.",
op_code, op_code,
@ -422,6 +446,38 @@ impl Operation {
Ok(sig.verify(pk, msg.raw())) Ok(sig.verify(pk, msg.raw()))
} }
fn check_replace_value_with_entry(
entries: &[Statement],
st_in: &Statement,
expected_st_out: &Statement,
) -> Result<bool> {
if entries.len() != BASE_PARAMS.max_statement_args {
return Ok(false);
}
let args = iter::zip(st_in.args(), entries)
.map(|(arg_in, entry)| match (arg_in, entry) {
(arg_in, Statement::None) => Ok(arg_in),
(
StatementArg::Literal(v_in),
Statement::Contains(
ValueRef::Literal(root),
ValueRef::Literal(key),
ValueRef::Literal(v),
),
) if v == &v_in => Ok(StatementArg::Key(AnchoredKey::new(
Hash::from(root.raw()),
Key::from(key.as_str().ok_or_else(|| Error::custom("not a string"))?),
))),
_ => Err(Error::custom(
"invalid statement argument in ReplaceValueWithEntry",
)),
})
.collect::<Result<Vec<_>>>()?;
let st_out = Statement::from_args(st_in.predicate(), args)?;
Ok(&st_out == expected_st_out)
}
/// Checks the given operation against a statement. /// Checks the given operation against a statement.
pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> { pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
use Statement::*; use Statement::*;
@ -541,7 +597,19 @@ impl Operation {
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index => if batch == &cpr.batch && index == &cpr.index =>
{ {
check_custom_pred(params, cpr, args, s_args).map(|_| true)? // The custom operation outputs statements with literal arguments. They can be
// replaced by references later with ReplaceValueWithEntry.
let s_args = s_args
.iter()
.map(|arg| match arg {
ValueRef::Literal(v) => Ok(v.clone()),
_ => Err(deduction_err()),
})
.collect::<Result<Vec<_>>>()?;
check_custom_pred(params, cpr, args, &s_args).map(|_| true)?
}
(Self::ReplaceValueWithEntry(entries, st_in), st_out) => {
Self::check_replace_value_with_entry(entries, st_in, st_out)?
} }
_ => return Err(deduction_err()), _ => return Err(deduction_err()),
}; };
@ -648,9 +716,9 @@ pub fn wildcard_values_from_op_st(
params: &Params, params: &Params,
pred: &CustomPredicate, pred: &CustomPredicate,
op_args: &[Statement], op_args: &[Statement],
st_args: &[Value], resolved_st_args: &[Value],
) -> Result<Vec<Value>> { ) -> Result<Vec<Value>> {
let mut wildcard_map = st_args let mut wildcard_map = resolved_st_args
.iter() .iter()
.map(|v| Some(v.clone())) .map(|v| Some(v.clone()))
.chain(core::iter::repeat(None)) .chain(core::iter::repeat(None))

View file

@ -311,7 +311,7 @@ pub enum Statement {
/* old_root */ ValueRef, /* old_root */ ValueRef,
/* key */ ValueRef, /* key */ ValueRef,
), ),
Custom(CustomPredicateRef, Vec<Value>), Custom(CustomPredicateRef, Vec<ValueRef>),
Intro(IntroPredicateRef, Vec<Value>), Intro(IntroPredicateRef, Vec<Value>),
} }
@ -407,7 +407,7 @@ impl Statement {
vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()] vec![ak1.into(), ak2.into(), ak3.into(), ak4.into()]
} }
Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()], Self::ContainerDelete(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Literal)), Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(StatementArg::from)),
Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)), Self::Intro(_, args) => Vec::from_iter(args.into_iter().map(Literal)),
} }
} }
@ -478,14 +478,11 @@ impl Statement {
} }
(BatchSelf(_), _) => unreachable!(), (BatchSelf(_), _) => unreachable!(),
(Custom(cpr), _) => { (Custom(cpr), _) => {
let v_args: Result<Vec<Value>> = args let v_args = args
.iter() .iter()
.map(|x| match x { .map(|x| x.try_into())
StatementArg::Literal(v) => Ok(v.clone()), .collect::<Result<Vec<ValueRef>>>()?;
_ => Err(Error::incorrect_statements_args()), Self::Custom(cpr, v_args)
})
.collect();
Self::Custom(cpr, v_args?)
} }
(Intro(ir), _) => { (Intro(ir), _) => {
let v_args: Result<Vec<Value>> = args let v_args: Result<Vec<Value>> = args