pod2/src/backends/plonky2/circuits/common.rs
Rob Knight 1e592e11cf
Self-referential predicate hashes as statement template args (#494)
* Support quoted predicate hashes, including self-referential predicates

* Clippy

* Review feedback
2026-03-24 07:25:11 -07:00

2134 lines
74 KiB
Rust

//! Common functionality to build Pod circuits with plonky2
use std::{array, iter};
use itertools::Itertools;
use plonky2::{
field::{
extension::Extendable,
types::{Field, PrimeField64},
},
hash::{
hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS},
poseidon::PoseidonHash,
},
iop::{
generator::{GeneratedValues, SimpleGenerator},
target::{BoolTarget, Target},
witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite},
},
plonk::config::Hasher,
util::serialization::{Buffer, IoResult, Read, Write},
};
use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::{
basetypes::{CircuitBuilder, CommonCircuitData, D},
circuits::mainpod::CustomPredicateVerification,
error::Result,
mainpod::{Operation, OperationArg, OperationAux, Statement},
primitives::merkletree::{
verify_merkle_proof_circuit, MerkleClaimAndProof, MerkleClaimAndProofTarget,
MerkleProof, MerkleTreeStateTransitionProofTarget,
},
},
middleware::{
hash_fields, CustomPredicate, CustomPredicateRef, NativeOperation, NativePredicate,
OperationType, Params, Predicate, PredicateOrWildcard, PredicateOrWildcardPrefix,
PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg,
StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE, STATEMENT_ARG_F_LEN,
VALUE_SIZE,
},
};
pub const CODE_SIZE: usize = HASH_SIZE + 2;
const NUM_BITS: usize = 32;
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct ValueTarget {
pub elements: [Target; VALUE_SIZE],
}
impl From<ValueTarget> for HashOutTarget {
fn from(v: ValueTarget) -> HashOutTarget {
HashOutTarget {
elements: v.elements,
}
}
}
impl From<HashOutTarget> for ValueTarget {
fn from(h: HashOutTarget) -> ValueTarget {
ValueTarget {
elements: h.elements,
}
}
}
impl ValueTarget {
pub fn zero(builder: &mut CircuitBuilder) -> Self {
Self {
elements: [builder.zero(); VALUE_SIZE],
}
}
pub fn one(builder: &mut CircuitBuilder) -> Self {
Self {
elements: array::from_fn(|i| {
if i == 0 {
builder.one()
} else {
builder.zero()
}
}),
}
}
pub fn from_slice(xs: &[Target]) -> Self {
assert_eq!(xs.len(), VALUE_SIZE);
Self {
elements: array::from_fn(|i| xs[i]),
}
}
pub fn set_targets(&self, pw: &mut PartialWitness<F>, value: &Value) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &value.raw().0)?)
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct StatementArgTarget {
#[serde(with = "serde_arrays")]
pub elements: [Target; STATEMENT_ARG_F_LEN],
}
impl StatementArgTarget {
pub fn set_targets(&self, pw: &mut PartialWitness<F>, arg: &StatementArg) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &arg.to_fields())?)
}
pub fn new(first: ValueTarget, second: ValueTarget) -> Self {
let elements: Vec<_> = first.elements.into_iter().chain(second.elements).collect();
StatementArgTarget {
elements: elements.try_into().expect("size STATEMENT_ARG_F_LEN"),
}
}
pub fn none(builder: &mut CircuitBuilder) -> Self {
let empty = builder.constant_value(EMPTY_VALUE);
Self::new(empty, empty)
}
pub fn literal(builder: &mut CircuitBuilder, value: &ValueTarget) -> Self {
let empty = builder.constant_value(EMPTY_VALUE);
Self::new(*value, empty)
}
pub fn anchored_key(
_builder: &mut CircuitBuilder,
dict: &ValueTarget,
key: &ValueTarget,
) -> Self {
Self::new(*dict, *key)
}
pub fn wildcard_literal(builder: &mut CircuitBuilder, value: &ValueTarget) -> Self {
let empty = builder.constant_value(EMPTY_VALUE);
Self::new(*value, empty)
}
/// StatementArgTarget to ValueTarget coercion. Make sure to check
/// that the arg is a value using the `statement_arg_is_value` method
/// first!
pub fn as_value(&self) -> ValueTarget {
ValueTarget::from_slice(&self.elements[..VALUE_SIZE])
}
fn size(_params: &Params) -> usize {
STATEMENT_ARG_F_LEN
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct StatementTarget {
// If the pred is Some, then the `pred_hash` is constrained to be the `hash(pred)`.
pred: Option<PredicateTarget>,
pred_hash: HashOutTarget,
pub args: Vec<StatementArgTarget>,
}
impl StatementTarget {
pub fn pred(&self) -> Option<&PredicateTarget> {
self.pred.as_ref()
}
pub fn pred_hash(&self) -> &HashOutTarget {
&self.pred_hash
}
pub fn new(pred_hash: HashOutTarget, args: Vec<StatementArgTarget>) -> Self {
Self {
pred: None,
pred_hash,
args,
}
}
pub fn new_with_pred(
builder: &mut CircuitBuilder,
params: &Params,
predicate: impl Build<PredicateTarget>,
args: &[StatementArgTarget],
) -> Self {
let pred = predicate.build(builder, params);
let pred_hash = pred.hash(builder);
Self {
pred: Some(pred),
pred_hash,
args: args
.iter()
.cloned()
.chain(iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(Params::max_statement_args())
.collect(),
}
}
pub fn new_native(
builder: &mut CircuitBuilder,
params: &Params,
native_predicate: impl Build<NativePredicateTarget>,
args: &[StatementArgTarget],
) -> Self {
let pred = PredicateTarget::new_native(builder, params, native_predicate);
Self::new_with_pred(builder, params, pred, args)
}
pub fn set_targets(&self, pw: &mut PartialWitness<F>, st: &Statement) -> Result<()> {
if let Some(pred) = &self.pred {
pred.set_targets(pw, &st.predicate())?;
}
pw.set_hash_target(self.pred_hash, HashOut::from(st.predicate().hash()))?;
for (i, arg) in st
.args()
.iter()
.chain(iter::repeat(&StatementArg::None))
.take(Params::max_statement_args())
.enumerate()
{
self.args[i].set_targets(pw, arg)?;
}
Ok(())
}
pub fn pred_is_blank_intro(&self, builder: &mut CircuitBuilder) -> BoolTarget {
let zero_hash = builder.constant_hash(HashOut {
elements: [F::ZERO, F::ZERO, F::ZERO, F::ZERO],
});
let blank_intro = PredicateTarget::new_intro(builder, zero_hash).hash(builder);
builder.is_equal_flattenable(&self.pred_hash, &blank_intro)
}
pub fn has_native_type(&self, builder: &mut CircuitBuilder, t: NativePredicate) -> BoolTarget {
let expected_predicate_hash =
builder.constant_hash(HashOut::from(Predicate::Native(t).hash()));
builder.is_equal_flattenable(&self.pred_hash, &expected_predicate_hash)
}
}
pub trait Build<T> {
fn build(self, builder: &mut CircuitBuilder, params: &Params) -> T;
}
impl Build<NativePredicateTarget> for NativePredicate {
fn build(self, builder: &mut CircuitBuilder, _params: &Params) -> NativePredicateTarget {
NativePredicateTarget::constant(builder, self)
}
}
impl<T> Build<T> for T {
fn build(self, _builder: &mut CircuitBuilder, _params: &Params) -> T {
self
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct OperationTypeTarget {
#[serde(with = "serde_arrays")]
pub elements: [Target; Params::operation_type_size()],
}
impl OperationTypeTarget {
pub fn new_custom(
builder: &mut CircuitBuilder,
batch_id: HashOutTarget,
index: Target,
) -> Self {
// TODO: Use an enum for these prefixes
let three = builder.constant(F::from_canonical_usize(3));
let id = batch_id.elements;
Self {
elements: [three, id[0], id[1], id[2], id[3], index],
}
}
pub fn as_custom(&self, builder: &mut CircuitBuilder) -> (BoolTarget, HashOutTarget, Target) {
// TODO: Use an enum for these prefixes
let three = builder.constant(F::from_canonical_usize(3));
let op_is_custom = builder.is_equal(self.elements[0], three);
let batch_id = HashOutTarget::from_vec(self.elements[1..5].to_vec());
let index = self.elements[5];
(op_is_custom, batch_id, index)
}
pub fn has_native(&self, builder: &mut CircuitBuilder, t: NativeOperation) -> BoolTarget {
// TODO: Use an enum for these prefixes
let one = builder.one();
let op_is_native = builder.is_equal(self.elements[0], one);
let op_code = builder.constant(F::from_canonical_u64(t as u64));
let op_code_matches = builder.is_equal(self.elements[1], op_code);
builder.and(op_is_native, op_code_matches)
}
pub fn set_targets(&self, pw: &mut PartialWitness<F>, op_type: &OperationType) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &op_type.to_fields())?)
}
fn size(_params: &Params) -> usize {
Params::operation_type_size()
}
}
// TODO: Implement Operation::to_field to determine the size of each element
#[derive(Clone, Serialize, Deserialize)]
pub struct OperationTarget {
pub op_type: OperationTypeTarget,
pub args: Vec<IndexTarget>,
pub aux_index: IndexTarget,
}
impl OperationTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
op: &Operation,
) -> Result<()> {
self.op_type.set_targets(pw, &op.op_type())?;
for (i, arg) in op
.args()
.iter()
.chain(iter::repeat(&OperationArg::None))
.take(params.max_operation_args)
.enumerate()
{
self.args[i].set_targets(pw, arg.as_usize())?;
}
self.aux_index.set_targets(pw, op.aux().table_index(params))
}
fn size(params: &Params) -> usize {
OperationTypeTarget::size(params)
+ params.max_operation_args * IndexTarget::size(params)
+ IndexTarget::size(params)
}
}
#[derive(Clone)]
pub struct NativePredicateTarget(Target);
impl NativePredicateTarget {
pub fn constant(builder: &mut CircuitBuilder, native_predicate: NativePredicate) -> Self {
let id = native_predicate.to_fields();
assert_eq!(1, id.len());
Self(builder.constant(id[0]))
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
native_predicate: NativePredicate,
) -> Result<()> {
let id = native_predicate.to_fields();
assert_eq!(1, id.len());
Ok(pw.set_target(self.0, id[0])?)
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct PredicateTarget {
#[serde(with = "serde_arrays")]
pub(crate) elements: [Target; Params::predicate_size()],
}
impl PredicateTarget {
pub fn new_native(
builder: &mut CircuitBuilder,
params: &Params,
native_predicate: impl Build<NativePredicateTarget>,
) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Native));
let id = native_predicate.build(builder, params).0;
let zero = builder.zero();
Self {
elements: [prefix, id, zero, zero, zero, zero, zero, zero],
}
}
pub fn new_batch_self(builder: &mut CircuitBuilder, index: Target) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::BatchSelf));
let zero = builder.zero();
Self {
elements: [prefix, index, zero, zero, zero, zero, zero, zero],
}
}
pub fn new_custom(
builder: &mut CircuitBuilder,
batch_id: HashOutTarget,
index: Target,
) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Custom));
let id = batch_id.elements;
let zero = builder.zero();
Self {
elements: [prefix, id[0], id[1], id[2], id[3], index, zero, zero],
}
}
pub fn new_intro(builder: &mut CircuitBuilder, vd_hash: HashOutTarget) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Intro));
let vh = vd_hash.elements;
let zero = builder.zero();
Self {
elements: [prefix, vh[0], vh[1], vh[2], vh[3], zero, zero, zero],
}
}
pub fn is_intro(&self, builder: &mut CircuitBuilder) -> BoolTarget {
let prefix = builder.constant(F::from(PredicatePrefix::Intro));
builder.is_equal(prefix, self.elements[0])
}
pub fn set_targets(&self, pw: &mut PartialWitness<F>, predicate: &Predicate) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &predicate.to_fields())?)
}
pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget {
// Optimization: if all the predicate values are constants we skip the hash circuit and
// return a hash constant
let mut predicate_values = [F::ZERO; Params::predicate_size()];
let mut predicate_constant = true;
for (i, target) in self.elements.iter().enumerate() {
if let Some(v) = builder.target_as_constant(*target) {
predicate_values[i] = v;
} else {
predicate_constant = false;
break;
}
}
if predicate_constant {
builder.constant_hash(PoseidonHash::hash_no_pad(&predicate_values))
} else {
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.elements.to_vec())
}
}
}
/// Mirrors `middleware::KeyOrWildcard`
#[derive(Clone)]
pub struct LiteralOrWildcardTarget {
pub elements: [Target; VALUE_SIZE],
}
impl LiteralOrWildcardTarget {
fn from_slice(v: &[Target]) -> Self {
Self {
elements: v.try_into().expect("len is VALUE_SIZE"),
}
}
/// cases: ((is_key, key), (is_wildcard, wildcard_index))
pub fn cases(
&self,
builder: &mut CircuitBuilder,
) -> ((BoolTarget, ValueTarget), (BoolTarget, Target)) {
let zero = builder.zero();
let is_zero_tail: Vec<_> = (1..4)
.map(|i| builder.is_equal(self.elements[i], zero))
.collect();
let is_wildcard = is_zero_tail
.into_iter()
.reduce(|acc, x| builder.and(acc, x))
.expect("len > 1");
let is_key = builder.not(is_wildcard);
let key = ValueTarget::from_slice(&self.elements);
let wildcard_index = self.elements[0];
((is_key, key), (is_wildcard, wildcard_index))
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct StatementTmplArgTarget {
#[serde(with = "serde_arrays")]
pub elements: [Target; Params::statement_tmpl_arg_size()],
}
impl StatementTmplArgTarget {
pub fn as_none(&self, builder: &mut CircuitBuilder) -> BoolTarget {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::None));
builder.is_equal(self.elements[0], prefix)
}
pub fn as_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, ValueTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal));
let case_ok = builder.is_equal(self.elements[0], prefix);
let value = ValueTarget::from_slice(&self.elements[1..5]);
(case_ok, value)
}
pub fn as_anchored_key(
&self,
builder: &mut CircuitBuilder,
) -> (BoolTarget, Target, LiteralOrWildcardTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::AnchoredKey));
let case_ok = builder.is_equal(self.elements[0], prefix);
let id_wildcard_index = self.elements[1];
let value_key_or_wildcard = LiteralOrWildcardTarget::from_slice(&self.elements[5..9]);
(case_ok, id_wildcard_index, value_key_or_wildcard)
}
pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, Target) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral));
let case_ok = builder.is_equal(self.elements[0], prefix);
let wildcard_index = self.elements[1];
(case_ok, wildcard_index)
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
st_tmpl_arg: &StatementTmplArg,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields())?)
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct PredicateHashOrWildcardTarget {
/// layout: `prefix | [data]`, where data is predicate_hash or wildcard_index
pub elements: [Target; Params::pred_hash_or_wc_size()],
}
impl PredicateHashOrWildcardTarget {
pub fn new(prefix: Target, data: ValueTarget) -> Self {
let v = data.elements;
Self {
elements: [prefix, v[0], v[1], v[2], v[3]],
}
}
pub fn new_pred_hash(builder: &mut CircuitBuilder, pred_hash: HashOutTarget) -> Self {
Self::new(
builder.constant(F::from(PredicateOrWildcardPrefix::Predicate)),
ValueTarget::from(pred_hash),
)
}
pub fn is_pred(&self, builder: &mut CircuitBuilder) -> BoolTarget {
let prefix_pred = builder.constant(F::from(PredicateOrWildcardPrefix::Predicate));
builder.is_equal(self.elements[0], prefix_pred)
}
pub fn data(&self) -> ValueTarget {
ValueTarget {
elements: self.elements[1..].try_into().expect("4 elements"),
}
}
pub fn pred_hash(&self) -> HashOutTarget {
HashOutTarget::from(self.data())
}
pub fn wc_index(&self) -> Target {
self.elements[1]
}
pub fn set_targets_raw(
&self,
pw: &mut PartialWitness<F>,
prefix: PredicateOrWildcardPrefix,
data: RawValue,
) -> Result<()> {
pw.set_target(self.elements[0], F::from(prefix))?;
pw.set_target_arr(&self.elements[1..], &data.0)?;
Ok(())
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
pred: &PredicateOrWildcard,
) -> Result<()> {
match pred {
PredicateOrWildcard::Predicate(pred) => {
self.set_targets_raw(
pw,
PredicateOrWildcardPrefix::Predicate,
RawValue::from(pred.hash()),
)?;
}
PredicateOrWildcard::Wildcard(wc) => {
self.set_targets_raw(
pw,
PredicateOrWildcardPrefix::Wildcard,
RawValue([F::from_canonical_usize(wc.index), F::ZERO, F::ZERO, F::ZERO]),
)?;
}
}
Ok(())
}
}
impl Flattenable for PredicateHashOrWildcardTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self {
elements: vs.try_into().expect("5 elements"),
}
}
fn size(_params: &Params) -> usize {
Params::pred_hash_or_wc_size()
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct StatementTmplTarget {
/// The preimage of the predicate_hash. This predicate is needed only to build the custom
/// predicate table because it needs to normalize statement templates with predicates that
/// refer to self into content-addressed predicates (using the batch id and index). The
/// predicate type is inspected to do this normalization. After the table is built we only use
/// the predicate hash for equality checks.
pred: Option<PredicateTarget>,
/// This is constrained to be `hash(pred)` through the type constructor when we have `pred`
/// and the template uses a predicate and not a wildcard.
pred_hash_or_wc: PredicateHashOrWildcardTarget,
pub args: Vec<StatementTmplArgTarget>,
}
impl StatementTmplTarget {
pub fn new(
pred_hash_or_wc: PredicateHashOrWildcardTarget,
args: Vec<StatementTmplArgTarget>,
) -> Self {
Self {
pred: None,
pred_hash_or_wc,
args,
}
}
pub fn set_targets(&self, pw: &mut PartialWitness<F>, st_tmpl: &StatementTmpl) -> Result<()> {
if let Some(pred) = &self.pred {
match &st_tmpl.pred_or_wc {
PredicateOrWildcard::Predicate(p) => {
// We store a predicate (not a wildcard) and we have it available. In this
// case the hash will be calculated by constraints later on and we should not
// rely on the original data.
pred.set_targets(pw, p)?
}
PredicateOrWildcard::Wildcard(_wc) => {
// Fill in with a recognizable constant for better debugging; this value is
// not supposed to be used.
pw.set_target_arr(&pred.elements, &[F(0xdead); Params::predicate_size()])?
}
}
}
self.pred_hash_or_wc.set_targets(pw, &st_tmpl.pred_or_wc)?;
let arg_pad = StatementTmplArg::None;
for (i, arg) in st_tmpl
.args
.iter()
.chain(iter::repeat(&arg_pad))
.take(Params::max_statement_args())
.enumerate()
{
self.args[i].set_targets(pw, arg)?;
}
Ok(())
}
pub fn pred(&self) -> Option<&PredicateTarget> {
self.pred.as_ref()
}
pub fn pred_hash_or_wc(&self) -> &PredicateHashOrWildcardTarget {
&self.pred_hash_or_wc
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateTarget {
pub conjunction: BoolTarget,
// len = params.max_custom_predicate_arity
pub statements: Vec<StatementTmplTarget>,
pub args_len: Target,
}
impl CustomPredicateTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
custom_pred: &CustomPredicate,
) -> Result<()> {
pw.set_target(
self.conjunction.target,
F::from_bool(custom_pred.conjunction),
)?;
let st_tmpl_pad = custom_pred.pad_statement_tmpl();
for (i, st_tmpl) in custom_pred
.statements
.iter()
.chain(iter::repeat(&st_tmpl_pad))
.take(Params::max_custom_predicate_arity())
.enumerate()
{
self.statements[i].set_targets(pw, st_tmpl)?;
}
pw.set_target(self.args_len, F::from_canonical_usize(custom_pred.args_len))?;
Ok(())
}
}
/// Custom predicate structure that can be verified to belong to a batch id at a particular index
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateInBatchTarget {
pub id: HashOutTarget,
pub index: Target,
/// Predicate that may use references to another predicate of the batch with BatchSelf
pub self_predicate: CustomPredicateTarget,
pub mtp: MerkleClaimAndProofTarget,
}
impl CustomPredicateInBatchTarget {
/// This constructor connects the merkle proof and claim targets with with the (index,
/// self_predicate) and id.
pub fn new_virtual(builder: &mut CircuitBuilder) -> CustomPredicateInBatchTarget {
let index = builder.add_virtual_target();
let self_predicate = builder.add_virtual_custom_predicate(true);
// Existence Merkle Tree proof of (index, hash(self_predicate)) -> id
let mtp =
MerkleClaimAndProofTarget::new_virtual(Params::max_depth_custom_batch_mt(), builder);
let _true = builder._true();
builder.connect(_true.target, mtp.enabled.target);
builder.connect(_true.target, mtp.existence.target);
let zero = builder.constant(F(0));
let key = ValueTarget {
elements: [index, zero, zero, zero],
};
builder.connect_values(key, mtp.key);
let id = mtp.root;
Self {
id,
index,
mtp,
self_predicate,
}
}
/// Hash the predicate, connect it to the merkle proof claim value and verify the merkle proof.
pub fn verify_circuit(&self, builder: &mut CircuitBuilder) {
let value = builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.self_predicate.flatten());
builder.connect_array(value.elements, self.mtp.value.elements);
verify_merkle_proof_circuit(builder, &self.mtp);
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
predicate_ref: &CustomPredicateRef,
mtp: &MerkleProof,
) -> Result<()> {
pw.set_target_arr(&self.id.elements, &predicate_ref.batch.id().0)?;
pw.set_target(self.index, F::from_canonical_usize(predicate_ref.index))?;
let predicate = predicate_ref.predicate();
self.self_predicate.set_targets(pw, predicate)?;
let mtp_claim = MerkleClaimAndProof {
root: predicate_ref.batch.id(),
key: Value::from(predicate_ref.index as i64).raw(),
value: RawValue::from(hash_fields(&predicate.to_fields())),
proof: mtp.clone(),
};
self.mtp.set_targets(pw, true, &mtp_claim)?;
Ok(())
}
}
/// Custom predicate table entry
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateEntryTarget {
pub id: HashOutTarget,
pub index: Target,
pub predicate: CustomPredicateTarget,
}
impl CustomPredicateEntryTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
predicate: &CustomPredicateRef,
) -> Result<()> {
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
// Replace BatchSelf predicates with Custom(batch, i), and
// SelfPredicateHash args with Literal(hash(Custom(batch, i)))
let batch = &predicate.batch;
let predicate = predicate.predicate();
let statements = predicate
.statements
.clone()
.into_iter()
.map(|st_tmpl| {
let pred_or_wc = match st_tmpl.pred_or_wc {
PredicateOrWildcard::Predicate(Predicate::BatchSelf(i)) => {
PredicateOrWildcard::Predicate(Predicate::Custom(CustomPredicateRef {
batch: batch.clone(),
index: i,
}))
}
x => x.clone(),
};
let args = st_tmpl
.args
.into_iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: batch.clone(),
index: i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
}
other => other,
})
.collect();
StatementTmpl { pred_or_wc, args }
})
.collect_vec();
let predicate = CustomPredicate {
name: predicate.name.clone(),
conjunction: predicate.conjunction,
statements,
args_len: predicate.args_len,
wildcard_names: predicate.wildcard_names.clone(),
};
self.predicate.set_targets(pw, &predicate)?;
Ok(())
}
}
impl Flattenable for CustomPredicateEntryTarget {
fn flatten(&self) -> Vec<Target> {
self.id
.elements
.iter()
.chain(iter::once(&self.index))
.chain(self.predicate.flatten().iter())
.cloned()
.collect()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
Self {
id: HashOutTarget::from_flattened(params, &vs[0..4]),
index: vs[4],
predicate: CustomPredicateTarget::from_flattened(params, &vs[5..]),
}
}
fn size(params: &Params) -> usize {
HashOutTarget::size(params) + 1 + CustomPredicateTarget::size(params)
}
}
impl CustomPredicateEntryTarget {
pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget {
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
}
}
// Custom predicate verification table entry
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateVerifyEntryTarget {
pub custom_predicate_table_index: IndexTarget,
pub custom_predicate: CustomPredicateEntryTarget,
pub args: Vec<ValueTarget>,
pub op_args: Vec<StatementTarget>,
}
impl CustomPredicateVerifyEntryTarget {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
CustomPredicateVerifyEntryTarget {
custom_predicate_table_index: IndexTarget::new_virtual(
params.max_custom_predicates,
builder,
),
custom_predicate: builder.add_virtual_custom_predicate_entry(),
args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect(),
op_args: (0..params.max_operation_args)
.map(|_| builder.add_virtual_statement(false))
.collect(),
}
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
cpv: &CustomPredicateVerification,
) -> Result<()> {
self.custom_predicate_table_index
.set_targets(pw, cpv.custom_predicate_table_index)?;
// Replace statement templates of batch-self with (id,index)
self.custom_predicate
.set_targets(pw, &cpv.custom_predicate)?;
let pad_arg = Value::from(0);
for (arg_target, arg) in self.args.iter().zip_eq(
cpv.args
.iter()
.chain(iter::repeat(&pad_arg))
.take(params.max_custom_predicate_wildcards),
) {
arg_target.set_targets(pw, &Value::from(arg.raw()))?;
}
let pad_op_arg = Statement(Predicate::Native(NativePredicate::None), vec![]);
for (op_arg_target, op_arg) in self.op_args.iter().zip_eq(
cpv.op_args
.iter()
.chain(iter::repeat(&pad_op_arg))
.take(params.max_operation_args),
) {
op_arg_target.set_targets(pw, op_arg)?
}
Ok(())
}
}
/// Query for the custom predicate verification table
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateVerifyQueryTarget {
pub statement: StatementTarget,
pub op_type: OperationTypeTarget,
pub op_args: Vec<StatementTarget>,
}
impl CustomPredicateVerifyQueryTarget {
pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget {
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
}
}
impl Flattenable for CustomPredicateVerifyQueryTarget {
fn flatten(&self) -> Vec<Target> {
self.statement
.flatten()
.iter()
.chain(self.op_type.elements.iter())
.cloned()
.chain(self.op_args.iter().flat_map(|op_arg| op_arg.flatten()))
.collect()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
let (pos, size) = (0, StatementTarget::size(params));
let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]);
let (pos, size) = (pos + size, OperationTypeTarget::size(params));
let op_type = OperationTypeTarget {
elements: vs[pos..pos + size]
.try_into()
.expect("len = operation_type_size"),
};
let (pos, size) = (pos + size, StatementTarget::size(params));
let op_args = (0..params.max_operation_args)
.map(|i| {
StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size])
})
.collect();
Self {
statement,
op_type,
op_args,
}
}
fn size(params: &Params) -> usize {
StatementTarget::size(params) * (1 + params.max_operation_args)
+ OperationTarget::size(params)
}
}
/// Trait for target structs that may be converted to and from vectors
/// of targets.
pub trait Flattenable {
fn flatten(&self) -> Vec<Target>;
fn from_flattened(params: &Params, vs: &[Target]) -> Self;
/// Size in number of `Target`s
fn size(params: &Params) -> usize;
}
// TODO: Figure out why this is defined in common and not in the merkletree directory
/// For the purpose of op verification, we need only look up the
/// Merkle claim rather than the Merkle proof since it is verified
/// elsewhere.
#[derive(Copy, Clone)]
pub struct MerkleClaimTarget {
pub(crate) enabled: BoolTarget,
pub(crate) root: HashOutTarget,
pub(crate) key: ValueTarget,
pub(crate) value: ValueTarget,
pub(crate) existence: BoolTarget,
}
impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
fn from(pf: MerkleClaimAndProofTarget) -> Self {
Self {
enabled: pf.enabled,
root: pf.root,
key: pf.key,
value: pf.value,
existence: pf.existence,
}
}
}
/// For the purpose of op verification, we need only look up the
/// Merkle state transition claim rather than the Merkle state
/// transition proof since it is verified elsewhere.
#[derive(Copy, Clone)]
pub struct MerkleTreeStateTransitionClaimTarget {
pub(crate) enabled: BoolTarget,
pub(crate) op: Target,
pub(crate) old_root: HashOutTarget,
pub(crate) new_root: HashOutTarget,
pub(crate) op_key: ValueTarget,
pub(crate) op_value: ValueTarget,
}
impl From<MerkleTreeStateTransitionProofTarget> for MerkleTreeStateTransitionClaimTarget {
fn from(pf: MerkleTreeStateTransitionProofTarget) -> Self {
Self {
enabled: pf.enabled,
op: pf.op,
old_root: pf.old_root,
new_root: pf.new_root,
op_key: pf.op_key,
op_value: pf.op_value,
}
}
}
impl Flattenable for HashOutTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
Self {
elements: array::from_fn(|i| vs[i]),
}
}
fn size(_params: &Params) -> usize {
4
}
}
impl Flattenable for ValueTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
Self::from_slice(vs)
}
fn size(_params: &Params) -> usize {
4
}
}
impl Flattenable for MerkleClaimTarget {
fn flatten(&self) -> Vec<Target> {
[
vec![self.enabled.target],
self.root.elements.to_vec(),
self.key.elements.to_vec(),
self.value.elements.to_vec(),
vec![self.existence.target],
]
.concat()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
Self {
enabled: BoolTarget::new_unsafe(vs[0]),
root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()),
key: ValueTarget::from_slice(
&vs[1 + NUM_HASH_OUT_ELTS..1 + NUM_HASH_OUT_ELTS + VALUE_SIZE],
),
value: ValueTarget::from_slice(
&vs[1 + NUM_HASH_OUT_ELTS + VALUE_SIZE..1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE],
),
existence: BoolTarget::new_unsafe(vs[1 + NUM_HASH_OUT_ELTS + 2 * VALUE_SIZE]),
}
}
fn size(params: &Params) -> usize {
2 + HashOutTarget::size(params) + 2 * ValueTarget::size(params)
}
}
impl Flattenable for MerkleTreeStateTransitionClaimTarget {
fn flatten(&self) -> Vec<Target> {
[
vec![self.enabled.target, self.op],
self.old_root.elements.to_vec(),
self.new_root.elements.to_vec(),
self.op_key.elements.to_vec(),
self.op_value.elements.to_vec(),
]
.concat()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), Self::size(params));
Self {
enabled: BoolTarget::new_unsafe(vs[0]),
op: vs[1],
old_root: HashOutTarget::from_vec(vs[2..2 + NUM_HASH_OUT_ELTS].to_vec()),
new_root: HashOutTarget::from_vec(
vs[2 + NUM_HASH_OUT_ELTS..2 * (1 + NUM_HASH_OUT_ELTS)].to_vec(),
),
op_key: ValueTarget::from_slice(
&vs[2 * (1 + NUM_HASH_OUT_ELTS)..2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE],
),
op_value: ValueTarget::from_slice(
&vs[2 * (1 + NUM_HASH_OUT_ELTS) + VALUE_SIZE
..2 * (1 + NUM_HASH_OUT_ELTS) + 2 * VALUE_SIZE],
),
}
}
fn size(params: &Params) -> usize {
2 * (1 + HashOutTarget::size(params)) + 2 * ValueTarget::size(params)
}
}
impl Flattenable for PredicateTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
Self {
elements: v.try_into().expect("len is predicate_size"),
}
}
fn size(_params: &Params) -> usize {
Params::predicate_size()
}
}
impl Flattenable for StatementTarget {
fn flatten(&self) -> Vec<Target> {
self.pred_hash
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|a| &a.elements).cloned())
.collect()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
let predicate_hash = HashOutTarget::from_flattened(params, &v[..HASH_SIZE]);
let args = (0..Params::max_statement_args())
.map(|i| StatementArgTarget {
elements: array::from_fn(|j| v[HASH_SIZE + i * STATEMENT_ARG_F_LEN + j]),
})
.collect();
Self {
pred: None,
pred_hash: predicate_hash,
args,
}
}
fn size(params: &Params) -> usize {
HASH_SIZE + Params::max_statement_args() * StatementArgTarget::size(params)
}
}
impl Flattenable for CustomPredicateTarget {
fn flatten(&self) -> Vec<Target> {
iter::once(self.conjunction.target)
.chain(iter::once(self.args_len))
.chain(self.statements.iter().flat_map(|s| s.flatten()))
.collect()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
// We assume that `from_flattened` is always called with the output of `flattened`, so
// this `BoolTarget` should actually safe.
let conjunction = BoolTarget::new_unsafe(v[0]);
let args_len = v[1];
let st_tmpl_size = Params::statement_tmpl_size();
let statements = (0..Params::max_custom_predicate_arity())
.map(|i| {
let st_v = &v[2 + st_tmpl_size * i..2 + st_tmpl_size * (i + 1)];
StatementTmplTarget::from_flattened(params, st_v)
})
.collect();
Self {
conjunction,
statements,
args_len,
}
}
fn size(params: &Params) -> usize {
2 + Params::max_custom_predicate_arity() * StatementTmplTarget::size(params)
}
}
impl Flattenable for StatementTmplTarget {
fn flatten(&self) -> Vec<Target> {
self.pred_hash_or_wc
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
.collect()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
let pred_hash_or_wc_end = Params::pred_hash_or_wc_size();
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]);
let sta_size = Params::statement_tmpl_arg_size();
let args = (0..Params::max_statement_args())
.map(|i| {
let sta_v = &v
[pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)];
StatementTmplArgTarget::from_flattened(params, sta_v)
})
.collect();
Self {
pred: None,
pred_hash_or_wc,
args,
}
}
fn size(params: &Params) -> usize {
Params::pred_hash_or_wc_size()
+ Params::max_statement_args() * StatementTmplArgTarget::size(params)
}
}
impl Flattenable for StatementTmplArgTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
Self {
elements: v.try_into().expect("len is statement_tmpl_arg_size"),
}
}
fn size(_params: &Params) -> usize {
Params::statement_tmpl_arg_size()
}
}
/// Index to an array for random access
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IndexTarget {
pub max_array_len: usize,
pub low: Target,
pub high: Target,
}
impl IndexTarget {
// Length in field elements
pub fn size(_params: &Params) -> usize {
2
}
pub fn new_virtual(max_array_len: usize, builder: &mut CircuitBuilder) -> Self {
// Limit the maximum array length to avoid abusing `vec_ref`
assert!(max_array_len <= 256);
Self {
max_array_len,
low: builder.add_virtual_target(),
high: if max_array_len > 64 {
builder.add_virtual_target()
} else {
builder.zero()
},
}
}
pub fn set_targets(&self, pw: &mut PartialWitness<F>, index: usize) -> Result<()> {
assert!(index == 0 || index < self.max_array_len);
pw.set_target(self.low, F::from_canonical_usize(index & ((1 << 6) - 1)))?;
pw.set_target(self.high, F::from_canonical_usize(index >> 6))?;
Ok(())
}
}
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget);
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
fn add_virtual_value(&mut self) -> ValueTarget;
fn add_virtual_statement(&mut self, with_pred: bool) -> StatementTarget;
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget;
fn add_virtual_predicate(&mut self) -> PredicateTarget;
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget;
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget;
fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget;
fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget;
fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_statement_arg(
&mut self,
b: BoolTarget,
x: &StatementArgTarget,
y: &StatementArgTarget,
) -> StatementArgTarget;
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
fn constant_value(&mut self, v: RawValue) -> ValueTarget;
fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget;
// Convenience methods for checking values.
/// Checks whether `xs` is right-padded with 0s so as to represent a `Value`.
fn statement_arg_is_value(&mut self, arg: &StatementArgTarget) -> BoolTarget;
/// Checks whether `x` is an i64, which involves checking that it
/// consists of two `u32` limbs.
fn assert_i64(&mut self, x: ValueTarget);
/// Checks whether an i64 is negative.
fn i64_is_negative(&mut self, x: ValueTarget) -> BoolTarget;
/// Checks whether `x < y` if `b` is true. This assumes that `x`
/// and `y` each consist of two `u32` limbs.
fn assert_i64_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
/// Computes `x + y` assuming `x` and `y` are assigned `i64`
/// values.
fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
/// Computes `x + y` assuming `x` and `y` are assigned `i64`
/// values. Enforces no overflow.
fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
/// Computes `x * y` assuming `x` and `y` are assigned `i64`
/// values. Enforces no overflow.
fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
/// Computes the canonical involution of `x` in `i64`, i.e. the
/// negation of `x` as an `i64`.
fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget;
/// Computes the absolute value of `x` *as an element of
/// `i64`*. Includes sign indicator (true if negative).
fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget);
/// Creates value target that is a hash of two given values.
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
/// Like `random_access` but allows using longer arrays.
fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target;
/// Convenience methods for accessing and connecting elements of
/// (vectors of) flattenables.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
b: BoolTarget,
x: &T,
y: &T,
) -> T;
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
/// Convenience methods for Boolean into-iters.
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
/// Return a bit-mask of size `len` that selects all positions lower than `n`
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget>;
}
impl CircuitBuilderPod<F, D> for CircuitBuilder {
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]) {
assert_eq!(xs.len(), ys.len());
for (x, y) in xs.iter().zip(ys.iter()) {
self.connect(*x, *y);
}
}
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget) {
self.connect_slice(&x.elements, &y.elements);
}
fn add_virtual_value(&mut self) -> ValueTarget {
ValueTarget {
elements: self.add_virtual_target_arr(),
}
}
/// If `with_pred = true` a predicate is included and its hash constrained.
/// If `with_pred = false` only the predicate hash is included.
fn add_virtual_statement(&mut self, with_pred: bool) -> StatementTarget {
let (pred, pred_hash) = if with_pred {
let pred = self.add_virtual_predicate();
let pred_hash = pred.hash(self);
(Some(pred), pred_hash)
} else {
let pred_hash = self.add_virtual_hash();
(None, pred_hash)
};
StatementTarget {
pred,
pred_hash,
args: (0..Params::max_statement_args())
.map(|_| self.add_virtual_statement_arg())
.collect(),
}
}
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget {
StatementArgTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_predicate(&mut self) -> PredicateTarget {
PredicateTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget {
OperationTypeTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
OperationTarget {
op_type: self.add_virtual_operation_type(),
args: (0..params.max_operation_args)
.map(|_| IndexTarget::new_virtual(params.statement_table_size(), self))
.collect(),
aux_index: IndexTarget::new_virtual(OperationAux::table_size(params), self),
}
}
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget {
StatementTmplArgTarget {
elements: self.add_virtual_target_arr(),
}
}
/// If `with_pred = true` a predicate is included.
/// If `with_pred = false` only the predicate hash is included.
/// The pred_hash is constrained to be hash(pred) conditionally on the template using a
/// predicate and not a wildcard.
fn add_virtual_statement_tmpl(&mut self, with_pred: bool) -> StatementTmplTarget {
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value());
let pred = if with_pred {
let pred = self.add_virtual_predicate();
let pred_hash = pred.hash(self);
let is_pred = pred_hash_or_wc.is_pred(self);
let data = pred_hash_or_wc.data();
for i in 0..VALUE_SIZE {
self.conditional_assert_eq(is_pred.target, data.elements[i], pred_hash.elements[i]);
}
Some(pred)
} else {
None
};
StatementTmplTarget {
pred,
pred_hash_or_wc,
args: (0..Params::max_statement_args())
.map(|_| self.add_virtual_statement_tmpl_arg())
.collect(),
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate(&mut self, with_pred: bool) -> CustomPredicateTarget {
let statements = (0..Params::max_custom_predicate_arity())
.map(|_| self.add_virtual_statement_tmpl(with_pred))
.collect();
CustomPredicateTarget {
conjunction: self.add_virtual_bool_target_safe(),
statements,
args_len: self.add_virtual_target(),
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate_entry(&mut self) -> CustomPredicateEntryTarget {
CustomPredicateEntryTarget {
id: self.add_virtual_hash(),
index: self.add_virtual_target(),
predicate: self.add_virtual_custom_predicate(false),
}
}
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget {
ValueTarget {
elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])),
}
}
fn select_statement_arg(
&mut self,
b: BoolTarget,
x: &StatementArgTarget,
y: &StatementArgTarget,
) -> StatementArgTarget {
StatementArgTarget {
elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])),
}
}
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget {
BoolTarget::new_unsafe(self.select(b, x.target, y.target))
}
fn constant_value(&mut self, v: RawValue) -> ValueTarget {
ValueTarget {
elements: std::array::from_fn(|i| {
self.constant(F::from_noncanonical_u64(v.0[i].to_noncanonical_u64()))
}),
}
}
fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget {
assert_eq!(xs.len(), ys.len());
let init = self._true();
xs.iter().zip(ys.iter()).fold(init, |ok, (x, y)| {
let is_eq = self.is_equal(*x, *y);
self.and(ok, is_eq)
})
}
fn statement_arg_is_value(&mut self, arg: &StatementArgTarget) -> BoolTarget {
let zeros = iter::repeat(self.zero())
.take(STATEMENT_ARG_F_LEN - VALUE_SIZE)
.collect::<Vec<_>>();
self.is_equal_slice(&arg.elements[VALUE_SIZE..], &zeros)
}
fn assert_i64(&mut self, x: ValueTarget) {
// `x` should only have two limbs.
x.elements
.into_iter()
.skip(2)
.for_each(|l| self.assert_zero(l));
// 32-bit range check.
self.range_check(x.elements[0], NUM_BITS);
self.range_check(x.elements[1], NUM_BITS);
}
fn i64_is_negative(&mut self, x: ValueTarget) -> BoolTarget {
// x is negative if the most significant bit of its most
// significant limb is 1.
let high_bits = self.split_le(x.elements[1], NUM_BITS);
high_bits[31]
}
fn assert_i64_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) {
// If b is false, replace `x` and `y` with dummy values.
let zero = ValueTarget::zero(self);
let one = ValueTarget::one(self);
let x = self.select_value(b, x, zero);
let y = self.select_value(b, y, one);
// Lt assertion.
let assert_limb_lt = |builder: &mut Self, x, y| {
// Check that `y-1-x` fits within `NUM_BITS` bits.
let one = builder.one();
let y_minus_one = builder.sub(y, one);
let expr = builder.sub(y_minus_one, x);
builder.range_check(expr, NUM_BITS);
};
// Check if `x` and `y` have the same sign. If not, swap.
let x_is_negative = self.i64_is_negative(x);
let y_is_negative = self.i64_is_negative(y);
let same_sign_ind = self.is_equal(x_is_negative.target, y_is_negative.target);
let (x, y) = (
self.select_value(same_sign_ind, x, y),
self.select_value(same_sign_ind, y, x),
);
let big_limbs_eq = self.is_equal(x.elements[1], y.elements[1]);
let lhs = self.select(big_limbs_eq, x.elements[0], x.elements[1]);
let rhs = self.select(big_limbs_eq, y.elements[0], y.elements[1]);
assert_limb_lt(self, lhs, rhs);
}
fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
let zero = self.zero();
// Add components and carry where appropriate.
let (_, sum) = std::iter::zip(&x.elements[..2], &y.elements[..2]).fold(
(zero, vec![]),
|(carry, out), (&a, &b)| {
let sum = [a, b, carry]
.into_iter()
.reduce(|alpha, beta| self.add(alpha, beta))
.expect("Iterator should be nonempty.");
let (sum_residue, sum_quotient) = self.split_low_high(sum, NUM_BITS, F::BITS);
(sum_quotient, [out, vec![sum_residue]].concat())
},
);
ValueTarget::from_slice(&[sum[0], sum[1], zero, zero])
}
fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
let zero = self.zero();
let sum = self.i64_wrapping_add(x, y);
// Overflow check.
let x_is_negative = self.i64_is_negative(x);
let x_is_nonnegative = self.not(x_is_negative);
let y_is_negative = self.i64_is_negative(y);
let y_is_nonnegative = self.not(y_is_negative);
let sum_is_negative = self.i64_is_negative(sum);
let sum_is_nonnegative = self.not(sum_is_negative);
let overflow_conditions = [
self.all([x_is_negative, y_is_negative, sum_is_nonnegative]),
self.all([x_is_nonnegative, y_is_nonnegative, sum_is_negative]),
];
let overflow = self.any(overflow_conditions);
self.connect(overflow.target, zero);
sum
}
fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
let zero = self.zero();
let i64_min = ValueTarget::from_slice(&self.constants(&RawValue::from(i64::MIN).0));
let (abs_x, x_is_negative) = self.i64_abs(x);
let (abs_y, y_is_negative) = self.i64_abs(y);
// Sign indicators.
let same_sign_ind = self.is_equal(x_is_negative.target, y_is_negative.target);
let prod_sign = self.not(same_sign_ind);
// Determine product of absolute values.
let x = abs_x.elements[..2].to_vec();
let y = abs_y.elements[..2].to_vec();
let prods = [
self.mul(x[0], y[0]),
self.mul(x[0], y[1]),
self.mul(x[1], y[0]),
]
.into_iter()
.map(|p| self.split_low_high(p, NUM_BITS, F::BITS))
.collect::<Vec<_>>();
let prod_lower = prods[0].0;
let (prod_upper, _) = {
let sum1 = self.add(prods[1].0, prods[2].0);
let sum2 = self.add(sum1, prods[0].1);
self.split_low_high(sum2, NUM_BITS, F::BITS)
};
let abs_prod = ValueTarget::from_slice(&[prod_lower, prod_upper, zero, zero]);
// Overflow check: The latter two products in `prods` should
// have zero higher-order coefficients.
let no_spillovers = [
self.is_equal(prods[1].1, zero),
self.is_equal(prods[2].1, zero),
]
.into_iter()
.reduce(|a, b| self.and(a, b))
.expect("Iterator should be nonempty.");
// Overflow check: The product of the higher-order
// coefficients should be zero.
let higher_prod = self.mul(x[1], y[1]);
let higher_prod_is_zero = self.is_equal(higher_prod, zero);
// Overflow check: The product of the absolute values is
// either nonnegative or negative and equal to `i64::MIN`.
let abs_prod_is_negative = self.i64_is_negative(abs_prod);
let abs_prod_is_nonnegative = self.not(abs_prod_is_negative);
let abs_prod_is_min = self.is_equal_slice(&abs_prod.elements, &i64_min.elements);
let abs_prod_sign_ok = self.and(abs_prod_is_min, prod_sign);
let abs_prod_sign_ok = self.or(abs_prod_sign_ok, abs_prod_is_nonnegative);
// Combine the above conditions.
let no_overflow = self.and(abs_prod_sign_ok, higher_prod_is_zero);
let no_overflow = self.and(no_overflow, no_spillovers);
self.assert_one(no_overflow.target);
// Take sign into account.
let minus_abs_prod = self.i64_inv(abs_prod);
self.select_value(prod_sign, minus_abs_prod, abs_prod)
}
fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget {
let zero = self.zero();
let one = ValueTarget::one(self);
let u32_max = self.constant(F::from_canonical_u32(u32::MAX));
let flipped_x = ValueTarget::from_slice(&[
self.sub(u32_max, x.elements[0]),
self.sub(u32_max, x.elements[1]),
zero,
zero,
]);
self.i64_wrapping_add(one, flipped_x)
}
fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget) {
let x_is_negative = self.i64_is_negative(x);
let minus_x = self.i64_inv(x);
(self.select_value(x_is_negative, minus_x, x), x_is_negative)
}
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
ValueTarget::from_slice(
&self
.hash_n_to_hash_no_pad::<PoseidonHash>([x.elements, y.elements].concat())
.elements,
)
}
fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target {
const CHUNK_LEN: usize = 64; // Max size of a single gate native random access
assert!(array.len() <= i.max_array_len);
// Limit to 4 chunks (combination of 4 random_access of CHUNK_LEN elements) to avoid
// abusing this method.
assert!(array.len() <= 4 * CHUNK_LEN);
// We do several random accesses over chunks of CHUNK_LEN using the lowest bits of the
// index. Then we combine them using the highest bits of the index.
let mut chunk_res = Vec::new();
let num_chunks = array.len().div_ceil(CHUNK_LEN);
for chunk in array.chunks(CHUNK_LEN) {
let mut index_chunk = i.low;
// I we have several chunks and the last one is smaller (it's index needs less than 6
// bits), make it zero except when it's used so that the range check over the index
// passes.
if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 {
let last_chunk_index_high = self.constant(F::from_canonical_usize(num_chunks - 1));
let selector = self.is_equal(i.high, last_chunk_index_high);
index_chunk = self.mul(index_chunk, selector.target);
}
let res = self.random_access(index_chunk, chunk.to_vec());
chunk_res.push(res);
}
self.random_access(i.high, chunk_res)
}
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
// The idea would be the following: Take the array `ts` and hash each element. Then do the
// random access on the hash result. Finally "unhash" to recover the resolved element.
// We don't want to hash each element from the array each time, so we should cache the hashed
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
// then do `ts: &[HashCache<T>]`.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
let num_rows = m.len();
let num_columns = m
.first()
.map(|row| {
let row_len = row.len();
assert!(m.iter().all(|row| row.len() == row_len));
row_len
})
.unwrap_or(0);
(0..num_columns)
.map(|j| {
builder
.random_access_long(i, &(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>())
})
.collect::<Vec<_>>()
};
let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>();
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
}
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
let zero = self.zero();
self.vec_ref(
params,
ts,
&IndexTarget {
max_array_len: 64,
low: i,
high: zero,
},
)
}
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
b: BoolTarget,
x: &T,
y: &T,
) -> T {
let flattened_x = x.flatten();
let flattened_y = y.flatten();
T::from_flattened(
params,
&iter::zip(flattened_x, flattened_y)
.map(|(x, y)| self.select(b, x, y))
.collect::<Vec<_>>(),
)
}
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) {
self.connect_slice(&xs.flatten(), &ys.flatten())
}
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget {
self.is_equal_slice(&xs.flatten(), &ys.flatten())
}
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
xs.into_iter()
.reduce(|a, b| self.and(a, b))
.unwrap_or(self._true())
}
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
xs.into_iter()
.reduce(|a, b| self.or(a, b))
.unwrap_or(self._false())
}
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget> {
let zero = self.zero();
let mask: Vec<_> = (0..len)
.map(|_| self.add_virtual_bool_target_safe())
.collect();
self.add_simple_generator(LtMaskGenerator {
n,
mask: mask.iter().map(|bt| bt.target).collect(),
});
// We have `n` ones in the mask
let mask_sum = mask
.iter()
.map(|b| b.target)
.reduce(|acc, x| self.add(acc, x))
.unwrap_or(zero);
self.connect(n, mask_sum);
// The elements in the mask can only transition from 1 to 0 or 0 to 0.
for i in 0..len - 1 {
let diff = self.sub(mask[i].target, mask[i + 1].target);
self.assert_bool(BoolTarget::new_unsafe(diff));
}
mask
}
}
#[derive(Debug, Default, Clone)]
pub struct LtMaskGenerator {
pub(crate) n: Target,
pub(crate) mask: Vec<Target>,
}
impl SimpleGenerator<F, D> for LtMaskGenerator {
fn id(&self) -> String {
"LtMaskGenerator".to_string()
}
fn dependencies(&self) -> Vec<Target> {
vec![self.n]
}
fn run_once(
&self,
witness: &PartitionWitness<F>,
out_buffer: &mut GeneratedValues<F>,
) -> anyhow::Result<()> {
let n = witness.get_target(self.n).to_canonical_u64();
for (i, mask_i) in self.mask.iter().enumerate() {
let v = if (i as u64) < n { F::ONE } else { F::ZERO };
out_buffer.set_target(*mask_i, v)?;
}
Ok(())
}
fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData) -> IoResult<()> {
dst.write_target(self.n)?;
dst.write_target_vec(&self.mask)
}
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult<Self> {
let n = src.read_target()?;
let mask = src.read_target_vec()?;
Ok(Self { n, mask })
}
}
#[cfg(test)]
pub(crate) mod tests {
use std::sync::Arc;
use anyhow::anyhow;
use itertools::Itertools;
use plonky2::plonk::{
circuit_builder::CircuitBuilder, circuit_data::CircuitConfig,
config::PoseidonGoldilocksConfig,
};
use super::*;
use crate::{
backends::plonky2::basetypes::C,
examples::custom::eth_dos_batch,
frontend::{self, CustomPredicateBatchBuilder},
middleware::CustomPredicateBatch,
};
pub(crate) const I64_TEST_PAIRS: [(i64, i64); 36] = [
// Nonnegative numbers
(0, 0),
(0, 50),
(35, 50),
(483748374, 221672),
(2, 1 << 31),
(2, 1 << 62),
(0, 1 << 62),
(1 << 31, 1 << 62),
(1 << 32, 1 << 32),
(1 << 62, 1 << 62),
(0, i64::MAX),
(i64::MAX, 1 << 62),
(i64::MAX, i64::MAX),
// Negative numbers
(-35, -50),
(-483748374, -221672),
(-(1 << 33), -1),
(-(1 << 32), -(1 << 32)),
(-(1 << 33), -(1 << 29)),
(-(1 << 33), -(1 << 30)),
(-(1 << 33), -(1 << 62)),
(-(1 << 62), -(1 << 62)),
(i64::MIN, -1),
(i64::MIN, -(1 << 31)),
(i64::MIN, -(1 << 62)),
(i64::MIN, i64::MIN),
// Mix of numbers
(-35, 50),
(-483748374, 221672),
(-(1 << 32), (1 << 32)),
(-(1 << 33), (1 << 30) - 1),
(-(1 << 33), (1 << 30)),
(-(1 << 62), (1 << 62)),
(i64::MIN, 0),
(i64::MIN, 1),
(i64::MIN, 1 << 31),
(i64::MIN, 1 << 62),
(i64::MIN, i64::MAX),
];
#[test]
fn custom_predicate_target() -> frontend::Result<()> {
let params = Params::default();
let config = CircuitConfig::standard_recursion_config();
let custom_predicate_batch = eth_dos_batch(&params)?;
for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() {
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let flattened = cp.to_fields();
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
let cp_target = CustomPredicateTarget::from_flattened(&params, &flatteend_target);
// Round trip of from_flattened to flattened
let flatteend_target_rt = cp_target.flatten();
// TODO: Instead of connect, assign witness to result
builder.connect_slice(&flatteend_target, &flatteend_target_rt);
let pw = PartialWitness::<F>::new();
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap_or_else(|_| panic!("predicate {}", i));
data.verify(proof.clone()).unwrap();
}
Ok(())
}
fn helper_custom_predicate_in_batch_target(
custom_predicate_batch: &Arc<CustomPredicateBatch>,
) -> Result<()> {
for index in 0..custom_predicate_batch.predicates().len() {
let cpr = custom_predicate_batch
.predicate_ref_by_index(index)
.unwrap();
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let custom_pred_in_batch_target =
CustomPredicateInBatchTarget::new_virtual(&mut builder);
custom_pred_in_batch_target.verify_circuit(&mut builder);
let mut pw = PartialWitness::<F>::new();
let (_, mtp) = custom_predicate_batch
.mt()
.prove(&Value::from(index as i64).raw())
.unwrap();
custom_pred_in_batch_target.set_targets(&mut pw, &cpr, &mtp)?;
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof.clone()).unwrap();
}
Ok(())
}
#[test]
fn test_custom_predicate_in_batch_target() -> frontend::Result<()> {
let params = Params::default();
// Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish()?;
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
// Some cases from the examples
let custom_predicate_batch = eth_dos_batch(&params)?;
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
let custom_predicate_batch =
CustomPredicateBatch::new("empty".to_string(), vec![CustomPredicate::empty()]);
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
Ok(())
}
#[test]
fn test_i64_addition() -> Result<(), anyhow::Error> {
// Circuit declaration
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let sum_target = builder.i64_add(x_target, y_target);
let data = builder.build::<PoseidonGoldilocksConfig>();
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
let mut pw = PartialWitness::<F>::new();
let (sum, overflow) = x.overflowing_add(y);
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields())?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields())?;
pw.set_target_arr(&sum_target.elements, &RawValue::from(sum).to_fields())?;
let proof = data.prove(pw);
match (overflow, proof) {
(false, Ok(pf)) => data.verify(pf),
(false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)),
(true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")),
(true, Err(_)) => Ok(()),
}
})
}
#[test]
fn test_i64_multiplication() -> Result<(), anyhow::Error> {
// Circuit declaration
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let prod_target = builder.i64_mul(x_target, y_target);
let data = builder.build::<PoseidonGoldilocksConfig>();
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
println!("{}, {}", x, y);
let mut pw = PartialWitness::<F>::new();
let (prod, overflow) = x.overflowing_mul(y);
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields())?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields())?;
pw.set_target_arr(&prod_target.elements, &RawValue::from(prod).to_fields())?;
let proof = data.prove(pw);
match (overflow, proof) {
(false, Ok(pf)) => data.verify(pf),
(false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)),
(true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")),
(true, Err(_)) => Ok(()),
}
})
}
#[test]
fn test_random_access_long() -> Result<(), anyhow::Error> {
let lens: [usize; _] = [10, 60, 64, 96, 126, 159, 190, 256];
for len in &lens {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let array = builder.add_virtual_targets(*len);
let index_target = IndexTarget::new_virtual(*len, &mut builder);
let res = builder.random_access_long(&index_target, &array);
let data = builder.build::<PoseidonGoldilocksConfig>();
for i in 0..3 {
let index = (len - 1) * i / 2;
println!("len={}, index={}", len, index);
let mut pw = PartialWitness::<F>::new();
for (j, elem) in array.iter().enumerate() {
pw.set_target(*elem, F::from_canonical_usize(j * 11))?;
}
index_target.set_targets(&mut pw, index)?;
pw.set_target(res, F::from_canonical_usize(index * 11))?; // Expected
let proof = data.prove(pw)?;
data.verify(proof)?;
}
}
Ok(())
}
}