add target types for custom predicates (#223)

* add target types for custom predicates

* simplify

* fix clippy

* fix typo

* don't use ref for NativePredicate

* fix wrong len

* apply feedback from @ax0
This commit is contained in:
Eduard S. 2025-05-07 11:09:38 +02:00 committed by GitHub
parent bf394eada3
commit 726f95483d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 527 additions and 123 deletions

View file

@ -26,9 +26,9 @@ use crate::{
primitives::merkletree::MerkleClaimAndProofTarget,
},
middleware::{
NativeOperation, NativePredicate, Params, Predicate, RawValue, StatementArg, ToFields,
EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN,
VALUE_SIZE,
NativeOperation, NativePredicate, Params, Predicate, PredicatePrefix, RawValue,
StatementArg, StatementTmplArgPrefix, ToFields, EMPTY_VALUE, F, HASH_SIZE,
OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE,
},
};
@ -117,20 +117,37 @@ impl StatementArgTarget {
#[derive(Clone)]
pub struct StatementTarget {
pub predicate: [Target; Params::predicate_size()],
pub predicate: PredicateTarget,
pub args: Vec<StatementArgTarget>,
}
pub trait Build<T> {
fn build(self, builder: &mut CircuitBuilder<F, D>, params: &Params) -> T;
}
impl Build<NativePredicateTarget> for NativePredicate {
fn build(self, builder: &mut CircuitBuilder<F, D>, params: &Params) -> NativePredicateTarget {
NativePredicateTarget::constant(builder, params, self)
}
}
impl<T> Build<T> for T {
fn build(self, _builder: &mut CircuitBuilder<F, D>, _params: &Params) -> T {
self
}
}
impl StatementTarget {
pub fn new_native(
builder: &mut CircuitBuilder<F, D>,
params: &Params,
predicate: NativePredicate,
native_predicate: impl Build<NativePredicateTarget>,
args: &[StatementArgTarget],
) -> Self {
let predicate_vec = builder.constants(&Predicate::Native(predicate).to_fields(params));
// if native_predicate is const then NativePredicate -> NativePredicateTarget
// else just use as is
Self {
predicate: array::from_fn(|i| predicate_vec[i]),
predicate: PredicateTarget::new_native(builder, params, native_predicate),
args: args
.iter()
.cloned()
@ -146,7 +163,7 @@ impl StatementTarget {
params: &Params,
st: &Statement,
) -> Result<()> {
pw.set_target_arr(&self.predicate, &st.predicate().to_fields(params))?;
self.predicate.set_targets(pw, params, st.predicate())?;
for (i, arg) in st
.args()
.iter()
@ -165,8 +182,8 @@ impl StatementTarget {
params: &Params,
t: NativePredicate,
) -> BoolTarget {
let st_code = builder.constants(&Predicate::Native(t).to_fields(params));
builder.is_equal_slice(&self.predicate, &st_code)
let expected_predicate = PredicateTarget::new_native(builder, params, t);
builder.is_equal_flattenable(&self.predicate, &expected_predicate)
}
}
@ -212,11 +229,159 @@ impl OperationTarget {
}
}
#[derive(Clone)]
pub struct NativePredicateTarget(Target);
impl NativePredicateTarget {
pub fn constant(
builder: &mut CircuitBuilder<F, D>,
params: &Params,
native_predicate: NativePredicate,
) -> Self {
let id = native_predicate.to_fields(params);
assert_eq!(1, id.len());
Self(builder.constant(id[0]))
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
native_predicate: NativePredicate,
) -> Result<()> {
let id = native_predicate.to_fields(params);
assert_eq!(1, id.len());
Ok(pw.set_target(self.0, id[0])?)
}
}
#[derive(Clone)]
pub struct PredicateTarget {
elements: [Target; Params::predicate_size()],
}
impl PredicateTarget {
pub fn new_native(
builder: &mut CircuitBuilder<F, D>,
params: &Params,
native_predicate: impl Build<NativePredicateTarget>,
) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Native));
let id = native_predicate.build(builder, params).0;
let zero = builder.zero();
Self {
elements: [prefix, id, zero, zero, zero, zero],
}
}
pub fn new_batch_self(builder: &mut CircuitBuilder<F, D>, index: Target) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::BatchSelf));
let zero = builder.zero();
Self {
elements: [prefix, index, zero, zero, zero, zero],
}
}
pub fn new_custom(
builder: &mut CircuitBuilder<F, D>,
batch_id: HashOutTarget,
index: Target,
) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Custom));
let id = batch_id.elements;
Self {
elements: [prefix, id[0], id[1], id[2], id[3], index],
}
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
predicate: Predicate,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &predicate.to_fields(params))?)
}
}
#[derive(Clone)]
pub struct KeyOrWildcardTarget {
pub elements: [Target; VALUE_SIZE],
}
impl KeyOrWildcardTarget {
fn from_slice(v: &[Target]) -> Self {
Self {
elements: v.try_into().expect("len is VALUE_SIZE"),
}
}
}
#[derive(Clone)]
pub struct StatementTmplArgTarget {
pub elements: [Target; Params::statement_tmpl_arg_size()],
}
impl StatementTmplArgTarget {
pub fn as_none(&self, builder: &mut CircuitBuilder<F, D>) -> BoolTarget {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::None));
builder.is_equal(self.elements[0], prefix)
}
pub fn as_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, ValueTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal));
let case_ok = builder.is_equal(self.elements[0], prefix);
let value = ValueTarget::from_slice(&self.elements[1..5]);
(case_ok, value)
}
pub fn as_key(
&self,
builder: &mut CircuitBuilder<F, D>,
) -> (BoolTarget, Target, KeyOrWildcardTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Key));
let case_ok = builder.is_equal(self.elements[0], prefix);
let id_wildcard_index = self.elements[1];
let value_key_or_wildcard = KeyOrWildcardTarget::from_slice(&self.elements[5..9]);
(case_ok, id_wildcard_index, value_key_or_wildcard)
}
pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, Target) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral));
let case_ok = builder.is_equal(self.elements[0], prefix);
let wildcard_index = self.elements[1];
(case_ok, wildcard_index)
}
}
#[derive(Clone)]
pub struct StatementTmplTarget {
pub pred: PredicateTarget,
pub args: Vec<StatementTmplArgTarget>,
}
#[derive(Clone)]
pub struct CustomPredicateTarget {
pub conjunction: BoolTarget,
// len = params.max_custom_predicate_arity
pub statements: Vec<StatementTmplTarget>,
pub args_len: Target,
}
#[derive(Clone)]
pub struct CustomPredicateBatchTarget {
pub predicates: Vec<CustomPredicateTarget>,
}
impl CustomPredicateBatchTarget {
pub fn id(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
}
}
/// Trait for target structs that may be converted to and from vectors
/// of targets.
pub trait Flattenable {
fn flatten(&self) -> Vec<Target>;
fn from_flattened(vs: &[Target]) -> Self;
fn from_flattened(params: &Params, vs: &[Target]) -> Self;
}
/// For the purpose of op verification, we need only look up the
@ -255,7 +420,7 @@ impl Flattenable for MerkleClaimTarget {
.concat()
}
fn from_flattened(vs: &[Target]) -> Self {
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self {
enabled: BoolTarget::new_unsafe(vs[0]),
root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()),
@ -270,22 +435,34 @@ impl Flattenable for MerkleClaimTarget {
}
}
impl Flattenable for PredicateTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, v: &[Target]) -> Self {
Self {
elements: v.try_into().expect("len is predicate_size"),
}
}
}
impl Flattenable for StatementTarget {
fn flatten(&self) -> Vec<Target> {
self.predicate
.iter()
.chain(self.args.iter().flat_map(|a| &a.elements))
.cloned()
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|a| &a.elements).cloned())
.collect()
}
fn from_flattened(v: &[Target]) -> Self {
fn from_flattened(params: &Params, v: &[Target]) -> Self {
let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN;
assert_eq!(
v.len(),
Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN
);
let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]);
let predicate = PredicateTarget::from_flattened(params, &v[..Params::predicate_size()]);
let args = (0..num_args)
.map(|i| StatementArgTarget {
elements: array::from_fn(|j| {
@ -298,11 +475,75 @@ impl Flattenable for StatementTarget {
}
}
impl Flattenable for CustomPredicateTarget {
fn flatten(&self) -> Vec<Target> {
iter::once(self.conjunction.target)
.chain(iter::once(self.args_len))
.chain(self.statements.iter().flat_map(|s| s.flatten()))
.collect()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
// We assume that `from_flattened` is always called with the output of `flattened`, so
// this `BoolTarget` should actually safe.
let conjunction = BoolTarget::new_unsafe(v[0]);
let args_len = v[1];
let st_tmpl_size = params.statement_tmpl_size();
let statements = (0..params.max_custom_predicate_arity)
.map(|i| {
let st_v = &v[2 + st_tmpl_size * i..2 + st_tmpl_size * (i + 1)];
StatementTmplTarget::from_flattened(params, st_v)
})
.collect();
Self {
conjunction,
statements,
args_len,
}
}
}
impl Flattenable for StatementTmplTarget {
fn flatten(&self) -> Vec<Target> {
self.pred
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
.collect()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
let pred_end = Params::predicate_size();
let pred = PredicateTarget::from_flattened(params, &v[..pred_end]);
let sta_size = Params::statement_tmpl_arg_size();
let args = (0..params.max_statement_args)
.map(|i| {
let sta_v = &v[pred_end + sta_size * i..pred_end + sta_size * (i + 1)];
StatementTmplArgTarget::from_flattened(params, sta_v)
})
.collect();
Self { pred, args }
}
}
impl Flattenable for StatementTmplArgTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, v: &[Target]) -> Self {
Self {
elements: v.try_into().expect("len is statement_tmpl_arg_size"),
}
}
}
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget);
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
fn add_virtual_value(&mut self) -> ValueTarget;
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget;
fn add_virtual_predicate(&mut self) -> PredicateTarget;
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
@ -329,8 +570,14 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
// Convenience methods for accessing and connecting elements of
// (vectors of) flattenables.
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T;
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T;
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
b: BoolTarget,
x: &T,
y: &T,
) -> T;
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
@ -358,8 +605,9 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
}
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget {
let predicate = self.add_virtual_predicate();
StatementTarget {
predicate: self.add_virtual_target_arr(),
predicate,
args: (0..params.max_statement_args)
.map(|_| StatementArgTarget {
elements: self.add_virtual_target_arr(),
@ -368,6 +616,12 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
}
}
fn add_virtual_predicate(&mut self) -> PredicateTarget {
PredicateTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
OperationTarget {
op_type: self.add_virtual_target_arr(),
@ -470,7 +724,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
)
}
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T {
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
// TODO: Revisit this when we need more than 64 statements.
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
assert!(v.len() <= 64);
@ -498,14 +752,21 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
};
let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>();
T::from_flattened(&matrix_row_ref(self, &flattened_ts, i))
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
}
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T {
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
b: BoolTarget,
x: &T,
y: &T,
) -> T {
let flattened_x = x.flatten();
let flattened_y = y.flatten();
T::from_flattened(
params,
&iter::zip(flattened_x, flattened_y)
.map(|(x, y)| self.select(b, x, y))
.collect::<Vec<_>>(),
@ -532,3 +793,123 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
.unwrap_or(self._false())
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig};
use super::*;
use crate::{
backends::plonky2::basetypes::C,
examples::custom::{eth_dos_batch, eth_friend_batch},
frontend,
frontend::CustomPredicateBatchBuilder,
middleware::CustomPredicateBatch,
};
#[test]
fn custom_predicate_target() -> frontend::Result<()> {
let params = Params::default();
let config = CircuitConfig::standard_recursion_config();
let custom_predicate_batch = eth_friend_batch(&params)?;
for (i, cp) in custom_predicate_batch.predicates.iter().enumerate() {
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let flattened = cp.to_fields(&params);
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
let cp_target = CustomPredicateTarget::from_flattened(&params, &flatteend_target);
// Round trip of from_flattened to flattened
let flatteend_target_rt = cp_target.flatten();
builder.connect_slice(&flatteend_target, &flatteend_target_rt);
let pw = PartialWitness::<F>::new();
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).expect(&format!("predicate {}", i));
data.verify(proof.clone()).unwrap();
}
Ok(())
}
fn test_custom_predicate_batch_target_id(
params: &Params,
custom_predicate_batch: &CustomPredicateBatch,
) -> frontend::Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let zero = builder.zero();
let predicate_targets = custom_predicate_batch
.predicates
.iter()
.map(|cp| {
let flattened = cp.to_fields(params);
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
CustomPredicateTarget::from_flattened(params, &flatteend_target)
})
.chain(iter::repeat({
let empty_flatteend_target = iter::repeat(zero)
.take(params.custom_predicate_size())
.collect_vec();
CustomPredicateTarget::from_flattened(params, &empty_flatteend_target)
}))
.take(params.max_custom_batch_size)
.collect();
let custom_predicate_batch_target = CustomPredicateBatchTarget {
predicates: predicate_targets,
};
// Calculate the id in constraints and compare it against the id calculated natively
let id_target = custom_predicate_batch_target.id(&mut builder);
let id = custom_predicate_batch.id(params);
let id_expected_target = HashOutTarget {
elements: id
.to_fields(params)
.iter()
.map(|v| builder.constant(*v))
.collect_vec()
.try_into()
.unwrap(),
};
builder.connect_array(id_target.elements, id_expected_target.elements);
let pw = PartialWitness::<F>::new();
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof.clone()).unwrap();
Ok(())
}
#[test]
fn custom_predicate_batch_target() -> frontend::Result<()> {
let params = Params {
max_statement_args: 6,
max_custom_predicate_wildcards: 12,
..Default::default()
};
// Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new("empty".into());
_ = cpb_builder.predicate_and("empty", &params, &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish();
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?;
// Some cases from the examples
let custom_predicate_batch = eth_friend_batch(&params)?;
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?;
let custom_predicate_batch = eth_dos_batch(&params)?;
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?;
Ok(())
}
}

View file

@ -85,14 +85,14 @@ impl OperationVerifyGadget {
op.args
.iter()
.flatten()
.map(|&i| builder.vec_ref(prev_statements, i))
.map(|&i| builder.vec_ref(&self.params, prev_statements, i))
.collect::<Vec<_>>()
};
// Certain operations (Contains/NotContains) will refer to one
// of the provided Merkle proofs (if any). These proofs have already
// been verified, so we need only look up the claim.
let resolved_merkle_claim =
(!merkle_claims.is_empty()).then(|| builder.vec_ref(merkle_claims, op.aux[0]));
let resolved_merkle_claim = (!merkle_claims.is_empty())
.then(|| builder.vec_ref(&self.params, merkle_claims, op.aux[0]));
// The verification may require aux data which needs to be stored in the
// `OperationVerifyTarget` so that we can set during witness generation.
@ -455,7 +455,7 @@ impl OperationVerifyGadget {
let individual_checks = prev_statements
.iter()
.map(|ps| {
let same_predicate = builder.is_equal_slice(&st.predicate, &ps.predicate);
let same_predicate = builder.is_equal_flattenable(&st.predicate, &ps.predicate);
let same_anchored_key =
builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements);
builder.and(same_predicate, same_anchored_key)
@ -575,15 +575,7 @@ impl MainPodVerifyGadget {
.collect();
// 2. Calculate the Pod Id from the public statements
let pub_statements_flattened = pub_statements
.iter()
.flat_map(|s| {
s.predicate
.iter()
.chain(s.args.iter().flat_map(|a| &a.elements))
})
.cloned()
.collect();
let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect();
let id = builder.hash_n_to_hash_no_pad::<PoseidonHash>(pub_statements_flattened);
// 4. Verify type
@ -591,6 +583,7 @@ impl MainPodVerifyGadget {
// TODO: Store this hash in a global static with lazy init so that we don't have to
// compute it every time.
let expected_type_statement = StatementTarget::from_flattened(
&self.params,
&builder.constants(
&Statement::ValueOf(
AnchoredKey::from((SELF, KEY_TYPE)),

View file

@ -3,17 +3,16 @@ use std::iter;
use itertools::Itertools;
use plonky2::{
hash::hash_types::{HashOut, HashOutTarget},
iop::{
target::Target,
witness::{PartialWitness, WitnessWrite},
},
iop::witness::{PartialWitness, WitnessWrite},
plonk::circuit_builder::CircuitBuilder,
};
use crate::{
backends::plonky2::{
basetypes::D,
circuits::common::{CircuitBuilderPod, StatementArgTarget, StatementTarget, ValueTarget},
circuits::common::{
CircuitBuilderPod, PredicateTarget, StatementArgTarget, StatementTarget, ValueTarget,
},
error::Result,
primitives::{
merkletree::{
@ -24,8 +23,8 @@ use crate::{
signedpod::SignedPod,
},
middleware::{
hash_str, Key, NativePredicate, Params, PodType, Predicate, RawValue, ToFields, Value, F,
KEY_SIGNER, KEY_TYPE, SELF,
hash_str, Key, NativePredicate, Params, PodType, RawValue, Value, F, KEY_SIGNER, KEY_TYPE,
SELF,
},
};
@ -91,10 +90,8 @@ impl SignedPodVerifyTarget {
self_id: bool,
) -> Vec<StatementTarget> {
let mut statements = Vec::new();
let predicate: [Target; Params::predicate_size()] = builder
.constants(&Predicate::Native(NativePredicate::ValueOf).to_fields(&self.params))
.try_into()
.expect("size predicate_size");
let predicate =
PredicateTarget::new_native(builder, &self.params, NativePredicate::ValueOf);
let pod_id = if self_id {
builder.constant_value(SELF.0.into())
} else {
@ -111,7 +108,10 @@ impl SignedPodVerifyTarget {
.chain(iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(self.params.max_statement_args)
.collect();
let statement = StatementTarget { predicate, args };
let statement = StatementTarget {
predicate: predicate.clone(),
args,
};
statements.push(statement);
}
statements