add target types for custom predicates (#223)

* add target types for custom predicates

* simplify

* fix clippy

* fix typo

* don't use ref for NativePredicate

* fix wrong len

* apply feedback from @ax0
This commit is contained in:
Eduard S. 2025-05-07 11:09:38 +02:00 committed by GitHub
parent bf394eada3
commit 726f95483d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 527 additions and 123 deletions

View file

@ -26,9 +26,9 @@ use crate::{
primitives::merkletree::MerkleClaimAndProofTarget, primitives::merkletree::MerkleClaimAndProofTarget,
}, },
middleware::{ middleware::{
NativeOperation, NativePredicate, Params, Predicate, RawValue, StatementArg, ToFields, NativeOperation, NativePredicate, Params, Predicate, PredicatePrefix, RawValue,
EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, StatementArg, StatementTmplArgPrefix, ToFields, EMPTY_VALUE, F, HASH_SIZE,
VALUE_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE,
}, },
}; };
@ -117,20 +117,37 @@ impl StatementArgTarget {
#[derive(Clone)] #[derive(Clone)]
pub struct StatementTarget { pub struct StatementTarget {
pub predicate: [Target; Params::predicate_size()], pub predicate: PredicateTarget,
pub args: Vec<StatementArgTarget>, pub args: Vec<StatementArgTarget>,
} }
pub trait Build<T> {
fn build(self, builder: &mut CircuitBuilder<F, D>, params: &Params) -> T;
}
impl Build<NativePredicateTarget> for NativePredicate {
fn build(self, builder: &mut CircuitBuilder<F, D>, params: &Params) -> NativePredicateTarget {
NativePredicateTarget::constant(builder, params, self)
}
}
impl<T> Build<T> for T {
fn build(self, _builder: &mut CircuitBuilder<F, D>, _params: &Params) -> T {
self
}
}
impl StatementTarget { impl StatementTarget {
pub fn new_native( pub fn new_native(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
params: &Params, params: &Params,
predicate: NativePredicate, native_predicate: impl Build<NativePredicateTarget>,
args: &[StatementArgTarget], args: &[StatementArgTarget],
) -> Self { ) -> Self {
let predicate_vec = builder.constants(&Predicate::Native(predicate).to_fields(params)); // if native_predicate is const then NativePredicate -> NativePredicateTarget
// else just use as is
Self { Self {
predicate: array::from_fn(|i| predicate_vec[i]), predicate: PredicateTarget::new_native(builder, params, native_predicate),
args: args args: args
.iter() .iter()
.cloned() .cloned()
@ -146,7 +163,7 @@ impl StatementTarget {
params: &Params, params: &Params,
st: &Statement, st: &Statement,
) -> Result<()> { ) -> Result<()> {
pw.set_target_arr(&self.predicate, &st.predicate().to_fields(params))?; self.predicate.set_targets(pw, params, st.predicate())?;
for (i, arg) in st for (i, arg) in st
.args() .args()
.iter() .iter()
@ -165,8 +182,8 @@ impl StatementTarget {
params: &Params, params: &Params,
t: NativePredicate, t: NativePredicate,
) -> BoolTarget { ) -> BoolTarget {
let st_code = builder.constants(&Predicate::Native(t).to_fields(params)); let expected_predicate = PredicateTarget::new_native(builder, params, t);
builder.is_equal_slice(&self.predicate, &st_code) builder.is_equal_flattenable(&self.predicate, &expected_predicate)
} }
} }
@ -212,11 +229,159 @@ impl OperationTarget {
} }
} }
#[derive(Clone)]
pub struct NativePredicateTarget(Target);
impl NativePredicateTarget {
pub fn constant(
builder: &mut CircuitBuilder<F, D>,
params: &Params,
native_predicate: NativePredicate,
) -> Self {
let id = native_predicate.to_fields(params);
assert_eq!(1, id.len());
Self(builder.constant(id[0]))
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
native_predicate: NativePredicate,
) -> Result<()> {
let id = native_predicate.to_fields(params);
assert_eq!(1, id.len());
Ok(pw.set_target(self.0, id[0])?)
}
}
#[derive(Clone)]
pub struct PredicateTarget {
elements: [Target; Params::predicate_size()],
}
impl PredicateTarget {
pub fn new_native(
builder: &mut CircuitBuilder<F, D>,
params: &Params,
native_predicate: impl Build<NativePredicateTarget>,
) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Native));
let id = native_predicate.build(builder, params).0;
let zero = builder.zero();
Self {
elements: [prefix, id, zero, zero, zero, zero],
}
}
pub fn new_batch_self(builder: &mut CircuitBuilder<F, D>, index: Target) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::BatchSelf));
let zero = builder.zero();
Self {
elements: [prefix, index, zero, zero, zero, zero],
}
}
pub fn new_custom(
builder: &mut CircuitBuilder<F, D>,
batch_id: HashOutTarget,
index: Target,
) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Custom));
let id = batch_id.elements;
Self {
elements: [prefix, id[0], id[1], id[2], id[3], index],
}
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
predicate: Predicate,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &predicate.to_fields(params))?)
}
}
#[derive(Clone)]
pub struct KeyOrWildcardTarget {
pub elements: [Target; VALUE_SIZE],
}
impl KeyOrWildcardTarget {
fn from_slice(v: &[Target]) -> Self {
Self {
elements: v.try_into().expect("len is VALUE_SIZE"),
}
}
}
#[derive(Clone)]
pub struct StatementTmplArgTarget {
pub elements: [Target; Params::statement_tmpl_arg_size()],
}
impl StatementTmplArgTarget {
pub fn as_none(&self, builder: &mut CircuitBuilder<F, D>) -> BoolTarget {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::None));
builder.is_equal(self.elements[0], prefix)
}
pub fn as_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, ValueTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal));
let case_ok = builder.is_equal(self.elements[0], prefix);
let value = ValueTarget::from_slice(&self.elements[1..5]);
(case_ok, value)
}
pub fn as_key(
&self,
builder: &mut CircuitBuilder<F, D>,
) -> (BoolTarget, Target, KeyOrWildcardTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Key));
let case_ok = builder.is_equal(self.elements[0], prefix);
let id_wildcard_index = self.elements[1];
let value_key_or_wildcard = KeyOrWildcardTarget::from_slice(&self.elements[5..9]);
(case_ok, id_wildcard_index, value_key_or_wildcard)
}
pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, Target) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral));
let case_ok = builder.is_equal(self.elements[0], prefix);
let wildcard_index = self.elements[1];
(case_ok, wildcard_index)
}
}
#[derive(Clone)]
pub struct StatementTmplTarget {
pub pred: PredicateTarget,
pub args: Vec<StatementTmplArgTarget>,
}
#[derive(Clone)]
pub struct CustomPredicateTarget {
pub conjunction: BoolTarget,
// len = params.max_custom_predicate_arity
pub statements: Vec<StatementTmplTarget>,
pub args_len: Target,
}
#[derive(Clone)]
pub struct CustomPredicateBatchTarget {
pub predicates: Vec<CustomPredicateTarget>,
}
impl CustomPredicateBatchTarget {
pub fn id(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
}
}
/// Trait for target structs that may be converted to and from vectors /// Trait for target structs that may be converted to and from vectors
/// of targets. /// of targets.
pub trait Flattenable { pub trait Flattenable {
fn flatten(&self) -> Vec<Target>; fn flatten(&self) -> Vec<Target>;
fn from_flattened(vs: &[Target]) -> Self; fn from_flattened(params: &Params, vs: &[Target]) -> Self;
} }
/// For the purpose of op verification, we need only look up the /// For the purpose of op verification, we need only look up the
@ -255,7 +420,7 @@ impl Flattenable for MerkleClaimTarget {
.concat() .concat()
} }
fn from_flattened(vs: &[Target]) -> Self { fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self { Self {
enabled: BoolTarget::new_unsafe(vs[0]), enabled: BoolTarget::new_unsafe(vs[0]),
root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()), root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()),
@ -270,22 +435,34 @@ impl Flattenable for MerkleClaimTarget {
} }
} }
impl Flattenable for PredicateTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, v: &[Target]) -> Self {
Self {
elements: v.try_into().expect("len is predicate_size"),
}
}
}
impl Flattenable for StatementTarget { impl Flattenable for StatementTarget {
fn flatten(&self) -> Vec<Target> { fn flatten(&self) -> Vec<Target> {
self.predicate self.predicate
.iter() .flatten()
.chain(self.args.iter().flat_map(|a| &a.elements)) .into_iter()
.cloned() .chain(self.args.iter().flat_map(|a| &a.elements).cloned())
.collect() .collect()
} }
fn from_flattened(v: &[Target]) -> Self { fn from_flattened(params: &Params, v: &[Target]) -> Self {
let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN; let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN;
assert_eq!( assert_eq!(
v.len(), v.len(),
Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN
); );
let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]); let predicate = PredicateTarget::from_flattened(params, &v[..Params::predicate_size()]);
let args = (0..num_args) let args = (0..num_args)
.map(|i| StatementArgTarget { .map(|i| StatementArgTarget {
elements: array::from_fn(|j| { elements: array::from_fn(|j| {
@ -298,11 +475,75 @@ impl Flattenable for StatementTarget {
} }
} }
impl Flattenable for CustomPredicateTarget {
fn flatten(&self) -> Vec<Target> {
iter::once(self.conjunction.target)
.chain(iter::once(self.args_len))
.chain(self.statements.iter().flat_map(|s| s.flatten()))
.collect()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
// We assume that `from_flattened` is always called with the output of `flattened`, so
// this `BoolTarget` should actually safe.
let conjunction = BoolTarget::new_unsafe(v[0]);
let args_len = v[1];
let st_tmpl_size = params.statement_tmpl_size();
let statements = (0..params.max_custom_predicate_arity)
.map(|i| {
let st_v = &v[2 + st_tmpl_size * i..2 + st_tmpl_size * (i + 1)];
StatementTmplTarget::from_flattened(params, st_v)
})
.collect();
Self {
conjunction,
statements,
args_len,
}
}
}
impl Flattenable for StatementTmplTarget {
fn flatten(&self) -> Vec<Target> {
self.pred
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
.collect()
}
fn from_flattened(params: &Params, v: &[Target]) -> Self {
let pred_end = Params::predicate_size();
let pred = PredicateTarget::from_flattened(params, &v[..pred_end]);
let sta_size = Params::statement_tmpl_arg_size();
let args = (0..params.max_statement_args)
.map(|i| {
let sta_v = &v[pred_end + sta_size * i..pred_end + sta_size * (i + 1)];
StatementTmplArgTarget::from_flattened(params, sta_v)
})
.collect();
Self { pred, args }
}
}
impl Flattenable for StatementTmplArgTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, v: &[Target]) -> Self {
Self {
elements: v.try_into().expect("len is statement_tmpl_arg_size"),
}
}
}
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> { pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget); fn connect_values(&mut self, x: ValueTarget, y: ValueTarget);
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
fn add_virtual_value(&mut self) -> ValueTarget; fn add_virtual_value(&mut self) -> ValueTarget;
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget; fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget;
fn add_virtual_predicate(&mut self) -> PredicateTarget;
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget; fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
@ -329,8 +570,14 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
// Convenience methods for accessing and connecting elements of // Convenience methods for accessing and connecting elements of
// (vectors of) flattenables. // (vectors of) flattenables.
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T; fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T; fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
b: BoolTarget,
x: &T,
y: &T,
) -> T;
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T); fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget; fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
@ -358,8 +605,9 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
} }
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget { fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget {
let predicate = self.add_virtual_predicate();
StatementTarget { StatementTarget {
predicate: self.add_virtual_target_arr(), predicate,
args: (0..params.max_statement_args) args: (0..params.max_statement_args)
.map(|_| StatementArgTarget { .map(|_| StatementArgTarget {
elements: self.add_virtual_target_arr(), elements: self.add_virtual_target_arr(),
@ -368,6 +616,12 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
} }
} }
fn add_virtual_predicate(&mut self) -> PredicateTarget {
PredicateTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
OperationTarget { OperationTarget {
op_type: self.add_virtual_target_arr(), op_type: self.add_virtual_target_arr(),
@ -470,7 +724,7 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
) )
} }
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T { fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
// TODO: Revisit this when we need more than 64 statements. // TODO: Revisit this when we need more than 64 statements.
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| { let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
assert!(v.len() <= 64); assert!(v.len() <= 64);
@ -498,14 +752,21 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
}; };
let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>(); let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>();
T::from_flattened(&matrix_row_ref(self, &flattened_ts, i)) T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
} }
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T { fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
b: BoolTarget,
x: &T,
y: &T,
) -> T {
let flattened_x = x.flatten(); let flattened_x = x.flatten();
let flattened_y = y.flatten(); let flattened_y = y.flatten();
T::from_flattened( T::from_flattened(
params,
&iter::zip(flattened_x, flattened_y) &iter::zip(flattened_x, flattened_y)
.map(|(x, y)| self.select(b, x, y)) .map(|(x, y)| self.select(b, x, y))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
@ -532,3 +793,123 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
.unwrap_or(self._false()) .unwrap_or(self._false())
} }
} }
#[cfg(test)]
mod tests {
use itertools::Itertools;
use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig};
use super::*;
use crate::{
backends::plonky2::basetypes::C,
examples::custom::{eth_dos_batch, eth_friend_batch},
frontend,
frontend::CustomPredicateBatchBuilder,
middleware::CustomPredicateBatch,
};
#[test]
fn custom_predicate_target() -> frontend::Result<()> {
let params = Params::default();
let config = CircuitConfig::standard_recursion_config();
let custom_predicate_batch = eth_friend_batch(&params)?;
for (i, cp) in custom_predicate_batch.predicates.iter().enumerate() {
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let flattened = cp.to_fields(&params);
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
let cp_target = CustomPredicateTarget::from_flattened(&params, &flatteend_target);
// Round trip of from_flattened to flattened
let flatteend_target_rt = cp_target.flatten();
builder.connect_slice(&flatteend_target, &flatteend_target_rt);
let pw = PartialWitness::<F>::new();
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).expect(&format!("predicate {}", i));
data.verify(proof.clone()).unwrap();
}
Ok(())
}
fn test_custom_predicate_batch_target_id(
params: &Params,
custom_predicate_batch: &CustomPredicateBatch,
) -> frontend::Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let zero = builder.zero();
let predicate_targets = custom_predicate_batch
.predicates
.iter()
.map(|cp| {
let flattened = cp.to_fields(params);
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
CustomPredicateTarget::from_flattened(params, &flatteend_target)
})
.chain(iter::repeat({
let empty_flatteend_target = iter::repeat(zero)
.take(params.custom_predicate_size())
.collect_vec();
CustomPredicateTarget::from_flattened(params, &empty_flatteend_target)
}))
.take(params.max_custom_batch_size)
.collect();
let custom_predicate_batch_target = CustomPredicateBatchTarget {
predicates: predicate_targets,
};
// Calculate the id in constraints and compare it against the id calculated natively
let id_target = custom_predicate_batch_target.id(&mut builder);
let id = custom_predicate_batch.id(params);
let id_expected_target = HashOutTarget {
elements: id
.to_fields(params)
.iter()
.map(|v| builder.constant(*v))
.collect_vec()
.try_into()
.unwrap(),
};
builder.connect_array(id_target.elements, id_expected_target.elements);
let pw = PartialWitness::<F>::new();
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();
data.verify(proof.clone()).unwrap();
Ok(())
}
#[test]
fn custom_predicate_batch_target() -> frontend::Result<()> {
let params = Params {
max_statement_args: 6,
max_custom_predicate_wildcards: 12,
..Default::default()
};
// Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new("empty".into());
_ = cpb_builder.predicate_and("empty", &params, &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish();
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?;
// Some cases from the examples
let custom_predicate_batch = eth_friend_batch(&params)?;
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?;
let custom_predicate_batch = eth_dos_batch(&params)?;
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?;
Ok(())
}
}

View file

@ -85,14 +85,14 @@ impl OperationVerifyGadget {
op.args op.args
.iter() .iter()
.flatten() .flatten()
.map(|&i| builder.vec_ref(prev_statements, i)) .map(|&i| builder.vec_ref(&self.params, prev_statements, i))
.collect::<Vec<_>>() .collect::<Vec<_>>()
}; };
// Certain operations (Contains/NotContains) will refer to one // Certain operations (Contains/NotContains) will refer to one
// of the provided Merkle proofs (if any). These proofs have already // of the provided Merkle proofs (if any). These proofs have already
// been verified, so we need only look up the claim. // been verified, so we need only look up the claim.
let resolved_merkle_claim = let resolved_merkle_claim = (!merkle_claims.is_empty())
(!merkle_claims.is_empty()).then(|| builder.vec_ref(merkle_claims, op.aux[0])); .then(|| builder.vec_ref(&self.params, merkle_claims, op.aux[0]));
// The verification may require aux data which needs to be stored in the // The verification may require aux data which needs to be stored in the
// `OperationVerifyTarget` so that we can set during witness generation. // `OperationVerifyTarget` so that we can set during witness generation.
@ -455,7 +455,7 @@ impl OperationVerifyGadget {
let individual_checks = prev_statements let individual_checks = prev_statements
.iter() .iter()
.map(|ps| { .map(|ps| {
let same_predicate = builder.is_equal_slice(&st.predicate, &ps.predicate); let same_predicate = builder.is_equal_flattenable(&st.predicate, &ps.predicate);
let same_anchored_key = let same_anchored_key =
builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements); builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements);
builder.and(same_predicate, same_anchored_key) builder.and(same_predicate, same_anchored_key)
@ -575,15 +575,7 @@ impl MainPodVerifyGadget {
.collect(); .collect();
// 2. Calculate the Pod Id from the public statements // 2. Calculate the Pod Id from the public statements
let pub_statements_flattened = pub_statements let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect();
.iter()
.flat_map(|s| {
s.predicate
.iter()
.chain(s.args.iter().flat_map(|a| &a.elements))
})
.cloned()
.collect();
let id = builder.hash_n_to_hash_no_pad::<PoseidonHash>(pub_statements_flattened); let id = builder.hash_n_to_hash_no_pad::<PoseidonHash>(pub_statements_flattened);
// 4. Verify type // 4. Verify type
@ -591,6 +583,7 @@ impl MainPodVerifyGadget {
// TODO: Store this hash in a global static with lazy init so that we don't have to // TODO: Store this hash in a global static with lazy init so that we don't have to
// compute it every time. // compute it every time.
let expected_type_statement = StatementTarget::from_flattened( let expected_type_statement = StatementTarget::from_flattened(
&self.params,
&builder.constants( &builder.constants(
&Statement::ValueOf( &Statement::ValueOf(
AnchoredKey::from((SELF, KEY_TYPE)), AnchoredKey::from((SELF, KEY_TYPE)),

View file

@ -3,17 +3,16 @@ use std::iter;
use itertools::Itertools; use itertools::Itertools;
use plonky2::{ use plonky2::{
hash::hash_types::{HashOut, HashOutTarget}, hash::hash_types::{HashOut, HashOutTarget},
iop::{ iop::witness::{PartialWitness, WitnessWrite},
target::Target,
witness::{PartialWitness, WitnessWrite},
},
plonk::circuit_builder::CircuitBuilder, plonk::circuit_builder::CircuitBuilder,
}; };
use crate::{ use crate::{
backends::plonky2::{ backends::plonky2::{
basetypes::D, basetypes::D,
circuits::common::{CircuitBuilderPod, StatementArgTarget, StatementTarget, ValueTarget}, circuits::common::{
CircuitBuilderPod, PredicateTarget, StatementArgTarget, StatementTarget, ValueTarget,
},
error::Result, error::Result,
primitives::{ primitives::{
merkletree::{ merkletree::{
@ -24,8 +23,8 @@ use crate::{
signedpod::SignedPod, signedpod::SignedPod,
}, },
middleware::{ middleware::{
hash_str, Key, NativePredicate, Params, PodType, Predicate, RawValue, ToFields, Value, F, hash_str, Key, NativePredicate, Params, PodType, RawValue, Value, F, KEY_SIGNER, KEY_TYPE,
KEY_SIGNER, KEY_TYPE, SELF, SELF,
}, },
}; };
@ -91,10 +90,8 @@ impl SignedPodVerifyTarget {
self_id: bool, self_id: bool,
) -> Vec<StatementTarget> { ) -> Vec<StatementTarget> {
let mut statements = Vec::new(); let mut statements = Vec::new();
let predicate: [Target; Params::predicate_size()] = builder let predicate =
.constants(&Predicate::Native(NativePredicate::ValueOf).to_fields(&self.params)) PredicateTarget::new_native(builder, &self.params, NativePredicate::ValueOf);
.try_into()
.expect("size predicate_size");
let pod_id = if self_id { let pod_id = if self_id {
builder.constant_value(SELF.0.into()) builder.constant_value(SELF.0.into())
} else { } else {
@ -111,7 +108,10 @@ impl SignedPodVerifyTarget {
.chain(iter::repeat_with(|| StatementArgTarget::none(builder))) .chain(iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(self.params.max_statement_args) .take(self.params.max_statement_args)
.collect(); .collect();
let statement = StatementTarget { predicate, args }; let statement = StatementTarget {
predicate: predicate.clone(),
args,
};
statements.push(statement); statements.push(statement);
} }
statements statements

View file

@ -5,7 +5,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::middleware::{ use crate::middleware::{
hash_fields, Error, Hash, Key, NativePredicate, Params, Result, ToFields, Value, F, HASH_SIZE, hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE,
}; };
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -72,40 +72,54 @@ pub enum StatementTmplArg {
WildcardLiteral(Wildcard), WildcardLiteral(Wildcard),
} }
#[derive(Clone, Copy)]
pub enum StatementTmplArgPrefix {
None = 0,
Literal = 1,
Key = 2,
WildcardLiteral = 3,
}
impl From<StatementTmplArgPrefix> for F {
fn from(prefix: StatementTmplArgPrefix) -> Self {
Self::from_canonical_usize(prefix as usize)
}
}
impl ToFields for StatementTmplArg { impl ToFields for StatementTmplArg {
fn to_fields(&self, params: &Params) -> Vec<F> { fn to_fields(&self, params: &Params) -> Vec<F> {
// None => (0, ...) // None => (0, ...)
// Literal(value) => (1, [value], 0, 0, 0, 0) // Literal(value) => (1, [value], 0, 0, 0, 0)
// Key(wildcard1, key_or_wildcard2) // Key(wildcard1_index, key_or_wildcard2)
// => (2, [wildcard1], [key_or_wildcard2]) // => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2])
// WildcardLiteral(wildcard) => (3, [wildcard], 0, 0, 0, 0) // WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
let statement_tmpl_arg_size = 2 * HASH_SIZE + 1;
match self { match self {
StatementTmplArg::None => { StatementTmplArg::None => {
let fields: Vec<F> = iter::repeat_with(|| F::from_canonical_u64(0)) let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::None))
.take(statement_tmpl_arg_size) .chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect(); .collect();
fields fields
} }
StatementTmplArg::Literal(v) => { StatementTmplArg::Literal(v) => {
let fields: Vec<F> = iter::once(F::from_canonical_u64(1)) let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
.chain(v.raw().to_fields(params)) .chain(v.raw().to_fields(params))
.chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .chain(iter::repeat(F::ZERO).take(HASH_SIZE))
.collect(); .collect();
fields fields
} }
StatementTmplArg::Key(wc1, kw2) => { StatementTmplArg::Key(wc1, kw2) => {
let fields: Vec<F> = iter::once(F::from_canonical_u64(2)) let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Key))
.chain(wc1.to_fields(params)) .chain(wc1.to_fields(params))
.chain(kw2.to_fields(params)) .chain(kw2.to_fields(params))
.collect(); .collect();
fields fields
} }
StatementTmplArg::WildcardLiteral(wc) => { StatementTmplArg::WildcardLiteral(wc) => {
let fields: Vec<F> = iter::once(F::from_canonical_u64(3)) let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
.chain(wc.to_fields(params)) .chain(wc.to_fields(params))
.chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .chain(iter::repeat(F::ZERO).take(HASH_SIZE))
.collect(); .collect();
fields fields
} }
@ -312,7 +326,10 @@ impl ToFields for CustomPredicateBatch {
} }
impl CustomPredicateBatch { impl CustomPredicateBatch {
pub fn hash(&self, params: &Params) -> Hash { /// Cryptographic identifier for the batch.
pub fn id(&self, params: &Params) -> Hash {
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
// but ideally we want to use the root of a merkle tree built from the custom predicates.
let input = self.to_fields(params); let input = self.to_fields(params);
hash_fields(&input) hash_fields(&input)
@ -334,65 +351,6 @@ impl CustomPredicateRef {
} }
} }
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "value")]
pub enum Predicate {
Native(NativePredicate),
BatchSelf(usize),
Custom(CustomPredicateRef),
}
impl From<NativePredicate> for Predicate {
fn from(v: NativePredicate) -> Self {
Self::Native(v)
}
}
impl ToFields for Predicate {
fn to_fields(&self, params: &Params) -> Vec<F> {
// serialize:
// NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize
// BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize
// CustomPredicateRef(pb, i) as
// (2, [hash of pb], i) -- pb hashes to 4 field elements
// -- i: usize
// in every case: pad to (hash_size + 2) field elements
let mut fields: Vec<F> = match self {
Self::Native(p) => iter::once(F::from_canonical_u64(1))
.chain(p.to_fields(params))
.collect(),
Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2))
.chain(iter::once(F::from_canonical_usize(*i)))
.collect(),
Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from_canonical_u64(3))
.chain(batch.hash(params).0)
.chain(iter::once(F::from_canonical_usize(*index)))
.collect()
}
};
fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0));
fields
}
}
impl fmt::Display for Predicate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Native(p) => write!(f, "{:?}", p),
Self::BatchSelf(i) => write!(f, "self.{}", i),
Self::Custom(CustomPredicateRef { batch, index }) => {
write!(
f,
"{}.{}[{}]",
batch.name, index, batch.predicates[*index].name
)
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{array, sync::Arc}; use std::{array, sync::Arc};

View file

@ -43,7 +43,7 @@ impl ToFields for OperationType {
.collect(), .collect(),
Self::Custom(CustomPredicateRef { batch, index }) => { Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from_canonical_u64(3)) iter::once(F::from_canonical_u64(3))
.chain(batch.hash(params).0) .chain(batch.id(params).0)
.chain(iter::once(F::from_canonical_usize(*index))) .chain(iter::once(F::from_canonical_usize(*index)))
.collect() .collect()
} }

View file

@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize};
use strum_macros::FromRepr; use strum_macros::FromRepr;
use crate::middleware::{ use crate::middleware::{
AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, Predicate, RawValue, Result, AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value,
ToFields, Value, F, VALUE_SIZE, F, VALUE_SIZE,
}; };
// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static // TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static
@ -84,6 +84,78 @@ impl ToFields for WildcardValue {
} }
} }
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "value")]
pub enum Predicate {
Native(NativePredicate),
BatchSelf(usize),
Custom(CustomPredicateRef),
}
impl From<NativePredicate> for Predicate {
fn from(v: NativePredicate) -> Self {
Self::Native(v)
}
}
#[derive(Clone, Copy)]
pub enum PredicatePrefix {
Native = 1,
BatchSelf = 2,
Custom = 3,
}
impl From<PredicatePrefix> for F {
fn from(prefix: PredicatePrefix) -> Self {
Self::from_canonical_usize(prefix as usize)
}
}
impl ToFields for Predicate {
fn to_fields(&self, params: &Params) -> Vec<F> {
// serialize:
// NativePredicate(id) as (1, id, 0, 0, 0, 0) -- id: usize
// BatchSelf(i) as (2, i, 0, 0, 0, 0) -- i: usize
// CustomPredicateRef(pb, i) as
// (3, [hash of pb], i) -- pb hashes to 4 field elements
// -- i: usize
// in every case: pad to (hash_size + 2) field elements
let mut fields: Vec<F> = match self {
Self::Native(p) => iter::once(F::from(PredicatePrefix::Native))
.chain(p.to_fields(params))
.collect(),
Self::BatchSelf(i) => iter::once(F::from(PredicatePrefix::BatchSelf))
.chain(iter::once(F::from_canonical_usize(*i)))
.collect(),
Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from(PredicatePrefix::Custom))
.chain(batch.id(params).0)
.chain(iter::once(F::from_canonical_usize(*index)))
.collect()
}
};
fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0));
fields
}
}
impl fmt::Display for Predicate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Native(p) => write!(f, "{:?}", p),
Self::BatchSelf(i) => write!(f, "self.{}", i),
Self::Custom(CustomPredicateRef { batch, index }) => {
write!(
f,
"{}.{}[{}]",
batch.name, index, batch.predicates[*index].name
)
}
}
}
}
/// Type encapsulating statements with their associated arguments. /// Type encapsulating statements with their associated arguments.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "predicate", content = "args")] #[serde(tag = "predicate", content = "args")]