Allow entries as args in custom statements (#498)

- Introduce a new operation ReplaceValueWithEntry that allows taking any statement and replacing literal arguments with entries given a matching Contains statement.
- Allow entries as args in custom statements
- Circuit optimization: For the public statements slots in the circuit we only support None and Copy which take at most 1 argument; but we were still doing max_statement_args random accesses per slot; so I reduced that to just 1 random access to a previous statement.
This commit is contained in:
Eduard S. 2026-04-01 23:49:29 +02:00 committed by GitHub
parent 22d25e5cb2
commit dbd958dcca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 515 additions and 190 deletions

View file

@ -55,7 +55,7 @@ use crate::{
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix,
ToFields, Value, F, HASH_SIZE,
ToFields, Value, BASE_PARAMS, F, HASH_SIZE,
},
};
//
@ -69,30 +69,37 @@ pub const PI_OFFSET_VDSROOT: usize = 4;
pub const NUM_PUBLIC_INPUTS: usize = 8;
const MAX_VALUE_ARGS: usize = 4;
const MAX_VALUE_ARGS: usize = 5;
struct StatementArgCache {
rhs: ValueTarget,
lhs: StatementArgTarget,
valid: BoolTarget,
pred_is_none: BoolTarget,
is_reference: BoolTarget,
// if `is_reference` then this is the AnchoredKey found in the Contains statement
reference: StatementArgTarget,
// if `is_reference` then this is the value found in the Contains statement
value: ValueTarget,
}
struct StatementCache {
equations: [StatementArgCache; MAX_VALUE_ARGS],
first_n_equations_valid: [BoolTarget; MAX_VALUE_ARGS],
struct StatementCache<const MAX_EQS: usize> {
equations: [StatementArgCache; MAX_EQS],
first_n_equations_valid: [BoolTarget; MAX_EQS],
op_args: Vec<StatementTarget>,
}
impl StatementCache {
impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
fn new(
params: &Params,
max_operation_args: usize,
builder: &mut CircuitBuilder,
op: &OperationTarget,
st: &StatementTarget,
prev_statements: &[StatementTarget],
) -> Self {
let op_args = if prev_statements.is_empty() {
(0..params.max_operation_args)
(0..max_operation_args)
.map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[]))
.collect_vec()
} else {
@ -100,10 +107,10 @@ impl StatementCache {
// converting a length 1 array into a scalar.
op.args
.iter()
.take(max_operation_args)
.map(|i| builder.vec_ref(params, prev_statements, i))
.collect::<Vec<_>>()
};
assert!(params.max_operation_args >= MAX_VALUE_ARGS);
assert!(Params::max_statement_args() >= MAX_VALUE_ARGS);
let equations = array::from_fn(|i| {
let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None);
@ -117,9 +124,9 @@ impl StatementCache {
let is_reference = builder.and(pred_is_contains, ref_is_value);
let valid = builder.or(is_literal, is_reference);
let rhs_literal = st.args[i].as_value();
let rhs_reference = op_args[i].args[2].as_value();
let rhs = builder.select_value(pred_is_none, rhs_literal, rhs_reference);
let rhs_from_literal = st.args[i].as_value();
let rhs_from_reference = op_args[i].args[2].as_value();
let rhs = builder.select_value(pred_is_none, rhs_from_literal, rhs_from_reference);
let lhs_literal = &st.args[i];
let lhs_reference = StatementArgTarget::anchored_key(
builder,
@ -127,10 +134,22 @@ impl StatementCache {
&op_args[i].args[1].as_value(),
);
let lhs = builder.select_statement_arg(pred_is_none, lhs_literal, &lhs_reference);
StatementArgCache { rhs, lhs, valid }
StatementArgCache {
rhs,
lhs,
valid,
pred_is_none,
is_reference,
reference: lhs_reference,
value: rhs_from_reference,
}
});
let mut first_n_equations_valid = [equations[0].valid; MAX_VALUE_ARGS];
for i in 1..MAX_VALUE_ARGS {
let mut first_n_equations_valid = if MAX_EQS != 0 {
[equations[0].valid; MAX_EQS]
} else {
[builder._false(); MAX_EQS]
};
for i in 1..MAX_EQS {
first_n_equations_valid[i] =
builder.and(equations[i].valid, first_n_equations_valid[i - 1]);
}
@ -145,7 +164,7 @@ impl StatementCache {
///
/// If the operation argument is a statement of type `None`, then the value
/// should be the corresponding argument of the current statement.
/// If the operation argument is a statement of type `Equals`, then the value
/// If the operation argument is a statement of type `Contains`, then the value
/// should be the argument at index 1 of that statement.
/// If the function successfully interprets the arguments as values,
/// returns `True` along with those values. Otherwise, returns `False`
@ -158,6 +177,12 @@ impl StatementCache {
}
}
/// Statement cache for private statements
type StatementCachePriv = StatementCache<MAX_VALUE_ARGS>;
/// Statement cache for public statements. Since the operations can only be None or Copy, no
/// equation is needed because none of these operations dereference entries.
type StatementCachePub = StatementCache<0>;
/// Specialized implementation of `verify_operation_circuit` for operations that generate public
/// statement. This only allows operations to be None, NewEntry or Copy and accounts for the fact
/// that public statements in the current implementation are always generated by copying private
@ -169,13 +194,15 @@ fn verify_operation_public_statement_circuit(
op: &OperationTarget,
prev_statements: &[StatementTarget],
) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerify");
let measure = measure_gates_begin!(builder, "OpVerifyPub");
// Verify that the operation `op` correctly generates the statement `st`. The operation
// can reference any of the `prev_statements`.
// TODO: Clean this up.
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
let cache = StatementCache::new(params, builder, op, st, prev_statements);
// None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the
// StatementCache requires.
let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements);
measure_gates_end!(builder, measure_resolve_op_args);
let op_checks = vec![
@ -406,7 +433,7 @@ fn verify_operation_circuit(
prev_statements: &[StatementTarget],
aux_table: &MuxTableTarget,
) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerify");
let measure = measure_gates_begin!(builder, "OpVerifyPriv");
let _true = builder._true();
let _false = builder._false();
@ -414,7 +441,14 @@ fn verify_operation_circuit(
// can reference any of the `prev_statements`.
// TODO: Clean this up.
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
let cache = StatementCache::new(params, builder, op, st, prev_statements);
let cache = StatementCachePriv::new(
params,
BASE_PARAMS.max_operation_args,
builder,
op,
st,
prev_statements,
);
measure_gates_end!(builder, measure_resolve_op_args);
// Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified
@ -442,6 +476,7 @@ fn verify_operation_circuit(
verify_sum_of_circuit(params, builder, st, &op.op_type, &cache),
verify_product_of_circuit(params, builder, st, &op.op_type, &cache),
verify_max_of_circuit(params, builder, st, &op.op_type, &cache),
verify_replace_value_with_entry_circuit(params, builder, st, &op.op_type, &cache),
]);
}
// Skip these if there are no resolved aux entries
@ -542,7 +577,7 @@ fn verify_contains_from_entries_circuit(
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpContainsFromEntries");
let (aux_tag_ok, resolved_merkle_claim) =
@ -592,7 +627,7 @@ fn verify_not_contains_from_entries_circuit(
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries");
let (aux_tag_ok, resolved_merkle_claim) =
@ -639,7 +674,7 @@ fn verify_merkle_insert_circuit(
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleInsertOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
@ -714,7 +749,7 @@ fn verify_merkle_update_circuit(
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleUpdateOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
@ -789,7 +824,7 @@ fn verify_merkle_delete_circuit(
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleDeleteOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
@ -883,7 +918,7 @@ fn verify_eq_neq_from_entries_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpEqNeqFromEntries");
let eq_op_st_code_ok = {
@ -932,9 +967,9 @@ fn verify_lt_lteq_from_entries_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpLtLteqFromEntries");
let measure = measure_gates_begin!(builder, "OpLtEqFromEntries");
let zero = ValueTarget::zero(builder);
let one = ValueTarget::one(builder);
@ -1000,7 +1035,7 @@ fn verify_hash_of_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpHashOf");
let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf);
@ -1033,7 +1068,7 @@ fn verify_public_key_of_circuit(
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpPublicKeyOf");
let (aux_tag_ok, resolved_pk_sk) =
@ -1069,7 +1104,7 @@ fn verify_signed_by_circuit(
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpSignedBy");
let (aux_tag_ok, resolved_msg_pk) =
@ -1104,7 +1139,7 @@ fn verify_sum_of_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpSumOf");
let value_zero = ValueTarget::zero(builder);
@ -1142,7 +1177,7 @@ fn verify_product_of_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpProductOf");
let value_zero = ValueTarget::zero(builder);
@ -1180,7 +1215,7 @@ fn verify_max_of_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCache,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpMaxOf");
let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf);
@ -1220,6 +1255,47 @@ fn verify_max_of_circuit(
ok
}
fn verify_replace_value_with_entry_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpReplaceValueWithEntry");
let op_code_ok = op_type.has_native(builder, NativeOperation::ReplaceValueWithEntry);
let st_in = &cache.op_args[BASE_PARAMS.max_statement_args];
let mut args = Vec::new();
let mut args_ok = builder._true();
for (arg_in, entry_cache) in zip_eq(&st_in.args, &cache.equations) {
// if the op_arg is None, keep the original argument, if it's a Contains swap the value by
// the reference Entry while checking that the value in Contains matches the original
// argument.
let arg = builder.select_flattenable(
params,
entry_cache.pred_is_none,
arg_in,
&entry_cache.reference,
);
args.push(arg);
let arg_ref_ok = {
let arg_in_is_value = builder.statement_arg_is_value(arg_in);
let value_eq = builder.is_equal_flattenable(&arg_in.as_value(), &entry_cache.value);
builder.all([entry_cache.is_reference, arg_in_is_value, value_eq])
};
let arg_ok = builder.or(entry_cache.pred_is_none, arg_ref_ok);
args_ok = builder.and(args_ok, arg_ok);
}
let expected_statement = StatementTarget::new(*st_in.pred_hash(), args);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
let ok = builder.all([op_code_ok, args_ok, st_ok]);
measure_gates_end!(builder, measure);
ok
}
fn verify_transitive_eq_circuit(
params: &Params,
builder: &mut CircuitBuilder,
@ -1429,7 +1505,7 @@ fn make_custom_statement_circuit(
) -> Result<(StatementTarget, OperationTypeTarget)> {
let measure = measure_gates_begin!(builder, "CustomOpVerify");
// Some sanity checks
assert_eq!(params.max_operation_args, op_args.len());
assert_eq!(BASE_PARAMS.max_operation_args, op_args.len());
assert_eq!(params.max_custom_predicate_wildcards, args.len());
let (batch_id, index) = (custom_predicate.id, custom_predicate.index);
@ -1463,7 +1539,6 @@ fn make_custom_statement_circuit(
.collect();
// expected_sts.len() == params.max_custom_predicate_arity
// op_args.len() == params.max_operation_args;
assert!(Params::max_custom_predicate_arity() <= params.max_operation_args);
let sts_eq: Vec<_> = expected_sts
.iter()
@ -2076,7 +2151,8 @@ mod tests {
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
middleware::{
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard,
RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard, EMPTY_VALUE,
RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard,
BASE_PARAMS, EMPTY_VALUE,
},
};
@ -3068,6 +3144,33 @@ mod tests {
operation_verify(st, op, prev_statements, Aux::signed_by(signed_by))
}
#[test]
fn test_operation_replace_value_with_entry() -> Result<()> {
let d = dict!({"a" => 42, "b" => 33});
// 0: None
// 1: Lt(5, 42)
let st_in: mainpod::Statement = Statement::lt(5, 42).into();
// 2: Contains(d, "a", 42)
let st_entry: mainpod::Statement = Statement::contains(d.clone(), "a", 42).into();
let st_out: mainpod::Statement =
Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into();
let mut op_args: Vec<_> = iter::repeat(OperationArg::None)
.take(BASE_PARAMS.max_statement_args + 1)
.collect();
op_args[1] = OperationArg::Index(2);
op_args[BASE_PARAMS.max_statement_args] = OperationArg::Index(1);
let op = mainpod::Operation(
OperationType::Native(NativeOperation::ReplaceValueWithEntry),
op_args,
OperationAux::None,
);
let prev_statements = vec![Statement::None.into(), st_in, st_entry];
operation_verify(st_out, op, prev_statements, Aux::default())
}
fn helper_statement_arg_from_template(
params: &Params,
st_tmpl_arg: StatementTmplArg,
@ -3226,7 +3329,7 @@ mod tests {
expected_st: Option<Statement>,
) -> Result<()> {
// Pad
for _ in op_args.len()..params.max_operation_args {
for _ in op_args.len()..BASE_PARAMS.max_operation_args {
op_args.push(Statement::None);
}
for _ in args.len()..params.max_custom_predicate_wildcards {
@ -3275,6 +3378,10 @@ mod tests {
Ok(data.verify(proof.clone())?)
}
fn value_ref(v: impl Into<ValueRef>) -> ValueRef {
v.into()
}
// TODO: Add negative tests
#[test]
fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> {
@ -3309,7 +3416,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(1234)];
let expected_st = Statement::Custom(
custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)],
vec![value_ref(args[0].clone()), value_ref(0)],
);
helper_custom_operation_verify_gadget(
@ -3330,7 +3437,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(0)];
let expected_st = Statement::Custom(
custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)],
vec![value_ref(args[0].clone()), value_ref(0)],
);
helper_custom_operation_verify_gadget(
@ -3351,7 +3458,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(1234)];
let expected_st = Statement::Custom(
custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)],
vec![value_ref(args[0].clone()), value_ref(0)],
);
helper_custom_operation_verify_gadget(
@ -3403,7 +3510,7 @@ mod tests {
let args = vec![Value::from(dict), Value::from(secret_dict)];
let expected_st = Statement::Custom(
custom_predicate.clone(),
vec![args[0].clone(), Value::from(0)],
vec![value_ref(args[0].clone()), value_ref(0)],
);
helper_custom_operation_verify_gadget(