Constraints for custom predicates (#227)

* add target types for custom predicates

* simplify

* fix clippy

* fix typo

* don't use ref for NativePredicate

* fix wrong len

* precalculate CustomPredicateBatch id

* wip

* wip

* move code back

* great progress

* wip

* code complete, hopefully; missing tests

* fill aux for custom predicate op

* fix clippy warnings

* fix typos

* fix test import

* fix missing assignment in lt_mask, test custom_operation_verify_gadget

* fix mistake

* wip

* fix

* debug revert except for let entry = CustomPredicateVerifyEntryTarget

* fix batch_id calculation by fixing padding

* oops

* remove completed TODOs
This commit is contained in:
Eduard S. 2025-05-13 11:00:45 +02:00 committed by GitHub
parent 4fa9e20ecd
commit 024ed8bd04
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1597 additions and 291 deletions

View file

@ -2,6 +2,7 @@
use std::{array, iter}; use std::{array, iter};
use itertools::Itertools;
use plonky2::{ use plonky2::{
field::{ field::{
extension::Extendable, extension::Extendable,
@ -12,23 +13,28 @@ use plonky2::{
poseidon::PoseidonHash, poseidon::PoseidonHash,
}, },
iop::{ iop::{
generator::{GeneratedValues, SimpleGenerator},
target::{BoolTarget, Target}, target::{BoolTarget, Target},
witness::{PartialWitness, WitnessWrite}, witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite},
}, },
plonk::circuit_builder::CircuitBuilder, plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData},
util::serialization::{Buffer, IoResult, Read, Write},
}; };
use crate::{ use crate::{
backends::plonky2::{ backends::plonky2::{
basetypes::D, basetypes::D,
circuits::mainpod::CustomPredicateVerification,
error::Result, error::Result,
mainpod::{Operation, OperationArg, Statement}, mainpod::{Operation, OperationArg, Statement},
primitives::merkletree::MerkleClaimAndProofTarget, primitives::merkletree::MerkleClaimAndProofTarget,
}, },
middleware::{ middleware::{
NativeOperation, NativePredicate, Params, Predicate, PredicatePrefix, RawValue, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
StatementArg, StatementTmplArgPrefix, ToFields, EMPTY_VALUE, F, HASH_SIZE, NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE, StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, WildcardValue,
EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN,
VALUE_SIZE,
}, },
}; };
@ -65,6 +71,10 @@ impl ValueTarget {
elements: array::from_fn(|i| xs[i]), elements: array::from_fn(|i| xs[i]),
} }
} }
pub fn set_targets(&self, pw: &mut PartialWitness<F>, value: &Value) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &value.raw().0)?)
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -82,7 +92,7 @@ impl StatementArgTarget {
Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?) Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?)
} }
fn new(first: ValueTarget, second: ValueTarget) -> Self { pub fn new(first: ValueTarget, second: ValueTarget) -> Self {
let elements: Vec<_> = first.elements.into_iter().chain(second.elements).collect(); let elements: Vec<_> = first.elements.into_iter().chain(second.elements).collect();
StatementArgTarget { StatementArgTarget {
elements: elements.try_into().expect("size STATEMENT_ARG_F_LEN"), elements: elements.try_into().expect("size STATEMENT_ARG_F_LEN"),
@ -107,6 +117,11 @@ impl StatementArgTarget {
Self::new(*pod_id, *key) Self::new(*pod_id, *key)
} }
pub fn wildcard_literal(builder: &mut CircuitBuilder<F, D>, value: &ValueTarget) -> Self {
let empty = builder.constant_value(EMPTY_VALUE);
Self::new(*value, empty)
}
/// StatementArgTarget to ValueTarget coercion. Make sure to check /// StatementArgTarget to ValueTarget coercion. Make sure to check
/// that the arg is a value using the `statement_arg_is_value` method /// that the arg is a value using the `statement_arg_is_value` method
/// first! /// first!
@ -138,6 +153,7 @@ impl<T> Build<T> for T {
} }
impl StatementTarget { impl StatementTarget {
/// Build a new native StatementTarget
pub fn new_native( pub fn new_native(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
params: &Params, params: &Params,
@ -187,10 +203,60 @@ impl StatementTarget {
} }
} }
#[derive(Clone)]
pub struct OperationTypeTarget {
pub elements: [Target; Params::operation_type_size()],
}
impl OperationTypeTarget {
pub fn new_custom(
builder: &mut CircuitBuilder<F, D>,
batch_id: HashOutTarget,
index: Target,
) -> Self {
// TODO: Use an enum for these prefixes
let three = builder.constant(F::from_canonical_usize(3));
let id = batch_id.elements;
Self {
elements: [three, id[0], id[1], id[2], id[3], index],
}
}
pub fn as_custom(
&self,
builder: &mut CircuitBuilder<F, D>,
) -> (BoolTarget, HashOutTarget, Target) {
// TODO: Use an enum for these prefixes
let three = builder.constant(F::from_canonical_usize(3));
let op_is_custom = builder.is_equal(self.elements[0], three);
let batch_id = HashOutTarget::from_vec(self.elements[1..5].to_vec());
let index = self.elements[5];
(op_is_custom, batch_id, index)
}
pub fn has_native(&self, builder: &mut CircuitBuilder<F, D>, t: NativeOperation) -> BoolTarget {
// TODO: Use an enum for these prefixes
let one = builder.one();
let op_is_native = builder.is_equal(self.elements[0], one);
let op_code = builder.constant(F::from_canonical_u64(t as u64));
let op_code_matches = builder.is_equal(self.elements[1], op_code);
builder.and(op_is_native, op_code_matches)
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
op_type: &OperationType,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &op_type.to_fields(params))?)
}
}
// TODO: Implement Operation::to_field to determine the size of each element // TODO: Implement Operation::to_field to determine the size of each element
#[derive(Clone)] #[derive(Clone)]
pub struct OperationTarget { pub struct OperationTarget {
pub op_type: [Target; Params::operation_type_size()], pub op_type: OperationTypeTarget,
pub args: Vec<[Target; OPERATION_ARG_F_LEN]>, pub args: Vec<[Target; OPERATION_ARG_F_LEN]>,
pub aux: [Target; OPERATION_AUX_F_LEN], pub aux: [Target; OPERATION_AUX_F_LEN],
} }
@ -202,7 +268,7 @@ impl OperationTarget {
params: &Params, params: &Params,
op: &Operation, op: &Operation,
) -> Result<()> { ) -> Result<()> {
pw.set_target_arr(&self.op_type, &op.op_type().to_fields(params))?; self.op_type.set_targets(pw, params, &op.op_type())?;
for (i, arg) in op for (i, arg) in op
.args() .args()
.iter() .iter()
@ -215,18 +281,6 @@ impl OperationTarget {
pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?; pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?;
Ok(()) Ok(())
} }
pub fn has_native_type(
&self,
builder: &mut CircuitBuilder<F, D>,
t: NativeOperation,
) -> BoolTarget {
let one = builder.one();
let op_is_native = builder.is_equal(self.op_type[0], one);
let op_code = builder.constant(F::from_canonical_u64(t as u64));
let op_code_matches = builder.is_equal(self.op_type[1], op_code);
builder.and(op_is_native, op_code_matches)
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -304,17 +358,37 @@ impl PredicateTarget {
} }
} }
/// Mirrors `middleware::KeyOrWildcard`
#[derive(Clone)] #[derive(Clone)]
pub struct KeyOrWildcardTarget { pub struct LiteralOrWildcardTarget {
pub elements: [Target; VALUE_SIZE], pub elements: [Target; VALUE_SIZE],
} }
impl KeyOrWildcardTarget { impl LiteralOrWildcardTarget {
fn from_slice(v: &[Target]) -> Self { fn from_slice(v: &[Target]) -> Self {
Self { Self {
elements: v.try_into().expect("len is VALUE_SIZE"), elements: v.try_into().expect("len is VALUE_SIZE"),
} }
} }
/// cases: ((is_key, key), (is_wildcard, wildcard_index))
pub fn cases(
&self,
builder: &mut CircuitBuilder<F, D>,
) -> ((BoolTarget, ValueTarget), (BoolTarget, Target)) {
let zero = builder.zero();
let is_zero_tail: Vec<_> = (1..4)
.map(|i| builder.is_equal(self.elements[i], zero))
.collect();
let is_wildcard = is_zero_tail
.into_iter()
.reduce(|acc, x| builder.and(acc, x))
.expect("len > 1");
let is_key = builder.not(is_wildcard);
let key = ValueTarget::from_slice(&self.elements);
let wildcard_index = self.elements[0];
((is_key, key), (is_wildcard, wildcard_index))
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -327,28 +401,40 @@ impl StatementTmplArgTarget {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::None)); let prefix = builder.constant(F::from(StatementTmplArgPrefix::None));
builder.is_equal(self.elements[0], prefix) builder.is_equal(self.elements[0], prefix)
} }
pub fn as_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, ValueTarget) { pub fn as_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, ValueTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal)); let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal));
let case_ok = builder.is_equal(self.elements[0], prefix); let case_ok = builder.is_equal(self.elements[0], prefix);
let value = ValueTarget::from_slice(&self.elements[1..5]); let value = ValueTarget::from_slice(&self.elements[1..5]);
(case_ok, value) (case_ok, value)
} }
pub fn as_key(
pub fn as_anchored_key(
&self, &self,
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
) -> (BoolTarget, Target, KeyOrWildcardTarget) { ) -> (BoolTarget, Target, LiteralOrWildcardTarget) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Key)); let prefix = builder.constant(F::from(StatementTmplArgPrefix::AnchoredKey));
let case_ok = builder.is_equal(self.elements[0], prefix); let case_ok = builder.is_equal(self.elements[0], prefix);
let id_wildcard_index = self.elements[1]; let id_wildcard_index = self.elements[1];
let value_key_or_wildcard = KeyOrWildcardTarget::from_slice(&self.elements[5..9]); let value_key_or_wildcard = LiteralOrWildcardTarget::from_slice(&self.elements[5..9]);
(case_ok, id_wildcard_index, value_key_or_wildcard) (case_ok, id_wildcard_index, value_key_or_wildcard)
} }
pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, Target) { pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, Target) {
let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral)); let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral));
let case_ok = builder.is_equal(self.elements[0], prefix); let case_ok = builder.is_equal(self.elements[0], prefix);
let wildcard_index = self.elements[1]; let wildcard_index = self.elements[1];
(case_ok, wildcard_index) (case_ok, wildcard_index)
} }
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st_tmpl_arg: &StatementTmplArg,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields(params))?)
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -357,6 +443,17 @@ pub struct StatementTmplTarget {
pub args: Vec<StatementTmplArgTarget>, pub args: Vec<StatementTmplArgTarget>,
} }
impl StatementTmplTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st_tmpl: &StatementTmpl,
) -> Result<()> {
Ok(pw.set_target_arr(&self.flatten(), &st_tmpl.to_fields(params))?)
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct CustomPredicateTarget { pub struct CustomPredicateTarget {
pub conjunction: BoolTarget, pub conjunction: BoolTarget,
@ -365,6 +462,17 @@ pub struct CustomPredicateTarget {
pub args_len: Target, pub args_len: Target,
} }
impl CustomPredicateTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
custom_predicate: &CustomPredicate,
) -> Result<()> {
Ok(pw.set_target_arr(&self.flatten(), &custom_predicate.to_fields(params))?)
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct CustomPredicateBatchTarget { pub struct CustomPredicateBatchTarget {
pub predicates: Vec<CustomPredicateTarget>, pub predicates: Vec<CustomPredicateTarget>,
@ -375,6 +483,161 @@ impl CustomPredicateBatchTarget {
let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect(); let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened) builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
} }
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
custom_predicate_batch: &CustomPredicateBatch,
) -> Result<()> {
let pad_predicate = CustomPredicate::empty();
for (i, predicate) in custom_predicate_batch
.predicates()
.iter()
.chain(iter::repeat(&pad_predicate))
.take(params.max_custom_batch_size)
.enumerate()
{
self.predicates[i].set_targets(pw, params, predicate)?;
}
Ok(())
}
}
/// Custom predicate table entry
pub struct CustomPredicateEntryTarget {
pub id: HashOutTarget,
pub index: Target,
pub predicate: CustomPredicateTarget,
}
impl CustomPredicateEntryTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
predicate: &CustomPredicateRef,
) -> Result<()> {
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
self.predicate
.set_targets(pw, params, predicate.predicate())?;
Ok(())
}
}
impl Flattenable for CustomPredicateEntryTarget {
fn flatten(&self) -> Vec<Target> {
self.id
.elements
.iter()
.chain(iter::once(&self.index))
.chain(self.predicate.flatten().iter())
.cloned()
.collect()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
Self {
id: HashOutTarget::from_flattened(params, &vs[0..4]),
index: vs[4],
predicate: CustomPredicateTarget::from_flattened(params, &vs[5..]),
}
}
}
impl CustomPredicateEntryTarget {
pub fn hash(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
}
}
// Custom predicate verification table entry
pub struct CustomPredicateVerifyEntryTarget {
pub custom_predicate_table_index: Target,
pub custom_predicate: CustomPredicateEntryTarget,
pub args: Vec<ValueTarget>,
pub query: CustomPredicateVerifyQueryTarget,
}
impl CustomPredicateVerifyEntryTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
cpv: &CustomPredicateVerification,
) -> Result<()> {
pw.set_target(
self.custom_predicate_table_index,
F::from_canonical_usize(cpv.custom_predicate_table_index),
)?;
self.custom_predicate
.set_targets(pw, params, &cpv.custom_predicate)?;
let pad_arg = WildcardValue::None;
for (arg_target, arg) in self.args.iter().zip_eq(
cpv.args
.iter()
.chain(iter::repeat(&pad_arg))
.take(params.max_custom_predicate_wildcards),
) {
arg_target.set_targets(pw, &Value::from(arg.raw()))?;
}
let pad_op_arg = Statement(Predicate::Native(NativePredicate::None), vec![]);
for (op_arg_target, op_arg) in self.query.op_args.iter().zip_eq(
cpv.op_args
.iter()
.chain(iter::repeat(&pad_op_arg))
.take(params.max_operation_args),
) {
op_arg_target.set_targets(pw, params, op_arg)?
}
Ok(())
}
}
/// Query for the custom predicate verification table
pub struct CustomPredicateVerifyQueryTarget {
pub statement: StatementTarget,
pub op_type: OperationTypeTarget,
pub op_args: Vec<StatementTarget>,
}
impl CustomPredicateVerifyQueryTarget {
pub fn hash(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
}
}
impl Flattenable for CustomPredicateVerifyQueryTarget {
fn flatten(&self) -> Vec<Target> {
self.statement
.flatten()
.iter()
.chain(self.op_type.elements.iter())
.cloned()
.chain(self.op_args.iter().flat_map(|op_arg| op_arg.flatten()))
.collect()
}
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
let (pos, size) = (0, params.statement_size());
let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]);
let (pos, size) = (pos + size, params.operation_size());
let op_type = OperationTypeTarget {
elements: vs[pos..pos + size]
.try_into()
.expect("len = operation_type_size"),
};
let (pos, size) = (pos + size, params.statement_size());
let op_args = (0..params.max_operation_args)
.map(|i| {
StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size])
})
.collect();
Self {
statement,
op_type,
op_args,
}
}
} }
/// Trait for target structs that may be converted to and from vectors /// Trait for target structs that may be converted to and from vectors
@ -408,6 +671,27 @@ impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
} }
} }
impl Flattenable for HashOutTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
assert_eq!(vs.len(), HASH_SIZE);
Self {
elements: array::from_fn(|i| vs[i]),
}
}
}
impl Flattenable for ValueTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self::from_slice(vs)
}
}
impl Flattenable for MerkleClaimTarget { impl Flattenable for MerkleClaimTarget {
fn flatten(&self) -> Vec<Target> { fn flatten(&self) -> Vec<Target> {
[ [
@ -543,8 +827,17 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
fn add_virtual_value(&mut self) -> ValueTarget; fn add_virtual_value(&mut self) -> ValueTarget;
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget; fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget;
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget;
fn add_virtual_predicate(&mut self) -> PredicateTarget; fn add_virtual_predicate(&mut self) -> PredicateTarget;
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget;
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget;
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget;
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget;
fn add_virtual_custom_predicate_batch(&mut self, params: &Params)
-> CustomPredicateBatchTarget;
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
-> CustomPredicateEntryTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget; fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
fn constant_value(&mut self, v: RawValue) -> ValueTarget; fn constant_value(&mut self, v: RawValue) -> ValueTarget;
@ -604,6 +897,9 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
// Convenience methods for Boolean into-iters. // Convenience methods for Boolean into-iters.
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget; fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget; fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
// Return a bit-mask of size `len` that selects all positions lower than `n`
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget>;
} }
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> { impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
@ -629,22 +925,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
StatementTarget { StatementTarget {
predicate, predicate,
args: (0..params.max_statement_args) args: (0..params.max_statement_args)
.map(|_| StatementArgTarget { .map(|_| self.add_virtual_statement_arg())
elements: self.add_virtual_target_arr(),
})
.collect(), .collect(),
} }
} }
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget {
StatementArgTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_predicate(&mut self) -> PredicateTarget { fn add_virtual_predicate(&mut self) -> PredicateTarget {
PredicateTarget { PredicateTarget {
elements: self.add_virtual_target_arr(), elements: self.add_virtual_target_arr(),
} }
} }
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget {
OperationTypeTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
OperationTarget { OperationTarget {
op_type: self.add_virtual_target_arr(), op_type: self.add_virtual_operation_type(),
args: (0..params.max_operation_args) args: (0..params.max_operation_args)
.map(|_| self.add_virtual_target_arr()) .map(|_| self.add_virtual_target_arr())
.collect(), .collect(),
@ -652,6 +958,55 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
} }
} }
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget {
StatementTmplArgTarget {
elements: self.add_virtual_target_arr(),
}
}
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget {
let args = (0..params.max_statement_args)
.map(|_| self.add_virtual_statement_tmpl_arg())
.collect();
StatementTmplTarget {
pred: self.add_virtual_predicate(),
args,
}
}
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget {
let statements = (0..params.max_custom_predicate_arity)
.map(|_| self.add_virtual_statement_tmpl(params))
.collect();
CustomPredicateTarget {
conjunction: self.add_virtual_bool_target_safe(),
statements,
args_len: self.add_virtual_target(),
}
}
fn add_virtual_custom_predicate_batch(
&mut self,
params: &Params,
) -> CustomPredicateBatchTarget {
CustomPredicateBatchTarget {
predicates: (0..params.max_custom_batch_size)
.map(|_| self.add_virtual_custom_predicate(params))
.collect(),
}
}
fn add_virtual_custom_predicate_entry(
&mut self,
params: &Params,
) -> CustomPredicateEntryTarget {
CustomPredicateEntryTarget {
id: self.add_virtual_hash(),
index: self.add_virtual_target(),
predicate: self.add_virtual_custom_predicate(params),
}
}
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget { fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget {
ValueTarget { ValueTarget {
elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])), elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])),
@ -876,6 +1231,12 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
) )
} }
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
// The idea would be the following: Take the array `ts` and hash each element. Then do the
// random access on the hash result. Finally "unhash" to recover the resolved element.
// We don't want to hash each element from the array each time, so we should cache the hashed
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
// then do `ts: &[HashCache<T>]`.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T { fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
// TODO: Revisit this when we need more than 64 statements. // TODO: Revisit this when we need more than 64 statements.
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| { let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
@ -944,6 +1305,73 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
.reduce(|a, b| self.or(a, b)) .reduce(|a, b| self.or(a, b))
.unwrap_or(self._false()) .unwrap_or(self._false())
} }
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget> {
let zero = self.zero();
let mask: Vec<_> = (0..len)
.map(|_| self.add_virtual_bool_target_safe())
.collect();
self.add_simple_generator(LtMaskGenerator {
n,
mask: mask.iter().map(|bt| bt.target).collect(),
});
// We have `n` ones in the mask
let mask_sum = mask
.iter()
.map(|b| b.target)
.reduce(|acc, x| self.add(acc, x))
.unwrap_or(zero);
self.connect(n, mask_sum);
// The elements in the mask can only transition from 1 to 0 or 0 to 0.
for i in 0..len - 1 {
let diff = self.sub(mask[i].target, mask[i + 1].target);
self.assert_bool(BoolTarget::new_unsafe(diff));
}
mask
}
}
#[derive(Debug, Default)]
pub struct LtMaskGenerator {
pub(crate) n: Target,
pub(crate) mask: Vec<Target>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for LtMaskGenerator {
fn id(&self) -> String {
"LtMaskGenerator".to_string()
}
fn dependencies(&self) -> Vec<Target> {
vec![self.n]
}
fn run_once(
&self,
witness: &PartitionWitness<F>,
out_buffer: &mut GeneratedValues<F>,
) -> anyhow::Result<()> {
let n = witness.get_target(self.n).to_canonical_u64();
for (i, mask_i) in self.mask.iter().enumerate() {
let v = if (i as u64) < n { F::ONE } else { F::ZERO };
out_buffer.set_target(*mask_i, v)?;
}
Ok(())
}
fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
dst.write_target(self.n)?;
dst.write_target_vec(&self.mask)
}
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
let n = src.read_target()?;
let mask = src.read_target_vec()?;
Ok(Self { n, mask })
}
} }
#[cfg(test)] #[cfg(test)]
@ -1013,13 +1441,14 @@ pub(crate) mod tests {
let custom_predicate_batch = eth_friend_batch(&params)?; let custom_predicate_batch = eth_friend_batch(&params)?;
for (i, cp) in custom_predicate_batch.predicates.iter().enumerate() { for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() {
let mut builder = CircuitBuilder::<F, D>::new(config.clone()); let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let flattened = cp.to_fields(&params); let flattened = cp.to_fields(&params);
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec(); let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
let cp_target = CustomPredicateTarget::from_flattened(&params, &flatteend_target); let cp_target = CustomPredicateTarget::from_flattened(&params, &flatteend_target);
// Round trip of from_flattened to flattened // Round trip of from_flattened to flattened
let flatteend_target_rt = cp_target.flatten(); let flatteend_target_rt = cp_target.flatten();
// TODO: Instead of connect, assign witness to result
builder.connect_slice(&flatteend_target, &flatteend_target_rt); builder.connect_slice(&flatteend_target, &flatteend_target_rt);
let pw = PartialWitness::<F>::new(); let pw = PartialWitness::<F>::new();
@ -1033,51 +1462,22 @@ pub(crate) mod tests {
Ok(()) Ok(())
} }
fn test_custom_predicate_batch_target_id( fn helper_custom_predicate_batch_target_id(
params: &Params, params: &Params,
custom_predicate_batch: &CustomPredicateBatch, custom_predicate_batch: &CustomPredicateBatch,
) -> frontend::Result<()> { ) -> Result<()> {
let config = CircuitConfig::standard_recursion_config(); let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config); let mut builder = CircuitBuilder::<F, D>::new(config);
let zero = builder.zero(); let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(params);
let predicate_targets = custom_predicate_batch
.predicates
.iter()
.map(|cp| {
let flattened = cp.to_fields(params);
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
CustomPredicateTarget::from_flattened(params, &flatteend_target)
})
.chain(iter::repeat({
let empty_flatteend_target = iter::repeat(zero)
.take(params.custom_predicate_size())
.collect_vec();
CustomPredicateTarget::from_flattened(params, &empty_flatteend_target)
}))
.take(params.max_custom_batch_size)
.collect();
let custom_predicate_batch_target = CustomPredicateBatchTarget {
predicates: predicate_targets,
};
// Calculate the id in constraints and compare it against the id calculated natively // Calculate the id in constraints and compare it against the id calculated natively
let id_target = custom_predicate_batch_target.id(&mut builder); let id_target = custom_predicate_batch_target.id(&mut builder);
let id = custom_predicate_batch.id(params);
let id_expected_target = HashOutTarget { let mut pw = PartialWitness::<F>::new();
elements: id custom_predicate_batch_target.set_targets(&mut pw, params, custom_predicate_batch)?;
.to_fields(params) let id = custom_predicate_batch.id();
.iter() pw.set_target_arr(&id_target.elements, &id.0)?;
.map(|v| builder.constant(*v))
.collect_vec()
.try_into()
.unwrap(),
};
builder.connect_array(id_target.elements, id_expected_target.elements);
let pw = PartialWitness::<F>::new();
// generate & verify proof // generate & verify proof
let data = builder.build::<C>(); let data = builder.build::<C>();
@ -1088,7 +1488,7 @@ pub(crate) mod tests {
} }
#[test] #[test]
fn custom_predicate_batch_target() -> frontend::Result<()> { fn test_custom_predicate_batch_target_id() -> frontend::Result<()> {
let params = Params { let params = Params {
max_statement_args: 6, max_statement_args: 6,
max_custom_predicate_wildcards: 12, max_custom_predicate_wildcards: 12,
@ -1096,17 +1496,21 @@ pub(crate) mod tests {
}; };
// Empty case // Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new("empty".into()); let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
_ = cpb_builder.predicate_and("empty", &params, &[], &[], &[])?; _ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish(); let custom_predicate_batch = cpb_builder.finish();
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?; helper_custom_predicate_batch_target_id(&params, &custom_predicate_batch).unwrap();
// Some cases from the examples // Some cases from the examples
let custom_predicate_batch = eth_friend_batch(&params)?; let custom_predicate_batch = eth_friend_batch(&params)?;
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?; helper_custom_predicate_batch_target_id(&params, &custom_predicate_batch).unwrap();
let custom_predicate_batch = eth_dos_batch(&params)?; let custom_predicate_batch = eth_dos_batch(&params)?;
test_custom_predicate_batch_target_id(&params, &custom_predicate_batch)?; helper_custom_predicate_batch_target_id(&params, &custom_predicate_batch).unwrap();
let custom_predicate_batch =
CustomPredicateBatch::new(&params, "empty".to_string(), vec![CustomPredicate::empty()]);
helper_custom_predicate_batch_target_id(&params, &custom_predicate_batch).unwrap();
Ok(()) Ok(())
} }

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
pub mod operation; pub mod operation;
pub mod statement; pub mod statement;
use std::any::Any; use std::{any::Any, sync::Arc};
use itertools::Itertools; use itertools::Itertools;
pub use operation::*; pub use operation::*;
@ -17,14 +17,17 @@ pub use statement::*;
use crate::{ use crate::{
backends::plonky2::{ backends::plonky2::{
basetypes::{C, D}, basetypes::{C, D},
circuits::mainpod::{MainPodVerifyCircuit, MainPodVerifyInput}, circuits::mainpod::{
CustomPredicateVerification, MainPodVerifyCircuit, MainPodVerifyInput,
},
error::{Error, Result}, error::{Error, Result},
primitives::merkletree::MerkleClaimAndProof, primitives::merkletree::MerkleClaimAndProof,
signedpod::SignedPod, signedpod::SignedPod,
}, },
middleware::{ middleware::{
self, AnchoredKey, DynError, Hash, MainPodInputs, NativeOperation, NonePod, OperationType, self, resolve_wildcard_values, AnchoredKey, CustomPredicateBatch, DynError, Hash,
Params, Pod, PodId, PodProver, PodType, StatementArg, ToFields, F, KEY_TYPE, SELF, MainPodInputs, NativeOperation, NonePod, OperationType, Params, Pod, PodId, PodProver,
PodType, StatementArg, ToFields, F, KEY_TYPE, SELF,
}, },
}; };
@ -37,7 +40,71 @@ pub(crate) fn hash_statements(statements: &[Statement], _params: &Params) -> mid
Hash(PoseidonHash::hash_no_pad(&field_elems).elements) Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
} }
/// Extracts and pads Merkle proofs from Contains/NotContains ops. /// Extracts unique `CustomPredicateBatch`es from Custom ops.
pub(crate) fn extract_custom_predicate_batches(
params: &Params,
operations: &[middleware::Operation],
) -> Result<Vec<Arc<CustomPredicateBatch>>> {
let custom_predicate_batches: Vec<_> = operations
.iter()
.flat_map(|op| match op {
middleware::Operation::Custom(cpr, _) => Some(cpr.batch.clone()),
_ => None,
})
.unique_by(|cpr| cpr.id())
.collect();
if custom_predicate_batches.len() > params.max_custom_predicate_batches {
return Err(Error::custom(format!(
"The number of required `CustomPredicateBatch`es ({}) exceeds the maximum number ({}).",
custom_predicate_batches.len(),
params.max_custom_predicate_batches
)));
}
Ok(custom_predicate_batches)
}
/// Extracts all custom predicate operations with all the data required to verify them.
pub(crate) fn extract_custom_predicate_verifications(
params: &Params,
operations: &[middleware::Operation],
custom_predicate_batches: &[Arc<CustomPredicateBatch>],
) -> Result<Vec<CustomPredicateVerification>> {
let custom_predicate_data: Vec<_> = operations
.iter()
.flat_map(|op| match op {
middleware::Operation::Custom(cpr, sts) => Some((cpr, sts)),
_ => None,
})
.map(|(cpr, sts)| {
let wildcard_values =
resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards");
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
let batch_index = custom_predicate_batches
.iter()
.enumerate()
.find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i))
.expect("find the custom predicate from the extracted unique list");
let custom_predicate_table_index =
batch_index * params.max_custom_predicate_batches + cpr.index;
CustomPredicateVerification {
custom_predicate_table_index,
custom_predicate: cpr.clone(),
args: wildcard_values,
op_args: sts,
}
})
.collect();
if custom_predicate_data.len() > params.max_custom_predicate_verifications {
return Err(Error::custom(format!(
"The number of required custom predicate verifications ({}) exceeds the maximum number ({}).",
custom_predicate_data.len(),
params.max_custom_predicate_verifications
)));
}
Ok(custom_predicate_data)
}
/// Extracts Merkle proofs from Contains/NotContains ops.
pub(crate) fn extract_merkle_proofs( pub(crate) fn extract_merkle_proofs(
params: &Params, params: &Params,
operations: &[middleware::Operation], operations: &[middleware::Operation],
@ -98,11 +165,32 @@ fn find_op_arg(statements: &[Statement], op_arg: &middleware::Statement) -> Resu
} }
/// Find the operation auxiliary data in the list of auxiliary data and return the index. /// Find the operation auxiliary data in the list of auxiliary data and return the index.
// NOTE: The `custom_predicate_verifications` is optional because in the MainPod we want to store
// the index of a custom predicate verification in the aux data, but in the MockMainPod we don't
// need that because we keep a reference to the custom predicate in the operation type, which
// removes the need for indexing. We could change the OperationType and Predicate for the backend
// to not keep a reference to the custom predicate and instead just keep the id and index and then
// do the same double indexing that the MainPod does to verify custom predicates.
fn find_op_aux( fn find_op_aux(
merkle_proofs: &[MerkleClaimAndProof], merkle_proofs: &[MerkleClaimAndProof],
op_aux: &middleware::OperationAux, custom_predicate_verifications: Option<&[CustomPredicateVerification]>,
op: &middleware::Operation,
) -> Result<OperationAux> { ) -> Result<OperationAux> {
match op_aux { let op_aux = op.aux();
let op_type = op.op_type();
if let (OperationType::Custom(cpr), Some(cpvs)) = (op_type, custom_predicate_verifications) {
return Ok(cpvs
.iter()
.enumerate()
.find_map(|(i, cpv)| {
(cpv.custom_predicate.batch.id() == cpr.batch.id()
&& cpv.custom_predicate.index == cpr.index)
.then_some(i)
})
.map(OperationAux::CustomPredVerifyIndex)
.expect("custom predicate verification in the list"));
}
match &op_aux {
middleware::OperationAux::None => Ok(OperationAux::None), middleware::OperationAux::None => Ok(OperationAux::None),
middleware::OperationAux::MerkleProof(pf_arg) => merkle_proofs middleware::OperationAux::MerkleProof(pf_arg) => merkle_proofs
.iter() .iter()
@ -217,6 +305,7 @@ pub(crate) fn process_private_statements_operations(
params: &Params, params: &Params,
statements: &[Statement], statements: &[Statement],
merkle_proofs: &[MerkleClaimAndProof], merkle_proofs: &[MerkleClaimAndProof],
custom_predicate_verifications: Option<&[CustomPredicateVerification]>,
input_operations: &[middleware::Operation], input_operations: &[middleware::Operation],
) -> Result<Vec<Operation>> { ) -> Result<Vec<Operation>> {
let mut operations = Vec::new(); let mut operations = Vec::new();
@ -231,8 +320,7 @@ pub(crate) fn process_private_statements_operations(
.map(|mid_arg| find_op_arg(statements, mid_arg)) .map(|mid_arg| find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let mid_aux = op.aux(); let aux = find_op_aux(merkle_proofs, custom_predicate_verifications, &op)?;
let aux = find_op_aux(merkle_proofs, &mid_aux)?;
pad_operation_args(params, &mut args); pad_operation_args(params, &mut args);
operations.push(Operation(op.op_type(), args, aux)); operations.push(Operation(op.op_type(), args, aux));
@ -301,12 +389,19 @@ impl Prover {
.collect_vec(); .collect_vec();
let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?; let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?;
let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?;
let custom_predicate_verifications = extract_custom_predicate_verifications(
params,
inputs.operations,
&custom_predicate_batches,
)?;
let statements = layout_statements(params, &inputs); let statements = layout_statements(params, &inputs);
let operations = process_private_statements_operations( let operations = process_private_statements_operations(
params, params,
&statements, &statements,
&merkle_proofs, &merkle_proofs,
Some(&custom_predicate_verifications),
inputs.operations, inputs.operations,
)?; )?;
let operations = process_public_statements_operations(params, &statements, operations)?; let operations = process_public_statements_operations(params, &statements, operations)?;
@ -321,6 +416,8 @@ impl Prover {
statements: statements[statements.len() - params.max_statements..].to_vec(), statements: statements[statements.len() - params.max_statements..].to_vec(),
operations, operations,
merkle_proofs, merkle_proofs,
custom_predicate_batches,
custom_predicate_verifications,
}; };
main_pod.set_targets(&mut pw, &input)?; main_pod.set_targets(&mut pw, &input)?;
@ -505,4 +602,41 @@ pub mod tests {
let pod = (kyc_pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap(); let pod = (kyc_pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
pod.verify().unwrap() pod.verify().unwrap()
} }
#[test]
fn test_mainpod_small_empty() {
let params = middleware::Params {
max_input_signed_pods: 0,
max_input_main_pods: 0,
max_statements: 5,
max_signed_pod_values: 2,
max_public_statements: 2,
max_statement_args: 2,
max_operation_args: 3,
max_custom_predicate_batches: 2,
max_custom_predicate_verifications: 2,
max_custom_predicate_arity: 2,
max_custom_predicate_wildcards: 2,
max_custom_batch_size: 2,
max_merkle_proofs: 2,
max_depth_mt_gadget: 4,
};
let pod_builder = frontend::MainPodBuilder::new(&params);
// Mock
let mut prover = MockProver {};
let kyc_pod = pod_builder.prove(&mut prover, &params).unwrap();
let pod = (kyc_pod.pod as Box<dyn Any>)
.downcast::<MockMainPod>()
.unwrap();
pod.verify().unwrap();
println!("{:#}", pod);
// Real
let mut prover = Prover {};
let kyc_pod = pod_builder.prove(&mut prover, &params).unwrap();
let pod = (kyc_pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
pod.verify().unwrap()
}
} }

View file

@ -38,15 +38,17 @@ impl OperationArg {
pub enum OperationAux { pub enum OperationAux {
None, None,
MerkleProofIndex(usize), MerkleProofIndex(usize),
CustomPredVerifyIndex(usize),
} }
impl ToFields for OperationAux { impl ToFields for OperationAux {
fn to_fields(&self, _params: &Params) -> Vec<F> { fn to_fields(&self, _params: &Params) -> Vec<F> {
let f = match self { let fs = match self {
Self::None => F::ZERO, Self::None => [F::ZERO, F::ZERO],
Self::MerkleProofIndex(i) => F::from_canonical_usize(*i), Self::MerkleProofIndex(i) => [F::from_canonical_usize(*i), F::ZERO],
Self::CustomPredVerifyIndex(i) => [F::ZERO, F::from_canonical_usize(*i)],
}; };
vec![f] vec![fs[0], fs[1]]
} }
} }
@ -78,6 +80,7 @@ impl Operation {
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let deref_aux = match self.2 { let deref_aux = match self.2 {
OperationAux::None => crate::middleware::OperationAux::None, OperationAux::None => crate::middleware::OperationAux::None,
OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None,
OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof( OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof(
merkle_proofs merkle_proofs
.get(i) .get(i)
@ -111,6 +114,7 @@ impl fmt::Display for Operation {
match self.2 { match self.2 {
OperationAux::None => (), OperationAux::None => (),
OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?,
OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?,
} }
Ok(()) Ok(())
} }

View file

@ -147,6 +147,7 @@ impl MockMainPod {
params, params,
&statements, &statements,
&merkle_proofs, &merkle_proofs,
None,
inputs.operations, inputs.operations,
)?; )?;
let operations = process_public_statements_operations(params, &statements, operations)?; let operations = process_public_statements_operations(params, &statements, operations)?;

View file

@ -12,10 +12,9 @@ use crate::{
/// Instantiates an ETH friend batch /// Instantiates an ETH friend batch
pub fn eth_friend_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> { pub fn eth_friend_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "eth_friend".into());
let _eth_friend = builder.predicate_and( let _eth_friend = builder.predicate_and(
"eth_friend", "eth_friend",
params,
// arguments: // arguments:
&["src_ori", "src_key", "dst_ori", "dst_key"], &["src_ori", "src_key", "dst_ori", "dst_key"],
// private arguments: // private arguments:
@ -44,7 +43,8 @@ pub fn eth_friend_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
/// Instantiates an ETHDoS batch /// Instantiates an ETHDoS batch
pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> { pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params)?, 0)); let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params)?, 0));
let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); let mut builder =
CustomPredicateBatchBuilder::new(params.clone(), "eth_dos_distance_base".into());
// eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and< // eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and<
// eq(src_or, src_key, dst_or, dst_key), // eq(src_or, src_key, dst_or, dst_key),
@ -52,7 +52,6 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
// > // >
let eth_dos_distance_base = builder.predicate_and( let eth_dos_distance_base = builder.predicate_and(
"eth_dos_distance_base", "eth_dos_distance_base",
params,
&[ &[
// arguments: // arguments:
"src_ori", "src_ori",
@ -83,7 +82,6 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
let eth_dos_distance_ind = builder.predicate_and( let eth_dos_distance_ind = builder.predicate_and(
"eth_dos_distance_ind", "eth_dos_distance_ind",
params,
&[ &[
// arguments: // arguments:
"src_ori", "src_ori",
@ -135,7 +133,6 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
let _eth_dos_distance = builder.predicate_or( let _eth_dos_distance = builder.predicate_or(
"eth_dos_distance", "eth_dos_distance",
params,
&[ &[
"src_ori", "src_ori",
"src_key", "src_key",

View file

@ -128,13 +128,15 @@ impl StatementTmplBuilder {
} }
pub struct CustomPredicateBatchBuilder { pub struct CustomPredicateBatchBuilder {
params: Params,
pub name: String, pub name: String,
pub predicates: Vec<CustomPredicate>, pub predicates: Vec<CustomPredicate>,
} }
impl CustomPredicateBatchBuilder { impl CustomPredicateBatchBuilder {
pub fn new(name: String) -> Self { pub fn new(params: Params, name: String) -> Self {
Self { Self {
params,
name, name,
predicates: Vec::new(), predicates: Vec::new(),
} }
@ -143,23 +145,21 @@ impl CustomPredicateBatchBuilder {
pub fn predicate_and( pub fn predicate_and(
&mut self, &mut self,
name: &str, name: &str,
params: &Params,
args: &[&str], args: &[&str],
priv_args: &[&str], priv_args: &[&str],
sts: &[StatementTmplBuilder], sts: &[StatementTmplBuilder],
) -> Result<Predicate> { ) -> Result<Predicate> {
self.predicate(name, params, true, args, priv_args, sts) self.predicate(name, true, args, priv_args, sts)
} }
pub fn predicate_or( pub fn predicate_or(
&mut self, &mut self,
name: &str, name: &str,
params: &Params,
args: &[&str], args: &[&str],
priv_args: &[&str], priv_args: &[&str],
sts: &[StatementTmplBuilder], sts: &[StatementTmplBuilder],
) -> Result<Predicate> { ) -> Result<Predicate> {
self.predicate(name, params, false, args, priv_args, sts) self.predicate(name, false, args, priv_args, sts)
} }
/// creates the custom predicate from the given input, adds it to the /// creates the custom predicate from the given input, adds it to the
@ -167,24 +167,23 @@ impl CustomPredicateBatchBuilder {
fn predicate( fn predicate(
&mut self, &mut self,
name: &str, name: &str,
params: &Params,
conjunction: bool, conjunction: bool,
args: &[&str], args: &[&str],
priv_args: &[&str], priv_args: &[&str],
sts: &[StatementTmplBuilder], sts: &[StatementTmplBuilder],
) -> Result<Predicate> { ) -> Result<Predicate> {
if args.len() > params.max_statement_args { if args.len() > self.params.max_statement_args {
return Err(Error::max_length( return Err(Error::max_length(
"args.len".to_string(), "args.len".to_string(),
args.len(), args.len(),
params.max_statement_args, self.params.max_statement_args,
)); ));
} }
if (args.len() + priv_args.len()) > params.max_custom_predicate_wildcards { if (args.len() + priv_args.len()) > self.params.max_custom_predicate_wildcards {
return Err(Error::max_length( return Err(Error::max_length(
"wildcards.len".to_string(), "wildcards.len".to_string(),
args.len() + priv_args.len(), args.len() + priv_args.len(),
params.max_custom_predicate_wildcards, self.params.max_custom_predicate_wildcards,
)); ));
} }
@ -197,7 +196,7 @@ impl CustomPredicateBatchBuilder {
.iter() .iter()
.map(|a| match a { .map(|a| match a {
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
BuilderArg::Key(pod_id, key) => StatementTmplArg::Key( BuilderArg::Key(pod_id, key) => StatementTmplArg::AnchoredKey(
resolve_wildcard(args, priv_args, pod_id), resolve_wildcard(args, priv_args, pod_id),
resolve_key_or_wildcard(args, priv_args, key), resolve_key_or_wildcard(args, priv_args, key),
), ),
@ -212,17 +211,19 @@ impl CustomPredicateBatchBuilder {
} }
}) })
.collect(); .collect();
let custom_predicate = let custom_predicate = CustomPredicate::new(
CustomPredicate::new(name.into(), params, conjunction, statements, args.len())?; &self.params,
name.into(),
conjunction,
statements,
args.len(),
)?;
self.predicates.push(custom_predicate); self.predicates.push(custom_predicate);
Ok(Predicate::BatchSelf(self.predicates.len() - 1)) Ok(Predicate::BatchSelf(self.predicates.len() - 1))
} }
pub fn finish(self) -> Arc<CustomPredicateBatch> { pub fn finish(self) -> Arc<CustomPredicateBatch> {
Arc::new(CustomPredicateBatch { CustomPredicateBatch::new(&self.params, self.name, self.predicates)
name: self.name,
predicates: self.predicates,
})
} }
} }
@ -290,7 +291,7 @@ mod tests {
#[test] #[test]
fn test_desugared_gt_custom_pred() -> Result<()> { fn test_desugared_gt_custom_pred() -> Result<()> {
let params = Params::default(); let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new("gt_custom_pred".into()); let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt) let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
.arg(("s1_origin", "s1_key")) .arg(("s1_origin", "s1_key"))
@ -298,7 +299,6 @@ mod tests {
builder.predicate_and( builder.predicate_and(
"gt_custom_pred", "gt_custom_pred",
&params,
&["s1_origin", "s1_key", "s2_origin", "s2_key"], &["s1_origin", "s1_key", "s2_origin", "s2_key"],
&[], &[],
&[gt_stb], &[gt_stb],
@ -322,7 +322,7 @@ mod tests {
// Check that the desugared predicate is the same as the one in the statement template // Check that the desugared predicate is the same as the one in the statement template
assert_eq!( assert_eq!(
desugared_gt.predicate(), desugared_gt.predicate(),
*batch_clone.predicates[0].statements[0].pred() *batch_clone.predicates()[0].statements[0].pred()
); );
// Check that our custom predicate matches the statement template // Check that our custom predicate matches the statement template
@ -339,7 +339,8 @@ mod tests {
#[test] #[test]
fn test_desugared_set_contains_custom_pred() -> Result<()> { fn test_desugared_set_contains_custom_pred() -> Result<()> {
let params = Params::default(); let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new("set_contains_custom_pred".into()); let mut builder =
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains) let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
.arg(("s1_origin", "s1_key")) .arg(("s1_origin", "s1_key"))
@ -347,7 +348,6 @@ mod tests {
builder.predicate_and( builder.predicate_and(
"set_contains_custom_pred", "set_contains_custom_pred",
&params,
&["s1_origin", "s1_key", "s2_origin", "s2_key"], &["s1_origin", "s1_key", "s2_origin", "s2_key"],
&[], &[],
&[set_contains_stb], &[set_contains_stb],
@ -368,7 +368,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
set_contains.predicate(), set_contains.predicate(),
*batch_clone.predicates[0].statements[0].pred() *batch_clone.predicates()[0].statements[0].pred()
); );
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0); let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);

View file

@ -466,7 +466,7 @@ impl MainPodBuilder {
)))?, )))?,
}, },
OperationType::Custom(cpr) => { OperationType::Custom(cpr) => {
let pred = &cpr.batch.predicates[cpr.index]; let pred = &cpr.batch.predicates()[cpr.index];
if pred.statements.len() != args.len() { if pred.statements.len() != args.len() {
return Err(Error::custom(format!( return Err(Error::custom(format!(
"Custom predicate operation needs {} statements but has {}.", "Custom predicate operation needs {} statements but has {}.",

View file

@ -5,7 +5,8 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::middleware::{ use crate::middleware::{
hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE, hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value,
EMPTY_HASH, F, HASH_SIZE, VALUE_SIZE,
}; };
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -49,12 +50,15 @@ impl fmt::Display for KeyOrWildcard {
} }
impl ToFields for KeyOrWildcard { impl ToFields for KeyOrWildcard {
// Encoding:
// - Key(k) => [[k]]
// - Wildcard(index) => [[index], 0, 0, 0]
fn to_fields(&self, params: &Params) -> Vec<F> { fn to_fields(&self, params: &Params) -> Vec<F> {
match self { match self {
KeyOrWildcard::Key(k) => k.hash().to_fields(params), KeyOrWildcard::Key(k) => k.hash().to_fields(params),
KeyOrWildcard::Wildcard(wc) => iter::once(F::ZERO) KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64))
.take(HASH_SIZE - 1) .chain(iter::repeat(F::ZERO))
.chain(iter::once(F::from_canonical_u64(wc.index as u64))) .take(HASH_SIZE)
.collect(), .collect(),
} }
} }
@ -66,7 +70,7 @@ pub enum StatementTmplArg {
None, None,
Literal(Value), Literal(Value),
// AnchoredKey // AnchoredKey
Key(Wildcard, KeyOrWildcard), AnchoredKey(Wildcard, KeyOrWildcard),
// TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard... // TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard...
// Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key? // Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key?
WildcardLiteral(Wildcard), WildcardLiteral(Wildcard),
@ -76,7 +80,7 @@ pub enum StatementTmplArg {
pub enum StatementTmplArgPrefix { pub enum StatementTmplArgPrefix {
None = 0, None = 0,
Literal = 1, Literal = 1,
Key = 2, AnchoredKey = 2,
WildcardLiteral = 3, WildcardLiteral = 3,
} }
@ -88,11 +92,11 @@ impl From<StatementTmplArgPrefix> for F {
impl ToFields for StatementTmplArg { impl ToFields for StatementTmplArg {
fn to_fields(&self, params: &Params) -> Vec<F> { fn to_fields(&self, params: &Params) -> Vec<F> {
// None => (0, ...) // Encoding:
// Literal(value) => (1, [value], 0, 0, 0, 0) // None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
// Key(wildcard1_index, key_or_wildcard2) // Literal(v) => (1, [v ], 0, 0, 0, 0)
// => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2]) // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
// WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0) // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
match self { match self {
StatementTmplArg::None => { StatementTmplArg::None => {
@ -105,13 +109,15 @@ impl ToFields for StatementTmplArg {
StatementTmplArg::Literal(v) => { StatementTmplArg::Literal(v) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal)) let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
.chain(v.raw().to_fields(params)) .chain(v.raw().to_fields(params))
.chain(iter::repeat(F::ZERO).take(HASH_SIZE)) .chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect(); .collect();
fields fields
} }
StatementTmplArg::Key(wc1, kw2) => { StatementTmplArg::AnchoredKey(wc1, kw2) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Key)) let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
.chain(wc1.to_fields(params)) .chain(wc1.to_fields(params))
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
.chain(kw2.to_fields(params)) .chain(kw2.to_fields(params))
.collect(); .collect();
fields fields
@ -119,7 +125,8 @@ impl ToFields for StatementTmplArg {
StatementTmplArg::WildcardLiteral(wc) => { StatementTmplArg::WildcardLiteral(wc) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral)) let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
.chain(wc.to_fields(params)) .chain(wc.to_fields(params))
.chain(iter::repeat(F::ZERO).take(HASH_SIZE)) .chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect(); .collect();
fields fields
} }
@ -132,7 +139,7 @@ impl fmt::Display for StatementTmplArg {
match self { match self {
Self::None => write!(f, "none"), Self::None => write!(f, "none"),
Self::Literal(v) => write!(f, "{}", v), Self::Literal(v) => write!(f, "{}", v),
Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), Self::AnchoredKey(pod_id, key) => write!(f, "({}, {})", pod_id, key),
Self::WildcardLiteral(v) => write!(f, "{}", v), Self::WildcardLiteral(v) => write!(f, "{}", v),
} }
} }
@ -177,7 +184,11 @@ impl ToFields for StatementTmpl {
// instead of at the `to_fields` method, where we should assume that the // instead of at the `to_fields` method, where we should assume that the
// values are already valid // values are already valid
if self.args.len() > params.max_statement_args { if self.args.len() > params.max_statement_args {
panic!("Statement template has too many arguments"); panic!(
"Statement template has too many arguments {} > {}",
self.args.len(),
params.max_statement_args
);
} }
let mut fields: Vec<F> = self let mut fields: Vec<F> = self
@ -206,25 +217,36 @@ pub struct CustomPredicate {
} }
impl CustomPredicate { impl CustomPredicate {
pub fn empty() -> Self {
Self {
name: "empty".to_string(),
conjunction: false,
statements: vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::None),
args: vec![],
}],
args_len: 0,
}
}
pub fn and( pub fn and(
name: String,
params: &Params, params: &Params,
name: String,
statements: Vec<StatementTmpl>, statements: Vec<StatementTmpl>,
args_len: usize, args_len: usize,
) -> Result<Self> { ) -> Result<Self> {
Self::new(name, params, true, statements, args_len) Self::new(params, name, true, statements, args_len)
} }
pub fn or( pub fn or(
name: String,
params: &Params, params: &Params,
name: String,
statements: Vec<StatementTmpl>, statements: Vec<StatementTmpl>,
args_len: usize, args_len: usize,
) -> Result<Self> { ) -> Result<Self> {
Self::new(name, params, false, statements, args_len) Self::new(params, name, false, statements, args_len)
} }
pub fn new( pub fn new(
name: String,
params: &Params, params: &Params,
name: String,
conjunction: bool, conjunction: bool,
statements: Vec<StatementTmpl>, statements: Vec<StatementTmpl>,
args_len: usize, args_len: usize,
@ -236,6 +258,13 @@ impl CustomPredicate {
params.max_custom_predicate_arity, params.max_custom_predicate_arity,
)); ));
} }
if args_len > params.max_statement_args {
return Err(Error::max_length(
"statement_args.len".to_string(),
args_len,
params.max_statement_args,
));
}
Ok(Self { Ok(Self {
name, name,
@ -244,6 +273,16 @@ impl CustomPredicate {
args_len, args_len,
}) })
} }
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
StatementTmpl {
pred: Predicate::Native(if self.conjunction {
NativePredicate::False
} else {
NativePredicate::None
}),
args: vec![],
}
}
} }
impl ToFields for CustomPredicate { impl ToFields for CustomPredicate {
@ -262,11 +301,17 @@ impl ToFields for CustomPredicate {
panic!("Custom predicate depends on too many statements"); panic!("Custom predicate depends on too many statements");
} }
let mut fields: Vec<F> = iter::once(F::from_bool(self.conjunction)) let pad_st = self.pad_statement_tmpl();
let fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
.chain(iter::once(F::from_canonical_usize(self.args_len))) .chain(iter::once(F::from_canonical_usize(self.args_len)))
.chain(self.statements.iter().flat_map(|st| st.to_fields(params))) .chain(
self.statements
.iter()
.chain(iter::repeat(&pad_st))
.take(params.max_custom_predicate_arity)
.flat_map(|st| st.to_fields(params)),
)
.collect(); .collect();
fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0));
fields fields
} }
} }
@ -298,8 +343,9 @@ impl fmt::Display for CustomPredicate {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct CustomPredicateBatch { pub struct CustomPredicateBatch {
id: Hash,
pub name: String, pub name: String,
pub predicates: Vec<CustomPredicate>, predicates: Vec<CustomPredicate>,
} }
impl ToFields for CustomPredicateBatch { impl ToFields for CustomPredicateBatch {
@ -313,27 +359,45 @@ impl ToFields for CustomPredicateBatch {
panic!("Predicate batch exceeds maximum size"); panic!("Predicate batch exceeds maximum size");
} }
let mut fields: Vec<F> = self let pad_pred = CustomPredicate::empty();
let fields: Vec<F> = self
.predicates .predicates
.iter() .iter()
.chain(iter::repeat(&pad_pred))
.take(params.max_custom_batch_size)
.flat_map(|p| p.to_fields(params)) .flat_map(|p| p.to_fields(params))
.collect(); .collect();
fields.resize_with(params.custom_predicate_batch_size_field_elts(), || {
F::from_canonical_u64(0)
});
fields fields
} }
} }
impl CustomPredicateBatch { impl CustomPredicateBatch {
pub fn new(params: &Params, name: String, predicates: Vec<CustomPredicate>) -> Arc<Self> {
let mut cpb = Self {
id: EMPTY_HASH,
name,
predicates,
};
let id = cpb.calculate_id(params);
cpb.id = id;
Arc::new(cpb)
}
/// Cryptographic identifier for the batch. /// Cryptographic identifier for the batch.
pub fn id(&self, params: &Params) -> Hash { fn calculate_id(&self, params: &Params) -> Hash {
// NOTE: This implementation just hashes the concatenation of all the custom predicates, // NOTE: This implementation just hashes the concatenation of all the custom predicates,
// but ideally we want to use the root of a merkle tree built from the custom predicates. // but ideally we want to use the root of a merkle tree built from the custom predicates.
let input = self.to_fields(params); let input = self.to_fields(params);
hash_fields(&input) hash_fields(&input)
} }
pub fn id(&self) -> Hash {
self.id
}
pub fn predicates(&self) -> &[CustomPredicate] {
&self.predicates
}
} }
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -347,13 +411,16 @@ impl CustomPredicateRef {
Self { batch, index } Self { batch, index }
} }
pub fn arg_len(&self) -> usize { pub fn arg_len(&self) -> usize {
self.batch.predicates[self.index].args_len self.predicate().args_len
}
pub fn predicate(&self) -> &CustomPredicate {
&self.batch.predicates[self.index]
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{array, sync::Arc}; use std::array;
use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::goldilocks_field::GoldilocksField;
@ -392,28 +459,29 @@ mod tests {
p:value_of(Constant, 2), p:value_of(Constant, 2),
p:product_of(S1, Constant, S2) p:product_of(S1, Constant, S2)
*/ */
let cust_pred_batch = Arc::new(CustomPredicateBatch { let cust_pred_batch = CustomPredicateBatch::new(
name: "is_double".to_string(), &params,
predicates: vec![CustomPredicate::and( "is_double".to_string(),
"_".into(), vec![CustomPredicate::and(
&params, &params,
"_".into(),
vec![ vec![
st( st(
P::Native(NP::ValueOf), P::Native(NP::ValueOf),
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(2.into())], vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(2.into())],
), ),
st( st(
P::Native(NP::ProductOf), P::Native(NP::ProductOf),
vec![ vec![
STA::Key(wc(0), kow_wc(1)), STA::AnchoredKey(wc(0), kow_wc(1)),
STA::Key(wc(4), kow_wc(5)), STA::AnchoredKey(wc(4), kow_wc(5)),
STA::Key(wc(2), kow_wc(3)), STA::AnchoredKey(wc(2), kow_wc(3)),
], ],
), ),
], ],
2, 2,
)?], )?],
}); );
let custom_statement = Statement::Custom( let custom_statement = Statement::Custom(
CustomPredicateRef::new(cust_pred_batch.clone(), 0), CustomPredicateRef::new(cust_pred_batch.clone(), 0),
@ -444,55 +512,57 @@ mod tests {
fn ethdos_test() -> Result<()> { fn ethdos_test() -> Result<()> {
let params = Params { let params = Params {
max_custom_predicate_wildcards: 12, max_custom_predicate_wildcards: 12,
max_statement_args: 6,
..Default::default() ..Default::default()
}; };
let eth_friend_cp = CustomPredicate::and( let eth_friend_cp = CustomPredicate::and(
"eth_friend_cp".into(),
&params, &params,
"eth_friend_cp".into(),
vec![ vec![
st( st(
P::Native(NP::ValueOf), P::Native(NP::ValueOf),
vec![ vec![
STA::Key(wc(4), KeyOrWildcard::Key("type".into())), STA::AnchoredKey(wc(4), KeyOrWildcard::Key("type".into())),
STA::Literal(PodType::Signed.into()), STA::Literal(PodType::Signed.into()),
], ],
), ),
st( st(
P::Native(NP::Equal), P::Native(NP::Equal),
vec![ vec![
STA::Key(wc(4), KeyOrWildcard::Key("signer".into())), STA::AnchoredKey(wc(4), KeyOrWildcard::Key("signer".into())),
STA::Key(wc(0), kow_wc(1)), STA::AnchoredKey(wc(0), kow_wc(1)),
], ],
), ),
st( st(
P::Native(NP::Equal), P::Native(NP::Equal),
vec![ vec![
STA::Key(wc(4), KeyOrWildcard::Key("attestation".into())), STA::AnchoredKey(wc(4), KeyOrWildcard::Key("attestation".into())),
STA::Key(wc(2), kow_wc(3)), STA::AnchoredKey(wc(2), kow_wc(3)),
], ],
), ),
], ],
4, 4,
)?; )?;
let eth_friend_batch = Arc::new(CustomPredicateBatch { let eth_friend_batch =
name: "eth_friend".to_string(), CustomPredicateBatch::new(&params, "eth_friend".to_string(), vec![eth_friend_cp]);
predicates: vec![eth_friend_cp],
});
// 0 // 0
let eth_dos_base = CustomPredicate::and( let eth_dos_base = CustomPredicate::and(
"eth_dos_base".into(),
&params, &params,
"eth_dos_base".into(),
vec![ vec![
st( st(
P::Native(NP::Equal), P::Native(NP::Equal),
vec![STA::Key(wc(0), kow_wc(1)), STA::Key(wc(2), kow_wc(3))], vec![
STA::AnchoredKey(wc(0), kow_wc(1)),
STA::AnchoredKey(wc(2), kow_wc(3)),
],
), ),
st( st(
P::Native(NP::ValueOf), P::Native(NP::ValueOf),
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(0.into())], vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(0.into())],
), ),
], ],
6, 6,
@ -500,8 +570,8 @@ mod tests {
// 1 // 1
let eth_dos_ind = CustomPredicate::and( let eth_dos_ind = CustomPredicate::and(
"eth_dos_ind".into(),
&params, &params,
"eth_dos_ind".into(),
vec![ vec![
st( st(
P::BatchSelf(2), P::BatchSelf(2),
@ -516,14 +586,14 @@ mod tests {
), ),
st( st(
P::Native(NP::ValueOf), P::Native(NP::ValueOf),
vec![STA::Key(wc(6), kow_wc(7)), STA::Literal(1.into())], vec![STA::AnchoredKey(wc(6), kow_wc(7)), STA::Literal(1.into())],
), ),
st( st(
P::Native(NP::SumOf), P::Native(NP::SumOf),
vec![ vec![
STA::Key(wc(4), kow_wc(5)), STA::AnchoredKey(wc(4), kow_wc(5)),
STA::Key(wc(8), kow_wc(9)), STA::AnchoredKey(wc(8), kow_wc(9)),
STA::Key(wc(6), kow_wc(7)), STA::AnchoredKey(wc(6), kow_wc(7)),
], ],
), ),
st( st(
@ -541,8 +611,8 @@ mod tests {
// 2 // 2
let eth_dos_distance_either = CustomPredicate::or( let eth_dos_distance_either = CustomPredicate::or(
"eth_dos_distance_either".into(),
&params, &params,
"eth_dos_distance_either".into(),
vec![ vec![
st( st(
P::BatchSelf(0), P::BatchSelf(0),
@ -570,10 +640,11 @@ mod tests {
6, 6,
)?; )?;
let eth_dos_distance_batch = Arc::new(CustomPredicateBatch { let eth_dos_distance_batch = CustomPredicateBatch::new(
name: "ETHDoS_distance".to_string(), &params,
predicates: vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either], "ETHDoS_distance".to_string(),
}); vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
);
// Some POD IDs // Some POD IDs
let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64)))); let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64))));

View file

@ -584,6 +584,10 @@ pub struct Params {
pub max_public_statements: usize, pub max_public_statements: usize,
pub max_statement_args: usize, pub max_statement_args: usize,
pub max_operation_args: usize, pub max_operation_args: usize,
// max number of custom predicates batches that a MainPod can use
pub max_custom_predicate_batches: usize,
// max number of operations using custom predicates that can be verified in the MainPod
pub max_custom_predicate_verifications: usize,
// max number of statements that can be ANDed or ORed together // max number of statements that can be ANDed or ORed together
// in a custom predicate // in a custom predicate
pub max_custom_predicate_arity: usize, pub max_custom_predicate_arity: usize,
@ -605,6 +609,8 @@ impl Default for Params {
max_public_statements: 10, max_public_statements: 10,
max_statement_args: 5, max_statement_args: 5,
max_operation_args: 5, max_operation_args: 5,
max_custom_predicate_batches: 2,
max_custom_predicate_verifications: 5,
max_custom_predicate_arity: 5, max_custom_predicate_arity: 5,
max_custom_predicate_wildcards: 10, max_custom_predicate_wildcards: 10,
max_custom_batch_size: 5, max_custom_batch_size: 5,

View file

@ -1,4 +1,4 @@
use std::{fmt, iter, sync::Arc}; use std::{fmt, iter};
use log::error; use log::error;
use plonky2::field::types::Field; use plonky2::field::types::Field;
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
backends::plonky2::primitives::merkletree::MerkleProof, backends::plonky2::primitives::merkletree::MerkleProof,
middleware::{ middleware::{
custom::KeyOrWildcard, AnchoredKey, CustomPredicateBatch, CustomPredicateRef, Error, custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg, NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg,
ToFields, Wildcard, WildcardValue, F, SELF, ToFields, Wildcard, WildcardValue, F, SELF,
}, },
@ -36,6 +36,9 @@ impl fmt::Display for OperationAux {
} }
impl ToFields for OperationType { impl ToFields for OperationType {
/// Encoding:
/// - Native(native_op) => [1, [native_op], 0, 0, 0, 0]
/// - Custom(batch, index) => [3, [batch.id], index]
fn to_fields(&self, params: &Params) -> Vec<F> { fn to_fields(&self, params: &Params) -> Vec<F> {
let mut fields: Vec<F> = match self { let mut fields: Vec<F> = match self {
Self::Native(p) => iter::once(F::from_canonical_u64(1)) Self::Native(p) => iter::once(F::from_canonical_u64(1))
@ -43,7 +46,7 @@ impl ToFields for OperationType {
.collect(), .collect(),
Self::Custom(CustomPredicateRef { batch, index }) => { Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from_canonical_u64(3)) iter::once(F::from_canonical_u64(3))
.chain(batch.id(params).0) .chain(batch.id().0)
.chain(iter::once(F::from_canonical_usize(*index))) .chain(iter::once(F::from_canonical_usize(*index)))
.collect() .collect()
} }
@ -321,7 +324,7 @@ impl Operation {
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index => if batch == &cpr.batch && index == &cpr.index =>
{ {
check_custom_pred(params, batch, *index, args, s_args) check_custom_pred(params, cpr, args, s_args)
} }
_ => Err(Error::invalid_deduction( _ => Err(Error::invalid_deduction(
self.clone(), self.clone(),
@ -360,7 +363,7 @@ pub fn check_st_tmpl(
(StatementTmplArg::None, StatementArg::None) => true, (StatementTmplArg::None, StatementArg::None) => true,
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true, (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
( (
StatementTmplArg::Key(pod_id_wc, key_or_wc), StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc),
StatementArg::Key(AnchoredKey { pod_id, key }), StatementArg::Key(AnchoredKey { pod_id, key }),
) => { ) => {
let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map); let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map);
@ -379,14 +382,46 @@ pub fn check_st_tmpl(
} }
} }
pub fn resolve_wildcard_values(
params: &Params,
pred: &CustomPredicate,
args: &[Statement],
) -> Option<Vec<WildcardValue>> {
// Check that all wildcard have consistent values as assigned in the statements while storing a
// map of their values.
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
// disjunctions we expect Statement::None for the unused statements.
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
// TODO: Better errors. Example:
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
// println!("{} doesn't match {}", st, st_tmpl);
return None;
}
}
}
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
// they are beyond the number of used wildcards in this custom predicate, or they could be
// private arguments that are unused in a particular disjunction.
Some(
wildcard_map
.into_iter()
.map(|opt| opt.unwrap_or(WildcardValue::None))
.collect(),
)
}
fn check_custom_pred( fn check_custom_pred(
params: &Params, params: &Params,
batch: &Arc<CustomPredicateBatch>, custom_pred_ref: &CustomPredicateRef,
index: usize,
args: &[Statement], args: &[Statement],
s_args: &[WildcardValue], s_args: &[WildcardValue],
) -> Result<bool> { ) -> Result<bool> {
let pred = &batch.predicates[index]; let pred = custom_pred_ref.predicate();
if pred.statements.len() != args.len() { if pred.statements.len() != args.len() {
return Err(Error::diff_amount( return Err(Error::diff_amount(
"custom predicate operation".to_string(), "custom predicate operation".to_string(),
@ -404,26 +439,12 @@ fn check_custom_pred(
)); ));
} }
// Check that all wildcard have consistent values as assigned in the statements while storing a // Count the number of statements that match the templates by predicate.
// map of their values. Count the number of statements that match the templates by predicate.
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
// disjunctions we expect Statement::None for the unused statements.
let mut num_matches = 0; let mut num_matches = 0;
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
for (st_tmpl, st) in pred.statements.iter().zip(args) { for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
// TODO: Better errors. Example:
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
// println!("{} doesn't match {}", st, st_tmpl);
return Ok(false);
}
}
let st_tmpl_pred = match &st_tmpl.pred { let st_tmpl_pred = match &st_tmpl.pred {
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
batch: batch.clone(), batch: custom_pred_ref.batch.clone(),
index: *i, index: *i,
}), }),
p => p.clone(), p => p.clone(),
@ -433,9 +454,14 @@ fn check_custom_pred(
} }
} }
let wildcard_map = match resolve_wildcard_values(params, pred, args) {
Some(wc_map) => wc_map,
None => return Ok(false),
};
// Check that the resolved wildcard match the statement arguments. // Check that the resolved wildcard match the statement arguments.
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) { for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
if !wc_value.as_ref().is_none_or(|wc_value| *wc_value == *s_arg) { if *wc_value != *s_arg {
return Ok(false); return Ok(false);
} }
} }

View file

@ -7,7 +7,7 @@ use strum_macros::FromRepr;
use crate::middleware::{ use crate::middleware::{
AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value, AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value,
F, VALUE_SIZE, EMPTY_VALUE, F, VALUE_SIZE,
}; };
// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static // TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static
@ -17,22 +17,23 @@ pub const KEY_SIGNER: &str = "_signer";
pub const KEY_TYPE: &str = "_type"; pub const KEY_TYPE: &str = "_type";
pub const STATEMENT_ARG_F_LEN: usize = 8; pub const STATEMENT_ARG_F_LEN: usize = 8;
pub const OPERATION_ARG_F_LEN: usize = 1; pub const OPERATION_ARG_F_LEN: usize = 1;
pub const OPERATION_AUX_F_LEN: usize = 1; pub const OPERATION_AUX_F_LEN: usize = 2;
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub enum NativePredicate { pub enum NativePredicate {
None = 0, None = 0, // Always true
ValueOf = 1, False = 1, // Always false
Equal = 2, ValueOf = 2,
NotEqual = 3, Equal = 3,
LtEq = 4, NotEqual = 4,
Lt = 5, LtEq = 5,
Contains = 6, Lt = 6,
NotContains = 7, Contains = 7,
SumOf = 8, NotContains = 8,
ProductOf = 9, SumOf = 9,
MaxOf = 10, ProductOf = 10,
HashOf = 11, MaxOf = 11,
HashOf = 12,
// Syntactic sugar predicates. These predicates are not supported by the backend. The // Syntactic sugar predicates. These predicates are not supported by the backend. The
// frontend compiler is responsible of translating these predicates into the predicates above. // frontend compiler is responsible of translating these predicates into the predicates above.
@ -53,6 +54,7 @@ impl ToFields for NativePredicate {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub enum WildcardValue { pub enum WildcardValue {
None,
PodId(PodId), PodId(PodId),
Key(Key), Key(Key),
} }
@ -60,6 +62,7 @@ pub enum WildcardValue {
impl WildcardValue { impl WildcardValue {
pub fn raw(&self) -> RawValue { pub fn raw(&self) -> RawValue {
match self { match self {
WildcardValue::None => EMPTY_VALUE,
WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0), WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0),
WildcardValue::Key(key) => key.raw(), WildcardValue::Key(key) => key.raw(),
} }
@ -69,6 +72,7 @@ impl WildcardValue {
impl fmt::Display for WildcardValue { impl fmt::Display for WildcardValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
WildcardValue::None => write!(f, "none"),
WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id), WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id),
WildcardValue::Key(key) => write!(f, "{}", key), WildcardValue::Key(key) => write!(f, "{}", key),
} }
@ -77,10 +81,7 @@ impl fmt::Display for WildcardValue {
impl ToFields for WildcardValue { impl ToFields for WildcardValue {
fn to_fields(&self, params: &Params) -> Vec<F> { fn to_fields(&self, params: &Params) -> Vec<F> {
match self { self.raw().to_fields(params)
WildcardValue::PodId(pod_id) => pod_id.to_fields(params),
WildcardValue::Key(key) => key.to_fields(params),
}
} }
} }
@ -130,7 +131,7 @@ impl ToFields for Predicate {
.collect(), .collect(),
Self::Custom(CustomPredicateRef { batch, index }) => { Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from(PredicatePrefix::Custom)) iter::once(F::from(PredicatePrefix::Custom))
.chain(batch.id(params).0) .chain(batch.id().0)
.chain(iter::once(F::from_canonical_usize(*index))) .chain(iter::once(F::from_canonical_usize(*index)))
.collect() .collect()
} }
@ -149,7 +150,9 @@ impl fmt::Display for Predicate {
write!( write!(
f, f,
"{}.{}[{}]", "{}.{}[{}]",
batch.name, index, batch.predicates[*index].name batch.name,
index,
batch.predicates()[*index].name
) )
} }
} }
@ -397,14 +400,14 @@ impl StatementArg {
} }
impl ToFields for StatementArg { impl ToFields for StatementArg {
fn to_fields(&self, _params: &Params) -> Vec<F> { /// Encoding:
// NOTE: current version returns always the same amount of field elements in the returned /// - None => [0, 0, 0, 0, 0, 0, 0, 0]
// vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case /// - Literal(v) => [[v], 0, 0, 0, 0]
// is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced /// - Key(pod_id, key) => [[pod_id], [key]]
// in-circuit), we might be interested into reducing the length of it. If that's the case, /// - WildcardLiteral(v) => [[v], 0, 0, 0, 0]
// we can check if it makes sense to make it dependant on the concrete StatementArg; that fn to_fields(&self, params: &Params) -> Vec<F> {
// is, when dealing with a `None` it would be a single field element (zero value), and when // NOTE for @ax0: I removed the old comment because may `to_fields` implementations do
// dealing with `Literal` it would be of length 4. // padding and we need fixed output length for the circuits.
let f = match self { let f = match self {
StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN], StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN],
StatementArg::Literal(v) => v StatementArg::Literal(v) => v
@ -414,8 +417,8 @@ impl ToFields for StatementArg {
.chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE)) .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
.collect(), .collect(),
StatementArg::Key(ak) => { StatementArg::Key(ak) => {
let mut fields = ak.pod_id.to_fields(_params); let mut fields = ak.pod_id.to_fields(params);
fields.extend(ak.key.to_fields(_params)); fields.extend(ak.key.to_fields(params));
fields fields
} }
StatementArg::WildcardLiteral(v) => v StatementArg::WildcardLiteral(v) => v