pod2/src/backends/plonky2/circuits/mainpod.rs
Eduard S. dbd958dcca
Allow entries as args in custom statements (#498)
- Introduce a new operation ReplaceValueWithEntry that allows taking any statement and replacing literal arguments with entries given a matching Contains statement.
- Allow entries as args in custom statements
- Circuit optimization: For the public statements slots in the circuit we only support None and Copy which take at most 1 argument; but we were still doing max_statement_args random accesses per slot; so I reduced that to just 1 random access to a previous statement.
2026-04-01 23:49:29 +02:00

3810 lines
143 KiB
Rust

use std::{array, iter};
use itertools::{izip, zip_eq, Itertools};
use num::{BigUint, One};
use plonky2::{
field::types::Field,
hash::{
hash_types::HashOutTarget,
poseidon::{PoseidonHash, PoseidonPermutation},
},
iop::{
target::{BoolTarget, Target},
witness::{PartialWitness, WitnessWrite},
},
};
use plonky2_u32::gadgets::multiple_comparison::list_le_circuit;
use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::{
basetypes::{CircuitBuilder, VDSet},
circuits::{
common::{
CircuitBuilderPod, CustomPredicateEntryTarget, CustomPredicateInBatchTarget,
CustomPredicateTarget, CustomPredicateVerifyEntryTarget,
CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget,
MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget,
PredicateHashOrWildcardTarget, PredicateTarget, StatementArgTarget,
StatementTarget, StatementTmplArgTarget, StatementTmplTarget, ValueTarget,
},
hash::{hash_from_state_circuit, precompute_hash_state},
mux_table::{MuxTableTarget, TableEntryTarget},
},
emptypod::EmptyPod,
error::Result,
mainpod::{self, pad_statement, SignedBy},
primitives::{
ec::{
bits::{BigUInt320Target, CircuitBuilderBits},
curve::{
CircuitBuilderElliptic, Point, PointTarget, WitnessWriteCurve, GROUP_ORDER,
},
schnorr::{CircuitBuilderSchnorr, SecretKey, SignatureTarget, WitnessWriteSchnorr},
},
merkletree::{
verify_merkle_proof_circuit, verify_merkle_state_transition_circuit,
MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProof, MerkleTreeOp,
MerkleTreeStateTransitionProof, MerkleTreeStateTransitionProofTarget,
},
signature::{verify_signature_circuit, SignatureVerifyTarget},
},
recursion::{InnerCircuit, VerifiedProofTarget},
},
measure_gates_begin, measure_gates_end,
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix,
ToFields, Value, BASE_PARAMS, F, HASH_SIZE,
},
};
//
// MainPod verification
//
/// Offset in public inputs where we store the statements hash
pub const PI_OFFSET_STATEMENTS_HASH: usize = 0;
/// Offset in public inputs where we store the verified data array root
pub const PI_OFFSET_VDSROOT: usize = 4;
/// Number of public inputs.
// NOTE(review): presumably two 4-element hashes (statements hash at offset 0, VDS root at
// offset 4) -- confirm against the code that registers the public inputs.
pub const NUM_PUBLIC_INPUTS: usize = 8;
/// Number of equations tracked by `StatementCachePriv`, i.e. the largest number of leading
/// arguments that can be resolved to values for a private statement. `StatementCache::new`
/// asserts that `Params::max_statement_args()` is at least this large.
const MAX_VALUE_ARGS: usize = 5;
/// Per-argument data precomputed by `StatementCache::new` for one (statement argument,
/// operation argument) pair.
struct StatementArgCache {
// Value of the argument: the literal statement argument when `pred_is_none`, otherwise the
// argument at index 2 of the operation argument statement.
rhs: ValueTarget,
// Expected statement argument: the statement argument itself when `pred_is_none`, otherwise
// the anchored key built from arguments 0 and 1 of the operation argument statement.
lhs: StatementArgTarget,
// True if the pair is either a literal (operation argument is `None` and the statement
// argument is a value) or a reference (see `is_reference`).
valid: BoolTarget,
// True if the operation argument is a statement of native type `None`.
pred_is_none: BoolTarget,
// True if the operation argument is a `Contains` statement whose first three arguments are
// all values.
is_reference: BoolTarget,
// if `is_reference` then this is the AnchoredKey found in the Contains statement
reference: StatementArgTarget,
// if `is_reference` then this is the value found in the Contains statement
value: ValueTarget,
}
/// Data shared by all operation checks for one (statement, operation) pair, computed once by
/// `StatementCache::new`. `MAX_EQS` is the number of leading arguments for which a
/// `StatementArgCache` is built.
struct StatementCache<const MAX_EQS: usize> {
// One cached entry per leading argument index `0..MAX_EQS`.
equations: [StatementArgCache; MAX_EQS],
// `first_n_equations_valid[i]` is the conjunction of `equations[0..=i].valid`.
first_n_equations_valid: [BoolTarget; MAX_EQS],
// Statements referenced by the operation arguments (or `None` statements when there are no
// previous statements to reference).
op_args: Vec<StatementTarget>,
}
impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
    /// Resolves the operation arguments of `op` and precomputes, for each of the first
    /// `MAX_EQS` argument positions, how the statement argument relates to the operation
    /// argument (literal value or reference through a `Contains` statement).
    ///
    /// * `max_operation_args` - number of operation arguments to resolve against
    ///   `prev_statements` (one lookup per argument).
    /// * `op` - operation whose `args` select entries of `prev_statements`.
    /// * `st` - statement claimed to be produced by `op`.
    /// * `prev_statements` - statements that `op` may reference. When empty, every operation
    ///   argument is replaced by a `None` statement.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `MAX_EQS` statement arguments or resolved operation arguments are
    /// available.
    fn new(
        params: &Params,
        max_operation_args: usize,
        builder: &mut CircuitBuilder,
        op: &OperationTarget,
        st: &StatementTarget,
        prev_statements: &[StatementTarget],
    ) -> Self {
        let op_args = if prev_statements.is_empty() {
            (0..max_operation_args)
                .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[]))
                .collect_vec()
        } else {
            // Each element of `op.args` selects one of `prev_statements`; resolve it with a
            // lookup into the slice.
            op.args
                .iter()
                .take(max_operation_args)
                .map(|i| builder.vec_ref(params, prev_statements, i))
                .collect::<Vec<_>>()
        };
        // Every equation reads `st.args[i]` and `op_args[i]`, so both must provide at least
        // `MAX_EQS` entries.
        assert!(
            Params::max_statement_args() >= MAX_EQS,
            "statement has fewer arguments than the {} equations requested",
            MAX_EQS
        );
        assert!(
            op_args.len() >= MAX_EQS,
            "only {} operation arguments resolved but {} equations requested",
            op_args.len(),
            MAX_EQS
        );
        let equations = array::from_fn(|i| {
            // Literal case: the operation argument is `None` and the statement argument is a
            // plain value.
            let pred_is_none = op_args[i].has_native_type(builder, NativePredicate::None);
            let arg_is_value = builder.statement_arg_is_value(&st.args[i]);
            let is_literal = builder.and(pred_is_none, arg_is_value);
            // Reference case: the operation argument is a `Contains` statement whose three
            // arguments are all values.
            let pred_is_contains = op_args[i].has_native_type(builder, NativePredicate::Contains);
            let ref_is_value_arg: [_; 3] =
                array::from_fn(|j| builder.statement_arg_is_value(&op_args[i].args[j]));
            let ref_is_value = builder.and(ref_is_value_arg[0], ref_is_value_arg[1]);
            let ref_is_value = builder.and(ref_is_value, ref_is_value_arg[2]);
            let is_reference = builder.and(pred_is_contains, ref_is_value);
            let valid = builder.or(is_literal, is_reference);
            // The value comes from the statement itself (literal) or from argument 2 of the
            // `Contains` statement (reference).
            let rhs_from_literal = st.args[i].as_value();
            let rhs_from_reference = op_args[i].args[2].as_value();
            let rhs = builder.select_value(pred_is_none, rhs_from_literal, rhs_from_reference);
            // The expected statement argument is the literal itself, or the anchored key built
            // from arguments 0 and 1 of the `Contains` statement.
            let lhs_literal = &st.args[i];
            let lhs_reference = StatementArgTarget::anchored_key(
                builder,
                &op_args[i].args[0].as_value(),
                &op_args[i].args[1].as_value(),
            );
            let lhs = builder.select_statement_arg(pred_is_none, lhs_literal, &lhs_reference);
            StatementArgCache {
                rhs,
                lhs,
                valid,
                pred_is_none,
                is_reference,
                reference: lhs_reference,
                value: rhs_from_reference,
            }
        });
        // Running conjunction: entry `i` is true iff equations `0..=i` are all valid.
        let mut first_n_equations_valid = if MAX_EQS != 0 {
            [equations[0].valid; MAX_EQS]
        } else {
            [builder._false(); MAX_EQS]
        };
        for i in 1..MAX_EQS {
            first_n_equations_valid[i] =
                builder.and(equations[i].valid, first_n_equations_valid[i - 1]);
        }
        StatementCache {
            equations,
            first_n_equations_valid,
            op_args,
        }
    }

    /// Attempts to interpret the first `N` arguments as values.
    ///
    /// If the operation argument is a statement of type `None`, then the value
    /// should be the corresponding argument of the current statement.
    /// If the operation argument is a statement of type `Contains`, then the value
    /// should be the argument at index 2 of that statement.
    /// If the function successfully interprets the arguments as values,
    /// returns `True` along with those values. Otherwise, returns `False`
    /// along with some arbitrary values.
    ///
    /// # Panics
    ///
    /// Panics unless `1 <= N <= MAX_EQS`.
    fn first_n_args_as_values<const N: usize>(&self) -> (BoolTarget, [ValueTarget; N]) {
        assert!(
            (1..=MAX_EQS).contains(&N),
            "cannot interpret the first {} arguments as values with {} cached equations",
            N,
            MAX_EQS
        );
        (
            self.first_n_equations_valid[N - 1],
            array::from_fn(|i| self.equations[i].rhs),
        )
    }
}
/// Statement cache for private statements; tracks `MAX_VALUE_ARGS` equations.
type StatementCachePriv = StatementCache<MAX_VALUE_ARGS>;
/// Statement cache for public statements. Since the operations can only be None or Copy, no
/// equation is needed because none of these operations dereference entries.
type StatementCachePub = StatementCache<0>;
/// Variant of `verify_operation_circuit` specialised for the public statement slots.
///
/// The constraint added here is satisfied only when `op` is a `None` operation or a `Copy` of
/// one of `prev_statements` that yields `st`; no other operation check is instantiated.
/// Because `None` takes no arguments and `Copy` takes one, a single operation argument is
/// resolved against `prev_statements`.
fn verify_operation_public_statement_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op: &OperationTarget,
    prev_statements: &[StatementTarget],
) -> Result<()> {
    let gates_total = measure_gates_begin!(builder, "OpVerifyPub");

    // Resolve just one operation argument: that is the most any accepted operation needs.
    let gates_resolve = measure_gates_begin!(builder, "ResolveOpArgs");
    let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements);
    measure_gates_end!(builder, gates_resolve);

    // The statement is accepted if at least one of the per-operation checks passes.
    let none_ok = verify_none_circuit(params, builder, st, &op.op_type);
    let copy_ok = verify_copy_circuit(builder, st, &op.op_type, &cache.op_args);
    let any_op_ok = builder.any(vec![none_ok, copy_ok]);
    builder.assert_one(any_op_ok.target);

    measure_gates_end!(builder, gates_total);
    Ok(())
}
/// Tags identifying the kind of entry stored in the operation aux table. The discriminants
/// are cast to `u32` when entries are pushed to and read back from the table.
enum OperationAuxTableTag {
// Fixed empty entry.
None = 0,
// Verified Merkle (non-)inclusion claim (`MerkleClaimTarget`).
MerkleProof = 1,
// Pair of public-key hash and secret-key hash (`PubKeySecKeyTarget`).
PublicKeyOf = 2,
// Pair of signed message and public-key hash (`MsgPubKeyTarget`).
SignedBy = 3,
// Verified Merkle tree state transition claim (`MerkleTreeStateTransitionClaimTarget`).
MerkleTreeStateTransitionProof = 4,
// Verified custom predicate query (`CustomPredicateVerifyQueryTarget`).
CustomPredVerify = 5,
}
/// Returns the flattened size of the largest aux-table entry type that is enabled in `params`
/// (an entry type is enabled when its corresponding `max_*` parameter is non-zero), or 0 when
/// none is enabled.
fn max_operation_aux_entry_len(params: &Params) -> usize {
    let mut longest: usize = 0;
    if params.max_merkle_proofs_containers > 0 {
        longest = longest.max(MerkleClaimTarget::size(params));
    }
    if params.max_public_key_of > 0 {
        longest = longest.max(PubKeySecKeyTarget::size(params));
    }
    if params.max_signed_by > 0 {
        longest = longest.max(MsgPubKeyTarget::size(params));
    }
    if params.max_merkle_tree_state_transition_proofs_containers > 0 {
        longest = longest.max(MerkleTreeStateTransitionClaimTarget::size(params));
    }
    if params.max_custom_predicate_verifications > 0 {
        longest = longest.max(CustomPredicateVerifyQueryTarget::size(params));
    }
    longest
}
/// Two hash targets stored together as a single flattenable value.
#[derive(Copy, Clone)]
struct HashPairTarget(HashOutTarget, HashOutTarget);

impl Flattenable for HashPairTarget {
    /// Concatenates the elements of the first hash followed by those of the second.
    fn flatten(&self) -> Vec<Target> {
        let Self(first, second) = self;
        [first.elements, second.elements].concat()
    }

    /// Rebuilds the pair from exactly `Self::size(params)` targets: the first four form the
    /// first hash and the remaining four the second.
    fn from_flattened(params: &Params, vs: &[Target]) -> Self {
        assert_eq!(vs.len(), Self::size(params));
        let (head, tail) = vs.split_at(4);
        let first = HashOutTarget::try_from(head).expect("len = 4");
        let second = HashOutTarget::try_from(tail).expect("len = 4");
        Self(first, second)
    }

    /// Number of targets in the flattened representation.
    fn size(_params: &Params) -> usize {
        8
    }
}
/// Aux-table entry for `PublicKeyOf`: hash of the public key followed by hash of the secret key.
type PubKeySecKeyTarget = HashPairTarget; // (public_key, secret_key)
/// Aux-table entry for `SignedBy`: the signed message followed by hash of the public key.
type MsgPubKeyTarget = HashPairTarget; // (message, public_key)
/// Circuit targets for one signature verification: a message, the public key it is claimed to
/// be signed with, and the signature itself.
#[derive(Clone, Serialize, Deserialize)]
struct SignedByTarget {
// Signed message.
msg: ValueTarget,
// Public key (curve point).
pk: PointTarget,
// Schnorr signature over `msg`.
sig: SignatureTarget,
}
impl SignedByTarget {
    /// Writes the message, public key and signature of `signed_by` into the witness `pw`,
    /// in that order, returning the first error encountered.
    pub fn set_targets(&self, pw: &mut PartialWitness<F>, signed_by: &SignedBy) -> Result<()> {
        let msg_value = Value::from(signed_by.msg);
        self.msg.set_targets(pw, &msg_value)?;
        pw.set_point_target(&self.pk, &signed_by.pk)?;
        pw.set_signature_target(&self.sig, &signed_by.sig)?;
        Ok(())
    }

    /// Allocates fresh virtual targets for the message, the public key and the signature.
    pub fn new_virtual(builder: &mut CircuitBuilder) -> Self {
        let msg = builder.add_virtual_value();
        let pk = builder.add_virtual_point_target();
        let sig = builder.add_virtual_schnorr_signature_target();
        Self { msg, pk, sig }
    }
}
/// Verifies every auxiliary proof object supplied to the circuit and collects the resulting
/// claims into a `MuxTableTarget`, each tagged with its `OperationAuxTableTag`.
///
/// Entries are pushed in this order: the fixed `None` entry, one entry per Merkle proof, one
/// per `PublicKeyOf` secret key, one per `SignedBy` signature, one per Merkle tree state
/// transition proof and one per custom predicate verification.
///
/// Panics if `custom_predicate_verifications` or `merkle_proofs` do not have the lengths
/// given by `params`. Returns an error if `make_custom_statement_circuit` fails.
#[allow(clippy::too_many_arguments)]
fn build_operation_aux_table_circuit(
params: &Params,
builder: &mut CircuitBuilder,
merkle_proofs: &[MerkleClaimAndProofTarget],
public_key_of_sks: &[BigUInt320Target],
signed_bys: &[SignedByTarget],
merkle_tree_state_transition_proofs: &[MerkleTreeStateTransitionProofTarget],
custom_predicate_verifications: &[CustomPredicateVerifyEntryTarget],
custom_predicate_table: &[HashOutTarget],
) -> Result<MuxTableTarget> {
let measure = measure_gates_begin!(builder, "BuildOpAuxTbl");
assert_eq!(
params.max_custom_predicate_verifications,
custom_predicate_verifications.len()
);
assert_eq!(params.max_merkle_proofs_containers, merkle_proofs.len());
// All entries share one table, so it is sized for the largest enabled entry type.
let max_entry_len = max_operation_aux_entry_len(params);
let mut table = MuxTableTarget::new(params, max_entry_len);
// None
table.push_flattened(builder, OperationAuxTableTag::None as u32, &[]);
// MerkleProofs: verify container merkle proofs (inclusion/non-inclusion)
for merkle_proof in merkle_proofs {
verify_merkle_proof_circuit(builder, merkle_proof);
let entry = MerkleClaimTarget::from(merkle_proof.clone());
table.push(builder, OperationAuxTableTag::MerkleProof as u32, &entry);
}
// PublicKeyOf: verify the derivation from a Schnorr secret key to public key
for sk in public_key_of_sks {
let measure = measure_gates_begin!(builder, "PublicKeyOf");
let invgenerator = builder.constant_point(Point::generator().inverse());
// Constrain the secret key to be at most `GROUP_ORDER - 1` by comparing its limbs.
let group_orderm1 = &*GROUP_ORDER - BigUint::one();
let group_orderm1target = builder.constant_biguint320(&group_orderm1);
let compare_ok = list_le_circuit(
builder,
sk.limbs.to_vec(),
group_orderm1target.limbs.to_vec(),
32,
);
builder.assert_one(compare_ok.target);
// public_key = g^-secret key
let pk = builder.multiply_point(&sk.bits, &invgenerator);
// The table stores hashes of the secret key limbs and of the public key coordinates.
let sk_hash = builder.hash_n_to_hash_no_pad::<PoseidonHash>(sk.limbs.to_vec());
let pk_hash = builder.hash_n_to_hash_no_pad::<PoseidonHash>(
pk.x.components.into_iter().chain(pk.u.components).collect(),
);
let entry: PubKeySecKeyTarget = HashPairTarget(pk_hash, sk_hash);
table.push(builder, OperationAuxTableTag::PublicKeyOf as u32, &entry);
measure_gates_end!(builder, measure);
}
// SignedBy: verify the Schnorr signature of a message with a public key
for signed_by in signed_bys {
let measure = measure_gates_begin!(builder, "SignedBy");
// The signature check is always enabled for these entries.
let signature_verify = SignatureVerifyTarget {
enabled: builder._true(),
pk: signed_by.pk.clone(),
msg: signed_by.msg,
sig: signed_by.sig.clone(),
};
verify_signature_circuit(builder, &signature_verify);
// TODO: Add a function to hash the public key
let pk_hash = builder.hash_n_to_hash_no_pad::<PoseidonHash>(
signed_by
.pk
.x
.components
.into_iter()
.chain(signed_by.pk.u.components)
.collect(),
);
let entry: MsgPubKeyTarget = HashPairTarget(HashOutTarget::from(signed_by.msg), pk_hash);
table.push(builder, OperationAuxTableTag::SignedBy as u32, &entry);
measure_gates_end!(builder, measure);
}
// Merkle state transition proofs: verify op proof (insert/update/delete)
for merkle_tree_state_transition_proof in merkle_tree_state_transition_proofs {
verify_merkle_state_transition_circuit(builder, merkle_tree_state_transition_proof);
let entry =
MerkleTreeStateTransitionClaimTarget::from(merkle_tree_state_transition_proof.clone());
table.push(
builder,
OperationAuxTableTag::MerkleTreeStateTransitionProof as u32,
&entry,
);
}
// CustomPredVerify: verify custom predicate statements verification against operations
for entry in custom_predicate_verifications {
let measure = measure_gates_begin!(builder, "CustomPredVerify");
// Verify the custom predicate operation
let (statement, op_type) = make_custom_statement_circuit(
params,
builder,
&entry.custom_predicate,
&entry.op_args,
&entry.args,
)?;
// Check that the batch id is correct by querying the custom predicate batches table
let table_query_hash = builder.vec_ref(
params,
custom_predicate_table,
&entry.custom_predicate_table_index,
);
let out_query_hash = entry.custom_predicate.hash(builder);
builder.connect_array(table_query_hash.elements, out_query_hash.elements);
let query = CustomPredicateVerifyQueryTarget {
statement, // output
op_type, // output
op_args: entry.op_args.clone(), // input
};
table.push(
builder,
OperationAuxTableTag::CustomPredVerify as u32,
&query,
);
measure_gates_end!(builder, measure);
}
measure_gates_end!(builder, measure);
Ok(table)
}
/// Constrains that the operation `op` correctly generates the private statement `st`.
///
/// The operation may reference any of `prev_statements` and, through `op.aux_index`, one
/// entry of the already-verified `aux_table`. One boolean check is built per supported
/// operation kind and the circuit asserts that at least one of them holds. Checks that depend
/// on an aux-table entry type are only instantiated when the corresponding `params.max_*`
/// value is non-zero.
#[allow(clippy::too_many_arguments)]
fn verify_operation_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op: &OperationTarget,
prev_statements: &[StatementTarget],
aux_table: &MuxTableTarget,
) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerifyPriv");
// NOTE(review): these two bindings are not read anywhere else in this function.
let _true = builder._true();
let _false = builder._false();
// Verify that the operation `op` correctly generates the statement `st`. The operation
// can reference any of the `prev_statements`.
// TODO: Clean this up.
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
let cache = StatementCachePriv::new(
params,
BASE_PARAMS.max_operation_args,
builder,
op,
st,
prev_statements,
);
measure_gates_end!(builder, measure_resolve_op_args);
// Certain operations (e.g.: Contains/NotContains) will refer to one of the provided verified
// entries in a table (e.g.: Merkle proofs ). These entries have already been verified, so we
// need only look up the claim.
// The aux table always has a fixed zero entry, so we check if there are more than 1 entries to
// trigger the unhashing.
let resolved_aux = (aux_table.len() > 1).then(|| aux_table.get(builder, &op.aux_index));
// Op checks to carry out. Each 'verify_X_circuit' should be thought of as operation check
// restricted to the op of type X, where the returned target is `false` if the input targets
// lie outside of the domain.
let mut op_checks = Vec::new();
op_checks.extend_from_slice(&[verify_none_circuit(params, builder, st, &op.op_type)]);
// Skip these if there are no resolved op args
if !cache.op_args.is_empty() {
op_checks.extend_from_slice(&[
verify_copy_circuit(builder, st, &op.op_type, &cache.op_args),
verify_eq_neq_from_entries_circuit(builder, st, &op.op_type, &cache),
verify_lt_lteq_from_entries_circuit(builder, st, &op.op_type, &cache),
verify_transitive_eq_circuit(params, builder, st, &op.op_type, &cache.op_args),
verify_lt_to_neq_circuit(params, builder, st, &op.op_type, &cache.op_args),
verify_hash_of_circuit(params, builder, st, &op.op_type, &cache),
verify_sum_of_circuit(params, builder, st, &op.op_type, &cache),
verify_product_of_circuit(params, builder, st, &op.op_type, &cache),
verify_max_of_circuit(params, builder, st, &op.op_type, &cache),
verify_replace_value_with_entry_circuit(params, builder, st, &op.op_type, &cache),
]);
}
// Skip these if there are no resolved aux entries
if let Some(resolved_aux) = resolved_aux {
// Checks backed by Merkle (non-)inclusion claims.
if params.max_merkle_proofs_containers > 0 {
op_checks.extend_from_slice(&[
verify_contains_from_entries_circuit(
params,
builder,
st,
&op.op_type,
&resolved_aux,
&cache,
),
verify_not_contains_from_entries_circuit(
params,
builder,
st,
&op.op_type,
&resolved_aux,
&cache,
),
]);
}
// Check backed by public-key derivation entries.
if params.max_public_key_of > 0 {
op_checks.push(verify_public_key_of_circuit(
params,
builder,
st,
&op.op_type,
&resolved_aux,
&cache,
));
}
// Check backed by verified signature entries.
if params.max_signed_by > 0 {
op_checks.push(verify_signed_by_circuit(
params,
builder,
st,
&op.op_type,
&resolved_aux,
&cache,
));
}
// Checks backed by Merkle tree state transition claims.
if params.max_merkle_tree_state_transition_proofs_containers > 0 {
op_checks.extend_from_slice(&[
verify_merkle_insert_circuit(
params,
builder,
st,
&op.op_type,
&resolved_aux,
&cache,
),
verify_merkle_update_circuit(
params,
builder,
st,
&op.op_type,
&resolved_aux,
&cache,
),
verify_merkle_delete_circuit(
params,
builder,
st,
&op.op_type,
&resolved_aux,
&cache,
),
]);
}
// Check backed by custom predicate verification entries.
if params.max_custom_predicate_verifications > 0 {
op_checks.push(verify_custom_circuit(
builder,
st,
&op.op_type,
&resolved_aux,
&cache.op_args,
));
}
}
// The statement is accepted if at least one operation check passes.
let ok = builder.any(op_checks);
builder.assert_one(ok.target);
measure_gates_end!(builder, measure);
Ok(())
}
//
// Native operation constraints
//
fn verify_contains_from_entries_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpContainsFromEntries");
let (aux_tag_ok, resolved_merkle_claim) =
aux.as_type::<MerkleClaimTarget>(builder, OperationAuxTableTag::MerkleProof as u32);
let op_code_ok = op_type.has_native(builder, NativeOperation::ContainsFromEntries);
let (arg_types_ok, [merkle_root_value, key_value, value_value]) =
cache.first_n_args_as_values();
// Check Merkle proof (verified elsewhere) against op args.
let merkle_proof_checks = [
/* The supplied Merkle proof must be enabled. */
resolved_merkle_claim.enabled,
/* ...and it must be an existence proof. */
resolved_merkle_claim.existence,
/* ...for the root-key-value triple in the resolved op args. */
builder.is_equal_slice(
&merkle_root_value.elements,
&resolved_merkle_claim.root.elements,
),
builder.is_equal_slice(&key_value.elements, &resolved_merkle_claim.key.elements),
builder.is_equal_slice(&value_value.elements, &resolved_merkle_claim.value.elements),
];
let merkle_proof_ok = builder.all(merkle_proof_checks);
// Check output statement
let arg1_expected = cache.equations[0].lhs.clone();
let arg2_expected = cache.equations[1].lhs.clone();
let arg3_expected = cache.equations[2].lhs.clone();
let expected_statement = StatementTarget::new_native(
builder,
params,
NativePredicate::Contains,
&[arg1_expected, arg2_expected, arg3_expected],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
let ok = builder.all([op_code_ok, aux_tag_ok, arg_types_ok, merkle_proof_ok, st_ok]);
measure_gates_end!(builder, measure);
ok
}
fn verify_not_contains_from_entries_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "OpNotContainsFromEntries");
let (aux_tag_ok, resolved_merkle_claim) =
aux.as_type::<MerkleClaimTarget>(builder, OperationAuxTableTag::MerkleProof as u32);
let op_code_ok = op_type.has_native(builder, NativeOperation::NotContainsFromEntries);
let (arg_types_ok, [merkle_root_value, key_value]) = cache.first_n_args_as_values();
// Check Merkle proof (verified elsewhere) against op args.
let merkle_proof_checks = [
/* The supplied Merkle proof must be enabled. */
resolved_merkle_claim.enabled,
/* ...and it must be a nonexistence proof. */
builder.not(resolved_merkle_claim.existence),
/* ...for the root-key pair in the resolved op args. */
builder.is_equal_slice(
&merkle_root_value.elements,
&resolved_merkle_claim.root.elements,
),
builder.is_equal_slice(&key_value.elements, &resolved_merkle_claim.key.elements),
];
let merkle_proof_ok = builder.all(merkle_proof_checks);
// Check output statement
let arg1_expected = cache.equations[0].lhs.clone();
let arg2_expected = cache.equations[1].lhs.clone();
let expected_statement = StatementTarget::new_native(
builder,
params,
NativePredicate::NotContains,
&[arg1_expected, arg2_expected],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
let ok = builder.all([op_code_ok, aux_tag_ok, arg_types_ok, merkle_proof_ok, st_ok]);
measure_gates_end!(builder, measure);
ok
}
fn verify_merkle_insert_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleInsertOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
aux.as_type::<MerkleTreeStateTransitionClaimTarget>(
builder,
OperationAuxTableTag::MerkleTreeStateTransitionProof as u32,
);
let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerInsertFromEntries);
let (arg_types_ok, [new_root_value, old_root_value, op_key_value, op_value_value]) =
cache.first_n_args_as_values();
let expected_merkle_op = builder.constant(F::from_canonical_u8(MerkleTreeOp::Insert as u8));
// Check Merkle proof (verified elsewhere) against op args.
let merkle_proof_checks = [
/* The supplied Merkle transition proof must be enabled. */
resolved_merkle_tree_state_transition_claim.enabled,
/* ...and it must be an insertion proof. */
builder.is_equal(
resolved_merkle_tree_state_transition_claim.op,
expected_merkle_op,
),
/* ...for the root-key-value combination in the resolved op args. */
builder.is_equal_slice(
&old_root_value.elements,
&resolved_merkle_tree_state_transition_claim
.old_root
.elements,
),
builder.is_equal_slice(
&new_root_value.elements,
&resolved_merkle_tree_state_transition_claim
.new_root
.elements,
),
builder.is_equal_slice(
&op_key_value.elements,
&resolved_merkle_tree_state_transition_claim.op_key.elements,
),
builder.is_equal_slice(
&op_value_value.elements,
&resolved_merkle_tree_state_transition_claim
.op_value
.elements,
),
];
let merkle_proof_ok = builder.all(merkle_proof_checks);
// Check output statement
let arg1_expected = cache.equations[0].lhs.clone();
let arg2_expected = cache.equations[1].lhs.clone();
let arg3_expected = cache.equations[2].lhs.clone();
let arg4_expected = cache.equations[3].lhs.clone();
let expected_statement = StatementTarget::new_native(
builder,
params,
NativePredicate::ContainerInsert,
&[arg1_expected, arg2_expected, arg3_expected, arg4_expected],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
let ok = builder.all([op_code_ok, aux_tag_ok, arg_types_ok, merkle_proof_ok, st_ok]);
measure_gates_end!(builder, measure);
ok
}
fn verify_merkle_update_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleUpdateOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
aux.as_type::<MerkleTreeStateTransitionClaimTarget>(
builder,
OperationAuxTableTag::MerkleTreeStateTransitionProof as u32,
);
let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerUpdateFromEntries);
let (arg_types_ok, [new_root_value, old_root_value, op_key_value, op_value_value]) =
cache.first_n_args_as_values();
let expected_merkle_op = builder.constant(F::from_canonical_u8(MerkleTreeOp::Update as u8));
// Check Merkle proof (verified elsewhere) against op args.
let merkle_proof_checks = [
/* The supplied Merkle transition proof must be enabled. */
resolved_merkle_tree_state_transition_claim.enabled,
/* ...and it must be an update proof. */
builder.is_equal(
resolved_merkle_tree_state_transition_claim.op,
expected_merkle_op,
),
/* ...for the root-key-value combination in the resolved op args. */
builder.is_equal_slice(
&old_root_value.elements,
&resolved_merkle_tree_state_transition_claim
.old_root
.elements,
),
builder.is_equal_slice(
&new_root_value.elements,
&resolved_merkle_tree_state_transition_claim
.new_root
.elements,
),
builder.is_equal_slice(
&op_key_value.elements,
&resolved_merkle_tree_state_transition_claim.op_key.elements,
),
builder.is_equal_slice(
&op_value_value.elements,
&resolved_merkle_tree_state_transition_claim
.op_value
.elements,
),
];
let merkle_proof_ok = builder.all(merkle_proof_checks);
// Check output statement
let arg1_expected = cache.equations[0].lhs.clone();
let arg2_expected = cache.equations[1].lhs.clone();
let arg3_expected = cache.equations[2].lhs.clone();
let arg4_expected = cache.equations[3].lhs.clone();
let expected_statement = StatementTarget::new_native(
builder,
params,
NativePredicate::ContainerUpdate,
&[arg1_expected, arg2_expected, arg3_expected, arg4_expected],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
let ok = builder.all([op_code_ok, aux_tag_ok, arg_types_ok, merkle_proof_ok, st_ok]);
measure_gates_end!(builder, measure);
ok
}
fn verify_merkle_delete_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
aux: &TableEntryTarget,
cache: &StatementCachePriv,
) -> BoolTarget {
let measure = measure_gates_begin!(builder, "MerkleDeleteOp");
let (aux_tag_ok, resolved_merkle_tree_state_transition_claim) =
aux.as_type::<MerkleTreeStateTransitionClaimTarget>(
builder,
OperationAuxTableTag::MerkleTreeStateTransitionProof as u32,
);
let op_code_ok = op_type.has_native(builder, NativeOperation::ContainerDeleteFromEntries);
let (arg_types_ok, [new_root_value, old_root_value, op_key_value]) =
cache.first_n_args_as_values();
let expected_merkle_op = builder.constant(F::from_canonical_u8(MerkleTreeOp::Delete as u8));
// Check Merkle proof (verified elsewhere) against op args.
let merkle_proof_checks = [
/* The supplied Merkle transition proof must be enabled. */
resolved_merkle_tree_state_transition_claim.enabled,
/* ...and it must be a deletion proof. */
builder.is_equal(
resolved_merkle_tree_state_transition_claim.op,
expected_merkle_op,
),
/* ...for the root-key combination in the resolved op args. */
builder.is_equal_slice(
&old_root_value.elements,
&resolved_merkle_tree_state_transition_claim
.old_root
.elements,
),
builder.is_equal_slice(
&new_root_value.elements,
&resolved_merkle_tree_state_transition_claim
.new_root
.elements,
),
builder.is_equal_slice(
&op_key_value.elements,
&resolved_merkle_tree_state_transition_claim.op_key.elements,
),
];
let merkle_proof_ok = builder.all(merkle_proof_checks);
// Check output statement
let arg1_expected = cache.equations[0].lhs.clone();
let arg2_expected = cache.equations[1].lhs.clone();
let arg3_expected = cache.equations[2].lhs.clone();
let expected_statement = StatementTarget::new_native(
builder,
params,
NativePredicate::ContainerDelete,
&[arg1_expected, arg2_expected, arg3_expected],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
let ok = builder.all([op_code_ok, aux_tag_ok, arg_types_ok, merkle_proof_ok, st_ok]);
measure_gates_end!(builder, measure);
ok
}
/// Check for custom predicate operations.
///
/// Returns a target that is true when the aux entry is a custom predicate verification query
/// and that query equals the one formed by `st`, `op_type` and `resolved_op_args`.
fn verify_custom_circuit(
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    aux: &TableEntryTarget,
    resolved_op_args: &[StatementTarget],
) -> BoolTarget {
    let gates = measure_gates_begin!(builder, "OpCustom");

    let (aux_tag_ok, resolved_query) = aux.as_type::<CustomPredicateVerifyQueryTarget>(
        builder,
        OperationAuxTableTag::CustomPredVerify as u32,
    );
    // Query that the aux entry has to match for this statement/operation pair.
    let expected_query = CustomPredicateVerifyQueryTarget {
        statement: st.clone(),
        op_type: op_type.clone(),
        op_args: resolved_op_args.to_vec(),
    };
    let query_ok = builder.is_equal_flattenable(&resolved_query, &expected_query);

    let ok = builder.all([aux_tag_ok, query_ok]);
    measure_gates_end!(builder, gates);
    ok
}
/// Combined check for the `EqualFromEntries` and `NotEqualFromEntries` operations.
///
/// Returns a target that is true when the operation/statement codes form one of the two
/// accepted pairs, the first two arguments resolve to values whose equality agrees with that
/// pair, and the statement arguments are the two expected arguments followed by `none`
/// padding.
fn verify_eq_neq_from_entries_circuit(
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let gates = measure_gates_begin!(builder, "OpEqNeqFromEntries");

    // Operation and statement codes must agree: (EqualFromEntries, Equal) or
    // (NotEqualFromEntries, NotEqual).
    let eq_op = op_type.has_native(builder, NativeOperation::EqualFromEntries);
    let eq_st = st.has_native_type(builder, NativePredicate::Equal);
    let is_eq_case = builder.and(eq_op, eq_st);
    let neq_op = op_type.has_native(builder, NativeOperation::NotEqualFromEntries);
    let neq_st = st.has_native_type(builder, NativePredicate::NotEqual);
    let is_neq_case = builder.and(neq_op, neq_st);
    let op_st_code_ok = builder.or(is_eq_case, is_neq_case);

    // The values must be equal exactly when we are in the equality case.
    let (arg_types_ok, [first_value, second_value]) = cache.first_n_args_as_values();
    let values_equal = builder.is_equal_slice(&first_value.elements, &second_value.elements);
    let op_args_ok = builder.is_equal(values_equal.target, is_eq_case.target);

    // Expected flattened statement arguments: the two cached arguments, then `none` padding
    // up to the maximum number of statement arguments.
    let mut expected_flat = Vec::new();
    for slot in 0..Params::max_statement_args() {
        let arg = match slot {
            0 | 1 => cache.equations[slot].lhs.clone(),
            _ => StatementArgTarget::none(builder),
        };
        expected_flat.extend(arg.elements);
    }
    let actual_flat: Vec<_> = st.args.iter().flat_map(|arg| arg.elements).collect();
    let st_args_ok = builder.is_equal_slice(&expected_flat, &actual_flat);

    let ok = builder.all([op_st_code_ok, arg_types_ok, op_args_ok, st_args_ok]);
    measure_gates_end!(builder, gates);
    ok
}
/// Combined check for the `LtFromEntries` and `LtEqFromEntries` operations.
///
/// Returns a target that is true when the operation/statement codes form one of the two
/// accepted pairs, the first two arguments resolve to values, and the statement arguments are
/// the two expected arguments followed by `none` padding. The ordering itself is enforced by
/// assertions on the builder rather than through the returned target; when the codes do not
/// match, those assertions are applied to the dummy values 0 and 1 instead.
fn verify_lt_lteq_from_entries_circuit(
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let gates = measure_gates_begin!(builder, "OpLtEqFromEntries");
    let dummy_low = ValueTarget::zero(builder);
    let dummy_high = ValueTarget::one(builder);

    // Operation and statement codes must agree: (LtFromEntries, Lt) or (LtEqFromEntries, LtEq).
    let lt_op = op_type.has_native(builder, NativeOperation::LtFromEntries);
    let lt_st = st.has_native_type(builder, NativePredicate::Lt);
    let is_lt_case = builder.and(lt_op, lt_st);
    let lteq_op = op_type.has_native(builder, NativeOperation::LtEqFromEntries);
    let lteq_st = st.has_native_type(builder, NativePredicate::LtEq);
    let is_lteq_case = builder.and(lteq_op, lteq_st);
    let op_st_code_ok = builder.or(is_lt_case, is_lteq_case);

    let (arg_types_ok, [first_value, second_value]) = cache.first_n_args_as_values();
    // Outside of the two accepted cases the checks below run on dummy values so that they
    // cannot fail.
    let low = builder.select_value(op_st_code_ok, first_value, dummy_low);
    let high = builder.select_value(op_st_code_ok, second_value, dummy_high);

    // Both operands must be in the i64 range.
    builder.assert_i64(low);
    builder.assert_i64(high);

    // Strict comparison is required for Lt, and for LtEq only when the operands differ.
    let operands_equal = builder.is_equal_slice(&low.elements, &high.elements);
    let operands_differ = builder.not(operands_equal);
    let lteq_strict_case = builder.and(is_lteq_case, operands_differ);
    let needs_strict_less = builder.or(is_lt_case, lteq_strict_case);
    builder.assert_i64_less_if(needs_strict_less, low, high);

    // Expected flattened statement arguments: the two cached arguments, then `none` padding
    // up to the maximum number of statement arguments.
    let mut expected_flat = Vec::new();
    for slot in 0..Params::max_statement_args() {
        let arg = match slot {
            0 | 1 => cache.equations[slot].lhs.clone(),
            _ => StatementArgTarget::none(builder),
        };
        expected_flat.extend(arg.elements);
    }
    let actual_flat: Vec<_> = st.args.iter().flat_map(|arg| arg.elements).collect();
    let st_args_ok = builder.is_equal_slice(&expected_flat, &actual_flat);

    let ok = builder.all([op_st_code_ok, arg_types_ok, st_args_ok]);
    measure_gates_end!(builder, gates);
    ok
}
/// Constrains the `HashOf` operation.
///
/// The returned flag is true when the operation is `HashOf`, the argument type check reported
/// by `cache.first_n_args_as_values()` passes, the first argument value equals the hash of the
/// second and third argument values, and `st` is the `HashOf` statement over the left-hand
/// sides of the first three cached equations.
fn verify_hash_of_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpHashOf");
    let is_hash_of_op = op_type.has_native(builder, NativeOperation::HashOf);
    let (types_ok, [claimed_hash, left, right]) = cache.first_n_args_as_values();

    // Recompute the hash from the two inputs and compare it with the claimed one.
    let computed_hash = builder.hash_values(left, right);
    let hash_matches = builder.is_equal_slice(&claimed_hash.elements, &computed_hash.elements);

    // The statement must reference the same arguments as the cached equations.
    let expected_args = [0usize, 1, 2].map(|i| cache.equations[i].lhs.clone());
    let expected_st =
        StatementTarget::new_native(builder, params, NativePredicate::HashOf, &expected_args);
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([is_hash_of_op, types_ok, hash_matches, st_matches]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `PublicKeyOf` operation.
///
/// The auxiliary table entry is interpreted as a `PubKeySecKeyTarget`. The returned flag is true
/// when the operation is `PublicKeyOf`, the aux entry carries the `PublicKeyOf` tag, the argument
/// type check passes, the two argument values match the two components of the aux entry, and `st`
/// is the `PublicKeyOf` statement over the left-hand sides of the first two cached equations.
fn verify_public_key_of_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    aux: &TableEntryTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpPublicKeyOf");
    let (tag_ok, aux_pk_sk) =
        aux.as_type::<PubKeySecKeyTarget>(builder, OperationAuxTableTag::PublicKeyOf as u32);
    let is_public_key_of_op = op_type.has_native(builder, NativeOperation::PublicKeyOf);

    // Arguments are given in the order: public_key, secret_key.
    let (types_ok, [pk_value, sk_value]) = cache.first_n_args_as_values();
    let pk_matches = builder.is_equal_slice(&pk_value.elements, &aux_pk_sk.0.elements);
    let sk_matches = builder.is_equal_slice(&sk_value.elements, &aux_pk_sk.1.elements);

    let expected_args = [0usize, 1].map(|i| cache.equations[i].lhs.clone());
    let expected_st = StatementTarget::new_native(
        builder,
        params,
        NativePredicate::PublicKeyOf,
        &expected_args,
    );
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([
        is_public_key_of_op,
        tag_ok,
        types_ok,
        pk_matches,
        sk_matches,
        st_matches,
    ]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `SignedBy` operation.
///
/// The auxiliary table entry is interpreted as a `MsgPubKeyTarget`. The returned flag is true
/// when the operation is `SignedBy`, the aux entry carries the `SignedBy` tag, the argument type
/// check passes, the two argument values match the two components of the aux entry, and `st` is
/// the `SignedBy` statement over the left-hand sides of the first two cached equations.
fn verify_signed_by_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    aux: &TableEntryTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpSignedBy");
    let (tag_ok, aux_msg_pk) =
        aux.as_type::<MsgPubKeyTarget>(builder, OperationAuxTableTag::SignedBy as u32);
    let is_signed_by_op = op_type.has_native(builder, NativeOperation::SignedBy);

    // Arguments are given in the order: msg, pub_key.
    let (types_ok, [msg_value, pk_value]) = cache.first_n_args_as_values();
    let msg_matches = builder.is_equal_slice(&msg_value.elements, &aux_msg_pk.0.elements);
    let pk_matches = builder.is_equal_slice(&pk_value.elements, &aux_msg_pk.1.elements);

    let expected_args = [0usize, 1].map(|i| cache.equations[i].lhs.clone());
    let expected_st =
        StatementTarget::new_native(builder, params, NativePredicate::SignedBy, &expected_args);
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([
        is_signed_by_op,
        tag_ok,
        types_ok,
        msg_matches,
        pk_matches,
        st_matches,
    ]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `SumOf` operation.
///
/// The returned flag is true when the operation is `SumOf`, the argument type check passes, the
/// first argument value equals the i64 sum of the second and third argument values, and `st` is
/// the `SumOf` statement over the left-hand sides of the first three cached equations.
fn verify_sum_of_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpSumOf");
    let zero_value = ValueTarget::zero(builder);
    let is_sum_of_op = op_type.has_native(builder, NativeOperation::SumOf);
    let (types_ok, [claimed_sum, lhs_value, rhs_value]) = cache.first_n_args_as_values();

    // When another operation is selected, add zeros instead of the real arguments so that the
    // addition cannot overflow.
    let lhs = builder.select_value(is_sum_of_op, lhs_value, zero_value);
    let rhs = builder.select_value(is_sum_of_op, rhs_value, zero_value);
    let computed_sum = builder.i64_add(lhs, rhs);
    let sum_matches = builder.is_equal_slice(&claimed_sum.elements, &computed_sum.elements);

    let expected_args = [0usize, 1, 2].map(|i| cache.equations[i].lhs.clone());
    let expected_st =
        StatementTarget::new_native(builder, params, NativePredicate::SumOf, &expected_args);
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([is_sum_of_op, types_ok, sum_matches, st_matches]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `ProductOf` operation.
///
/// The returned flag is true when the operation is `ProductOf`, the argument type check passes,
/// the first argument value equals the i64 product of the second and third argument values, and
/// `st` is the `ProductOf` statement over the left-hand sides of the first three cached
/// equations.
fn verify_product_of_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpProductOf");
    let zero_value = ValueTarget::zero(builder);
    let is_product_of_op = op_type.has_native(builder, NativeOperation::ProductOf);
    let (types_ok, [claimed_product, lhs_value, rhs_value]) = cache.first_n_args_as_values();

    // When another operation is selected, multiply zeros instead of the real arguments so that
    // the multiplication cannot overflow.
    let lhs = builder.select_value(is_product_of_op, lhs_value, zero_value);
    let rhs = builder.select_value(is_product_of_op, rhs_value, zero_value);
    let computed_product = builder.i64_mul(lhs, rhs);
    let product_matches =
        builder.is_equal_slice(&claimed_product.elements, &computed_product.elements);

    let expected_args = [0usize, 1, 2].map(|i| cache.equations[i].lhs.clone());
    let expected_st =
        StatementTarget::new_native(builder, params, NativePredicate::ProductOf, &expected_args);
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([is_product_of_op, types_ok, product_matches, st_matches]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `MaxOf` operation.
///
/// The returned flag is true when the operation is `MaxOf`, the argument type check passes, the
/// first argument value equals at least one of the other two argument values, and `st` is the
/// `MaxOf` statement over the left-hand sides of the first three cached equations.
///
/// That the remaining operand is strictly smaller than the claimed maximum is enforced through
/// an assertion, which is only enabled for `MaxOf` operations whose three values are not all
/// equal.
fn verify_max_of_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpMaxOf");
    let is_max_of_op = op_type.has_native(builder, NativeOperation::MaxOf);
    let (types_ok, [max_value, left, right]) = cache.first_n_args_as_values();

    // The claimed maximum has to coincide with one of the two operands.
    let max_is_left = builder.is_equal_slice(&max_value.elements, &left.elements);
    let max_is_right = builder.is_equal_slice(&max_value.elements, &right.elements);
    let all_equal = builder.and(max_is_left, max_is_right);
    let not_all_equal = builder.not(all_equal);
    let max_is_operand = builder.or(max_is_left, max_is_right);

    // The operand to compare against: `right` when the maximum equals `left`, otherwise `left`.
    // It must be strictly less than the maximum.
    let other_operand = builder.select_value(max_is_left, right, left);
    // The comparison is skipped when all three values are equal.
    let enforce_lt = builder.and(not_all_equal, is_max_of_op);
    builder.assert_i64_less_if(enforce_lt, other_operand, max_value);

    let expected_args = [0usize, 1, 2].map(|i| cache.equations[i].lhs.clone());
    let expected_st =
        StatementTarget::new_native(builder, params, NativePredicate::MaxOf, &expected_args);
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([is_max_of_op, types_ok, max_is_operand, st_matches]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `ReplaceValueWithEntry` operation.
///
/// The input statement is read from `cache.op_args[BASE_PARAMS.max_statement_args]`. Each of its
/// arguments is paired with one cached equation: when the equation's `pred_is_none` flag is set
/// the argument is kept, otherwise it is replaced by the equation's `reference`, which requires
/// that the equation is flagged as a reference, that the original argument is a value, and that
/// this value equals the equation's `value`.
///
/// The returned flag is true when the operation is `ReplaceValueWithEntry`, every argument
/// passes the check above, and `st` equals the input statement with the replaced arguments.
fn verify_replace_value_with_entry_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    cache: &StatementCachePriv,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpReplaceValueWithEntry");
    let is_replace_op = op_type.has_native(builder, NativeOperation::ReplaceValueWithEntry);
    let source_st = &cache.op_args[BASE_PARAMS.max_statement_args];

    let mut replaced_args = Vec::with_capacity(source_st.args.len());
    let mut all_args_ok = builder._true();
    for (original_arg, entry) in zip_eq(&source_st.args, &cache.equations) {
        let keep_original = entry.pred_is_none;

        // Keep the original argument, or swap it for the referenced entry.
        let new_arg =
            builder.select_flattenable(params, keep_original, original_arg, &entry.reference);
        replaced_args.push(new_arg);

        // A replacement is only valid if the original argument is a value that matches the
        // value found in the cached equation.
        let original_is_value = builder.statement_arg_is_value(original_arg);
        let original_value = original_arg.as_value();
        let value_matches = builder.is_equal_flattenable(&original_value, &entry.value);
        let replacement_ok =
            builder.all([entry.is_reference, original_is_value, value_matches]);

        let this_arg_ok = builder.or(keep_original, replacement_ok);
        all_args_ok = builder.and(all_args_ok, this_arg_ok);
    }

    let expected_st = StatementTarget::new(*source_st.pred_hash(), replaced_args);
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([is_replace_op, all_args_ok, st_matches]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `TransitiveEqualFromStatements` operation.
///
/// Given the two resolved operation arguments `Equal(a, b)` and `Equal(c, d)`, the returned flag
/// is true when the operation code matches, both arguments are `Equal` statements, `b` and `c`
/// are identical, and `st` is `Equal(a, d)`.
fn verify_transitive_eq_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    resolved_op_args: &[StatementTarget],
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpTransitiveEq");
    let is_transitive_eq_op =
        op_type.has_native(builder, NativeOperation::TransitiveEqualFromStatements);

    let first_eq = &resolved_op_args[0];
    let first_is_equal = first_eq.has_native_type(builder, NativePredicate::Equal);
    let second_eq = &resolved_op_args[1];
    let second_is_equal = second_eq.has_native_type(builder, NativePredicate::Equal);
    let inputs_are_equal_sts = builder.all([first_is_equal, second_is_equal]);

    // The right-hand side of the first equality must be the left-hand side of the second one.
    let (outer_lhs, shared_first) = (&first_eq.args[0], &first_eq.args[1]);
    let (shared_second, outer_rhs) = (&second_eq.args[0], &second_eq.args[1]);
    let chain_ok = builder.is_equal_slice(&shared_first.elements, &shared_second.elements);

    let expected_st = StatementTarget::new_native(
        builder,
        params,
        NativePredicate::Equal,
        &[outer_lhs.clone(), outer_rhs.clone()],
    );
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([
        is_transitive_eq_op,
        inputs_are_equal_sts,
        chain_ok,
        st_matches,
    ]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `None` operation: the returned flag is true when the operation is `None` and
/// `st` is the native `None` statement built with no arguments.
fn verify_none_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpNone");
    let is_none_op = op_type.has_native(builder, NativeOperation::None);
    let none_st = StatementTarget::new_native(builder, params, NativePredicate::None, &[]);
    let st_is_none = builder.is_equal_flattenable(st, &none_st);
    let ok = builder.all([is_none_op, st_is_none]);
    measure_gates_end!(builder, measure);
    ok
}
/// Constrains the `LtToNotEqual` operation.
///
/// Given the resolved operation argument `Lt(a, b)`, the returned flag is true when the
/// operation code matches, the argument is an `Lt` statement, and `st` is `NotEqual(a, b)`.
fn verify_lt_to_neq_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    resolved_op_args: &[StatementTarget],
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpLtToNeq");
    let is_lt_to_neq_op = op_type.has_native(builder, NativeOperation::LtToNotEqual);

    let lt_st = &resolved_op_args[0];
    let input_is_lt = lt_st.has_native_type(builder, NativePredicate::Lt);

    // The output statement reuses both arguments of the Lt statement.
    let neq_args = [lt_st.args[0].clone(), lt_st.args[1].clone()];
    let expected_st =
        StatementTarget::new_native(builder, params, NativePredicate::NotEqual, &neq_args);
    let st_matches = builder.is_equal_flattenable(st, &expected_st);

    let ok = builder.all([is_lt_to_neq_op, input_is_lt, st_matches]);
    measure_gates_end!(builder, measure);
    ok
}
//
// Custom Predicate constraints
//
/// Constrains the `CopyStatement` operation: the returned flag is true when the operation code
/// matches and `st` is identical to the first resolved operation argument.
fn verify_copy_circuit(
    builder: &mut CircuitBuilder,
    st: &StatementTarget,
    op_type: &OperationTypeTarget,
    resolved_op_args: &[StatementTarget],
) -> BoolTarget {
    let measure = measure_gates_begin!(builder, "OpCopy");
    let is_copy_op = op_type.has_native(builder, NativeOperation::CopyStatement);
    let copied_st = &resolved_op_args[0];
    let st_matches = builder.is_equal_flattenable(st, copied_st);
    let ok = builder.all([is_copy_op, st_matches]);
    measure_gates_end!(builder, measure);
    ok
}
// NOTE: This is a bit messy. The target types are defined in `common.rs` because they are used in
// `add_virtual_foo` methods in the trait for the `CircuitBuilder`. But the constraint logic is
// here. Maybe we want to move everything related to custom predicates to its own module, but then
// should we add a new trait for the `add_virtual_foo` methods so that everything is contained in a
// module?

/// Builds a statement argument from a statement template argument, resolving wildcard indices
/// against `args`.
///
/// The two halves of the resulting argument are selected as follows:
/// - none: `(0, 0)`
/// - literal: `(literal value, 0)`
/// - anchored key: `(args[id wildcard index], key)`, where `key` is either the key literal or
///   `args[key wildcard index]`
/// - wildcard literal: `(args[wildcard index], 0)`
///
/// The selection is a chain of selects where a later case overrides an earlier one; presumably
/// the case flags are mutually exclusive (NOTE(review): confirm in `StatementTmplArgTarget`).
fn make_statement_arg_from_template_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st_tmpl_arg: &StatementTmplArgTarget,
args: &[ValueTarget],
) -> StatementArgTarget {
let zero = builder.zero();
// Decode the template argument under each of its possible interpretations.
let (is_literal, value_literal) = st_tmpl_arg.as_literal(builder);
let (is_ak, ak_id_wc_index, ak_key_lit_or_wc) = st_tmpl_arg.as_anchored_key(builder);
let (is_wc_literal, wc_index) = st_tmpl_arg.as_wildcard_literal(builder);
let ((_is_ak_key_lit, ak_key_lit), (is_ak_key_wc, ak_key_wc_index)) =
ak_key_lit_or_wc.cases(builder);
// optimization: ak_id_wc_index and wc_index use the same signals, so we only need to do one
// random access to resolve both of them
assert_eq!(ak_id_wc_index, wc_index);
// If the index is not used, use a 0 instead to still pass the range constraints from
// vec_ref
let first_index = ak_id_wc_index;
let is_first_index_valid = builder.or(is_ak, is_wc_literal);
let first_index = builder.select(is_first_index_valid, first_index, zero);
let resolved_ak_id = builder.vec_ref_small(params, args, first_index);
// Same lookup result, used for the wildcard literal case.
let resolved_wc = resolved_ak_id;
// If the index is not used, use a 0 instead to still pass the range constraints from
// vec_ref
let second_index = ak_key_wc_index;
let is_second_index_valid = builder.and(is_ak, is_ak_key_wc);
let second_index = builder.select(is_second_index_valid, second_index, zero);
let resolved_ak_key = builder.vec_ref_small(params, args, second_index);
// The anchored key's key: the literal by default, the resolved wildcard if it is a wildcard.
let ak_key = ak_key_lit; // is_ak_key_lit
let ak_key = builder.select_flattenable(params, is_ak_key_wc, &resolved_ak_key, &ak_key);
// First half of the argument: start from the none case and override per case.
let first = ValueTarget::zero(builder); // is_none
let first = builder.select_flattenable(params, is_literal, &value_literal, &first);
let first = builder.select_flattenable(params, is_ak, &resolved_ak_id, &first);
let first = builder.select_flattenable(params, is_wc_literal, &resolved_wc, &first);
// Second half of the argument: only the anchored key case sets it.
let second = ValueTarget::zero(builder); // is_none or is_literal or is_wc_literal
let second = builder.select_flattenable(params, is_ak, &ak_key, &second);
StatementArgTarget::new(first, second)
}
/// Resolves the predicate hash of a statement template.
///
/// When the template holds a predicate hash it is returned directly; otherwise the template
/// holds a wildcard index and the hash is taken from `args` at that index.
fn make_predicate_from_template_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    pred_hash_or_wc: &PredicateHashOrWildcardTarget,
    args: &[ValueTarget],
) -> HashOutTarget {
    let zero = builder.zero();
    let holds_pred_hash = pred_hash_or_wc.is_pred(builder);

    // The lookup is always performed; when the wildcard index is not in use, index 0 is used
    // instead so that the range constraints from vec_ref still pass.
    let wc_index = pred_hash_or_wc.wc_index();
    let lookup_index = builder.select(holds_pred_hash, zero, wc_index);
    let looked_up_value = builder.vec_ref_small(params, args, lookup_index);
    let wildcard_pred_hash = HashOutTarget::from(looked_up_value);

    let literal_pred_hash = pred_hash_or_wc.pred_hash();
    builder.select_flattenable(
        params,
        holds_pred_hash,
        &literal_pred_hash,
        &wildcard_pred_hash,
    )
}
/// Instantiates a statement template with the wildcard values in `args`: every template
/// argument and the template predicate are resolved, and the resulting statement is returned.
fn make_statement_from_template_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    st_tmpl: &StatementTmplTarget,
    args: &[ValueTarget],
) -> StatementTarget {
    // Arguments first ...
    let measure = measure_gates_begin!(builder, "StArgFromTmpl");
    let mut st_args = Vec::with_capacity(st_tmpl.args.len());
    for tmpl_arg in st_tmpl.args.iter() {
        let st_arg = make_statement_arg_from_template_circuit(params, builder, tmpl_arg, args);
        st_args.push(st_arg);
    }
    measure_gates_end!(builder, measure);

    // ... then the predicate.
    let measure = measure_gates_begin!(builder, "PredFromTmpl");
    let pred_tmpl = st_tmpl.pred_hash_or_wc();
    let pred_hash = make_predicate_from_template_circuit(params, builder, pred_tmpl, args);
    measure_gates_end!(builder, measure);

    StatementTarget::new(pred_hash, st_args)
}
/// Verifies the application of a custom predicate and builds its outputs.
///
/// Given a custom predicate, the operation arguments (statements) and the wildcard values
/// (`args`), this:
/// - asserts that the custom predicate is satisfied by the given statements: all of them must
///   match for a conjunction, at least one of them otherwise;
/// - builds the output statement, whose arguments are the first `args_len` wildcard values;
/// - builds the expected operation type.
///
/// # Panics
///
/// Panics if `op_args` does not have `BASE_PARAMS.max_operation_args` elements or `args` does
/// not have `params.max_custom_predicate_wildcards` elements.
fn make_custom_statement_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    custom_predicate: &CustomPredicateEntryTarget,
    op_args: &[StatementTarget],
    args: &[ValueTarget], // arguments to the custom predicate, public and private
) -> Result<(StatementTarget, OperationTypeTarget)> {
    let measure = measure_gates_begin!(builder, "CustomOpVerify");

    // Some sanity checks
    assert_eq!(BASE_PARAMS.max_operation_args, op_args.len());
    assert_eq!(params.max_custom_predicate_wildcards, args.len());

    let predicate = &custom_predicate.predicate;
    let batch_id = custom_predicate.id;
    let index = custom_predicate.index;
    let op_type = OperationTypeTarget::new_custom(builder, batch_id, index);

    // Build the statement: argument slots at positions >= args_len are filled with zero values.
    let st_predicate = PredicateTarget::new_custom(builder, batch_id, index);
    let arg_none = ValueTarget::zero(builder);
    let lt_mask = builder.lt_mask(Params::max_statement_args(), predicate.args_len);
    let mut st_args = Vec::new();
    for (in_use, arg) in std::iter::zip(lt_mask, args) {
        let value = builder.select_flattenable(params, in_use, arg, &arg_none);
        st_args.push(StatementArgTarget::wildcard_literal(builder, &value));
    }
    let statement_with_pred =
        StatementTarget::new_with_pred(builder, params, st_predicate, &st_args);

    // Check the operation arguments.
    // From each statement template we generate an expected statement by replacing the
    // wildcards with the arguments. Then we compare the expected statement with the operation
    // argument.
    let mut expected_sts = Vec::new();
    for st_tmpl in predicate.statements.iter() {
        expected_sts.push(make_statement_from_template_circuit(
            params, builder, st_tmpl, args,
        ));
    }

    // expected_sts.len() == params.max_custom_predicate_arity
    // op_args.len() == params.max_operation_args;
    let mut sts_eq = Vec::new();
    for (expected_st, op_arg) in expected_sts.iter().zip(op_args) {
        sts_eq.push(builder.is_equal_flattenable(expected_st, op_arg));
    }
    let all_st_eq = builder.all(sts_eq.clone());
    let some_st_eq = builder.any(sts_eq);

    // NOTE: This BoolTarget is safe because both inputs to the select are safe
    let selected = builder.select(
        predicate.conjunction,
        all_st_eq.target,
        some_st_eq.target,
    );
    let is_op_args_ok = BoolTarget::new_unsafe(selected);
    builder.assert_one(is_op_args_ok.target);

    measure_gates_end!(builder, measure);
    Ok((statement_with_pred, op_type))
}
/// Returns `statement` with a blank intro predicate replaced by the intro predicate built from
/// `vd_hash`. Statements with any other predicate keep their predicate hash; the arguments are
/// always kept.
fn normalize_statement_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    statement: &StatementTarget,
    vd_hash: &HashOutTarget,
) -> StatementTarget {
    let needs_vd_hash = statement.pred_is_blank_intro(builder);
    let current_pred_hash = statement.pred_hash();

    // The filled-in hash is always computed; the select below decides whether it is used.
    let filled_intro_pred = PredicateTarget::new_intro(builder, *vd_hash);
    let filled_intro_pred_hash = filled_intro_pred.hash(builder);

    let pred_hash = builder.select_flattenable(
        params,
        needs_vd_hash,
        &filled_intro_pred_hash,
        current_pred_hash,
    );
    StatementTarget::new(pred_hash, statement.args.clone())
}
/// Computes the hash of a list of statements.
///
/// `Params::num_public_statements_hash()` is the total number of statements that get hashed.
/// The hashed sequence consists of none-statements as front padding followed by the input
/// statements in reverse order. The padding is made of constants, so the part of the hash state
/// that depends only on it is precomputed outside of the circuit.
///
/// # Panics
///
/// Panics if more than `Params::num_public_statements_hash()` statements are given.
pub fn calculate_statements_hash_circuit(
    builder: &mut CircuitBuilder,
    // These statements will be front-padded with none-statements to reach
    // `Params::num_public_statements_hash()`
    statements: &[StatementTarget],
) -> HashOutTarget {
    assert!(statements.len() <= Params::num_public_statements_hash());
    let measure = measure_gates_begin!(builder, "CalculateStsHash");

    // Field elements of the front padding.
    let mut none_st = mainpod::Statement::from(Statement::None);
    pad_statement(&mut none_st);
    let num_padding_sts = Params::num_public_statements_hash() - statements.len();
    let padding_elts: Vec<_> = iter::repeat(&none_st)
        .take(num_padding_sts)
        .flat_map(|s| s.to_fields())
        .collect();

    // Precompute the Poseidon state for the initial padding chunks; the padding elements that
    // were not consumed are hashed in the circuit as constants.
    let (perm, padding_elts_rem) =
        precompute_hash_state::<F, PoseidonPermutation<F>>(&padding_elts);

    let mut inputs = Vec::new();
    for elt in padding_elts_rem.iter() {
        inputs.push(builder.constant(*elt));
    }
    for statement in statements.iter().rev() {
        inputs.extend(statement.flatten());
    }

    let sts_hash =
        hash_from_state_circuit::<PoseidonHash, PoseidonPermutation<F>>(builder, perm, &inputs);
    measure_gates_end!(builder, measure);
    sts_hash
}
/// Replace BatchSelf predicates with the corresponding Custom(batch_id, index), and
/// SelfPredicateHash args with Literal(hash(Custom(batch_id, index))).
///
/// `id` is the batch id used to build the `Custom` predicates. Templates that do not use a
/// BatchSelf predicate keep their predicate data, and args that are not SelfPredicateHash are
/// kept unchanged.
///
/// Panics if `st_tmpl.pred()` returns `None`.
fn normalize_st_tmpl_circuit(
params: &Params,
builder: &mut CircuitBuilder,
st_tmpl: &StatementTmplTarget,
id: HashOutTarget,
) -> StatementTmplTarget {
// If the custom predicate is self, we normalize it and then hash it.
let old_pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
// elements[0] of the predicate is compared against the BatchSelf prefix; elements[1] is used
// as the predicate index within the batch.
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
let is_batch_self = builder.is_equal(old_pred.elements[0], prefix_batch_self);
let pred_index = old_pred.elements[1];
// The normalized hash is always computed; the select below decides whether it is used.
let normalized_custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
let normalized_custom_pred_hash = normalized_custom_pred.hash(builder);
// If the template is using a predicate and it is batch self we use the freshly computed
// normalized predicate hash, otherwise we keep the original data.
let old_data = st_tmpl.pred_hash_or_wc().data();
let is_pred = st_tmpl.pred_hash_or_wc().is_pred(builder);
let is_pred_batch_self = builder.and(is_pred, is_batch_self);
let data = builder.select_flattenable(
params,
is_pred_batch_self,
&ValueTarget::from(normalized_custom_pred_hash),
&old_data,
);
// The first element of the original pred_hash_or_wc is kept as is; only the data changes.
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
// Normalize SelfPredicateHash args: replace prefix 4 with Literal containing the resolved
// predicate hash. Same pattern as the predicate normalization above.
let prefix_sph = builder.constant(F::from(StatementTmplArgPrefix::SelfPredicateHash));
let prefix_literal = builder.constant(F::from(StatementTmplArgPrefix::Literal));
let zero = builder.zero();
let normalized_args = st_tmpl
.args
.iter()
.map(|arg| {
let is_sph = builder.is_equal(arg.elements[0], prefix_sph);
// The predicate index is in elements[1] (same slot as WildcardLiteral).
let pred_index = arg.elements[1];
// Compute hash(Custom(batch_id, pred_index))
let pred_target = PredicateTarget::new_custom(builder, id, pred_index);
let pred_hash = pred_target.hash(builder);
// Build a Literal-encoded arg: [1, hash[0..4], 0, 0, 0, 0]
let mut literal_elements = [zero; Params::statement_tmpl_arg_size()];
literal_elements[0] = prefix_literal;
literal_elements[1] = pred_hash.elements[0];
literal_elements[2] = pred_hash.elements[1];
literal_elements[3] = pred_hash.elements[2];
literal_elements[4] = pred_hash.elements[3];
let normalized = StatementTmplArgTarget {
elements: literal_elements,
};
// Only SelfPredicateHash args are replaced; every other arg is kept unchanged.
builder.select_flattenable(params, is_sph, &normalized, arg)
})
.collect();
StatementTmplTarget::new(pred_hash_or_wc, normalized_args)
}
/// Builds the custom predicate table and returns the hash of each of its entries.
///
/// Every entry consists of `[batch_id, custom_predicate_index, custom_predicate]` and its
/// queryable part is `hash([batch_id, custom_predicate_index, custom_predicate])`. Each input is
/// verified with `verify_circuit` and its statement templates are normalized with its batch id
/// before the entry is hashed.
fn build_custom_predicate_table_circuit(
    params: &Params,
    builder: &mut CircuitBuilder,
    custom_predicates: &[CustomPredicateInBatchTarget],
) -> Result<Vec<HashOutTarget>> {
    let measure = measure_gates_begin!(builder, "BuildCustomPredTbl");
    let mut entry_hashes = Vec::with_capacity(params.max_custom_predicates);

    for cp in custom_predicates {
        let measure_cp = measure_gates_begin!(builder, "CustomPred");
        cp.verify_circuit(builder);

        // Normalize the statement templates of the predicate with the batch id.
        let self_pred = &cp.self_predicate;
        let mut statements = Vec::with_capacity(self_pred.statements.len());
        for st_tmpl in self_pred.statements.iter() {
            statements.push(normalize_st_tmpl_circuit(params, builder, st_tmpl, cp.id));
        }

        let predicate = CustomPredicateTarget {
            conjunction: self_pred.conjunction,
            statements,
            args_len: self_pred.args_len,
        }; // input
        let entry = CustomPredicateEntryTarget {
            id: cp.id,       // output
            index: cp.index, // input
            predicate,
        };

        let entry_hash = entry.hash(builder);
        entry_hashes.push(entry_hash);
        measure_gates_end!(builder, measure_cp);
    }

    measure_gates_end!(builder, measure);
    Ok(entry_hashes)
}
/// Builds the constraints that verify a MainPod and returns the hash of its public statements.
///
/// The statement table that operations can reference is laid out as:
/// - index 0: a `None` statement,
/// - then, for each input pod, its self statements (normalized with the pod's verifier data
///   hash),
/// - then the input statements of this pod, public ones last.
///
/// For each input pod the circuit checks that the statements hash and the vds root found in its
/// public inputs match, and that non-introduction pods have a verifier data hash that is proven
/// to be in the `vds_root` tree. Each input statement is then verified against its operation.
///
/// Panics if `verified_proofs.len()` differs from `params.max_input_pods`.
fn verify_main_pod_circuit(
builder: &mut CircuitBuilder,
main_pod: &MainPodVerifyTarget,
verified_proofs: &[VerifiedProofTarget],
) -> Result<HashOutTarget> {
let params = &main_pod.params;
assert_eq!(params.max_input_pods, verified_proofs.len());
let measure = measure_gates_begin!(builder, "MainPodVerify");
// Build the statement array
let mut statements = Vec::new();
// Statement at index 0 is always None to be used for padding operation arguments in custom
// predicate statements
let st_none = StatementTarget::new_native(builder, params, NativePredicate::None, &[]);
statements.push(st_none);
// 1a. Verify all input recursive pods
for (verified_proof, vd_mt_proof, input_pod_self_statements) in izip!(
verified_proofs,
&main_pod.vd_mt_proofs,
&main_pod.input_pods_self_statements
) {
let measure_in_pod = measure_gates_begin!(builder, "VerifyInPod");
//
// Verify sts_hash from the statements
//
let expected_sts_hash = HashOutTarget::try_from(
&verified_proof.public_inputs
[PI_OFFSET_STATEMENTS_HASH..PI_OFFSET_STATEMENTS_HASH + HASH_SIZE],
)
.expect("4 elements");
// NOTE: We use an EmptyPod for padding input pod slots. The EmptyPod is an introduction
// pod that declares a statement with no arguments.
let is_blank_intro = input_pod_self_statements[0].pred_is_blank_intro(builder);
// Introduction pods can only have Introduction or None statements
let mut intro_ok = is_blank_intro;
for self_st in &input_pod_self_statements[1..] {
let st_is_intro = self_st.pred_is_blank_intro(builder);
let st_is_none = self_st.has_native_type(builder, NativePredicate::None);
let st_is_intro_or_none = builder.or(st_is_intro, st_is_none);
intro_ok = builder.and(intro_ok, st_is_intro_or_none);
}
// `intro_ok` can only be true if `is_blank_intro` is, so this connection only constrains
// the case where the first statement is a blank intro: then all the remaining statements
// must be intro or None.
builder.connect(is_blank_intro.target, intro_ok.target);
// A pod whose first statement is not a blank intro is treated as a main pod.
let is_main = builder.not(is_blank_intro);
// Add the (normalized) self statements of the input pod to the statement table.
for self_st in input_pod_self_statements {
let normalized_st = normalize_statement_circuit(
params,
builder,
self_st,
&verified_proof.verifier_data_hash,
);
statements.push(normalized_st);
}
// The hash is computed over the statements as given, before normalization.
let sts_hash = calculate_statements_hash_circuit(builder, input_pod_self_statements);
builder.connect_hashes(expected_sts_hash, sts_hash);
//
// Verify that all main input pod proofs use verifier data from the public input VD
// array. This requires merkle proofs. introduction pods are not checked here because
// their verifier_data_hash appears in their introduction statement.
//
verify_merkle_proof_circuit(builder, vd_mt_proof);
// ensure that mt_proof is enabled if it's a main pod
builder.connect(vd_mt_proof.enabled.target, is_main.target);
// connect the vd_mt_proof's root to the actual vds_root, to ensure that the mt proof
// verifies against the vds_root
builder.connect_hashes(main_pod.vds_root, vd_mt_proof.root);
// connect vd_mt_proof's value with the verified_proof.verifier_data_hash
builder.connect_hashes(
verified_proof.verifier_data_hash,
HashOutTarget::from_vec(vd_mt_proof.value.elements.to_vec()),
);
//
// Verify that VD array that input pod uses is the same we use now.
//
let verified_proof_vds_root = HashOutTarget::try_from(
&verified_proof.public_inputs[PI_OFFSET_VDSROOT..PI_OFFSET_VDSROOT + HASH_SIZE],
)
.expect("4 elements");
builder.connect_hashes(main_pod.vds_root, verified_proof_vds_root);
measure_gates_end!(builder, measure_in_pod);
}
// Index in `statements` where the input statements of this pod start.
let input_statements_offset = statements.len();
// Add the input (private and public) statements
for statement in &main_pod.input_statements {
statements.push(statement.clone());
}
// The public statements are the last `params.max_public_statements` input statements.
let public_statements_offset = main_pod.input_statements.len() - params.max_public_statements;
let pub_statements = &main_pod.input_statements[public_statements_offset..];
// Table of custom predicate batches with batch_id calculation
let custom_predicate_table =
build_custom_predicate_table_circuit(params, builder, &main_pod.custom_predicates)?;
// Auxiliary table that operations reference for merkle proofs, public keys, signatures,
// merkle tree state transitions and custom predicate verifications.
let aux_table = build_operation_aux_table_circuit(
params,
builder,
&main_pod.merkle_proofs,
&main_pod.public_key_of_sks,
&main_pod.signed_bys,
&main_pod.merkle_tree_state_transition_proofs,
&main_pod.custom_predicate_verifications,
&custom_predicate_table,
)?;
// 2. Calculate the Pod Id from the public statements
let sts_hash = calculate_statements_hash_circuit(builder, pub_statements);
// 5. Verify input statements
for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() {
// An operation can only reference statements that appear before its own statement.
let prev_statements = &statements[..input_statements_offset + i];
if i < public_statements_offset {
verify_operation_circuit(params, builder, st, op, prev_statements, &aux_table)?;
} else {
// Public statement slots use a dedicated verification function that does not take the
// aux table.
verify_operation_public_statement_circuit(params, builder, st, op, prev_statements)?;
}
}
measure_gates_end!(builder, measure);
Ok(sts_hash)
}
/// Circuit targets for the verification of a MainPod. All the vectors have a fixed length
/// determined by `params` (see `MainPodVerifyTarget::new_virtual`).
#[derive(Clone, Serialize, Deserialize)]
pub struct MainPodVerifyTarget {
/// Parameters the targets were allocated with.
params: Params,
/// Root that the `vd_mt_proofs` are verified against.
vds_root: HashOutTarget,
/// One merkle proof per input pod (`params.max_input_pods`).
vd_mt_proofs: Vec<MerkleClaimAndProofTarget>,
/// For each input pod, its `params.max_input_pods_public_statements` statements.
input_pods_self_statements: Vec<Vec<StatementTarget>>,
// The KEY_TYPE statement must be the first public one
/// The `params.max_statements` statements of this pod.
input_statements: Vec<StatementTarget>,
/// One operation per input statement (`params.max_statements`).
operations: Vec<OperationTarget>,
/// `params.max_merkle_proofs_containers` merkle proofs.
merkle_proofs: Vec<MerkleClaimAndProofTarget>,
/// `params.max_public_key_of` values; presumably the secret keys used by `PublicKeyOf`
/// operations (NOTE(review): confirm).
public_key_of_sks: Vec<BigUInt320Target>,
/// `params.max_signed_by` entries.
signed_bys: Vec<SignedByTarget>,
/// `params.max_merkle_tree_state_transition_proofs_containers` state transition proofs.
merkle_tree_state_transition_proofs: Vec<MerkleTreeStateTransitionProofTarget>,
/// `params.max_custom_predicates` custom predicates.
custom_predicates: Vec<CustomPredicateInBatchTarget>,
/// `params.max_custom_predicate_verifications` custom predicate verification entries.
custom_predicate_verifications: Vec<CustomPredicateVerifyEntryTarget>,
}
impl MainPodVerifyTarget {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
MainPodVerifyTarget {
params: params.clone(),
vds_root: builder.add_virtual_hash(),
vd_mt_proofs: (0..params.max_input_pods)
.map(|_| MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_vds, builder))
.collect(),
input_pods_self_statements: (0..params.max_input_pods)
.map(|_| {
(0..params.max_input_pods_public_statements)
.map(|_| builder.add_virtual_statement(false))
.collect_vec()
})
.collect(),
input_statements: (0..params.max_statements)
.map(|_| builder.add_virtual_statement(false))
.collect(),
operations: (0..params.max_statements)
.map(|_| builder.add_virtual_operation(params))
.collect(),
merkle_proofs: (0..params.max_merkle_proofs_containers)
.map(|_| {
MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, builder)
})
.collect(),
public_key_of_sks: (0..params.max_public_key_of)
.map(|_| builder.add_virtual_biguint320_target())
.collect(),
signed_bys: (0..params.max_signed_by)
.map(|_| SignedByTarget::new_virtual(builder))
.collect(),
merkle_tree_state_transition_proofs: (0..params
.max_merkle_tree_state_transition_proofs_containers)
.map(|_| {
MerkleTreeStateTransitionProofTarget::new_virtual(
params.max_depth_mt_containers,
builder,
)
})
.collect(),
custom_predicates: (0..params.max_custom_predicates)
.map(|_| CustomPredicateInBatchTarget::new_virtual(builder))
.collect(),
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
.collect(),
}
}
}
/// Witness data for one custom predicate verification.
pub struct CustomPredicateVerification {
/// Presumably the index of the predicate in the custom predicate table built by the circuit
/// (NOTE(review): confirm against where this struct is consumed).
pub custom_predicate_table_index: usize,
/// The custom predicate being verified.
pub custom_predicate: CustomPredicateRef,
/// Values passed as arguments to the custom predicate.
pub args: Vec<Value>,
/// Statements passed as operation arguments.
pub op_args: Vec<mainpod::Statement>,
}
/// Witness input of the MainPod verification circuit (the `InnerCircuit::Input` of
/// `MainPodVerifyTarget`).
pub struct MainPodVerifyInput {
/// Set of verifier datas; its root is the one used by the circuit.
pub vds_set: VDSet,
/// field containing the `vd_mt_proofs` aside from the `vds_set`, because
/// inside the MainPodVerifyTarget circuit, since it is the InnerCircuit for
/// the RecursiveCircuit, we don't have access to the used verifier_datas.
/// The bool is used as `enabled` and will be false for intro pods.
pub vd_mt_proofs: Vec<(bool, MerkleClaimAndProof)>,
/// For each input pod, its public self statements.
pub input_pods_pub_self_statements: Vec<Vec<Statement>>,
/// Statements of the pod.
pub statements: Vec<mainpod::Statement>,
/// Operations of the pod; presumably one per statement (NOTE(review): confirm).
pub operations: Vec<mainpod::Operation>,
/// Merkle proofs referenced by the operations.
pub merkle_proofs: Vec<MerkleClaimAndProof>,
/// Secret keys; presumably the ones used by `PublicKeyOf` operations (NOTE(review): confirm).
pub public_key_of_sks: Vec<SecretKey>,
/// Data for the `SignedBy` operations.
pub signed_bys: Vec<SignedBy>,
/// Merkle tree state transition proofs referenced by the operations.
pub merkle_tree_state_transition_proofs: Vec<MerkleTreeStateTransitionProof>,
/// Custom predicates, each paired with a merkle proof.
pub custom_predicates_with_mpt_proofs: Vec<(CustomPredicateRef, MerkleProof)>,
/// Data for the custom predicate verifications.
pub custom_predicate_verifications: Vec<CustomPredicateVerification>,
}
/// Assigns the public statements of one input pod to its statement targets.
///
/// The first `statements.len()` targets receive the given statements (converted to the backend
/// representation); every remaining target receives a padded `None` statement.
///
/// # Panics
/// Panics if `statements_target` does not hold exactly
/// `params.max_input_pods_public_statements` targets, or if there are more statements than
/// `Params::num_public_statements_hash()`.
fn set_targets_input_pods_self_statements(
    pw: &mut PartialWitness<F>,
    params: &Params,
    statements_target: &[StatementTarget],
    statements: &[Statement],
) -> Result<()> {
    assert_eq!(
        statements_target.len(),
        params.max_input_pods_public_statements
    );
    assert!(statements.len() <= Params::num_public_statements_hash());

    // Real statements occupy the leading targets.
    let used = statements.len();
    for (slot, frontend_st) in statements.iter().enumerate() {
        let backend_st = mainpod::Statement::from(frontend_st.clone());
        statements_target[slot].set_targets(pw, &backend_st)?;
    }

    // Everything after them is filled with the same padded `None` statement.
    let filler = {
        let mut st = mainpod::Statement::from(Statement::None);
        pad_statement(&mut st);
        st
    };
    statements_target
        .iter()
        .skip(used)
        .try_for_each(|target| target.set_targets(pw, &filler))?;
    Ok(())
}
impl InnerCircuit for MainPodVerifyTarget {
    type Input = MainPodVerifyInput;
    type Params = Params;

    /// Allocates the targets, adds the main-pod verification constraints and registers the
    /// public inputs: first the statements hash, then the root of the verifier-data set.
    fn build(
        builder: &mut CircuitBuilder,
        params: &Self::Params,
        verified_proofs: &[VerifiedProofTarget],
    ) -> Result<Self> {
        let target = MainPodVerifyTarget::new_virtual(params, builder);
        let statements_hash = verify_main_pod_circuit(builder, &target, verified_proofs)?;
        builder.register_public_inputs(&statements_hash.elements);
        builder.register_public_inputs(&target.vds_root.elements);
        Ok(target)
    }

    /// Assigns the values of `input` to the targets, padding every collection that is shorter
    /// than its limit in `self.params`.
    ///
    /// # Panics
    /// Panics when a collection in `input` exceeds its limit, when `statements` does not have
    /// exactly `max_statements` entries, or when paired collections differ in length.
    fn set_targets(&self, pw: &mut PartialWitness<F>, input: &Self::Input) -> Result<()> {
        // ---- Root of the verifier-data set ----
        pw.set_target_arr(&self.vds_root.elements, &input.vds_set.root().0)?;

        // ---- Input pods: verifier-data proofs and public statements ----
        assert_eq!(
            input.vd_mt_proofs.len(),
            input.input_pods_pub_self_statements.len()
        );
        let input_pods_len = input.vd_mt_proofs.len();
        assert!(input_pods_len <= self.params.max_input_pods);
        let vd_slots = &self.vd_mt_proofs[..input_pods_len];
        for (slot, (enable, vd_mt_proof)) in vd_slots.iter().zip(&input.vd_mt_proofs) {
            slot.set_targets(pw, *enable, vd_mt_proof)?;
        }
        let statement_slots = &self.input_pods_self_statements[..input_pods_len];
        for (slot, pod_pub_statements) in
            statement_slots.iter().zip(&input.input_pods_pub_self_statements)
        {
            set_targets_input_pods_self_statements(pw, &self.params, slot, pod_pub_statements)?;
        }
        // Padding: unused slots get the empty pod, with the proof disabled.
        if input_pods_len != self.params.max_input_pods {
            let empty_pod = EmptyPod::new_boxed(input.vds_set.clone());
            let empty_pod_statements = empty_pod.pub_statements();
            let empty_mt_proof = MerkleClaimAndProof {
                root: input.vds_set.root(),
                value: RawValue::from(empty_pod.verifier_data_hash()),
                ..MerkleClaimAndProof::empty()
            };
            let unused = input_pods_len..self.params.max_input_pods;
            let unused_slots = self.vd_mt_proofs[unused.clone()]
                .iter()
                .zip(&self.input_pods_self_statements[unused]);
            for (vd_slot, statement_slot) in unused_slots {
                vd_slot.set_targets(pw, false, &empty_mt_proof)?;
                set_targets_input_pods_self_statements(
                    pw,
                    &self.params,
                    statement_slot,
                    &empty_pod_statements,
                )?;
            }
        }

        // ---- Statements and the operations that justify them ----
        assert_eq!(input.statements.len(), self.params.max_statements);
        for (slot, (statement, operation)) in
            zip_eq(&input.statements, &input.operations).enumerate()
        {
            self.input_statements[slot].set_targets(pw, statement)?;
            self.operations[slot].set_targets(pw, &self.params, operation)?;
        }

        // ---- Merkle proofs ----
        assert!(input.merkle_proofs.len() <= self.params.max_merkle_proofs_containers);
        let mp_used = input.merkle_proofs.len();
        for (slot, mp) in self.merkle_proofs[..mp_used].iter().zip(&input.merkle_proofs) {
            slot.set_targets(pw, true, mp)?;
        }
        // Padding
        let pad_mp = MerkleClaimAndProof::empty();
        for slot in &self.merkle_proofs[mp_used..self.params.max_merkle_proofs_containers] {
            slot.set_targets(pw, false, &pad_mp)?;
        }

        // ---- Secret keys for PublicKeyOf ----
        assert!(input.public_key_of_sks.len() <= self.params.max_public_key_of);
        let sk_used = input.public_key_of_sks.len();
        for (slot, sk) in self.public_key_of_sks[..sk_used]
            .iter()
            .zip(&input.public_key_of_sks)
        {
            pw.set_biguint320_target(slot, &sk.0)?;
        }
        // Padding
        let pad_sk = BigUint::ZERO;
        for slot in &self.public_key_of_sks[sk_used..self.params.max_public_key_of] {
            pw.set_biguint320_target(slot, &pad_sk)?;
        }

        // ---- Signatures ----
        assert!(input.signed_bys.len() <= self.params.max_signed_by);
        let sig_used = input.signed_bys.len();
        for (slot, signed_by) in self.signed_bys[..sig_used].iter().zip(&input.signed_bys) {
            slot.set_targets(pw, signed_by)?;
        }
        // Padding
        let pad_signed_by = SignedBy::dummy();
        for slot in &self.signed_bys[sig_used..self.params.max_signed_by] {
            slot.set_targets(pw, &pad_signed_by)?;
        }

        // ---- Merkle tree state transition proofs ----
        assert!(
            input.merkle_tree_state_transition_proofs.len()
                <= self
                    .params
                    .max_merkle_tree_state_transition_proofs_containers
        );
        let mtp_used = input.merkle_tree_state_transition_proofs.len();
        for (slot, mtp) in self.merkle_tree_state_transition_proofs[..mtp_used]
            .iter()
            .zip(&input.merkle_tree_state_transition_proofs)
        {
            slot.set_targets(pw, true, mtp)?;
        }
        // Padding
        let pad_mtp = MerkleTreeStateTransitionProof::empty();
        let mtp_limit = self
            .params
            .max_merkle_tree_state_transition_proofs_containers;
        for slot in &self.merkle_tree_state_transition_proofs[mtp_used..mtp_limit] {
            slot.set_targets(pw, false, &pad_mtp)?;
        }

        // ---- Custom predicates ----
        assert!(input.custom_predicates_with_mpt_proofs.len() <= self.params.max_custom_predicates);
        let cp_used = input.custom_predicates_with_mpt_proofs.len();
        for (slot, (cp, mtp)) in self.custom_predicates[..cp_used]
            .iter()
            .zip(&input.custom_predicates_with_mpt_proofs)
        {
            slot.set_targets(pw, cp, mtp)?;
        }
        // Padding: a batch holding a single empty predicate, proven at key 0.
        let pad_cpb =
            CustomPredicateBatch::new("empty".to_string(), vec![CustomPredicate::empty()]);
        let pad_cp = pad_cpb.predicate_ref_by_index(0).expect("index 0 exists");
        let (_, pad_mtp) = pad_cpb
            .mt()
            .prove(&Value::from(0i64).raw())
            .expect("exists");
        for slot in &self.custom_predicates[cp_used..self.params.max_custom_predicates] {
            slot.set_targets(pw, &pad_cp, &pad_mtp)?;
        }

        // ---- Custom predicate verifications ----
        assert!(
            input.custom_predicate_verifications.len()
                <= self.params.max_custom_predicate_verifications
        );
        let cpv_used = input.custom_predicate_verifications.len();
        for (slot, cpv) in self.custom_predicate_verifications[..cpv_used]
            .iter()
            .zip(&input.custom_predicate_verifications)
        {
            slot.set_targets(pw, &self.params, cpv)?;
        }
        // Padding. Use the first input if it exists. If it doesn't, all batches in this MainPod
        // are padding, so refer to the first padding entry.
        let empty_cpv = CustomPredicateVerification {
            custom_predicate_table_index: 0,
            custom_predicate: CustomPredicateRef::new(pad_cpb, 0),
            args: vec![],
            op_args: vec![],
        };
        let pad_cpv = input
            .custom_predicate_verifications
            .first()
            .unwrap_or(&empty_cpv);
        let cpv_limit = self.params.max_custom_predicate_verifications;
        for slot in &self.custom_predicate_verifications[cpv_used..cpv_limit] {
            slot.set_targets(pw, &self.params, pad_cpv)?;
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
use std::{iter, ops::Not};
use num::FromPrimitive;
use plonky2::{
field::{goldilocks_field::GoldilocksField, types::Field},
hash::hash_types::HashOut,
iop::witness::WitnessWrite,
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
};
use super::*;
use crate::{
backends::plonky2::{
basetypes::C,
circuits::common::tests::I64_TEST_PAIRS,
mainpod::{calculate_statements_hash, OperationArg, OperationAux},
primitives::{
ec::schnorr::SecretKey,
merkletree::{MerkleClaimAndProof, MerkleTree, MerkleTreeStateTransitionProof},
},
signer,
},
dict,
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
middleware::{
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard,
RawValue, StatementArg, StatementTmpl, StatementTmplArg, ValueRef, Wildcard,
BASE_PARAMS, EMPTY_VALUE,
},
};
/// Auxiliary data handed to `operation_verify`. The length of each vector becomes the
/// corresponding limit in the `Params` used to build the test circuit.
#[derive(Default)]
struct Aux {
// Merkle proofs; assigned as enabled witness values.
merkle_proofs: Vec<MerkleClaimAndProof>,
// Secret keys; embedded in the circuit as constants.
secret_keys: Vec<SecretKey>,
// Signature data; assigned as witness values.
signed_bys: Vec<SignedBy>,
// Merkle tree state transition proofs; assigned as enabled witness values.
merkle_tree_state_transition_proofs: Vec<MerkleTreeStateTransitionProof>,
}
impl Aux {
    /// `Aux` holding only the given Merkle proof.
    fn merkle_proof(v: MerkleClaimAndProof) -> Self {
        let merkle_proofs = vec![v];
        Self {
            merkle_proofs,
            ..Self::default()
        }
    }

    /// `Aux` holding only the given secret key.
    fn secret_key(v: SecretKey) -> Self {
        let secret_keys = vec![v];
        Self {
            secret_keys,
            ..Self::default()
        }
    }

    /// `Aux` holding only the given signature data.
    fn signed_by(v: SignedBy) -> Self {
        let signed_bys = vec![v];
        Self {
            signed_bys,
            ..Self::default()
        }
    }

    /// `Aux` holding only the given Merkle tree state transition proof.
    fn merkle_tree_state_transition_proof(v: MerkleTreeStateTransitionProof) -> Self {
        let merkle_tree_state_transition_proofs = vec![v];
        Self {
            merkle_tree_state_transition_proofs,
            ..Self::default()
        }
    }
}
fn operation_verify(
st: mainpod::Statement,
op: mainpod::Operation,
prev_statements: Vec<mainpod::Statement>,
aux: Aux,
) -> Result<()> {
let params = Params {
max_merkle_proofs_containers: aux.merkle_proofs.len(),
max_public_key_of: aux.secret_keys.len(),
max_signed_by: aux.signed_bys.len(),
max_merkle_tree_state_transition_proofs_containers: aux
.merkle_tree_state_transition_proofs
.len(),
max_custom_predicate_verifications: 0,
max_custom_predicates: 0,
..Default::default()
};
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let st_target = builder.add_virtual_statement(false);
let op_target = builder.add_virtual_operation(&params);
let prev_statements_target: Vec<_> = (0..prev_statements.len())
.map(|_| builder.add_virtual_statement(false))
.collect();
let merkle_proofs_target: Vec<_> = aux
.merkle_proofs
.iter()
.map(|_| {
MerkleClaimAndProofTarget::new_virtual(params.max_depth_mt_containers, &mut builder)
})
.collect();
let secret_keys_target: Vec<_> = aux
.secret_keys
.iter()
.map(|sk| builder.constant_biguint320(&sk.0))
.collect();
let signed_by_targets: Vec<_> = aux
.signed_bys
.iter()
.map(|_| SignedByTarget::new_virtual(&mut builder))
.collect();
let merkle_tree_state_transition_proofs_target: Vec<_> = aux
.merkle_tree_state_transition_proofs
.iter()
.map(|_| {
MerkleTreeStateTransitionProofTarget::new_virtual(
params.max_depth_mt_containers,
&mut builder,
)
})
.collect();
let aux_table = build_operation_aux_table_circuit(
&params,
&mut builder,
&merkle_proofs_target,
&secret_keys_target,
&signed_by_targets,
&merkle_tree_state_transition_proofs_target,
&[],
&[],
)?;
verify_operation_circuit(
&params,
&mut builder,
&st_target,
&op_target,
&prev_statements_target,
&aux_table,
)?;
let mut pw = PartialWitness::<F>::new();
st_target.set_targets(&mut pw, &st)?;
op_target.set_targets(&mut pw, &params, &op)?;
for (prev_st_target, prev_st) in prev_statements_target.iter().zip(prev_statements.iter()) {
prev_st_target.set_targets(&mut pw, prev_st)?;
}
for (signed_by_target, signed_by) in signed_by_targets.iter().zip(aux.signed_bys.iter()) {
signed_by_target.set_targets(&mut pw, signed_by)?
}
for (merkle_proof_target, merkle_proof) in
merkle_proofs_target.iter().zip(aux.merkle_proofs.iter())
{
merkle_proof_target.set_targets(&mut pw, true, merkle_proof)?
}
for (merkle_tree_state_transition_proof_target, merkle_tree_state_transition_proof) in
merkle_tree_state_transition_proofs_target
.iter()
.zip(aux.merkle_tree_state_transition_proofs.iter())
{
merkle_tree_state_transition_proof_target.set_targets(
&mut pw,
true,
merkle_tree_state_transition_proof,
)?
}
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
/// `Lt` / `LtEq` statements that are false, or that involve a value which is not a valid
/// integer, must not verify.
///
/// Each case is accepted when `operation_verify` returns an error, or when it panics with a
/// message containing "Integer too large to fit". Any other panic, or a successful
/// verification, fails the test.
#[test]
fn test_lt_lteq_verify_failures() {
    // First limb is p-1, where p is the Goldilocks prime.
    let invalid_int = RawValue([
        GoldilocksField::NEG_ONE,
        GoldilocksField::ZERO,
        GoldilocksField::ZERO,
        GoldilocksField::ZERO,
    ]);
    let prev_statements: Vec<mainpod::Statement> = vec![Statement::None.into()];

    let cases: Vec<(NativeOperation, mainpod::Statement)> = vec![
        // 56 < 55, 55 < 55, 56 <= 55, -55 < -55, -55 < -56, -55 <= -56 should fail to verify
        (NativeOperation::LtFromEntries, Statement::lt(56, 55).into()),
        (NativeOperation::LtFromEntries, Statement::lt(55, 55).into()),
        (
            NativeOperation::LtEqFromEntries,
            Statement::lt_eq(56, 55).into(),
        ),
        (
            NativeOperation::LtFromEntries,
            Statement::lt(-55, -55).into(),
        ),
        (
            NativeOperation::LtFromEntries,
            Statement::lt(-55, -56).into(),
        ),
        (
            NativeOperation::LtEqFromEntries,
            Statement::lt_eq(-55, -56).into(),
        ),
        // 56 < p-1 and p-1 <= p-1 should fail to verify, where p
        // is the Goldilocks prime and 'p-1' occupies a single
        // limb.
        (
            NativeOperation::LtFromEntries,
            Statement::lt(56, invalid_int).into(),
        ),
        (
            NativeOperation::LtEqFromEntries,
            Statement::lt_eq(invalid_int, invalid_int).into(),
        ),
    ];

    for (native_op, st) in cases {
        // Every case uses the same operation shape: both arguments point at slot 0.
        let op = mainpod::Operation(
            OperationType::Native(native_op),
            vec![OperationArg::Index(0), OperationArg::Index(0)],
            OperationAux::None,
        );
        let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            operation_verify(st, op, prev_statements.clone(), Aux::default())
        }));
        match check {
            Err(payload) => {
                // A panic payload is a `String` for formatted messages and a `&'static str`
                // for literal ones; accept both instead of unwrapping a failed downcast.
                let err_string = payload
                    .downcast_ref::<String>()
                    .map(String::as_str)
                    .or_else(|| payload.downcast_ref::<&'static str>().copied())
                    .unwrap_or("<non-string panic payload>");
                if !err_string.contains("Integer too large to fit") {
                    panic!("Test failed with an unexpected error: {}", err_string);
                }
            }
            Ok(Err(_)) => {}
            Ok(Ok(())) => panic!("Test passed, yet it should have failed!"),
        }
    }
}
/// False `Equal` / `NotEqual` statements over literals must be rejected.
#[test]
fn test_eq_neq_verify_failures() {
    let prev_statements: Vec<mainpod::Statement> = vec![Statement::None.into()];

    // 56 == 55, 55 != 55 should fail to verify
    let cases: [(NativeOperation, mainpod::Statement); 2] = [
        (
            NativeOperation::EqualFromEntries,
            Statement::equal(56, 55).into(),
        ),
        (
            NativeOperation::NotEqualFromEntries,
            Statement::not_equal(55, 55).into(),
        ),
    ];

    for (native_op, st) in cases {
        let op = mainpod::Operation(
            OperationType::Native(native_op),
            vec![OperationArg::Index(0), OperationArg::Index(0)],
            OperationAux::None,
        );
        assert!(operation_verify(st, op, prev_statements.to_vec(), Aux::default()).is_err())
    }
}
/// The `None` operation with no arguments justifies the `None` statement.
#[test]
fn test_operation_verify_none() -> Result<()> {
    let none_op = mainpod::Operation(
        OperationType::Native(NativeOperation::None),
        Vec::new(),
        OperationAux::None,
    );
    let claimed = mainpod::Statement::from(Statement::None);
    let earlier = vec![mainpod::Statement::from(Statement::None)];
    operation_verify(claimed, none_op, earlier, Aux::default())
}
/// Copying the `None` statement stored in slot 0 yields a `None` statement.
#[test]
fn test_operation_verify_copy() -> Result<()> {
    let copy_op = mainpod::Operation(
        OperationType::Native(NativeOperation::CopyStatement),
        vec![OperationArg::Index(0)],
        OperationAux::None,
    );
    let claimed = mainpod::Statement::from(Statement::None);
    let earlier = vec![mainpod::Statement::from(Statement::None)];
    operation_verify(claimed, copy_op, earlier, Aux::default())
}
/// Two entries holding the same value (55) in different dictionaries justify an `Equal`
/// statement between their anchored keys.
#[test]
fn test_operation_verify_eq() -> Result<()> {
    let dict1 = dict!({"hello" => 55});
    let dict2 = dict!({"world" => 55});

    let eq_op = mainpod::Operation(
        OperationType::Native(NativeOperation::EqualFromEntries),
        vec![OperationArg::Index(0), OperationArg::Index(1)],
        OperationAux::None,
    );
    let entries = vec![
        mainpod::Statement::from(Statement::contains(dict1.clone(), "hello", 55)),
        mainpod::Statement::from(Statement::contains(dict2.clone(), "world", 55)),
    ];
    let claimed = mainpod::Statement::from(Statement::equal(
        AnchoredKey::from((&dict1, "hello")),
        AnchoredKey::from((&dict2, "world")),
    ));
    operation_verify(claimed, eq_op, entries, Aux::default())
}
/// Two entries holding different values (55 and 75) justify a `NotEqual` statement between
/// their anchored keys.
#[test]
fn test_operation_verify_neq() -> Result<()> {
    let dict1 = dict!({"hello" => 55});
    let dict2 = dict!({"world" => 75});

    let neq_op = mainpod::Operation(
        OperationType::Native(NativeOperation::NotEqualFromEntries),
        vec![OperationArg::Index(0), OperationArg::Index(1)],
        OperationAux::None,
    );
    let entries = vec![
        mainpod::Statement::from(Statement::contains(dict1.clone(), "hello", 55)),
        mainpod::Statement::from(Statement::contains(dict2.clone(), "world", 75)),
    ];
    let claimed = mainpod::Statement::from(Statement::not_equal(
        AnchoredKey::from((&dict1, "hello")),
        AnchoredKey::from((&dict2, "world")),
    ));
    operation_verify(claimed, neq_op, entries, Aux::default())
}
/// `LtFromEntries` verifies for positive < positive, negative < negative and
/// negative < positive.
#[test]
fn test_operation_verify_lt() -> Result<()> {
    // Every case derives the statement from the entries in slots 0 and 1.
    let lt_op = || {
        mainpod::Operation(
            OperationType::Native(NativeOperation::LtFromEntries),
            vec![OperationArg::Index(0), OperationArg::Index(1)],
            OperationAux::None,
        )
    };

    // 55 < 56
    let dict1 = dict!({"hello" => 55});
    let dict2 = dict!({"hello" => 56});
    let entry_55 = mainpod::Statement::from(Statement::contains(dict1.clone(), "hello", 55));
    let entry_56 = mainpod::Statement::from(Statement::contains(dict2.clone(), "hello", 56));
    let claimed = mainpod::Statement::from(Statement::lt(
        AnchoredKey::from((&dict1, "hello")),
        AnchoredKey::from((&dict2, "hello")),
    ));
    operation_verify(
        claimed,
        lt_op(),
        vec![entry_55, entry_56.clone()],
        Aux::default(),
    )?;

    // Also check negative < negative
    let dict3 = dict!({"hola" => -56});
    let dict4 = dict!({"mundo" => -55});
    let entry_neg56 = mainpod::Statement::from(Statement::contains(dict3.clone(), "hola", -56));
    let entry_neg55 = mainpod::Statement::from(Statement::contains(dict4.clone(), "mundo", -55));
    let claimed = mainpod::Statement::from(Statement::lt(
        AnchoredKey::from((&dict3, "hola")),
        AnchoredKey::from((&dict4, "mundo")),
    ));
    operation_verify(
        claimed,
        lt_op(),
        vec![entry_neg56.clone(), entry_neg55],
        Aux::default(),
    )?;

    // Also check negative < positive
    let claimed = mainpod::Statement::from(Statement::lt(
        AnchoredKey::from((&dict3, "hola")),
        AnchoredKey::from((&dict2, "hello")),
    ));
    operation_verify(
        claimed,
        lt_op(),
        vec![entry_neg56, entry_56],
        Aux::default(),
    )
}
/// `LtEqFromEntries` verifies for strictly smaller operands of either sign and for equal
/// operands.
#[test]
fn test_operation_verify_lteq() -> Result<()> {
    let local = dict!({
        "n55" => 55,
        "n56" => 56,
        "n_56" => -56,
        "n_55" => -55,
    });
    // Operation taking its two entries from the given slots.
    let lteq_op = |lhs, rhs| {
        mainpod::Operation(
            OperationType::Native(NativeOperation::LtEqFromEntries),
            vec![OperationArg::Index(lhs), OperationArg::Index(rhs)],
            OperationAux::None,
        )
    };

    // 55 <= 56
    let entry_55 = mainpod::Statement::from(Statement::contains(local.clone(), "n55", 55));
    let entry_56 = mainpod::Statement::from(Statement::contains(local.clone(), "n56", 56));
    let claimed = mainpod::Statement::from(Statement::lt_eq(
        AnchoredKey::from((&local, "n55")),
        AnchoredKey::from((&local, "n56")),
    ));
    operation_verify(
        claimed,
        lteq_op(0, 1),
        vec![entry_55, entry_56.clone()],
        Aux::default(),
    )?;

    // Also check negative <= negative
    let entry_neg56 = mainpod::Statement::from(Statement::contains(local.clone(), "n_56", -56));
    let entry_neg55 = mainpod::Statement::from(Statement::contains(local.clone(), "n_55", -55));
    let claimed = mainpod::Statement::from(Statement::lt_eq(
        AnchoredKey::from((&local, "n_56")),
        AnchoredKey::from((&local, "n_55")),
    ));
    operation_verify(
        claimed,
        lteq_op(0, 1),
        vec![entry_neg56.clone(), entry_neg55],
        Aux::default(),
    )?;

    // The remaining cases share the previous statements [-56, 56].
    let prev_statements = vec![entry_neg56, entry_56];

    // Also check negative <= positive
    let claimed = mainpod::Statement::from(Statement::lt_eq(
        AnchoredKey::from((&local, "n_56")),
        AnchoredKey::from((&local, "n56")),
    ));
    operation_verify(
        claimed,
        lteq_op(0, 1),
        prev_statements.clone(),
        Aux::default(),
    )?;

    // Also check equality, both positive and negative.
    let claimed = mainpod::Statement::from(Statement::lt_eq(
        AnchoredKey::from((&local, "n_56")),
        AnchoredKey::from((&local, "n_56")),
    ));
    operation_verify(
        claimed,
        lteq_op(0, 0),
        prev_statements.clone(),
        Aux::default(),
    )?;
    let claimed = mainpod::Statement::from(Statement::lt_eq(
        AnchoredKey::from((&local, "n56")),
        AnchoredKey::from((&local, "n56")),
    ));
    operation_verify(claimed, lteq_op(1, 1), prev_statements, Aux::default())
}
/// `HashOf` verifies when the first entry is `hash_values` of the other two.
#[test]
fn test_operation_verify_hashof() -> Result<()> {
    let input_values = [
        Value::from(RawValue([
            GoldilocksField(1),
            GoldilocksField(2),
            GoldilocksField(3),
            GoldilocksField(4),
        ])),
        Value::from(512),
    ];
    let v1 = hash_values(&input_values);
    let [v2, v3] = input_values;

    let local = dict!({
        "hola" => v1,
        "mundo" => v2.clone(),
        "!" => v3.clone(),
    });

    // Slots 0..=2 hold the hash and its two inputs, in that order.
    let hash_op = mainpod::Operation(
        OperationType::Native(NativeOperation::HashOf),
        (0..3).map(OperationArg::Index).collect(),
        OperationAux::None,
    );
    let entries = vec![
        mainpod::Statement::from(Statement::contains(local.clone(), "hola", v1)),
        mainpod::Statement::from(Statement::contains(local.clone(), "mundo", v2)),
        mainpod::Statement::from(Statement::contains(local.clone(), "!", v3)),
    ];
    let claimed = mainpod::Statement::from(Statement::hash_of(
        AnchoredKey::from((&local, "hola")),
        AnchoredKey::from((&local, "mundo")),
        AnchoredKey::from((&local, "!")),
    ));
    operation_verify(claimed, hash_op, entries, Aux::default())
}
/// `SumOf` verifies for every pair of `I64_TEST_PAIRS` whose sum does not overflow `i64`.
#[test]
fn test_operation_verify_sumof() -> Result<()> {
    // Pairs whose sum overflows are skipped.
    let non_overflowing = I64_TEST_PAIRS.into_iter().filter_map(|(a, b)| {
        let (sum, overflow) = a.overflowing_add(b);
        overflow.not().then_some((a, b, sum))
    });

    for (a, b, sum) in non_overflowing {
        let local = dict!({
            "sum" => sum,
            "a" => a,
            "b" => b,
        });
        let sum_op = mainpod::Operation(
            OperationType::Native(NativeOperation::SumOf),
            vec![
                OperationArg::Index(0),
                OperationArg::Index(1),
                OperationArg::Index(2),
            ],
            OperationAux::None,
        );
        let entries = vec![
            mainpod::Statement::from(Statement::contains(local.clone(), "sum", sum)),
            mainpod::Statement::from(Statement::contains(local.clone(), "a", a)),
            mainpod::Statement::from(Statement::contains(local.clone(), "b", b)),
        ];
        let claimed = mainpod::Statement::from(Statement::sum_of(
            AnchoredKey::from((&local, "sum")),
            AnchoredKey::from((&local, "a")),
            AnchoredKey::from((&local, "b")),
        ));
        operation_verify(claimed, sum_op, entries, Aux::default())?;
    }
    Ok(())
}
/// `ProductOf` verifies for every pair of `I64_TEST_PAIRS` whose product does not overflow
/// `i64`.
#[test]
fn test_operation_verify_productof() -> Result<()> {
    // Pairs whose product overflows are skipped.
    let non_overflowing = I64_TEST_PAIRS.into_iter().filter_map(|(a, b)| {
        let (prod, overflow) = a.overflowing_mul(b);
        overflow.not().then_some((a, b, prod))
    });

    for (a, b, prod) in non_overflowing {
        let local = dict!({
            "prod" => prod,
            "a" => a,
            "b" => b,
        });
        let product_op = mainpod::Operation(
            OperationType::Native(NativeOperation::ProductOf),
            vec![
                OperationArg::Index(0),
                OperationArg::Index(1),
                OperationArg::Index(2),
            ],
            OperationAux::None,
        );
        let entries = vec![
            mainpod::Statement::from(Statement::contains(local.clone(), "prod", prod)),
            mainpod::Statement::from(Statement::contains(local.clone(), "a", a)),
            mainpod::Statement::from(Statement::contains(local.clone(), "b", b)),
        ];
        let claimed = mainpod::Statement::from(Statement::product_of(
            AnchoredKey::from((&local, "prod")),
            AnchoredKey::from((&local, "a")),
            AnchoredKey::from((&local, "b")),
        ));
        operation_verify(claimed, product_op, entries, Aux::default())?;
    }
    Ok(())
}
/// `MaxOf` verifies for every pair of `I64_TEST_PAIRS`.
#[test]
fn test_operation_verify_maxof() -> Result<()> {
    for (a, b) in I64_TEST_PAIRS.into_iter() {
        let max = i64::max(a, b);
        let local = dict!({
            "max" => max,
            "a" => a,
            "b" => b,
        });
        let max_op = mainpod::Operation(
            OperationType::Native(NativeOperation::MaxOf),
            vec![
                OperationArg::Index(0),
                OperationArg::Index(1),
                OperationArg::Index(2),
            ],
            OperationAux::None,
        );
        let entries = vec![
            mainpod::Statement::from(Statement::contains(local.clone(), "max", max)),
            mainpod::Statement::from(Statement::contains(local.clone(), "a", a)),
            mainpod::Statement::from(Statement::contains(local.clone(), "b", b)),
        ];
        let claimed = mainpod::Statement::from(Statement::max_of(
            AnchoredKey::from((&local, "max")),
            AnchoredKey::from((&local, "a")),
            AnchoredKey::from((&local, "b")),
        ));
        operation_verify(claimed, max_op, entries, Aux::default())?;
    }
    Ok(())
}
/// `MaxOf` statements whose first value is not the maximum of the other two must not verify.
///
/// Each case is accepted when `operation_verify` returns an error, or when it panics with a
/// message containing "Integer too large to fit". Any other panic, or a successful
/// verification, fails the test.
#[test]
fn test_operation_verify_maxof_failures() {
    let prev_statements: Vec<mainpod::Statement> = vec![Statement::None.into()];

    // (claimed max, a, b)
    for (max, a, b) in [(5, 3, 4), (5, 5, 8), (3, 4, 5)] {
        let st: mainpod::Statement = Statement::max_of(max, a, b).into();
        let op = mainpod::Operation(
            OperationType::Native(NativeOperation::MaxOf),
            vec![
                OperationArg::Index(0),
                OperationArg::Index(0),
                OperationArg::Index(0),
            ],
            OperationAux::None,
        );
        let check = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            operation_verify(st, op, prev_statements.clone(), Aux::default())
        }));
        match check {
            Err(payload) => {
                // A panic payload is a `String` for formatted messages and a `&'static str`
                // for literal ones; accept both instead of unwrapping a failed downcast.
                let err_string = payload
                    .downcast_ref::<String>()
                    .map(String::as_str)
                    .or_else(|| payload.downcast_ref::<&'static str>().copied())
                    .unwrap_or("<non-string panic payload>");
                if !err_string.contains("Integer too large to fit") {
                    panic!("Test failed with an unexpected error: {}", err_string);
                }
            }
            Ok(Err(_)) => {}
            Ok(Ok(())) => panic!("Test passed, yet it should have failed!"),
        }
    }
}
/// An `Lt` statement between two anchored keys justifies `NotEqual` between the same keys.
#[test]
fn test_operation_verify_lt_to_neq() -> Result<()> {
    let local = dict!({
        "a" => 10,
        "b" => 20,
    });

    let lt_to_neq_op = mainpod::Operation(
        OperationType::Native(NativeOperation::LtToNotEqual),
        vec![OperationArg::Index(0)],
        OperationAux::None,
    );
    let claimed = mainpod::Statement::from(Statement::not_equal(
        AnchoredKey::from((&local, "a")),
        AnchoredKey::from((&local, "b")),
    ));
    let premise = mainpod::Statement::from(Statement::lt(
        AnchoredKey::from((&local, "a")),
        AnchoredKey::from((&local, "b")),
    ));
    operation_verify(claimed, lt_to_neq_op, vec![premise], Aux::default())
}
/// `a == b` and `b == c` together justify `a == c`.
#[test]
fn test_operation_verify_transitive_eq() -> Result<()> {
    let local = dict!({
        "a" => 10,
        "b" => 10,
        "c" => 10,
    });
    // Equality statement between two keys of `local`.
    let equal_between = |lhs: &str, rhs: &str| {
        mainpod::Statement::from(Statement::equal(
            AnchoredKey::from((&local, lhs)),
            AnchoredKey::from((&local, rhs)),
        ))
    };

    let claimed = equal_between("a", "c");
    let premises = vec![equal_between("a", "b"), equal_between("b", "c")];
    let transitive_op = mainpod::Operation(
        OperationType::Native(NativeOperation::TransitiveEqualFromStatements),
        vec![OperationArg::Index(0), OperationArg::Index(1)],
        OperationAux::None,
    );
    operation_verify(claimed, transitive_op, premises, Aux::default())
}
/// `NotContainsFromEntries` verifies with a non-existence proof for a key (5) that is absent
/// from the tree.
///
/// NOTE(review): the function name looks like a typo (the test exercises NotContains); it is
/// left unchanged here.
#[test]
fn test_operation_verify_sintains() -> Result<()> {
    let kvs = [
        (1.into(), 55.into()),
        (2.into(), 88.into()),
        (175.into(), 0.into()),
    ]
    .into_iter()
    .collect();
    let mt = MerkleTree::new(&kvs);
    let root = mt.root();
    let key = Value::from(5);
    let local = dict!({
        "merkle_root" => root,
        "key" => key.clone(),
    });

    // Proof that `key` is not in the tree.
    let no_key_pf = mt.prove_nonexistence(&key.raw())?;
    let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), None, no_key_pf);

    let not_contains_op = mainpod::Operation(
        OperationType::Native(NativeOperation::NotContainsFromEntries),
        vec![OperationArg::Index(0), OperationArg::Index(1)],
        OperationAux::MerkleProofIndex(0),
    );
    let claimed = mainpod::Statement::from(Statement::not_contains(
        AnchoredKey::from((&local, "merkle_root")),
        AnchoredKey::from((&local, "key")),
    ));
    let entries = vec![
        mainpod::Statement::from(Statement::contains(local.clone(), "merkle_root", root)),
        mainpod::Statement::from(Statement::contains(local.clone(), "key", key.clone())),
    ];
    operation_verify(
        claimed,
        not_contains_op,
        entries,
        Aux::merkle_proof(merkle_proof),
    )
}
/// `ContainsFromEntries` verifies with an existence proof for a key (175) present in the tree.
#[test]
fn test_operation_verify_contains() -> Result<()> {
    let kvs = [
        (1.into(), 55.into()),
        (2.into(), 88.into()),
        (175.into(), 0.into()),
    ]
    .into_iter()
    .collect();
    let mt = MerkleTree::new(&kvs);
    let root = mt.root();
    let key = Value::from(175);
    // Value stored under `key`, together with its existence proof.
    let (value, key_pf) = mt.prove(&key.raw())?;
    let local = dict!({
        "merkle_root" => root,
        "key" => key.clone(),
        "value" => value,
    });

    let contains_op = mainpod::Operation(
        OperationType::Native(NativeOperation::ContainsFromEntries),
        vec![
            OperationArg::Index(0),
            OperationArg::Index(1),
            OperationArg::Index(2),
        ],
        OperationAux::MerkleProofIndex(0),
    );
    let claimed = mainpod::Statement::from(Statement::contains(
        AnchoredKey::from((&local, "merkle_root")),
        AnchoredKey::from((&local, "key")),
        AnchoredKey::from((&local, "value")),
    ));
    let entries = vec![
        mainpod::Statement::from(Statement::contains(local.clone(), "merkle_root", root)),
        mainpod::Statement::from(Statement::contains(local.clone(), "key", key.clone())),
        mainpod::Statement::from(Statement::contains(local.clone(), "value", value)),
    ];
    let merkle_proof = MerkleClaimAndProof::new(root, key.raw(), Some(value), key_pf);
    operation_verify(
        claimed,
        contains_op,
        entries,
        Aux::merkle_proof(merkle_proof),
    )
}
/// `ContainerInsertFromEntries` verifies with the state transition proof produced by
/// inserting a key into an empty tree. The statement carries literal values and every
/// operation argument points at slot 0, which holds `None`.
#[test]
fn test_operation_verify_merkle_insert() -> Result<()> {
    let key = Value::from(175);
    let value = Value::from(0);
    let mut tree = MerkleTree::new(&[].into());
    let state_transition_proof = tree.insert(&key.raw(), &value.raw())?;

    let claimed = mainpod::Statement::from(Statement::insert(
        state_transition_proof.new_root,
        state_transition_proof.old_root,
        key,
        value,
    ));
    let insert_op = mainpod::Operation(
        OperationType::Native(NativeOperation::ContainerInsertFromEntries),
        vec![
            OperationArg::Index(0),
            OperationArg::Index(0),
            OperationArg::Index(0),
            OperationArg::Index(0),
        ],
        OperationAux::MerkleTreeStateTransitionProofIndex(0),
    );
    let earlier = vec![mainpod::Statement::from(Statement::None)];
    let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof);
    operation_verify(claimed, insert_op, earlier, aux)
}
/// `ContainerUpdateFromEntries` verifies with the state transition proof produced by updating
/// an existing key (175: 55 -> 0). The statement carries literal values and every operation
/// argument points at slot 0, which holds `None`.
#[test]
fn test_operation_verify_merkle_update() -> Result<()> {
    let key = Value::from(175);
    let value = Value::from(0);
    let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into());
    let state_transition_proof = tree.update(&key.raw(), &value.raw())?;

    let claimed = mainpod::Statement::from(Statement::update(
        state_transition_proof.new_root,
        state_transition_proof.old_root,
        key,
        value,
    ));
    let update_op = mainpod::Operation(
        OperationType::Native(NativeOperation::ContainerUpdateFromEntries),
        vec![
            OperationArg::Index(0),
            OperationArg::Index(0),
            OperationArg::Index(0),
            OperationArg::Index(0),
        ],
        OperationAux::MerkleTreeStateTransitionProofIndex(0),
    );
    let earlier = vec![mainpod::Statement::from(Statement::None)];
    let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof);
    operation_verify(claimed, update_op, earlier, aux)
}
/// `ContainerDeleteFromEntries` verifies with the state transition proof produced by deleting
/// an existing key (175). The statement carries literal values and every operation argument
/// points at slot 0, which holds `None`.
#[test]
fn test_operation_verify_merkle_delete() -> Result<()> {
    let key = Value::from(175);
    let mut tree = MerkleTree::new(&[(175.into(), 55.into())].into());
    let state_transition_proof = tree.delete(&key.raw())?;

    let claimed = mainpod::Statement::from(Statement::delete(
        state_transition_proof.new_root,
        state_transition_proof.old_root,
        key,
    ));
    let delete_op = mainpod::Operation(
        OperationType::Native(NativeOperation::ContainerDeleteFromEntries),
        vec![
            OperationArg::Index(0),
            OperationArg::Index(0),
            OperationArg::Index(0),
        ],
        OperationAux::MerkleTreeStateTransitionProofIndex(0),
    );
    let earlier = vec![mainpod::Statement::from(Statement::None)];
    let aux = Aux::merkle_tree_state_transition_proof(state_transition_proof);
    operation_verify(claimed, delete_op, earlier, aux)
}
/// `PublicKeyOf` verifies for the secret keys 1, a random key and `GROUP_ORDER - 1`.
#[test]
fn test_operation_verify_publickeyof_ok() -> Result<()> {
    let secret_keys = [
        SecretKey(BigUint::one()),
        SecretKey::new_rand(),
        SecretKey(&*GROUP_ORDER - BigUint::one()),
    ];
    for secret_key in secret_keys {
        let public_key = secret_key.public_key();
        let claimed = mainpod::Statement::from(Statement::public_key_of(
            public_key,
            secret_key.clone(),
        ));
        let pk_op = mainpod::Operation(
            OperationType::Native(NativeOperation::PublicKeyOf),
            vec![OperationArg::Index(0), OperationArg::Index(0)],
            OperationAux::PublicKeyOfIndex(0),
        );
        let earlier = vec![mainpod::Statement::from(Statement::None)];
        operation_verify(claimed, pk_op, earlier, Aux::secret_key(secret_key))?;
    }
    Ok(())
}
#[test]
fn test_operation_verify_publickeyof_failure_wrong_key() {
    // The statement pairs secret key 1 with the public key derived from
    // secret key 0, so verification must be rejected.
    let sk = SecretKey(BigUint::one());
    let mismatched_pk = SecretKey(BigUint::ZERO).public_key();
    let st: mainpod::Statement = Statement::public_key_of(mismatched_pk, sk.clone()).into();

    let op_args = vec![OperationArg::Index(0), OperationArg::Index(0)];
    let op = mainpod::Operation(
        OperationType::Native(NativeOperation::PublicKeyOf),
        op_args,
        OperationAux::PublicKeyOfIndex(0),
    );

    let outcome = operation_verify(
        st,
        op,
        vec![Statement::None.into()],
        Aux::secret_key(sk),
    );
    assert!(outcome.is_err())
}
#[test]
fn test_operation_verify_publickeyof_failure_pk_type() {
    // A PublicKeyOf statement whose "public key" argument is a plain integer
    // must be rejected.
    let secret_key = SecretKey(BigUint::one());
    let public_key = 123i64;
    let st: mainpod::Statement =
        Statement::public_key_of(public_key, secret_key.clone()).into();
    let op = mainpod::Operation(
        OperationType::Native(NativeOperation::PublicKeyOf),
        vec![OperationArg::Index(0), OperationArg::Index(0)],
        // Reference the secret key supplied through `aux`, exactly like the
        // sibling PublicKeyOf tests, so that the wrongly typed public key is
        // the only invalid input. With `OperationAux::None` the operation
        // could be rejected for the missing aux reference instead, and the
        // test would pass without exercising the public-key check.
        OperationAux::PublicKeyOfIndex(0),
    );
    let prev_statements = vec![Statement::None.into()];
    assert!(operation_verify(st, op, prev_statements, Aux::secret_key(secret_key)).is_err())
}
#[test]
fn test_operation_verify_publickeyof_failure_sk_type() {
    // The statement carries the integer 123 in the secret-key position,
    // while the aux data holds the real `SecretKey` built from 123 (whose
    // public key is the one used in the statement). Verification is expected
    // to fail.
    let real_sk = SecretKey(BigUint::from(123u32));
    let pk = real_sk.public_key();
    let sk_as_int = 123i64;
    let st: mainpod::Statement = Statement::public_key_of(pk, sk_as_int).into();

    let op = mainpod::Operation(
        OperationType::Native(NativeOperation::PublicKeyOf),
        vec![OperationArg::Index(0), OperationArg::Index(0)],
        OperationAux::PublicKeyOfIndex(0),
    );

    let outcome = operation_verify(
        st,
        op,
        vec![Statement::None.into()],
        Aux::secret_key(real_sk),
    );
    assert!(outcome.is_err())
}
#[test]
fn test_operation_verify_publickeyof_failure_sk_size() {
    // GROUP_ORDER - 0, i.e. a secret key equal to the group order itself
    // (the largest key used by the `_ok` test is GROUP_ORDER - 1).
    let oversized_sk = SecretKey(&*GROUP_ORDER - BigUint::ZERO);
    let pk = oversized_sk.public_key();
    let st: mainpod::Statement = Statement::public_key_of(pk, oversized_sk.clone()).into();

    let op_args = vec![OperationArg::Index(0), OperationArg::Index(0)];
    let op = mainpod::Operation(
        OperationType::Native(NativeOperation::PublicKeyOf),
        op_args,
        OperationAux::PublicKeyOfIndex(0),
    );

    let outcome = operation_verify(
        st,
        op,
        vec![Statement::None.into()],
        Aux::secret_key(oversized_sk),
    );
    assert!(outcome.is_err())
}
#[test]
fn test_operation_verify_signedby_ok() -> Result<()> {
    // Fixed key, message and nonce keep the signature deterministic.
    let sk = SecretKey(BigUint::from_u32(0xbadcafe).unwrap());
    let pk = sk.public_key();
    let msg = RawValue([F(1), F(2), F(3), F(4)]);
    let sig = signer::Signer(sk).sign_with_nonce(BigUint::from_u32(123).unwrap(), msg);

    // Statement under test: SignedBy(msg, pk).
    let st: mainpod::Statement = Statement::signed_by(msg, pk).into();
    let op = mainpod::Operation(
        OperationType::Native(NativeOperation::SignedBy),
        vec![OperationArg::Index(0), OperationArg::Index(0)],
        OperationAux::SignedByIndex(0),
    );

    // The signature data is supplied as aux entry 0, matching `SignedByIndex(0)`.
    let aux = Aux::signed_by(SignedBy { msg, pk, sig });
    operation_verify(st, op, vec![Statement::None.into()], aux)
}
#[test]
fn test_operation_replace_value_with_entry() -> Result<()> {
    let d = dict!({"a" => 42, "b" => 33});

    // Previous statements, by index:
    //   0: None
    //   1: Lt(5, 42)            (the statement being rewritten)
    //   2: Contains(d, "a", 42) (the entry justifying the replacement)
    let prev_statements: Vec<mainpod::Statement> = vec![
        Statement::None.into(),
        Statement::lt(5, 42).into(),
        Statement::contains(d.clone(), "a", 42).into(),
    ];

    // Expected result: the literal 42 in Lt is replaced by the entry d["a"].
    let st_out: mainpod::Statement =
        Statement::lt(5, ValueRef::Key(AnchoredKey::from((&d, "a")))).into();

    // The operation takes `max_statement_args + 1` arguments. Here slot 1
    // references the Contains statement (index 2) and the last slot
    // references the input statement (index 1); all other slots stay `None`.
    let last_slot = BASE_PARAMS.max_statement_args;
    let mut op_args = vec![OperationArg::None; last_slot + 1];
    op_args[1] = OperationArg::Index(2);
    op_args[last_slot] = OperationArg::Index(1);

    let op = mainpod::Operation(
        OperationType::Native(NativeOperation::ReplaceValueWithEntry),
        op_args,
        OperationAux::None,
    );
    operation_verify(st_out, op, prev_statements, Aux::default())
}
/// Builds a circuit around `make_statement_arg_from_template_circuit`, feeds
/// it `st_tmpl_arg` and the wildcard values `args`, and proves/verifies that
/// the produced statement argument equals `expected_st_arg`.
///
/// Returns an error (instead of panicking) when witness assignment, proving
/// or verification fails, so callers can also assert on failures.
fn helper_statement_arg_from_template(
    params: &Params,
    st_tmpl_arg: StatementTmplArg,
    args: Vec<Value>,
    expected_st_arg: StatementArg,
) -> Result<()> {
    let config = CircuitConfig::standard_recursion_config();
    let mut builder = CircuitBuilder::new(config);

    // One virtual value target per wildcard value.
    let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg();
    let args_target: Vec<_> = args.iter().map(|_| builder.add_virtual_value()).collect();
    let st_arg_target = make_statement_arg_from_template_circuit(
        params,
        &mut builder,
        &st_tmpl_arg_target,
        &args_target,
    );

    // TODO: Instead of connect, assign witness to result
    let expected_st_arg_target = builder.add_virtual_statement_arg();
    builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements);

    let mut pw = PartialWitness::<F>::new();
    st_tmpl_arg_target.set_targets(&mut pw, &st_tmpl_arg)?;
    for (arg_target, arg) in args_target.iter().zip(args.iter()) {
        arg_target.set_targets(&mut pw, arg)?;
    }
    expected_st_arg_target.set_targets(&mut pw, &expected_st_arg)?;

    // Generate & verify the proof, propagating failures to the caller.
    let data = builder.build::<C>();
    let proof = data.prove(pw)?;
    data.verify(proof)?;
    Ok(())
}
#[test]
fn test_statement_arg_from_template() -> Result<()> {
    let params = Params::default();
    let dict = Hash([F(6), F(7), F(8), F(9)]);

    // Each case: (template argument, wildcard values, expected statement argument).
    let cases = [
        // None stays None regardless of the wildcard values.
        (
            StatementTmplArg::None,
            vec![Value::from(1), Value::from(2), Value::from(3)],
            StatementArg::None,
        ),
        // Literal is passed through unchanged.
        (
            StatementTmplArg::Literal(Value::from("foo")),
            vec![Value::from(1), Value::from(2), Value::from(3)],
            StatementArg::Literal(Value::from("foo")),
        ),
        // AnchoredKey(id_wildcard, key_literal): wildcard 1 supplies the dict.
        (
            StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("foo")),
            vec![Value::from(1), Value::from(dict), Value::from(3)],
            StatementArg::Key(AnchoredKey::new(dict, Key::from("foo"))),
        ),
        // Wildcard(wildcard): wildcard 1 supplies the literal value.
        (
            StatementTmplArg::Wildcard(Wildcard::new("a".to_string(), 1)),
            vec![Value::from(1), Value::from("key"), Value::from(3)],
            StatementArg::Literal(Value::from("key")),
        ),
    ];

    for (st_tmpl_arg, args, expected_st_arg) in cases {
        helper_statement_arg_from_template(&params, st_tmpl_arg, args, expected_st_arg)?;
    }
    Ok(())
}
/// Builds a circuit around `make_statement_from_template_circuit`, feeds it
/// the statement template `st_tmpl` and the wildcard values `args`, and
/// proves/verifies that the produced statement equals `expected_st`.
///
/// Returns an error (instead of panicking) when witness assignment, proving
/// or verification fails, so callers can also assert on failures.
fn helper_statement_from_template(
    params: &Params,
    st_tmpl: StatementTmpl,
    args: Vec<Value>,
    expected_st: Statement,
) -> Result<()> {
    let config = CircuitConfig::standard_recursion_config();
    let mut builder = CircuitBuilder::new(config);

    // One virtual value target per wildcard value.
    let st_tmpl_target = builder.add_virtual_statement_tmpl(false);
    let args_target: Vec<_> = args.iter().map(|_| builder.add_virtual_value()).collect();
    let st_target = make_statement_from_template_circuit(
        params,
        &mut builder,
        &st_tmpl_target,
        &args_target,
    );

    // TODO: Instead of connect, assign witness to result
    let expected_st_target = builder.add_virtual_statement(false);
    builder.connect_flattenable(&expected_st_target, &st_target);

    let mut pw = PartialWitness::<F>::new();
    st_tmpl_target.set_targets(&mut pw, &st_tmpl)?;
    for (arg_target, arg) in args_target.iter().zip(args.iter()) {
        arg_target.set_targets(&mut pw, arg)?;
    }
    expected_st_target.set_targets(&mut pw, &expected_st.into())?;

    // Generate & verify the proof, propagating failures to the caller.
    let data = builder.build::<C>();
    let proof = data.prove(pw)?;
    data.verify(proof)?;
    Ok(())
}
#[test]
fn test_statement_from_template() -> Result<()> {
    let params = Params::default();
    let dict = Hash([F(6), F(7), F(8), F(9)]);

    // Both templates share the same arguments: (wildcard 1)["key"] and the
    // literal "value".
    let tmpl_args = || {
        vec![
            StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
            StatementTmplArg::Literal(Value::from("value")),
        ]
    };
    let anchored_key = || AnchoredKey::new(dict, Key::from("key"));

    // Case 1: the predicate is fixed in the template (Equal).
    let fixed_pred_tmpl = StatementTmpl {
        pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)),
        args: tmpl_args(),
    };
    helper_statement_from_template(
        &params,
        fixed_pred_tmpl,
        vec![Value::from(1), Value::from(dict), Value::from(3)],
        Statement::equal(anchored_key(), Value::from("value")),
    )?;

    // Case 2: the predicate comes from wildcard 2, which holds the hash of
    // the native NotEqual predicate.
    let wildcard_pred_tmpl = StatementTmpl {
        pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)),
        args: tmpl_args(),
    };
    let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash();
    helper_statement_from_template(
        &params,
        wildcard_pred_tmpl,
        vec![Value::from(1), Value::from(dict), Value::from(pred_hash)],
        Statement::not_equal(anchored_key(), Value::from("value")),
    )?;

    Ok(())
}
/// Runs `make_custom_statement_circuit` for `custom_predicate` with the given
/// operation arguments (`op_args`) and wildcard values (`args`), then proves
/// and verifies the resulting circuit.
///
/// `op_args` is padded with `Statement::None` up to
/// `BASE_PARAMS.max_operation_args`, and `args` with `EMPTY_VALUE` up to
/// `params.max_custom_predicate_wildcards`. When `expected_st` is given it is
/// assigned to the output statement target; the output operation type is
/// always assigned `OperationType::Custom(custom_predicate)`.
fn helper_custom_operation_verify_gadget(
    params: &Params,
    custom_predicate: CustomPredicateRef,
    mut op_args: Vec<Statement>,
    mut args: Vec<Value>,
    expected_st: Option<Statement>,
) -> Result<()> {
    // Pad both inputs to the sizes described above.
    while op_args.len() < BASE_PARAMS.max_operation_args {
        op_args.push(Statement::None);
    }
    while args.len() < params.max_custom_predicate_wildcards {
        args.push(Value::from(EMPTY_VALUE));
    }

    let mut builder = CircuitBuilder::new(CircuitConfig::standard_recursion_config());
    let custom_predicate_target = builder.add_virtual_custom_predicate_entry();
    let op_args_target: Vec<_> = op_args
        .iter()
        .map(|_| builder.add_virtual_statement(false))
        .collect();
    let args_target: Vec<_> = args.iter().map(|_| builder.add_virtual_value()).collect();
    let (st_target, op_type_target) = make_custom_statement_circuit(
        params,
        &mut builder,
        &custom_predicate_target,
        &op_args_target,
        &args_target,
    )?;

    let mut pw = PartialWitness::<F>::new();

    // Inputs
    custom_predicate_target.set_targets(&mut pw, &custom_predicate)?;
    for (target, statement) in op_args_target.iter().zip(op_args) {
        target.set_targets(&mut pw, &statement.into())?;
    }
    for (target, value) in args_target.iter().zip(&args) {
        target.set_targets(&mut pw, &Value::from(value.raw()))?;
    }

    // Expected outputs
    if let Some(statement) = expected_st {
        st_target.set_targets(&mut pw, &statement.into())?;
    }
    op_type_target.set_targets(&mut pw, &OperationType::Custom(custom_predicate))?;

    // Generate & verify proof
    let data = builder.build::<C>();
    let proof = data.prove(pw)?;
    data.verify(proof)?;
    Ok(())
}
/// Shorthand that converts anything convertible into a `ValueRef`.
fn value_ref(v: impl Into<ValueRef>) -> ValueRef {
    Into::<ValueRef>::into(v)
}
// TODO: Add more negative tests (see `test_custom_operation_verify_gadget_negative` for the existing ones)
#[test]
fn test_custom_operation_verify_gadget_positive() -> frontend::Result<()> {
    use NativePredicate as NP;
    use StatementTmplBuilder as STB;

    let params = Params::default();

    // Batch with two predicates over the same templates:
    //   0: pred_and(id, private: secret) = Equal(id["score"], 42) AND Equal(id["key"], secret)
    //   1: pred_or (id, private: secret) = Equal(id["score"], 42) OR  Equal(id["key"], secret)
    let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
    let stb0 = STB::new_from_pred(NP::Equal)
        .arg(("id", "score"))
        .arg(literal(42));
    let stb1 = STB::new_from_pred(NP::Equal)
        .arg(("id", "key"))
        .arg("secret");
    let _ = builder.predicate_and(
        "pred_and",
        &["id"],
        &["secret"],
        &[stb0.clone(), stb1.clone()],
    )?;
    let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?;
    let batch = builder.finish()?;

    let dict = Hash([F(6), F(7), F(8), F(9)]);
    let score_is_42 =
        || Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42));
    let key_is_1234 =
        || Statement::equal(AnchoredKey::new(dict, Key::from("key")), Value::from(1234));

    // Each case: (predicate index in the batch, operation arguments, wildcard values).
    let cases = [
        // AND: both sub-statements are provided.
        (
            0,
            vec![score_is_42(), key_is_1234()],
            vec![Value::from(dict), Value::from(1234)],
        ),
        // OR (1): only the first sub-statement is provided.
        (
            1,
            vec![score_is_42(), Statement::None],
            vec![Value::from(dict), Value::from(0)],
        ),
        // OR (2): only the second sub-statement is provided.
        (
            1,
            vec![Statement::None, key_is_1234()],
            vec![Value::from(dict), Value::from(1234)],
        ),
    ];

    for (pred_index, op_args, args) in cases {
        let custom_predicate = CustomPredicateRef::new(batch.clone(), pred_index);
        // The expected statement carries the first wildcard value followed by 0.
        let expected_st = Statement::Custom(
            custom_predicate.clone(),
            vec![value_ref(args[0].clone()), value_ref(0)],
        );
        helper_custom_operation_verify_gadget(
            &params,
            custom_predicate,
            op_args,
            args,
            Some(expected_st),
        )
        .unwrap();
    }
    Ok(())
}
#[test]
fn test_custom_operation_verify_gadget_negative() -> frontend::Result<()> {
    use NativePredicate as NP;
    use StatementTmplBuilder as STB;

    let params = Params::default();

    // Batch with two predicates over the same templates:
    //   0: pred_and(id, private: secret_id)
    //        = Equal(id["score"], 42) AND Equal(secret_id["key"], id["score"])
    //   1: pred_or (id, private: secret_id) = the same templates combined with OR
    let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
    let stb0 = STB::new_from_pred(NP::Equal)
        .arg(("id", "score"))
        .arg(literal(42));
    let stb1 = STB::new_from_pred(NP::Equal)
        .arg(("secret_id", "key"))
        .arg(("id", "score"));
    let _ = builder.predicate_and(
        "pred_and",
        &["id"],
        &["secret_id"],
        &[stb0.clone(), stb1.clone()],
    )?;
    let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?;
    let batch = builder.finish()?;

    let dict = Hash([F(1), F(2), F(3), F(4)]);
    let secret_dict = Hash([F(6), F(7), F(8), F(9)]);

    // Building blocks shared by the cases below.
    let and_pred = || CustomPredicateRef::new(batch.clone(), 0);
    let or_pred = || CustomPredicateRef::new(batch.clone(), 1);
    let wildcard_values = || vec![Value::from(dict), Value::from(secret_dict)];
    let score_is_42 =
        || Statement::equal(AnchoredKey::new(dict, Key::from("score")), Value::from(42));
    let key_equals_score = || {
        Statement::equal(
            AnchoredKey::new(secret_dict, Key::from("key")),
            AnchoredKey::new(dict, Key::from("score")),
        )
    };

    // AND (0) Sanity check with correct values
    let custom_predicate = and_pred();
    let args = wildcard_values();
    let expected_st = Statement::Custom(
        custom_predicate.clone(),
        vec![value_ref(args[0].clone()), value_ref(0)],
    );
    helper_custom_operation_verify_gadget(
        &params,
        custom_predicate,
        vec![score_is_42(), key_equals_score()],
        args,
        Some(expected_st),
    )
    .unwrap();

    // Each case must be rejected: (predicate, operation arguments, wildcard values).
    let rejected = [
        // AND (1) Different dict for same wildcard
        (
            and_pred(),
            vec![
                score_is_42(),
                Statement::equal(
                    AnchoredKey::new(secret_dict, Key::from("key")),
                    AnchoredKey::new(Hash([F(0), F(5), F(1), F(6)]), Key::from("score")),
                ),
            ],
            wildcard_values(),
        ),
        // AND (2) key doesn't match template
        (
            and_pred(),
            vec![
                Statement::equal(AnchoredKey::new(dict, Key::from("BAD")), Value::from(42)),
                key_equals_score(),
            ],
            wildcard_values(),
        ),
        // AND (3) literal doesn't match template
        (
            and_pred(),
            vec![
                Statement::equal(
                    AnchoredKey::new(dict, Key::from("score")),
                    Value::from(0xbad),
                ),
                key_equals_score(),
            ],
            wildcard_values(),
        ),
        // AND (4) predicate doesn't match template
        (
            and_pred(),
            vec![
                score_is_42(),
                Statement::not_equal(
                    AnchoredKey::new(secret_dict, Key::from("key")),
                    AnchoredKey::new(dict, Key::from("score")),
                ),
            ],
            wildcard_values(),
        ),
        // OR (1) Two Nones
        (
            or_pred(),
            vec![Statement::None, Statement::None],
            vec![Value::from(dict), Value::from(0)],
        ),
    ];

    for (custom_predicate, op_args, args) in rejected {
        assert!(helper_custom_operation_verify_gadget(
            &params,
            custom_predicate,
            op_args,
            args,
            None,
        )
        .is_err());
    }
    Ok(())
}
fn helper_calculate_statements_hash(params: &Params, statements: &[Statement]) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let statements_target = (0..params.max_public_statements)
.map(|_| builder.add_virtual_statement(false))
.collect_vec();
let sts_hash_target = calculate_statements_hash_circuit(&mut builder, &statements_target);
let mut pw = PartialWitness::<F>::new();
// Input
let statements = statements
.iter()
.map(|st| {
let mut st = mainpod::Statement::from(st.clone());
pad_statement(&mut st);
st
})
.collect_vec();
for (st_target, st) in statements_target.iter().zip(statements.iter()) {
st_target.set_targets(&mut pw, st)?;
}
// Expected Output
let expected_sts_hash = calculate_statements_hash(&statements);
pw.set_hash_target(
sts_hash_target,
HashOut {
elements: expected_sts_hash.0,
},
)?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw)?;
Ok(data.verify(proof.clone())?)
}
#[test]
fn test_calculate_sts_hash() -> frontend::Result<()> {
    let hash_slots = Params::num_public_statements_hash();
    assert_eq!(hash_slots, 16);

    let dict = Hash([F(1), F(2), F(3), F(4)]);
    let dict2 = Hash([F(5), F(6), F(7), F(8)]);

    // Case with no public statements
    let params = Params {
        max_public_statements: 0,
        ..Default::default()
    };
    helper_calculate_statements_hash(&params, &[]).unwrap();

    // Case with number of statements for the sts_hash equal to number of public statements
    let params = Params {
        max_public_statements: hash_slots,
        ..Default::default()
    };
    let statements: Vec<Statement> = (0..hash_slots)
        .map(|i| Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(i as i64)))
        .collect();
    helper_calculate_statements_hash(&params, &statements).unwrap();

    // Case with more statements for the sts_hash than the number of public statements
    let params = Params {
        max_public_statements: 4,
        ..Default::default()
    };
    let mut statements = vec![
        Statement::equal(AnchoredKey::from((dict, "foo")), Value::from(42)),
        Statement::equal(
            AnchoredKey::from((dict, "bar")),
            AnchoredKey::from((dict, "baz")),
        ),
        Statement::lt(
            AnchoredKey::from((dict2, "one")),
            AnchoredKey::from((dict2, "two")),
        ),
    ];
    // Fill the remaining public slots with `None` statements.
    statements.resize(params.max_public_statements, Statement::None);
    helper_calculate_statements_hash(&params, &statements).unwrap();

    Ok(())
}
// Checks that `normalize_st_tmpl_circuit` resolves a `SelfPredicateHash(0)`
// template argument into a literal holding the hash of predicate 0 of the
// same batch, by feeding the normalized template into
// `make_statement_from_template_circuit` and comparing the resulting statement
// with the expected one.
#[test]
fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> {
let params = Params::default();
// Build a batch with two predicates:
// pred_A: Equal(x, y)
// pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash
use NativePredicate as NP;
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal)
.arg("x")
.arg("y");
cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a])
.unwrap();
// Build pred_B's template manually with SelfPredicateHash(0)
let stb_b_tmpl = StatementTmpl {
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)),
args: vec![
StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)),
StatementTmplArg::SelfPredicateHash(0),
],
};
// NOTE(review): the positional `true` and `1` presumably select a
// conjunction and one public argument -- confirm against
// `CustomPredicate::new`.
let pred_b = CustomPredicate::new(
&params,
"pred_B".into(),
true,
vec![stb_b_tmpl],
1,
vec!["x".to_string()],
)
.unwrap();
// pred_B is appended directly to the builder's predicate list, after
// pred_A; it is looked up at index 1 below.
cpb.predicates.push(pred_b);
let batch = cpb.finish().unwrap();
// Compute the expected resolved hash of pred_A
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = Predicate::Custom(pred_a_ref).hash();
let expected_pred_a_value = Value::from(pred_a_hash);
// Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to
// Literal(pred_a_hash). Then make_statement_from_template_circuit should produce
// a statement with that literal value.
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_tmpl = &pred_b_ref.predicate().statements[0];
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
// Create the template target and batch id target
let st_tmpl_target = builder.add_virtual_statement_tmpl(true);
let batch_id = builder.add_virtual_hash();
// Normalize the template (this is what we're testing)
let normalized =
normalize_st_tmpl_circuit(&params, &mut builder, &st_tmpl_target, batch_id);
// Feed normalized template into statement generation
let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect();
let st_target =
make_statement_from_template_circuit(&params, &mut builder, &normalized, &args_target);
// Connect to expected output
let expected_st_target = builder.add_virtual_statement(false);
builder.connect_flattenable(&expected_st_target, &st_target);
// Set witness: pred_B's (un-normalized) template and the id of the batch
// that contains it
let mut pw = PartialWitness::<F>::new();
st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?;
pw.set_target_arr(&batch_id.elements, &batch.id().0)?;
let some_value = Value::from(42);
// args: first wildcard is "x" = some_value, rest are padding
let mut args_values = vec![some_value.clone()];
for _ in 1..params.max_custom_predicate_wildcards {
args_values.push(Value::from(EMPTY_VALUE));
}
for (target, value) in args_target.iter().zip(args_values.iter()) {
target.set_targets(&mut pw, value)?;
}
// Expected statement: Equal(Literal(some_value), Literal(pred_a_hash))
let expected_st: crate::backends::plonky2::mainpod::Statement =
Statement::equal(some_value, expected_pred_a_value).into();
expected_st_target.set_targets(&mut pw, &expected_st)?;
// Build and verify; any proving or verification failure is returned as an error
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
}