Constraints for custom predicates (#227)
* add target types for custom predicates * simplify * fix clippy * fix typo * don't use ref for NativePredicate * fix wrong len * precalculate CustomPredicateBatch id * wip * wip * move code back * great progress * wip * code complete, hopefully; missing tests * fill aux for custom predicate op * fix clippy warnings * fix typos * fix test import * fix missing assignment in lt_mask, test custom_operation_verify_gadget * fix mistake * wip * fix * debug revert except for let entry = CustomPredicateVerifyEntryTarget * fix batch_id calculation by fixing padding * oops * remove completed TODOs
This commit is contained in:
parent
4fa9e20ecd
commit
024ed8bd04
12 changed files with 1597 additions and 291 deletions
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
use std::{array, iter};
|
use std::{array, iter};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
use plonky2::{
|
use plonky2::{
|
||||||
field::{
|
field::{
|
||||||
extension::Extendable,
|
extension::Extendable,
|
||||||
|
|
@ -12,23 +13,28 @@ use plonky2::{
|
||||||
poseidon::PoseidonHash,
|
poseidon::PoseidonHash,
|
||||||
},
|
},
|
||||||
iop::{
|
iop::{
|
||||||
|
generator::{GeneratedValues, SimpleGenerator},
|
||||||
target::{BoolTarget, Target},
|
target::{BoolTarget, Target},
|
||||||
witness::{PartialWitness, WitnessWrite},
|
witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite},
|
||||||
},
|
},
|
||||||
plonk::circuit_builder::CircuitBuilder,
|
plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData},
|
||||||
|
util::serialization::{Buffer, IoResult, Read, Write},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
backends::plonky2::{
|
backends::plonky2::{
|
||||||
basetypes::D,
|
basetypes::D,
|
||||||
|
circuits::mainpod::CustomPredicateVerification,
|
||||||
error::Result,
|
error::Result,
|
||||||
mainpod::{Operation, OperationArg, Statement},
|
mainpod::{Operation, OperationArg, Statement},
|
||||||
primitives::merkletree::MerkleClaimAndProofTarget,
|
primitives::merkletree::MerkleClaimAndProofTarget,
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
NativeOperation, NativePredicate, Params, Predicate, PredicatePrefix, RawValue,
|
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
||||||
StatementArg, StatementTmplArgPrefix, ToFields, EMPTY_VALUE, F, HASH_SIZE,
|
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
|
||||||
OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, WildcardValue,
|
||||||
|
EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN,
|
||||||
|
VALUE_SIZE,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -65,6 +71,10 @@ impl ValueTarget {
|
||||||
elements: array::from_fn(|i| xs[i]),
|
elements: array::from_fn(|i| xs[i]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_targets(&self, pw: &mut PartialWitness<F>, value: &Value) -> Result<()> {
|
||||||
|
Ok(pw.set_target_arr(&self.elements, &value.raw().0)?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
|
@ -82,7 +92,7 @@ impl StatementArgTarget {
|
||||||
Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?)
|
Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new(first: ValueTarget, second: ValueTarget) -> Self {
|
pub fn new(first: ValueTarget, second: ValueTarget) -> Self {
|
||||||
let elements: Vec<_> = first.elements.into_iter().chain(second.elements).collect();
|
let elements: Vec<_> = first.elements.into_iter().chain(second.elements).collect();
|
||||||
StatementArgTarget {
|
StatementArgTarget {
|
||||||
elements: elements.try_into().expect("size STATEMENT_ARG_F_LEN"),
|
elements: elements.try_into().expect("size STATEMENT_ARG_F_LEN"),
|
||||||
|
|
@ -107,6 +117,11 @@ impl StatementArgTarget {
|
||||||
Self::new(*pod_id, *key)
|
Self::new(*pod_id, *key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn wildcard_literal(builder: &mut CircuitBuilder<F, D>, value: &ValueTarget) -> Self {
|
||||||
|
let empty = builder.constant_value(EMPTY_VALUE);
|
||||||
|
Self::new(*value, empty)
|
||||||
|
}
|
||||||
|
|
||||||
/// StatementArgTarget to ValueTarget coercion. Make sure to check
|
/// StatementArgTarget to ValueTarget coercion. Make sure to check
|
||||||
/// that the arg is a value using the `statement_arg_is_value` method
|
/// that the arg is a value using the `statement_arg_is_value` method
|
||||||
/// first!
|
/// first!
|
||||||
|
|
@ -138,6 +153,7 @@ impl<T> Build<T> for T {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StatementTarget {
|
impl StatementTarget {
|
||||||
|
/// Build a new native StatementTarget
|
||||||
pub fn new_native(
|
pub fn new_native(
|
||||||
builder: &mut CircuitBuilder<F, D>,
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
params: &Params,
|
params: &Params,
|
||||||
|
|
@ -187,10 +203,60 @@ impl StatementTarget {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OperationTypeTarget {
|
||||||
|
pub elements: [Target; Params::operation_type_size()],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OperationTypeTarget {
|
||||||
|
pub fn new_custom(
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
batch_id: HashOutTarget,
|
||||||
|
index: Target,
|
||||||
|
) -> Self {
|
||||||
|
// TODO: Use an enum for these prefixes
|
||||||
|
let three = builder.constant(F::from_canonical_usize(3));
|
||||||
|
let id = batch_id.elements;
|
||||||
|
Self {
|
||||||
|
elements: [three, id[0], id[1], id[2], id[3], index],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_custom(
|
||||||
|
&self,
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
) -> (BoolTarget, HashOutTarget, Target) {
|
||||||
|
// TODO: Use an enum for these prefixes
|
||||||
|
let three = builder.constant(F::from_canonical_usize(3));
|
||||||
|
let op_is_custom = builder.is_equal(self.elements[0], three);
|
||||||
|
let batch_id = HashOutTarget::from_vec(self.elements[1..5].to_vec());
|
||||||
|
let index = self.elements[5];
|
||||||
|
(op_is_custom, batch_id, index)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_native(&self, builder: &mut CircuitBuilder<F, D>, t: NativeOperation) -> BoolTarget {
|
||||||
|
// TODO: Use an enum for these prefixes
|
||||||
|
let one = builder.one();
|
||||||
|
let op_is_native = builder.is_equal(self.elements[0], one);
|
||||||
|
let op_code = builder.constant(F::from_canonical_u64(t as u64));
|
||||||
|
let op_code_matches = builder.is_equal(self.elements[1], op_code);
|
||||||
|
builder.and(op_is_native, op_code_matches)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
op_type: &OperationType,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(pw.set_target_arr(&self.elements, &op_type.to_fields(params))?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Implement Operation::to_field to determine the size of each element
|
// TODO: Implement Operation::to_field to determine the size of each element
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OperationTarget {
|
pub struct OperationTarget {
|
||||||
pub op_type: [Target; Params::operation_type_size()],
|
pub op_type: OperationTypeTarget,
|
||||||
pub args: Vec<[Target; OPERATION_ARG_F_LEN]>,
|
pub args: Vec<[Target; OPERATION_ARG_F_LEN]>,
|
||||||
pub aux: [Target; OPERATION_AUX_F_LEN],
|
pub aux: [Target; OPERATION_AUX_F_LEN],
|
||||||
}
|
}
|
||||||
|
|
@ -202,7 +268,7 @@ impl OperationTarget {
|
||||||
params: &Params,
|
params: &Params,
|
||||||
op: &Operation,
|
op: &Operation,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
pw.set_target_arr(&self.op_type, &op.op_type().to_fields(params))?;
|
self.op_type.set_targets(pw, params, &op.op_type())?;
|
||||||
for (i, arg) in op
|
for (i, arg) in op
|
||||||
.args()
|
.args()
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -215,18 +281,6 @@ impl OperationTarget {
|
||||||
pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?;
|
pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn has_native_type(
|
|
||||||
&self,
|
|
||||||
builder: &mut CircuitBuilder<F, D>,
|
|
||||||
t: NativeOperation,
|
|
||||||
) -> BoolTarget {
|
|
||||||
let one = builder.one();
|
|
||||||
let op_is_native = builder.is_equal(self.op_type[0], one);
|
|
||||||
let op_code = builder.constant(F::from_canonical_u64(t as u64));
|
|
||||||
let op_code_matches = builder.is_equal(self.op_type[1], op_code);
|
|
||||||
builder.and(op_is_native, op_code_matches)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
|
@ -304,17 +358,37 @@ impl PredicateTarget {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Mirrors `middleware::KeyOrWildcard`
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct KeyOrWildcardTarget {
|
pub struct LiteralOrWildcardTarget {
|
||||||
pub elements: [Target; VALUE_SIZE],
|
pub elements: [Target; VALUE_SIZE],
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KeyOrWildcardTarget {
|
impl LiteralOrWildcardTarget {
|
||||||
fn from_slice(v: &[Target]) -> Self {
|
fn from_slice(v: &[Target]) -> Self {
|
||||||
Self {
|
Self {
|
||||||
elements: v.try_into().expect("len is VALUE_SIZE"),
|
elements: v.try_into().expect("len is VALUE_SIZE"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// cases: ((is_key, key), (is_wildcard, wildcard_index))
|
||||||
|
pub fn cases(
|
||||||
|
&self,
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
) -> ((BoolTarget, ValueTarget), (BoolTarget, Target)) {
|
||||||
|
let zero = builder.zero();
|
||||||
|
let is_zero_tail: Vec<_> = (1..4)
|
||||||
|
.map(|i| builder.is_equal(self.elements[i], zero))
|
||||||
|
.collect();
|
||||||
|
let is_wildcard = is_zero_tail
|
||||||
|
.into_iter()
|
||||||
|
.reduce(|acc, x| builder.and(acc, x))
|
||||||
|
.expect("len > 1");
|
||||||
|
let is_key = builder.not(is_wildcard);
|
||||||
|
let key = ValueTarget::from_slice(&self.elements);
|
||||||
|
let wildcard_index = self.elements[0];
|
||||||
|
|
||||||
|
((is_key, key), (is_wildcard, wildcard_index))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
|
@ -327,28 +401,40 @@ impl StatementTmplArgTarget {
|
||||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::None));
|
let prefix = builder.constant(F::from(StatementTmplArgPrefix::None));
|
||||||
builder.is_equal(self.elements[0], prefix)
|
builder.is_equal(self.elements[0], prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn as_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, ValueTarget) {
|
pub fn as_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, ValueTarget) {
|
||||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal));
|
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal));
|
||||||
let case_ok = builder.is_equal(self.elements[0], prefix);
|
let case_ok = builder.is_equal(self.elements[0], prefix);
|
||||||
let value = ValueTarget::from_slice(&self.elements[1..5]);
|
let value = ValueTarget::from_slice(&self.elements[1..5]);
|
||||||
(case_ok, value)
|
(case_ok, value)
|
||||||
}
|
}
|
||||||
pub fn as_key(
|
|
||||||
|
pub fn as_anchored_key(
|
||||||
&self,
|
&self,
|
||||||
builder: &mut CircuitBuilder<F, D>,
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
) -> (BoolTarget, Target, KeyOrWildcardTarget) {
|
) -> (BoolTarget, Target, LiteralOrWildcardTarget) {
|
||||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Key));
|
let prefix = builder.constant(F::from(StatementTmplArgPrefix::AnchoredKey));
|
||||||
let case_ok = builder.is_equal(self.elements[0], prefix);
|
let case_ok = builder.is_equal(self.elements[0], prefix);
|
||||||
let id_wildcard_index = self.elements[1];
|
let id_wildcard_index = self.elements[1];
|
||||||
let value_key_or_wildcard = KeyOrWildcardTarget::from_slice(&self.elements[5..9]);
|
let value_key_or_wildcard = LiteralOrWildcardTarget::from_slice(&self.elements[5..9]);
|
||||||
(case_ok, id_wildcard_index, value_key_or_wildcard)
|
(case_ok, id_wildcard_index, value_key_or_wildcard)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, Target) {
|
pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, Target) {
|
||||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral));
|
let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral));
|
||||||
let case_ok = builder.is_equal(self.elements[0], prefix);
|
let case_ok = builder.is_equal(self.elements[0], prefix);
|
||||||
let wildcard_index = self.elements[1];
|
let wildcard_index = self.elements[1];
|
||||||
(case_ok, wildcard_index)
|
(case_ok, wildcard_index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
st_tmpl_arg: &StatementTmplArg,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields(params))?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
|
@ -357,6 +443,17 @@ pub struct StatementTmplTarget {
|
||||||
pub args: Vec<StatementTmplArgTarget>,
|
pub args: Vec<StatementTmplArgTarget>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl StatementTmplTarget {
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
st_tmpl: &StatementTmpl,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(pw.set_target_arr(&self.flatten(), &st_tmpl.to_fields(params))?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CustomPredicateTarget {
|
pub struct CustomPredicateTarget {
|
||||||
pub conjunction: BoolTarget,
|
pub conjunction: BoolTarget,
|
||||||
|
|
@ -365,6 +462,17 @@ pub struct CustomPredicateTarget {
|
||||||
pub args_len: Target,
|
pub args_len: Target,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CustomPredicateTarget {
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
custom_predicate: &CustomPredicate,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(pw.set_target_arr(&self.flatten(), &custom_predicate.to_fields(params))?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CustomPredicateBatchTarget {
|
pub struct CustomPredicateBatchTarget {
|
||||||
pub predicates: Vec<CustomPredicateTarget>,
|
pub predicates: Vec<CustomPredicateTarget>,
|
||||||
|
|
@ -375,6 +483,161 @@ impl CustomPredicateBatchTarget {
|
||||||
let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
|
let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
|
||||||
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
|
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
custom_predicate_batch: &CustomPredicateBatch,
|
||||||
|
) -> Result<()> {
|
||||||
|
let pad_predicate = CustomPredicate::empty();
|
||||||
|
for (i, predicate) in custom_predicate_batch
|
||||||
|
.predicates()
|
||||||
|
.iter()
|
||||||
|
.chain(iter::repeat(&pad_predicate))
|
||||||
|
.take(params.max_custom_batch_size)
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
self.predicates[i].set_targets(pw, params, predicate)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Custom predicate table entry
|
||||||
|
pub struct CustomPredicateEntryTarget {
|
||||||
|
pub id: HashOutTarget,
|
||||||
|
pub index: Target,
|
||||||
|
pub predicate: CustomPredicateTarget,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomPredicateEntryTarget {
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
predicate: &CustomPredicateRef,
|
||||||
|
) -> Result<()> {
|
||||||
|
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
|
||||||
|
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
|
||||||
|
self.predicate
|
||||||
|
.set_targets(pw, params, predicate.predicate())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Flattenable for CustomPredicateEntryTarget {
|
||||||
|
fn flatten(&self) -> Vec<Target> {
|
||||||
|
self.id
|
||||||
|
.elements
|
||||||
|
.iter()
|
||||||
|
.chain(iter::once(&self.index))
|
||||||
|
.chain(self.predicate.flatten().iter())
|
||||||
|
.cloned()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
|
||||||
|
Self {
|
||||||
|
id: HashOutTarget::from_flattened(params, &vs[0..4]),
|
||||||
|
index: vs[4],
|
||||||
|
predicate: CustomPredicateTarget::from_flattened(params, &vs[5..]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomPredicateEntryTarget {
|
||||||
|
pub fn hash(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
|
||||||
|
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom predicate verification table entry
|
||||||
|
pub struct CustomPredicateVerifyEntryTarget {
|
||||||
|
pub custom_predicate_table_index: Target,
|
||||||
|
pub custom_predicate: CustomPredicateEntryTarget,
|
||||||
|
pub args: Vec<ValueTarget>,
|
||||||
|
pub query: CustomPredicateVerifyQueryTarget,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomPredicateVerifyEntryTarget {
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
cpv: &CustomPredicateVerification,
|
||||||
|
) -> Result<()> {
|
||||||
|
pw.set_target(
|
||||||
|
self.custom_predicate_table_index,
|
||||||
|
F::from_canonical_usize(cpv.custom_predicate_table_index),
|
||||||
|
)?;
|
||||||
|
self.custom_predicate
|
||||||
|
.set_targets(pw, params, &cpv.custom_predicate)?;
|
||||||
|
let pad_arg = WildcardValue::None;
|
||||||
|
for (arg_target, arg) in self.args.iter().zip_eq(
|
||||||
|
cpv.args
|
||||||
|
.iter()
|
||||||
|
.chain(iter::repeat(&pad_arg))
|
||||||
|
.take(params.max_custom_predicate_wildcards),
|
||||||
|
) {
|
||||||
|
arg_target.set_targets(pw, &Value::from(arg.raw()))?;
|
||||||
|
}
|
||||||
|
let pad_op_arg = Statement(Predicate::Native(NativePredicate::None), vec![]);
|
||||||
|
for (op_arg_target, op_arg) in self.query.op_args.iter().zip_eq(
|
||||||
|
cpv.op_args
|
||||||
|
.iter()
|
||||||
|
.chain(iter::repeat(&pad_op_arg))
|
||||||
|
.take(params.max_operation_args),
|
||||||
|
) {
|
||||||
|
op_arg_target.set_targets(pw, params, op_arg)?
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query for the custom predicate verification table
|
||||||
|
pub struct CustomPredicateVerifyQueryTarget {
|
||||||
|
pub statement: StatementTarget,
|
||||||
|
pub op_type: OperationTypeTarget,
|
||||||
|
pub op_args: Vec<StatementTarget>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomPredicateVerifyQueryTarget {
|
||||||
|
pub fn hash(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
|
||||||
|
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Flattenable for CustomPredicateVerifyQueryTarget {
|
||||||
|
fn flatten(&self) -> Vec<Target> {
|
||||||
|
self.statement
|
||||||
|
.flatten()
|
||||||
|
.iter()
|
||||||
|
.chain(self.op_type.elements.iter())
|
||||||
|
.cloned()
|
||||||
|
.chain(self.op_args.iter().flat_map(|op_arg| op_arg.flatten()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
|
||||||
|
let (pos, size) = (0, params.statement_size());
|
||||||
|
let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]);
|
||||||
|
let (pos, size) = (pos + size, params.operation_size());
|
||||||
|
let op_type = OperationTypeTarget {
|
||||||
|
elements: vs[pos..pos + size]
|
||||||
|
.try_into()
|
||||||
|
.expect("len = operation_type_size"),
|
||||||
|
};
|
||||||
|
let (pos, size) = (pos + size, params.statement_size());
|
||||||
|
let op_args = (0..params.max_operation_args)
|
||||||
|
.map(|i| {
|
||||||
|
StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size])
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Self {
|
||||||
|
statement,
|
||||||
|
op_type,
|
||||||
|
op_args,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for target structs that may be converted to and from vectors
|
/// Trait for target structs that may be converted to and from vectors
|
||||||
|
|
@ -408,6 +671,27 @@ impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Flattenable for HashOutTarget {
|
||||||
|
fn flatten(&self) -> Vec<Target> {
|
||||||
|
self.elements.to_vec()
|
||||||
|
}
|
||||||
|
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
|
||||||
|
assert_eq!(vs.len(), HASH_SIZE);
|
||||||
|
Self {
|
||||||
|
elements: array::from_fn(|i| vs[i]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Flattenable for ValueTarget {
|
||||||
|
fn flatten(&self) -> Vec<Target> {
|
||||||
|
self.elements.to_vec()
|
||||||
|
}
|
||||||
|
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
|
||||||
|
Self::from_slice(vs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Flattenable for MerkleClaimTarget {
|
impl Flattenable for MerkleClaimTarget {
|
||||||
fn flatten(&self) -> Vec<Target> {
|
fn flatten(&self) -> Vec<Target> {
|
||||||
[
|
[
|
||||||
|
|
@ -543,8 +827,17 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
||||||
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
|
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
|
||||||
fn add_virtual_value(&mut self) -> ValueTarget;
|
fn add_virtual_value(&mut self) -> ValueTarget;
|
||||||
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget;
|
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget;
|
||||||
|
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget;
|
||||||
fn add_virtual_predicate(&mut self) -> PredicateTarget;
|
fn add_virtual_predicate(&mut self) -> PredicateTarget;
|
||||||
|
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget;
|
||||||
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
|
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
|
||||||
|
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget;
|
||||||
|
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget;
|
||||||
|
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget;
|
||||||
|
fn add_virtual_custom_predicate_batch(&mut self, params: &Params)
|
||||||
|
-> CustomPredicateBatchTarget;
|
||||||
|
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
|
||||||
|
-> CustomPredicateEntryTarget;
|
||||||
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||||
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
|
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
|
||||||
fn constant_value(&mut self, v: RawValue) -> ValueTarget;
|
fn constant_value(&mut self, v: RawValue) -> ValueTarget;
|
||||||
|
|
@ -604,6 +897,9 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
||||||
// Convenience methods for Boolean into-iters.
|
// Convenience methods for Boolean into-iters.
|
||||||
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||||
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||||
|
|
||||||
|
// Return a bit-mask of size `len` that selects all positions lower than `n`
|
||||||
|
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||||
|
|
@ -629,22 +925,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||||
StatementTarget {
|
StatementTarget {
|
||||||
predicate,
|
predicate,
|
||||||
args: (0..params.max_statement_args)
|
args: (0..params.max_statement_args)
|
||||||
.map(|_| StatementArgTarget {
|
.map(|_| self.add_virtual_statement_arg())
|
||||||
elements: self.add_virtual_target_arr(),
|
|
||||||
})
|
|
||||||
.collect(),
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget {
|
||||||
|
StatementArgTarget {
|
||||||
|
elements: self.add_virtual_target_arr(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn add_virtual_predicate(&mut self) -> PredicateTarget {
|
fn add_virtual_predicate(&mut self) -> PredicateTarget {
|
||||||
PredicateTarget {
|
PredicateTarget {
|
||||||
elements: self.add_virtual_target_arr(),
|
elements: self.add_virtual_target_arr(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget {
|
||||||
|
OperationTypeTarget {
|
||||||
|
elements: self.add_virtual_target_arr(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
|
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
|
||||||
OperationTarget {
|
OperationTarget {
|
||||||
op_type: self.add_virtual_target_arr(),
|
op_type: self.add_virtual_operation_type(),
|
||||||
args: (0..params.max_operation_args)
|
args: (0..params.max_operation_args)
|
||||||
.map(|_| self.add_virtual_target_arr())
|
.map(|_| self.add_virtual_target_arr())
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
@ -652,6 +958,55 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget {
|
||||||
|
StatementTmplArgTarget {
|
||||||
|
elements: self.add_virtual_target_arr(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget {
|
||||||
|
let args = (0..params.max_statement_args)
|
||||||
|
.map(|_| self.add_virtual_statement_tmpl_arg())
|
||||||
|
.collect();
|
||||||
|
StatementTmplTarget {
|
||||||
|
pred: self.add_virtual_predicate(),
|
||||||
|
args,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget {
|
||||||
|
let statements = (0..params.max_custom_predicate_arity)
|
||||||
|
.map(|_| self.add_virtual_statement_tmpl(params))
|
||||||
|
.collect();
|
||||||
|
CustomPredicateTarget {
|
||||||
|
conjunction: self.add_virtual_bool_target_safe(),
|
||||||
|
statements,
|
||||||
|
args_len: self.add_virtual_target(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_virtual_custom_predicate_batch(
|
||||||
|
&mut self,
|
||||||
|
params: &Params,
|
||||||
|
) -> CustomPredicateBatchTarget {
|
||||||
|
CustomPredicateBatchTarget {
|
||||||
|
predicates: (0..params.max_custom_batch_size)
|
||||||
|
.map(|_| self.add_virtual_custom_predicate(params))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_virtual_custom_predicate_entry(
|
||||||
|
&mut self,
|
||||||
|
params: &Params,
|
||||||
|
) -> CustomPredicateEntryTarget {
|
||||||
|
CustomPredicateEntryTarget {
|
||||||
|
id: self.add_virtual_hash(),
|
||||||
|
index: self.add_virtual_target(),
|
||||||
|
predicate: self.add_virtual_custom_predicate(params),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
||||||
ValueTarget {
|
ValueTarget {
|
||||||
elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])),
|
elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])),
|
||||||
|
|
@ -876,6 +1231,12 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
|
||||||
|
// The idea would be the following: Take the array `ts` and hash each element. Then do the
|
||||||
|
// random access on the hash result. Finally "unhash" to recover the resolved element.
|
||||||
|
// We don't want to hash each element from the array each time, so we should cache the hashed
|
||||||
|
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
|
||||||
|
// then do `ts: &[HashCache<T>]`.
|
||||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
|
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
|
||||||
// TODO: Revisit this when we need more than 64 statements.
|
// TODO: Revisit this when we need more than 64 statements.
|
||||||
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
|
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
|
||||||
|
|
@ -944,6 +1305,73 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||||
.reduce(|a, b| self.or(a, b))
|
.reduce(|a, b| self.or(a, b))
|
||||||
.unwrap_or(self._false())
|
.unwrap_or(self._false())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget> {
|
||||||
|
let zero = self.zero();
|
||||||
|
let mask: Vec<_> = (0..len)
|
||||||
|
.map(|_| self.add_virtual_bool_target_safe())
|
||||||
|
.collect();
|
||||||
|
self.add_simple_generator(LtMaskGenerator {
|
||||||
|
n,
|
||||||
|
mask: mask.iter().map(|bt| bt.target).collect(),
|
||||||
|
});
|
||||||
|
// We have `n` ones in the mask
|
||||||
|
let mask_sum = mask
|
||||||
|
.iter()
|
||||||
|
.map(|b| b.target)
|
||||||
|
.reduce(|acc, x| self.add(acc, x))
|
||||||
|
.unwrap_or(zero);
|
||||||
|
self.connect(n, mask_sum);
|
||||||
|
|
||||||
|
// The elements in the mask can only transition from 1 to 0 or 0 to 0.
|
||||||
|
for i in 0..len - 1 {
|
||||||
|
let diff = self.sub(mask[i].target, mask[i + 1].target);
|
||||||
|
self.assert_bool(BoolTarget::new_unsafe(diff));
|
||||||
|
}
|
||||||
|
|
||||||
|
mask
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct LtMaskGenerator {
|
||||||
|
pub(crate) n: Target,
|
||||||
|
pub(crate) mask: Vec<Target>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for LtMaskGenerator {
|
||||||
|
fn id(&self) -> String {
|
||||||
|
"LtMaskGenerator".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dependencies(&self) -> Vec<Target> {
|
||||||
|
vec![self.n]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_once(
|
||||||
|
&self,
|
||||||
|
witness: &PartitionWitness<F>,
|
||||||
|
out_buffer: &mut GeneratedValues<F>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let n = witness.get_target(self.n).to_canonical_u64();
|
||||||
|
|
||||||
|
for (i, mask_i) in self.mask.iter().enumerate() {
|
||||||
|
let v = if (i as u64) < n { F::ONE } else { F::ZERO };
|
||||||
|
out_buffer.set_target(*mask_i, v)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
|
||||||
|
dst.write_target(self.n)?;
|
||||||
|
dst.write_target_vec(&self.mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
|
||||||
|
let n = src.read_target()?;
|
||||||
|
let mask = src.read_target_vec()?;
|
||||||
|
Ok(Self { n, mask })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -1013,13 +1441,14 @@ pub(crate) mod tests {
|
||||||
|
|
||||||
let custom_predicate_batch = eth_friend_batch(¶ms)?;
|
let custom_predicate_batch = eth_friend_batch(¶ms)?;
|
||||||
|
|
||||||
for (i, cp) in custom_predicate_batch.predicates.iter().enumerate() {
|
for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() {
|
||||||
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
|
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
|
||||||
let flattened = cp.to_fields(¶ms);
|
let flattened = cp.to_fields(¶ms);
|
||||||
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
|
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
|
||||||
let cp_target = CustomPredicateTarget::from_flattened(¶ms, &flatteend_target);
|
let cp_target = CustomPredicateTarget::from_flattened(¶ms, &flatteend_target);
|
||||||
// Round trip of from_flattened to flattened
|
// Round trip of from_flattened to flattened
|
||||||
let flatteend_target_rt = cp_target.flatten();
|
let flatteend_target_rt = cp_target.flatten();
|
||||||
|
// TODO: Instead of connect, assign witness to result
|
||||||
builder.connect_slice(&flatteend_target, &flatteend_target_rt);
|
builder.connect_slice(&flatteend_target, &flatteend_target_rt);
|
||||||
|
|
||||||
let pw = PartialWitness::<F>::new();
|
let pw = PartialWitness::<F>::new();
|
||||||
|
|
@ -1033,51 +1462,22 @@ pub(crate) mod tests {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_custom_predicate_batch_target_id(
|
fn helper_custom_predicate_batch_target_id(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
custom_predicate_batch: &CustomPredicateBatch,
|
custom_predicate_batch: &CustomPredicateBatch,
|
||||||
) -> frontend::Result<()> {
|
) -> Result<()> {
|
||||||
let config = CircuitConfig::standard_recursion_config();
|
let config = CircuitConfig::standard_recursion_config();
|
||||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||||
|
|
||||||
let zero = builder.zero();
|
let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(params);
|
||||||
let predicate_targets = custom_predicate_batch
|
|
||||||
.predicates
|
|
||||||
.iter()
|
|
||||||
.map(|cp| {
|
|
||||||
let flattened = cp.to_fields(params);
|
|
||||||
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
|
|
||||||
CustomPredicateTarget::from_flattened(params, &flatteend_target)
|
|
||||||
})
|
|
||||||
.chain(iter::repeat({
|
|
||||||
let empty_flatteend_target = iter::repeat(zero)
|
|
||||||
.take(params.custom_predicate_size())
|
|
||||||
.collect_vec();
|
|
||||||
CustomPredicateTarget::from_flattened(params, &empty_flatteend_target)
|
|
||||||
}))
|
|
||||||
.take(params.max_custom_batch_size)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let custom_predicate_batch_target = CustomPredicateBatchTarget {
|
|
||||||
predicates: predicate_targets,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Calculate the id in constraints and compare it against the id calculated natively
|
// Calculate the id in constraints and compare it against the id calculated natively
|
||||||
let id_target = custom_predicate_batch_target.id(&mut builder);
|
let id_target = custom_predicate_batch_target.id(&mut builder);
|
||||||
let id = custom_predicate_batch.id(params);
|
|
||||||
|
|
||||||
let id_expected_target = HashOutTarget {
|
let mut pw = PartialWitness::<F>::new();
|
||||||
elements: id
|
custom_predicate_batch_target.set_targets(&mut pw, params, custom_predicate_batch)?;
|
||||||
.to_fields(params)
|
let id = custom_predicate_batch.id();
|
||||||
.iter()
|
pw.set_target_arr(&id_target.elements, &id.0)?;
|
||||||
.map(|v| builder.constant(*v))
|
|
||||||
.collect_vec()
|
|
||||||
.try_into()
|
|
||||||
.unwrap(),
|
|
||||||
};
|
|
||||||
builder.connect_array(id_target.elements, id_expected_target.elements);
|
|
||||||
|
|
||||||
let pw = PartialWitness::<F>::new();
|
|
||||||
|
|
||||||
// generate & verify proof
|
// generate & verify proof
|
||||||
let data = builder.build::<C>();
|
let data = builder.build::<C>();
|
||||||
|
|
@ -1088,7 +1488,7 @@ pub(crate) mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn custom_predicate_batch_target() -> frontend::Result<()> {
|
fn test_custom_predicate_batch_target_id() -> frontend::Result<()> {
|
||||||
let params = Params {
|
let params = Params {
|
||||||
max_statement_args: 6,
|
max_statement_args: 6,
|
||||||
max_custom_predicate_wildcards: 12,
|
max_custom_predicate_wildcards: 12,
|
||||||
|
|
@ -1096,17 +1496,21 @@ pub(crate) mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Empty case
|
// Empty case
|
||||||
let mut cpb_builder = CustomPredicateBatchBuilder::new("empty".into());
|
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
|
||||||
_ = cpb_builder.predicate_and("empty", ¶ms, &[], &[], &[])?;
|
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
|
||||||
let custom_predicate_batch = cpb_builder.finish();
|
let custom_predicate_batch = cpb_builder.finish();
|
||||||
test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?;
|
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||||
|
|
||||||
// Some cases from the examples
|
// Some cases from the examples
|
||||||
let custom_predicate_batch = eth_friend_batch(¶ms)?;
|
let custom_predicate_batch = eth_friend_batch(¶ms)?;
|
||||||
test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?;
|
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||||
|
|
||||||
let custom_predicate_batch = eth_dos_batch(¶ms)?;
|
let custom_predicate_batch = eth_dos_batch(¶ms)?;
|
||||||
test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?;
|
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||||
|
|
||||||
|
let custom_predicate_batch =
|
||||||
|
CustomPredicateBatch::new(¶ms, "empty".to_string(), vec![CustomPredicate::empty()]);
|
||||||
|
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
||||||
pub mod operation;
|
pub mod operation;
|
||||||
pub mod statement;
|
pub mod statement;
|
||||||
use std::any::Any;
|
use std::{any::Any, sync::Arc};
|
||||||
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
pub use operation::*;
|
pub use operation::*;
|
||||||
|
|
@ -17,14 +17,17 @@ pub use statement::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
backends::plonky2::{
|
backends::plonky2::{
|
||||||
basetypes::{C, D},
|
basetypes::{C, D},
|
||||||
circuits::mainpod::{MainPodVerifyCircuit, MainPodVerifyInput},
|
circuits::mainpod::{
|
||||||
|
CustomPredicateVerification, MainPodVerifyCircuit, MainPodVerifyInput,
|
||||||
|
},
|
||||||
error::{Error, Result},
|
error::{Error, Result},
|
||||||
primitives::merkletree::MerkleClaimAndProof,
|
primitives::merkletree::MerkleClaimAndProof,
|
||||||
signedpod::SignedPod,
|
signedpod::SignedPod,
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
self, AnchoredKey, DynError, Hash, MainPodInputs, NativeOperation, NonePod, OperationType,
|
self, resolve_wildcard_values, AnchoredKey, CustomPredicateBatch, DynError, Hash,
|
||||||
Params, Pod, PodId, PodProver, PodType, StatementArg, ToFields, F, KEY_TYPE, SELF,
|
MainPodInputs, NativeOperation, NonePod, OperationType, Params, Pod, PodId, PodProver,
|
||||||
|
PodType, StatementArg, ToFields, F, KEY_TYPE, SELF,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -37,7 +40,71 @@ pub(crate) fn hash_statements(statements: &[Statement], _params: &Params) -> mid
|
||||||
Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
|
Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extracts and pads Merkle proofs from Contains/NotContains ops.
|
/// Extracts unique `CustomPredicateBatch`es from Custom ops.
|
||||||
|
pub(crate) fn extract_custom_predicate_batches(
|
||||||
|
params: &Params,
|
||||||
|
operations: &[middleware::Operation],
|
||||||
|
) -> Result<Vec<Arc<CustomPredicateBatch>>> {
|
||||||
|
let custom_predicate_batches: Vec<_> = operations
|
||||||
|
.iter()
|
||||||
|
.flat_map(|op| match op {
|
||||||
|
middleware::Operation::Custom(cpr, _) => Some(cpr.batch.clone()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.unique_by(|cpr| cpr.id())
|
||||||
|
.collect();
|
||||||
|
if custom_predicate_batches.len() > params.max_custom_predicate_batches {
|
||||||
|
return Err(Error::custom(format!(
|
||||||
|
"The number of required `CustomPredicateBatch`es ({}) exceeds the maximum number ({}).",
|
||||||
|
custom_predicate_batches.len(),
|
||||||
|
params.max_custom_predicate_batches
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(custom_predicate_batches)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts all custom predicate operations with all the data required to verify them.
|
||||||
|
pub(crate) fn extract_custom_predicate_verifications(
|
||||||
|
params: &Params,
|
||||||
|
operations: &[middleware::Operation],
|
||||||
|
custom_predicate_batches: &[Arc<CustomPredicateBatch>],
|
||||||
|
) -> Result<Vec<CustomPredicateVerification>> {
|
||||||
|
let custom_predicate_data: Vec<_> = operations
|
||||||
|
.iter()
|
||||||
|
.flat_map(|op| match op {
|
||||||
|
middleware::Operation::Custom(cpr, sts) => Some((cpr, sts)),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.map(|(cpr, sts)| {
|
||||||
|
let wildcard_values =
|
||||||
|
resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards");
|
||||||
|
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
|
||||||
|
let batch_index = custom_predicate_batches
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i))
|
||||||
|
.expect("find the custom predicate from the extracted unique list");
|
||||||
|
let custom_predicate_table_index =
|
||||||
|
batch_index * params.max_custom_predicate_batches + cpr.index;
|
||||||
|
CustomPredicateVerification {
|
||||||
|
custom_predicate_table_index,
|
||||||
|
custom_predicate: cpr.clone(),
|
||||||
|
args: wildcard_values,
|
||||||
|
op_args: sts,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
if custom_predicate_data.len() > params.max_custom_predicate_verifications {
|
||||||
|
return Err(Error::custom(format!(
|
||||||
|
"The number of required custom predicate verifications ({}) exceeds the maximum number ({}).",
|
||||||
|
custom_predicate_data.len(),
|
||||||
|
params.max_custom_predicate_verifications
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
Ok(custom_predicate_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts Merkle proofs from Contains/NotContains ops.
|
||||||
pub(crate) fn extract_merkle_proofs(
|
pub(crate) fn extract_merkle_proofs(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
operations: &[middleware::Operation],
|
operations: &[middleware::Operation],
|
||||||
|
|
@ -98,11 +165,32 @@ fn find_op_arg(statements: &[Statement], op_arg: &middleware::Statement) -> Resu
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find the operation auxiliary data in the list of auxiliary data and return the index.
|
/// Find the operation auxiliary data in the list of auxiliary data and return the index.
|
||||||
|
// NOTE: The `custom_predicate_verifications` is optional because in the MainPod we want to store
|
||||||
|
// the index of a custom predicate verification in the aux data, but in the MockMainPod we don't
|
||||||
|
// need that because we keep a reference to the custom predicate in the operation type, which
|
||||||
|
// removes the need for indexing. We could change the OperationType and Predicate for the backend
|
||||||
|
// to not keep a reference to the custom predicate and instead just keep the id and index and then
|
||||||
|
// do the same double indexing that the MainPod does to verify custom predicates.
|
||||||
fn find_op_aux(
|
fn find_op_aux(
|
||||||
merkle_proofs: &[MerkleClaimAndProof],
|
merkle_proofs: &[MerkleClaimAndProof],
|
||||||
op_aux: &middleware::OperationAux,
|
custom_predicate_verifications: Option<&[CustomPredicateVerification]>,
|
||||||
|
op: &middleware::Operation,
|
||||||
) -> Result<OperationAux> {
|
) -> Result<OperationAux> {
|
||||||
match op_aux {
|
let op_aux = op.aux();
|
||||||
|
let op_type = op.op_type();
|
||||||
|
if let (OperationType::Custom(cpr), Some(cpvs)) = (op_type, custom_predicate_verifications) {
|
||||||
|
return Ok(cpvs
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.find_map(|(i, cpv)| {
|
||||||
|
(cpv.custom_predicate.batch.id() == cpr.batch.id()
|
||||||
|
&& cpv.custom_predicate.index == cpr.index)
|
||||||
|
.then_some(i)
|
||||||
|
})
|
||||||
|
.map(OperationAux::CustomPredVerifyIndex)
|
||||||
|
.expect("custom predicate verification in the list"));
|
||||||
|
}
|
||||||
|
match &op_aux {
|
||||||
middleware::OperationAux::None => Ok(OperationAux::None),
|
middleware::OperationAux::None => Ok(OperationAux::None),
|
||||||
middleware::OperationAux::MerkleProof(pf_arg) => merkle_proofs
|
middleware::OperationAux::MerkleProof(pf_arg) => merkle_proofs
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -217,6 +305,7 @@ pub(crate) fn process_private_statements_operations(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
statements: &[Statement],
|
statements: &[Statement],
|
||||||
merkle_proofs: &[MerkleClaimAndProof],
|
merkle_proofs: &[MerkleClaimAndProof],
|
||||||
|
custom_predicate_verifications: Option<&[CustomPredicateVerification]>,
|
||||||
input_operations: &[middleware::Operation],
|
input_operations: &[middleware::Operation],
|
||||||
) -> Result<Vec<Operation>> {
|
) -> Result<Vec<Operation>> {
|
||||||
let mut operations = Vec::new();
|
let mut operations = Vec::new();
|
||||||
|
|
@ -231,8 +320,7 @@ pub(crate) fn process_private_statements_operations(
|
||||||
.map(|mid_arg| find_op_arg(statements, mid_arg))
|
.map(|mid_arg| find_op_arg(statements, mid_arg))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let mid_aux = op.aux();
|
let aux = find_op_aux(merkle_proofs, custom_predicate_verifications, &op)?;
|
||||||
let aux = find_op_aux(merkle_proofs, &mid_aux)?;
|
|
||||||
|
|
||||||
pad_operation_args(params, &mut args);
|
pad_operation_args(params, &mut args);
|
||||||
operations.push(Operation(op.op_type(), args, aux));
|
operations.push(Operation(op.op_type(), args, aux));
|
||||||
|
|
@ -301,12 +389,19 @@ impl Prover {
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
|
||||||
let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?;
|
let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?;
|
||||||
|
let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?;
|
||||||
|
let custom_predicate_verifications = extract_custom_predicate_verifications(
|
||||||
|
params,
|
||||||
|
inputs.operations,
|
||||||
|
&custom_predicate_batches,
|
||||||
|
)?;
|
||||||
|
|
||||||
let statements = layout_statements(params, &inputs);
|
let statements = layout_statements(params, &inputs);
|
||||||
let operations = process_private_statements_operations(
|
let operations = process_private_statements_operations(
|
||||||
params,
|
params,
|
||||||
&statements,
|
&statements,
|
||||||
&merkle_proofs,
|
&merkle_proofs,
|
||||||
|
Some(&custom_predicate_verifications),
|
||||||
inputs.operations,
|
inputs.operations,
|
||||||
)?;
|
)?;
|
||||||
let operations = process_public_statements_operations(params, &statements, operations)?;
|
let operations = process_public_statements_operations(params, &statements, operations)?;
|
||||||
|
|
@ -321,6 +416,8 @@ impl Prover {
|
||||||
statements: statements[statements.len() - params.max_statements..].to_vec(),
|
statements: statements[statements.len() - params.max_statements..].to_vec(),
|
||||||
operations,
|
operations,
|
||||||
merkle_proofs,
|
merkle_proofs,
|
||||||
|
custom_predicate_batches,
|
||||||
|
custom_predicate_verifications,
|
||||||
};
|
};
|
||||||
main_pod.set_targets(&mut pw, &input)?;
|
main_pod.set_targets(&mut pw, &input)?;
|
||||||
|
|
||||||
|
|
@ -505,4 +602,41 @@ pub mod tests {
|
||||||
let pod = (kyc_pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
|
let pod = (kyc_pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
|
||||||
pod.verify().unwrap()
|
pod.verify().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mainpod_small_empty() {
|
||||||
|
let params = middleware::Params {
|
||||||
|
max_input_signed_pods: 0,
|
||||||
|
max_input_main_pods: 0,
|
||||||
|
max_statements: 5,
|
||||||
|
max_signed_pod_values: 2,
|
||||||
|
max_public_statements: 2,
|
||||||
|
max_statement_args: 2,
|
||||||
|
max_operation_args: 3,
|
||||||
|
max_custom_predicate_batches: 2,
|
||||||
|
max_custom_predicate_verifications: 2,
|
||||||
|
max_custom_predicate_arity: 2,
|
||||||
|
max_custom_predicate_wildcards: 2,
|
||||||
|
max_custom_batch_size: 2,
|
||||||
|
max_merkle_proofs: 2,
|
||||||
|
max_depth_mt_gadget: 4,
|
||||||
|
};
|
||||||
|
|
||||||
|
let pod_builder = frontend::MainPodBuilder::new(¶ms);
|
||||||
|
|
||||||
|
// Mock
|
||||||
|
let mut prover = MockProver {};
|
||||||
|
let kyc_pod = pod_builder.prove(&mut prover, ¶ms).unwrap();
|
||||||
|
let pod = (kyc_pod.pod as Box<dyn Any>)
|
||||||
|
.downcast::<MockMainPod>()
|
||||||
|
.unwrap();
|
||||||
|
pod.verify().unwrap();
|
||||||
|
println!("{:#}", pod);
|
||||||
|
|
||||||
|
// Real
|
||||||
|
let mut prover = Prover {};
|
||||||
|
let kyc_pod = pod_builder.prove(&mut prover, ¶ms).unwrap();
|
||||||
|
let pod = (kyc_pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
|
||||||
|
pod.verify().unwrap()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -38,15 +38,17 @@ impl OperationArg {
|
||||||
pub enum OperationAux {
|
pub enum OperationAux {
|
||||||
None,
|
None,
|
||||||
MerkleProofIndex(usize),
|
MerkleProofIndex(usize),
|
||||||
|
CustomPredVerifyIndex(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToFields for OperationAux {
|
impl ToFields for OperationAux {
|
||||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
||||||
let f = match self {
|
let fs = match self {
|
||||||
Self::None => F::ZERO,
|
Self::None => [F::ZERO, F::ZERO],
|
||||||
Self::MerkleProofIndex(i) => F::from_canonical_usize(*i),
|
Self::MerkleProofIndex(i) => [F::from_canonical_usize(*i), F::ZERO],
|
||||||
|
Self::CustomPredVerifyIndex(i) => [F::ZERO, F::from_canonical_usize(*i)],
|
||||||
};
|
};
|
||||||
vec![f]
|
vec![fs[0], fs[1]]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -78,6 +80,7 @@ impl Operation {
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let deref_aux = match self.2 {
|
let deref_aux = match self.2 {
|
||||||
OperationAux::None => crate::middleware::OperationAux::None,
|
OperationAux::None => crate::middleware::OperationAux::None,
|
||||||
|
OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None,
|
||||||
OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof(
|
OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof(
|
||||||
merkle_proofs
|
merkle_proofs
|
||||||
.get(i)
|
.get(i)
|
||||||
|
|
@ -111,6 +114,7 @@ impl fmt::Display for Operation {
|
||||||
match self.2 {
|
match self.2 {
|
||||||
OperationAux::None => (),
|
OperationAux::None => (),
|
||||||
OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?,
|
OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?,
|
||||||
|
OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?,
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,7 @@ impl MockMainPod {
|
||||||
params,
|
params,
|
||||||
&statements,
|
&statements,
|
||||||
&merkle_proofs,
|
&merkle_proofs,
|
||||||
|
None,
|
||||||
inputs.operations,
|
inputs.operations,
|
||||||
)?;
|
)?;
|
||||||
let operations = process_public_statements_operations(params, &statements, operations)?;
|
let operations = process_public_statements_operations(params, &statements, operations)?;
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,9 @@ use crate::{
|
||||||
|
|
||||||
/// Instantiates an ETH friend batch
|
/// Instantiates an ETH friend batch
|
||||||
pub fn eth_friend_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
pub fn eth_friend_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into());
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "eth_friend".into());
|
||||||
let _eth_friend = builder.predicate_and(
|
let _eth_friend = builder.predicate_and(
|
||||||
"eth_friend",
|
"eth_friend",
|
||||||
params,
|
|
||||||
// arguments:
|
// arguments:
|
||||||
&["src_ori", "src_key", "dst_ori", "dst_key"],
|
&["src_ori", "src_key", "dst_ori", "dst_key"],
|
||||||
// private arguments:
|
// private arguments:
|
||||||
|
|
@ -44,7 +43,8 @@ pub fn eth_friend_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
/// Instantiates an ETHDoS batch
|
/// Instantiates an ETHDoS batch
|
||||||
pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params)?, 0));
|
let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params)?, 0));
|
||||||
let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into());
|
let mut builder =
|
||||||
|
CustomPredicateBatchBuilder::new(params.clone(), "eth_dos_distance_base".into());
|
||||||
|
|
||||||
// eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and<
|
// eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and<
|
||||||
// eq(src_or, src_key, dst_or, dst_key),
|
// eq(src_or, src_key, dst_or, dst_key),
|
||||||
|
|
@ -52,7 +52,6 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
// >
|
// >
|
||||||
let eth_dos_distance_base = builder.predicate_and(
|
let eth_dos_distance_base = builder.predicate_and(
|
||||||
"eth_dos_distance_base",
|
"eth_dos_distance_base",
|
||||||
params,
|
|
||||||
&[
|
&[
|
||||||
// arguments:
|
// arguments:
|
||||||
"src_ori",
|
"src_ori",
|
||||||
|
|
@ -83,7 +82,6 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
|
|
||||||
let eth_dos_distance_ind = builder.predicate_and(
|
let eth_dos_distance_ind = builder.predicate_and(
|
||||||
"eth_dos_distance_ind",
|
"eth_dos_distance_ind",
|
||||||
params,
|
|
||||||
&[
|
&[
|
||||||
// arguments:
|
// arguments:
|
||||||
"src_ori",
|
"src_ori",
|
||||||
|
|
@ -135,7 +133,6 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
|
|
||||||
let _eth_dos_distance = builder.predicate_or(
|
let _eth_dos_distance = builder.predicate_or(
|
||||||
"eth_dos_distance",
|
"eth_dos_distance",
|
||||||
params,
|
|
||||||
&[
|
&[
|
||||||
"src_ori",
|
"src_ori",
|
||||||
"src_key",
|
"src_key",
|
||||||
|
|
|
||||||
|
|
@ -128,13 +128,15 @@ impl StatementTmplBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CustomPredicateBatchBuilder {
|
pub struct CustomPredicateBatchBuilder {
|
||||||
|
params: Params,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub predicates: Vec<CustomPredicate>,
|
pub predicates: Vec<CustomPredicate>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CustomPredicateBatchBuilder {
|
impl CustomPredicateBatchBuilder {
|
||||||
pub fn new(name: String) -> Self {
|
pub fn new(params: Params, name: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
params,
|
||||||
name,
|
name,
|
||||||
predicates: Vec::new(),
|
predicates: Vec::new(),
|
||||||
}
|
}
|
||||||
|
|
@ -143,23 +145,21 @@ impl CustomPredicateBatchBuilder {
|
||||||
pub fn predicate_and(
|
pub fn predicate_and(
|
||||||
&mut self,
|
&mut self,
|
||||||
name: &str,
|
name: &str,
|
||||||
params: &Params,
|
|
||||||
args: &[&str],
|
args: &[&str],
|
||||||
priv_args: &[&str],
|
priv_args: &[&str],
|
||||||
sts: &[StatementTmplBuilder],
|
sts: &[StatementTmplBuilder],
|
||||||
) -> Result<Predicate> {
|
) -> Result<Predicate> {
|
||||||
self.predicate(name, params, true, args, priv_args, sts)
|
self.predicate(name, true, args, priv_args, sts)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn predicate_or(
|
pub fn predicate_or(
|
||||||
&mut self,
|
&mut self,
|
||||||
name: &str,
|
name: &str,
|
||||||
params: &Params,
|
|
||||||
args: &[&str],
|
args: &[&str],
|
||||||
priv_args: &[&str],
|
priv_args: &[&str],
|
||||||
sts: &[StatementTmplBuilder],
|
sts: &[StatementTmplBuilder],
|
||||||
) -> Result<Predicate> {
|
) -> Result<Predicate> {
|
||||||
self.predicate(name, params, false, args, priv_args, sts)
|
self.predicate(name, false, args, priv_args, sts)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// creates the custom predicate from the given input, adds it to the
|
/// creates the custom predicate from the given input, adds it to the
|
||||||
|
|
@ -167,24 +167,23 @@ impl CustomPredicateBatchBuilder {
|
||||||
fn predicate(
|
fn predicate(
|
||||||
&mut self,
|
&mut self,
|
||||||
name: &str,
|
name: &str,
|
||||||
params: &Params,
|
|
||||||
conjunction: bool,
|
conjunction: bool,
|
||||||
args: &[&str],
|
args: &[&str],
|
||||||
priv_args: &[&str],
|
priv_args: &[&str],
|
||||||
sts: &[StatementTmplBuilder],
|
sts: &[StatementTmplBuilder],
|
||||||
) -> Result<Predicate> {
|
) -> Result<Predicate> {
|
||||||
if args.len() > params.max_statement_args {
|
if args.len() > self.params.max_statement_args {
|
||||||
return Err(Error::max_length(
|
return Err(Error::max_length(
|
||||||
"args.len".to_string(),
|
"args.len".to_string(),
|
||||||
args.len(),
|
args.len(),
|
||||||
params.max_statement_args,
|
self.params.max_statement_args,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if (args.len() + priv_args.len()) > params.max_custom_predicate_wildcards {
|
if (args.len() + priv_args.len()) > self.params.max_custom_predicate_wildcards {
|
||||||
return Err(Error::max_length(
|
return Err(Error::max_length(
|
||||||
"wildcards.len".to_string(),
|
"wildcards.len".to_string(),
|
||||||
args.len() + priv_args.len(),
|
args.len() + priv_args.len(),
|
||||||
params.max_custom_predicate_wildcards,
|
self.params.max_custom_predicate_wildcards,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -197,7 +196,7 @@ impl CustomPredicateBatchBuilder {
|
||||||
.iter()
|
.iter()
|
||||||
.map(|a| match a {
|
.map(|a| match a {
|
||||||
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
|
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
|
||||||
BuilderArg::Key(pod_id, key) => StatementTmplArg::Key(
|
BuilderArg::Key(pod_id, key) => StatementTmplArg::AnchoredKey(
|
||||||
resolve_wildcard(args, priv_args, pod_id),
|
resolve_wildcard(args, priv_args, pod_id),
|
||||||
resolve_key_or_wildcard(args, priv_args, key),
|
resolve_key_or_wildcard(args, priv_args, key),
|
||||||
),
|
),
|
||||||
|
|
@ -212,17 +211,19 @@ impl CustomPredicateBatchBuilder {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let custom_predicate =
|
let custom_predicate = CustomPredicate::new(
|
||||||
CustomPredicate::new(name.into(), params, conjunction, statements, args.len())?;
|
&self.params,
|
||||||
|
name.into(),
|
||||||
|
conjunction,
|
||||||
|
statements,
|
||||||
|
args.len(),
|
||||||
|
)?;
|
||||||
self.predicates.push(custom_predicate);
|
self.predicates.push(custom_predicate);
|
||||||
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
|
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn finish(self) -> Arc<CustomPredicateBatch> {
|
pub fn finish(self) -> Arc<CustomPredicateBatch> {
|
||||||
Arc::new(CustomPredicateBatch {
|
CustomPredicateBatch::new(&self.params, self.name, self.predicates)
|
||||||
name: self.name,
|
|
||||||
predicates: self.predicates,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -290,7 +291,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_desugared_gt_custom_pred() -> Result<()> {
|
fn test_desugared_gt_custom_pred() -> Result<()> {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let mut builder = CustomPredicateBatchBuilder::new("gt_custom_pred".into());
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into());
|
||||||
|
|
||||||
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
|
let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt)
|
||||||
.arg(("s1_origin", "s1_key"))
|
.arg(("s1_origin", "s1_key"))
|
||||||
|
|
@ -298,7 +299,6 @@ mod tests {
|
||||||
|
|
||||||
builder.predicate_and(
|
builder.predicate_and(
|
||||||
"gt_custom_pred",
|
"gt_custom_pred",
|
||||||
¶ms,
|
|
||||||
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
|
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
|
||||||
&[],
|
&[],
|
||||||
&[gt_stb],
|
&[gt_stb],
|
||||||
|
|
@ -322,7 +322,7 @@ mod tests {
|
||||||
// Check that the desugared predicate is the same as the one in the statement template
|
// Check that the desugared predicate is the same as the one in the statement template
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
desugared_gt.predicate(),
|
desugared_gt.predicate(),
|
||||||
*batch_clone.predicates[0].statements[0].pred()
|
*batch_clone.predicates()[0].statements[0].pred()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check that our custom predicate matches the statement template
|
// Check that our custom predicate matches the statement template
|
||||||
|
|
@ -339,7 +339,8 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_desugared_set_contains_custom_pred() -> Result<()> {
|
fn test_desugared_set_contains_custom_pred() -> Result<()> {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let mut builder = CustomPredicateBatchBuilder::new("set_contains_custom_pred".into());
|
let mut builder =
|
||||||
|
CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into());
|
||||||
|
|
||||||
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
|
let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains)
|
||||||
.arg(("s1_origin", "s1_key"))
|
.arg(("s1_origin", "s1_key"))
|
||||||
|
|
@ -347,7 +348,6 @@ mod tests {
|
||||||
|
|
||||||
builder.predicate_and(
|
builder.predicate_and(
|
||||||
"set_contains_custom_pred",
|
"set_contains_custom_pred",
|
||||||
¶ms,
|
|
||||||
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
|
&["s1_origin", "s1_key", "s2_origin", "s2_key"],
|
||||||
&[],
|
&[],
|
||||||
&[set_contains_stb],
|
&[set_contains_stb],
|
||||||
|
|
@ -368,7 +368,7 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
set_contains.predicate(),
|
set_contains.predicate(),
|
||||||
*batch_clone.predicates[0].statements[0].pred()
|
*batch_clone.predicates()[0].statements[0].pred()
|
||||||
);
|
);
|
||||||
|
|
||||||
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);
|
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);
|
||||||
|
|
|
||||||
|
|
@ -466,7 +466,7 @@ impl MainPodBuilder {
|
||||||
)))?,
|
)))?,
|
||||||
},
|
},
|
||||||
OperationType::Custom(cpr) => {
|
OperationType::Custom(cpr) => {
|
||||||
let pred = &cpr.batch.predicates[cpr.index];
|
let pred = &cpr.batch.predicates()[cpr.index];
|
||||||
if pred.statements.len() != args.len() {
|
if pred.statements.len() != args.len() {
|
||||||
return Err(Error::custom(format!(
|
return Err(Error::custom(format!(
|
||||||
"Custom predicate operation needs {} statements but has {}.",
|
"Custom predicate operation needs {} statements but has {}.",
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,8 @@ use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::middleware::{
|
use crate::middleware::{
|
||||||
hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE,
|
hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value,
|
||||||
|
EMPTY_HASH, F, HASH_SIZE, VALUE_SIZE,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||||
|
|
@ -49,12 +50,15 @@ impl fmt::Display for KeyOrWildcard {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToFields for KeyOrWildcard {
|
impl ToFields for KeyOrWildcard {
|
||||||
|
// Encoding:
|
||||||
|
// - Key(k) => [[k]]
|
||||||
|
// - Wildcard(index) => [[index], 0, 0, 0]
|
||||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
match self {
|
match self {
|
||||||
KeyOrWildcard::Key(k) => k.hash().to_fields(params),
|
KeyOrWildcard::Key(k) => k.hash().to_fields(params),
|
||||||
KeyOrWildcard::Wildcard(wc) => iter::once(F::ZERO)
|
KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64))
|
||||||
.take(HASH_SIZE - 1)
|
.chain(iter::repeat(F::ZERO))
|
||||||
.chain(iter::once(F::from_canonical_u64(wc.index as u64)))
|
.take(HASH_SIZE)
|
||||||
.collect(),
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -66,7 +70,7 @@ pub enum StatementTmplArg {
|
||||||
None,
|
None,
|
||||||
Literal(Value),
|
Literal(Value),
|
||||||
// AnchoredKey
|
// AnchoredKey
|
||||||
Key(Wildcard, KeyOrWildcard),
|
AnchoredKey(Wildcard, KeyOrWildcard),
|
||||||
// TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard...
|
// TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard...
|
||||||
// Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key?
|
// Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key?
|
||||||
WildcardLiteral(Wildcard),
|
WildcardLiteral(Wildcard),
|
||||||
|
|
@ -76,7 +80,7 @@ pub enum StatementTmplArg {
|
||||||
pub enum StatementTmplArgPrefix {
|
pub enum StatementTmplArgPrefix {
|
||||||
None = 0,
|
None = 0,
|
||||||
Literal = 1,
|
Literal = 1,
|
||||||
Key = 2,
|
AnchoredKey = 2,
|
||||||
WildcardLiteral = 3,
|
WildcardLiteral = 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -88,11 +92,11 @@ impl From<StatementTmplArgPrefix> for F {
|
||||||
|
|
||||||
impl ToFields for StatementTmplArg {
|
impl ToFields for StatementTmplArg {
|
||||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
// None => (0, ...)
|
// Encoding:
|
||||||
// Literal(value) => (1, [value], 0, 0, 0, 0)
|
// None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||||
// Key(wildcard1_index, key_or_wildcard2)
|
// Literal(v) => (1, [v ], 0, 0, 0, 0)
|
||||||
// => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2])
|
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
|
||||||
// WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0)
|
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
|
||||||
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
|
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
|
||||||
match self {
|
match self {
|
||||||
StatementTmplArg::None => {
|
StatementTmplArg::None => {
|
||||||
|
|
@ -105,13 +109,15 @@ impl ToFields for StatementTmplArg {
|
||||||
StatementTmplArg::Literal(v) => {
|
StatementTmplArg::Literal(v) => {
|
||||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
|
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
|
||||||
.chain(v.raw().to_fields(params))
|
.chain(v.raw().to_fields(params))
|
||||||
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
|
.chain(iter::repeat(F::ZERO))
|
||||||
|
.take(Params::statement_tmpl_arg_size())
|
||||||
.collect();
|
.collect();
|
||||||
fields
|
fields
|
||||||
}
|
}
|
||||||
StatementTmplArg::Key(wc1, kw2) => {
|
StatementTmplArg::AnchoredKey(wc1, kw2) => {
|
||||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Key))
|
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
|
||||||
.chain(wc1.to_fields(params))
|
.chain(wc1.to_fields(params))
|
||||||
|
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
|
||||||
.chain(kw2.to_fields(params))
|
.chain(kw2.to_fields(params))
|
||||||
.collect();
|
.collect();
|
||||||
fields
|
fields
|
||||||
|
|
@ -119,7 +125,8 @@ impl ToFields for StatementTmplArg {
|
||||||
StatementTmplArg::WildcardLiteral(wc) => {
|
StatementTmplArg::WildcardLiteral(wc) => {
|
||||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
|
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
|
||||||
.chain(wc.to_fields(params))
|
.chain(wc.to_fields(params))
|
||||||
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
|
.chain(iter::repeat(F::ZERO))
|
||||||
|
.take(Params::statement_tmpl_arg_size())
|
||||||
.collect();
|
.collect();
|
||||||
fields
|
fields
|
||||||
}
|
}
|
||||||
|
|
@ -132,7 +139,7 @@ impl fmt::Display for StatementTmplArg {
|
||||||
match self {
|
match self {
|
||||||
Self::None => write!(f, "none"),
|
Self::None => write!(f, "none"),
|
||||||
Self::Literal(v) => write!(f, "{}", v),
|
Self::Literal(v) => write!(f, "{}", v),
|
||||||
Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key),
|
Self::AnchoredKey(pod_id, key) => write!(f, "({}, {})", pod_id, key),
|
||||||
Self::WildcardLiteral(v) => write!(f, "{}", v),
|
Self::WildcardLiteral(v) => write!(f, "{}", v),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -177,7 +184,11 @@ impl ToFields for StatementTmpl {
|
||||||
// instead of at the `to_fields` method, where we should assume that the
|
// instead of at the `to_fields` method, where we should assume that the
|
||||||
// values are already valid
|
// values are already valid
|
||||||
if self.args.len() > params.max_statement_args {
|
if self.args.len() > params.max_statement_args {
|
||||||
panic!("Statement template has too many arguments");
|
panic!(
|
||||||
|
"Statement template has too many arguments {} > {}",
|
||||||
|
self.args.len(),
|
||||||
|
params.max_statement_args
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut fields: Vec<F> = self
|
let mut fields: Vec<F> = self
|
||||||
|
|
@ -206,25 +217,36 @@ pub struct CustomPredicate {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CustomPredicate {
|
impl CustomPredicate {
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
Self {
|
||||||
|
name: "empty".to_string(),
|
||||||
|
conjunction: false,
|
||||||
|
statements: vec![StatementTmpl {
|
||||||
|
pred: Predicate::Native(NativePredicate::None),
|
||||||
|
args: vec![],
|
||||||
|
}],
|
||||||
|
args_len: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
pub fn and(
|
pub fn and(
|
||||||
name: String,
|
|
||||||
params: &Params,
|
params: &Params,
|
||||||
|
name: String,
|
||||||
statements: Vec<StatementTmpl>,
|
statements: Vec<StatementTmpl>,
|
||||||
args_len: usize,
|
args_len: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::new(name, params, true, statements, args_len)
|
Self::new(params, name, true, statements, args_len)
|
||||||
}
|
}
|
||||||
pub fn or(
|
pub fn or(
|
||||||
name: String,
|
|
||||||
params: &Params,
|
params: &Params,
|
||||||
|
name: String,
|
||||||
statements: Vec<StatementTmpl>,
|
statements: Vec<StatementTmpl>,
|
||||||
args_len: usize,
|
args_len: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::new(name, params, false, statements, args_len)
|
Self::new(params, name, false, statements, args_len)
|
||||||
}
|
}
|
||||||
pub fn new(
|
pub fn new(
|
||||||
name: String,
|
|
||||||
params: &Params,
|
params: &Params,
|
||||||
|
name: String,
|
||||||
conjunction: bool,
|
conjunction: bool,
|
||||||
statements: Vec<StatementTmpl>,
|
statements: Vec<StatementTmpl>,
|
||||||
args_len: usize,
|
args_len: usize,
|
||||||
|
|
@ -236,6 +258,13 @@ impl CustomPredicate {
|
||||||
params.max_custom_predicate_arity,
|
params.max_custom_predicate_arity,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
if args_len > params.max_statement_args {
|
||||||
|
return Err(Error::max_length(
|
||||||
|
"statement_args.len".to_string(),
|
||||||
|
args_len,
|
||||||
|
params.max_statement_args,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
name,
|
name,
|
||||||
|
|
@ -244,6 +273,16 @@ impl CustomPredicate {
|
||||||
args_len,
|
args_len,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
|
||||||
|
StatementTmpl {
|
||||||
|
pred: Predicate::Native(if self.conjunction {
|
||||||
|
NativePredicate::False
|
||||||
|
} else {
|
||||||
|
NativePredicate::None
|
||||||
|
}),
|
||||||
|
args: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToFields for CustomPredicate {
|
impl ToFields for CustomPredicate {
|
||||||
|
|
@ -262,11 +301,17 @@ impl ToFields for CustomPredicate {
|
||||||
panic!("Custom predicate depends on too many statements");
|
panic!("Custom predicate depends on too many statements");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
|
let pad_st = self.pad_statement_tmpl();
|
||||||
|
let fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
|
||||||
.chain(iter::once(F::from_canonical_usize(self.args_len)))
|
.chain(iter::once(F::from_canonical_usize(self.args_len)))
|
||||||
.chain(self.statements.iter().flat_map(|st| st.to_fields(params)))
|
.chain(
|
||||||
|
self.statements
|
||||||
|
.iter()
|
||||||
|
.chain(iter::repeat(&pad_st))
|
||||||
|
.take(params.max_custom_predicate_arity)
|
||||||
|
.flat_map(|st| st.to_fields(params)),
|
||||||
|
)
|
||||||
.collect();
|
.collect();
|
||||||
fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0));
|
|
||||||
fields
|
fields
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -298,8 +343,9 @@ impl fmt::Display for CustomPredicate {
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||||
pub struct CustomPredicateBatch {
|
pub struct CustomPredicateBatch {
|
||||||
|
id: Hash,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub predicates: Vec<CustomPredicate>,
|
predicates: Vec<CustomPredicate>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToFields for CustomPredicateBatch {
|
impl ToFields for CustomPredicateBatch {
|
||||||
|
|
@ -313,27 +359,45 @@ impl ToFields for CustomPredicateBatch {
|
||||||
panic!("Predicate batch exceeds maximum size");
|
panic!("Predicate batch exceeds maximum size");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut fields: Vec<F> = self
|
let pad_pred = CustomPredicate::empty();
|
||||||
|
let fields: Vec<F> = self
|
||||||
.predicates
|
.predicates
|
||||||
.iter()
|
.iter()
|
||||||
|
.chain(iter::repeat(&pad_pred))
|
||||||
|
.take(params.max_custom_batch_size)
|
||||||
.flat_map(|p| p.to_fields(params))
|
.flat_map(|p| p.to_fields(params))
|
||||||
.collect();
|
.collect();
|
||||||
fields.resize_with(params.custom_predicate_batch_size_field_elts(), || {
|
|
||||||
F::from_canonical_u64(0)
|
|
||||||
});
|
|
||||||
fields
|
fields
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CustomPredicateBatch {
|
impl CustomPredicateBatch {
|
||||||
|
pub fn new(params: &Params, name: String, predicates: Vec<CustomPredicate>) -> Arc<Self> {
|
||||||
|
let mut cpb = Self {
|
||||||
|
id: EMPTY_HASH,
|
||||||
|
name,
|
||||||
|
predicates,
|
||||||
|
};
|
||||||
|
let id = cpb.calculate_id(params);
|
||||||
|
cpb.id = id;
|
||||||
|
Arc::new(cpb)
|
||||||
|
}
|
||||||
|
|
||||||
/// Cryptographic identifier for the batch.
|
/// Cryptographic identifier for the batch.
|
||||||
pub fn id(&self, params: &Params) -> Hash {
|
fn calculate_id(&self, params: &Params) -> Hash {
|
||||||
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
|
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
|
||||||
// but ideally we want to use the root of a merkle tree built from the custom predicates.
|
// but ideally we want to use the root of a merkle tree built from the custom predicates.
|
||||||
let input = self.to_fields(params);
|
let input = self.to_fields(params);
|
||||||
|
|
||||||
hash_fields(&input)
|
hash_fields(&input)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> Hash {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
pub fn predicates(&self) -> &[CustomPredicate] {
|
||||||
|
&self.predicates
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||||
|
|
@ -347,13 +411,16 @@ impl CustomPredicateRef {
|
||||||
Self { batch, index }
|
Self { batch, index }
|
||||||
}
|
}
|
||||||
pub fn arg_len(&self) -> usize {
|
pub fn arg_len(&self) -> usize {
|
||||||
self.batch.predicates[self.index].args_len
|
self.predicate().args_len
|
||||||
|
}
|
||||||
|
pub fn predicate(&self) -> &CustomPredicate {
|
||||||
|
&self.batch.predicates[self.index]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::{array, sync::Arc};
|
use std::array;
|
||||||
|
|
||||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||||
|
|
||||||
|
|
@ -392,28 +459,29 @@ mod tests {
|
||||||
p:value_of(Constant, 2),
|
p:value_of(Constant, 2),
|
||||||
p:product_of(S1, Constant, S2)
|
p:product_of(S1, Constant, S2)
|
||||||
*/
|
*/
|
||||||
let cust_pred_batch = Arc::new(CustomPredicateBatch {
|
let cust_pred_batch = CustomPredicateBatch::new(
|
||||||
name: "is_double".to_string(),
|
¶ms,
|
||||||
predicates: vec![CustomPredicate::and(
|
"is_double".to_string(),
|
||||||
"_".into(),
|
vec![CustomPredicate::and(
|
||||||
¶ms,
|
¶ms,
|
||||||
|
"_".into(),
|
||||||
vec![
|
vec![
|
||||||
st(
|
st(
|
||||||
P::Native(NP::ValueOf),
|
P::Native(NP::ValueOf),
|
||||||
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(2.into())],
|
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(2.into())],
|
||||||
),
|
),
|
||||||
st(
|
st(
|
||||||
P::Native(NP::ProductOf),
|
P::Native(NP::ProductOf),
|
||||||
vec![
|
vec![
|
||||||
STA::Key(wc(0), kow_wc(1)),
|
STA::AnchoredKey(wc(0), kow_wc(1)),
|
||||||
STA::Key(wc(4), kow_wc(5)),
|
STA::AnchoredKey(wc(4), kow_wc(5)),
|
||||||
STA::Key(wc(2), kow_wc(3)),
|
STA::AnchoredKey(wc(2), kow_wc(3)),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
2,
|
2,
|
||||||
)?],
|
)?],
|
||||||
});
|
);
|
||||||
|
|
||||||
let custom_statement = Statement::Custom(
|
let custom_statement = Statement::Custom(
|
||||||
CustomPredicateRef::new(cust_pred_batch.clone(), 0),
|
CustomPredicateRef::new(cust_pred_batch.clone(), 0),
|
||||||
|
|
@ -444,55 +512,57 @@ mod tests {
|
||||||
fn ethdos_test() -> Result<()> {
|
fn ethdos_test() -> Result<()> {
|
||||||
let params = Params {
|
let params = Params {
|
||||||
max_custom_predicate_wildcards: 12,
|
max_custom_predicate_wildcards: 12,
|
||||||
|
max_statement_args: 6,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let eth_friend_cp = CustomPredicate::and(
|
let eth_friend_cp = CustomPredicate::and(
|
||||||
"eth_friend_cp".into(),
|
|
||||||
¶ms,
|
¶ms,
|
||||||
|
"eth_friend_cp".into(),
|
||||||
vec![
|
vec![
|
||||||
st(
|
st(
|
||||||
P::Native(NP::ValueOf),
|
P::Native(NP::ValueOf),
|
||||||
vec![
|
vec![
|
||||||
STA::Key(wc(4), KeyOrWildcard::Key("type".into())),
|
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("type".into())),
|
||||||
STA::Literal(PodType::Signed.into()),
|
STA::Literal(PodType::Signed.into()),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
st(
|
st(
|
||||||
P::Native(NP::Equal),
|
P::Native(NP::Equal),
|
||||||
vec![
|
vec![
|
||||||
STA::Key(wc(4), KeyOrWildcard::Key("signer".into())),
|
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("signer".into())),
|
||||||
STA::Key(wc(0), kow_wc(1)),
|
STA::AnchoredKey(wc(0), kow_wc(1)),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
st(
|
st(
|
||||||
P::Native(NP::Equal),
|
P::Native(NP::Equal),
|
||||||
vec![
|
vec![
|
||||||
STA::Key(wc(4), KeyOrWildcard::Key("attestation".into())),
|
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("attestation".into())),
|
||||||
STA::Key(wc(2), kow_wc(3)),
|
STA::AnchoredKey(wc(2), kow_wc(3)),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
4,
|
4,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let eth_friend_batch = Arc::new(CustomPredicateBatch {
|
let eth_friend_batch =
|
||||||
name: "eth_friend".to_string(),
|
CustomPredicateBatch::new(¶ms, "eth_friend".to_string(), vec![eth_friend_cp]);
|
||||||
predicates: vec![eth_friend_cp],
|
|
||||||
});
|
|
||||||
|
|
||||||
// 0
|
// 0
|
||||||
let eth_dos_base = CustomPredicate::and(
|
let eth_dos_base = CustomPredicate::and(
|
||||||
"eth_dos_base".into(),
|
|
||||||
¶ms,
|
¶ms,
|
||||||
|
"eth_dos_base".into(),
|
||||||
vec![
|
vec![
|
||||||
st(
|
st(
|
||||||
P::Native(NP::Equal),
|
P::Native(NP::Equal),
|
||||||
vec![STA::Key(wc(0), kow_wc(1)), STA::Key(wc(2), kow_wc(3))],
|
vec![
|
||||||
|
STA::AnchoredKey(wc(0), kow_wc(1)),
|
||||||
|
STA::AnchoredKey(wc(2), kow_wc(3)),
|
||||||
|
],
|
||||||
),
|
),
|
||||||
st(
|
st(
|
||||||
P::Native(NP::ValueOf),
|
P::Native(NP::ValueOf),
|
||||||
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(0.into())],
|
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(0.into())],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
6,
|
6,
|
||||||
|
|
@ -500,8 +570,8 @@ mod tests {
|
||||||
|
|
||||||
// 1
|
// 1
|
||||||
let eth_dos_ind = CustomPredicate::and(
|
let eth_dos_ind = CustomPredicate::and(
|
||||||
"eth_dos_ind".into(),
|
|
||||||
¶ms,
|
¶ms,
|
||||||
|
"eth_dos_ind".into(),
|
||||||
vec![
|
vec![
|
||||||
st(
|
st(
|
||||||
P::BatchSelf(2),
|
P::BatchSelf(2),
|
||||||
|
|
@ -516,14 +586,14 @@ mod tests {
|
||||||
),
|
),
|
||||||
st(
|
st(
|
||||||
P::Native(NP::ValueOf),
|
P::Native(NP::ValueOf),
|
||||||
vec![STA::Key(wc(6), kow_wc(7)), STA::Literal(1.into())],
|
vec![STA::AnchoredKey(wc(6), kow_wc(7)), STA::Literal(1.into())],
|
||||||
),
|
),
|
||||||
st(
|
st(
|
||||||
P::Native(NP::SumOf),
|
P::Native(NP::SumOf),
|
||||||
vec![
|
vec![
|
||||||
STA::Key(wc(4), kow_wc(5)),
|
STA::AnchoredKey(wc(4), kow_wc(5)),
|
||||||
STA::Key(wc(8), kow_wc(9)),
|
STA::AnchoredKey(wc(8), kow_wc(9)),
|
||||||
STA::Key(wc(6), kow_wc(7)),
|
STA::AnchoredKey(wc(6), kow_wc(7)),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
st(
|
st(
|
||||||
|
|
@ -541,8 +611,8 @@ mod tests {
|
||||||
|
|
||||||
// 2
|
// 2
|
||||||
let eth_dos_distance_either = CustomPredicate::or(
|
let eth_dos_distance_either = CustomPredicate::or(
|
||||||
"eth_dos_distance_either".into(),
|
|
||||||
¶ms,
|
¶ms,
|
||||||
|
"eth_dos_distance_either".into(),
|
||||||
vec![
|
vec![
|
||||||
st(
|
st(
|
||||||
P::BatchSelf(0),
|
P::BatchSelf(0),
|
||||||
|
|
@ -570,10 +640,11 @@ mod tests {
|
||||||
6,
|
6,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let eth_dos_distance_batch = Arc::new(CustomPredicateBatch {
|
let eth_dos_distance_batch = CustomPredicateBatch::new(
|
||||||
name: "ETHDoS_distance".to_string(),
|
¶ms,
|
||||||
predicates: vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
|
"ETHDoS_distance".to_string(),
|
||||||
});
|
vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
|
||||||
|
);
|
||||||
|
|
||||||
// Some POD IDs
|
// Some POD IDs
|
||||||
let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64))));
|
let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64))));
|
||||||
|
|
|
||||||
|
|
@ -584,6 +584,10 @@ pub struct Params {
|
||||||
pub max_public_statements: usize,
|
pub max_public_statements: usize,
|
||||||
pub max_statement_args: usize,
|
pub max_statement_args: usize,
|
||||||
pub max_operation_args: usize,
|
pub max_operation_args: usize,
|
||||||
|
// max number of custom predicates batches that a MainPod can use
|
||||||
|
pub max_custom_predicate_batches: usize,
|
||||||
|
// max number of operations using custom predicates that can be verified in the MainPod
|
||||||
|
pub max_custom_predicate_verifications: usize,
|
||||||
// max number of statements that can be ANDed or ORed together
|
// max number of statements that can be ANDed or ORed together
|
||||||
// in a custom predicate
|
// in a custom predicate
|
||||||
pub max_custom_predicate_arity: usize,
|
pub max_custom_predicate_arity: usize,
|
||||||
|
|
@ -605,6 +609,8 @@ impl Default for Params {
|
||||||
max_public_statements: 10,
|
max_public_statements: 10,
|
||||||
max_statement_args: 5,
|
max_statement_args: 5,
|
||||||
max_operation_args: 5,
|
max_operation_args: 5,
|
||||||
|
max_custom_predicate_batches: 2,
|
||||||
|
max_custom_predicate_verifications: 5,
|
||||||
max_custom_predicate_arity: 5,
|
max_custom_predicate_arity: 5,
|
||||||
max_custom_predicate_wildcards: 10,
|
max_custom_predicate_wildcards: 10,
|
||||||
max_custom_batch_size: 5,
|
max_custom_batch_size: 5,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{fmt, iter, sync::Arc};
|
use std::{fmt, iter};
|
||||||
|
|
||||||
use log::error;
|
use log::error;
|
||||||
use plonky2::field::types::Field;
|
use plonky2::field::types::Field;
|
||||||
|
|
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use crate::{
|
use crate::{
|
||||||
backends::plonky2::primitives::merkletree::MerkleProof,
|
backends::plonky2::primitives::merkletree::MerkleProof,
|
||||||
middleware::{
|
middleware::{
|
||||||
custom::KeyOrWildcard, AnchoredKey, CustomPredicateBatch, CustomPredicateRef, Error,
|
custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
|
||||||
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg,
|
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg,
|
||||||
ToFields, Wildcard, WildcardValue, F, SELF,
|
ToFields, Wildcard, WildcardValue, F, SELF,
|
||||||
},
|
},
|
||||||
|
|
@ -36,6 +36,9 @@ impl fmt::Display for OperationAux {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToFields for OperationType {
|
impl ToFields for OperationType {
|
||||||
|
/// Encoding:
|
||||||
|
/// - Native(native_op) => [1, [native_op], 0, 0, 0, 0]
|
||||||
|
/// - Custom(batch, index) => [3, [batch.id], index]
|
||||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
let mut fields: Vec<F> = match self {
|
let mut fields: Vec<F> = match self {
|
||||||
Self::Native(p) => iter::once(F::from_canonical_u64(1))
|
Self::Native(p) => iter::once(F::from_canonical_u64(1))
|
||||||
|
|
@ -43,7 +46,7 @@ impl ToFields for OperationType {
|
||||||
.collect(),
|
.collect(),
|
||||||
Self::Custom(CustomPredicateRef { batch, index }) => {
|
Self::Custom(CustomPredicateRef { batch, index }) => {
|
||||||
iter::once(F::from_canonical_u64(3))
|
iter::once(F::from_canonical_u64(3))
|
||||||
.chain(batch.id(params).0)
|
.chain(batch.id().0)
|
||||||
.chain(iter::once(F::from_canonical_usize(*index)))
|
.chain(iter::once(F::from_canonical_usize(*index)))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
@ -321,7 +324,7 @@ impl Operation {
|
||||||
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
|
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
|
||||||
if batch == &cpr.batch && index == &cpr.index =>
|
if batch == &cpr.batch && index == &cpr.index =>
|
||||||
{
|
{
|
||||||
check_custom_pred(params, batch, *index, args, s_args)
|
check_custom_pred(params, cpr, args, s_args)
|
||||||
}
|
}
|
||||||
_ => Err(Error::invalid_deduction(
|
_ => Err(Error::invalid_deduction(
|
||||||
self.clone(),
|
self.clone(),
|
||||||
|
|
@ -360,7 +363,7 @@ pub fn check_st_tmpl(
|
||||||
(StatementTmplArg::None, StatementArg::None) => true,
|
(StatementTmplArg::None, StatementArg::None) => true,
|
||||||
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
|
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
|
||||||
(
|
(
|
||||||
StatementTmplArg::Key(pod_id_wc, key_or_wc),
|
StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc),
|
||||||
StatementArg::Key(AnchoredKey { pod_id, key }),
|
StatementArg::Key(AnchoredKey { pod_id, key }),
|
||||||
) => {
|
) => {
|
||||||
let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map);
|
let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map);
|
||||||
|
|
@ -379,14 +382,46 @@ pub fn check_st_tmpl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn resolve_wildcard_values(
|
||||||
|
params: &Params,
|
||||||
|
pred: &CustomPredicate,
|
||||||
|
args: &[Statement],
|
||||||
|
) -> Option<Vec<WildcardValue>> {
|
||||||
|
// Check that all wildcard have consistent values as assigned in the statements while storing a
|
||||||
|
// map of their values.
|
||||||
|
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
|
||||||
|
// disjunctions we expect Statement::None for the unused statements.
|
||||||
|
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
|
||||||
|
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||||
|
let st_args = st.args();
|
||||||
|
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
||||||
|
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
|
||||||
|
// TODO: Better errors. Example:
|
||||||
|
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
|
||||||
|
// println!("{} doesn't match {}", st, st_tmpl);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
|
||||||
|
// they are beyond the number of used wildcards in this custom predicate, or they could be
|
||||||
|
// private arguments that are unused in a particular disjunction.
|
||||||
|
Some(
|
||||||
|
wildcard_map
|
||||||
|
.into_iter()
|
||||||
|
.map(|opt| opt.unwrap_or(WildcardValue::None))
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn check_custom_pred(
|
fn check_custom_pred(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
batch: &Arc<CustomPredicateBatch>,
|
custom_pred_ref: &CustomPredicateRef,
|
||||||
index: usize,
|
|
||||||
args: &[Statement],
|
args: &[Statement],
|
||||||
s_args: &[WildcardValue],
|
s_args: &[WildcardValue],
|
||||||
) -> Result<bool> {
|
) -> Result<bool> {
|
||||||
let pred = &batch.predicates[index];
|
let pred = custom_pred_ref.predicate();
|
||||||
if pred.statements.len() != args.len() {
|
if pred.statements.len() != args.len() {
|
||||||
return Err(Error::diff_amount(
|
return Err(Error::diff_amount(
|
||||||
"custom predicate operation".to_string(),
|
"custom predicate operation".to_string(),
|
||||||
|
|
@ -404,26 +439,12 @@ fn check_custom_pred(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that all wildcard have consistent values as assigned in the statements while storing a
|
// Count the number of statements that match the templates by predicate.
|
||||||
// map of their values. Count the number of statements that match the templates by predicate.
|
|
||||||
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
|
|
||||||
// disjunctions we expect Statement::None for the unused statements.
|
|
||||||
let mut num_matches = 0;
|
let mut num_matches = 0;
|
||||||
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
|
|
||||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||||
let st_args = st.args();
|
|
||||||
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
|
||||||
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
|
|
||||||
// TODO: Better errors. Example:
|
|
||||||
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
|
|
||||||
// println!("{} doesn't match {}", st, st_tmpl);
|
|
||||||
return Ok(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let st_tmpl_pred = match &st_tmpl.pred {
|
let st_tmpl_pred = match &st_tmpl.pred {
|
||||||
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
||||||
batch: batch.clone(),
|
batch: custom_pred_ref.batch.clone(),
|
||||||
index: *i,
|
index: *i,
|
||||||
}),
|
}),
|
||||||
p => p.clone(),
|
p => p.clone(),
|
||||||
|
|
@ -433,9 +454,14 @@ fn check_custom_pred(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let wildcard_map = match resolve_wildcard_values(params, pred, args) {
|
||||||
|
Some(wc_map) => wc_map,
|
||||||
|
None => return Ok(false),
|
||||||
|
};
|
||||||
|
|
||||||
// Check that the resolved wildcard match the statement arguments.
|
// Check that the resolved wildcard match the statement arguments.
|
||||||
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
|
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
|
||||||
if !wc_value.as_ref().is_none_or(|wc_value| *wc_value == *s_arg) {
|
if *wc_value != *s_arg {
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ use strum_macros::FromRepr;
|
||||||
|
|
||||||
use crate::middleware::{
|
use crate::middleware::{
|
||||||
AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value,
|
AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value,
|
||||||
F, VALUE_SIZE,
|
EMPTY_VALUE, F, VALUE_SIZE,
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static
|
// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static
|
||||||
|
|
@ -17,22 +17,23 @@ pub const KEY_SIGNER: &str = "_signer";
|
||||||
pub const KEY_TYPE: &str = "_type";
|
pub const KEY_TYPE: &str = "_type";
|
||||||
pub const STATEMENT_ARG_F_LEN: usize = 8;
|
pub const STATEMENT_ARG_F_LEN: usize = 8;
|
||||||
pub const OPERATION_ARG_F_LEN: usize = 1;
|
pub const OPERATION_ARG_F_LEN: usize = 1;
|
||||||
pub const OPERATION_AUX_F_LEN: usize = 1;
|
pub const OPERATION_AUX_F_LEN: usize = 2;
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
pub enum NativePredicate {
|
pub enum NativePredicate {
|
||||||
None = 0,
|
None = 0, // Always true
|
||||||
ValueOf = 1,
|
False = 1, // Always false
|
||||||
Equal = 2,
|
ValueOf = 2,
|
||||||
NotEqual = 3,
|
Equal = 3,
|
||||||
LtEq = 4,
|
NotEqual = 4,
|
||||||
Lt = 5,
|
LtEq = 5,
|
||||||
Contains = 6,
|
Lt = 6,
|
||||||
NotContains = 7,
|
Contains = 7,
|
||||||
SumOf = 8,
|
NotContains = 8,
|
||||||
ProductOf = 9,
|
SumOf = 9,
|
||||||
MaxOf = 10,
|
ProductOf = 10,
|
||||||
HashOf = 11,
|
MaxOf = 11,
|
||||||
|
HashOf = 12,
|
||||||
|
|
||||||
// Syntactic sugar predicates. These predicates are not supported by the backend. The
|
// Syntactic sugar predicates. These predicates are not supported by the backend. The
|
||||||
// frontend compiler is responsible of translating these predicates into the predicates above.
|
// frontend compiler is responsible of translating these predicates into the predicates above.
|
||||||
|
|
@ -53,6 +54,7 @@ impl ToFields for NativePredicate {
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||||
pub enum WildcardValue {
|
pub enum WildcardValue {
|
||||||
|
None,
|
||||||
PodId(PodId),
|
PodId(PodId),
|
||||||
Key(Key),
|
Key(Key),
|
||||||
}
|
}
|
||||||
|
|
@ -60,6 +62,7 @@ pub enum WildcardValue {
|
||||||
impl WildcardValue {
|
impl WildcardValue {
|
||||||
pub fn raw(&self) -> RawValue {
|
pub fn raw(&self) -> RawValue {
|
||||||
match self {
|
match self {
|
||||||
|
WildcardValue::None => EMPTY_VALUE,
|
||||||
WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0),
|
WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0),
|
||||||
WildcardValue::Key(key) => key.raw(),
|
WildcardValue::Key(key) => key.raw(),
|
||||||
}
|
}
|
||||||
|
|
@ -69,6 +72,7 @@ impl WildcardValue {
|
||||||
impl fmt::Display for WildcardValue {
|
impl fmt::Display for WildcardValue {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
WildcardValue::None => write!(f, "none"),
|
||||||
WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id),
|
WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id),
|
||||||
WildcardValue::Key(key) => write!(f, "{}", key),
|
WildcardValue::Key(key) => write!(f, "{}", key),
|
||||||
}
|
}
|
||||||
|
|
@ -77,10 +81,7 @@ impl fmt::Display for WildcardValue {
|
||||||
|
|
||||||
impl ToFields for WildcardValue {
|
impl ToFields for WildcardValue {
|
||||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
match self {
|
self.raw().to_fields(params)
|
||||||
WildcardValue::PodId(pod_id) => pod_id.to_fields(params),
|
|
||||||
WildcardValue::Key(key) => key.to_fields(params),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -130,7 +131,7 @@ impl ToFields for Predicate {
|
||||||
.collect(),
|
.collect(),
|
||||||
Self::Custom(CustomPredicateRef { batch, index }) => {
|
Self::Custom(CustomPredicateRef { batch, index }) => {
|
||||||
iter::once(F::from(PredicatePrefix::Custom))
|
iter::once(F::from(PredicatePrefix::Custom))
|
||||||
.chain(batch.id(params).0)
|
.chain(batch.id().0)
|
||||||
.chain(iter::once(F::from_canonical_usize(*index)))
|
.chain(iter::once(F::from_canonical_usize(*index)))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
@ -149,7 +150,9 @@ impl fmt::Display for Predicate {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"{}.{}[{}]",
|
"{}.{}[{}]",
|
||||||
batch.name, index, batch.predicates[*index].name
|
batch.name,
|
||||||
|
index,
|
||||||
|
batch.predicates()[*index].name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -397,14 +400,14 @@ impl StatementArg {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToFields for StatementArg {
|
impl ToFields for StatementArg {
|
||||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
/// Encoding:
|
||||||
// NOTE: current version returns always the same amount of field elements in the returned
|
/// - None => [0, 0, 0, 0, 0, 0, 0, 0]
|
||||||
// vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case
|
/// - Literal(v) => [[v], 0, 0, 0, 0]
|
||||||
// is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced
|
/// - Key(pod_id, key) => [[pod_id], [key]]
|
||||||
// in-circuit), we might be interested into reducing the length of it. If that's the case,
|
/// - WildcardLiteral(v) => [[v], 0, 0, 0, 0]
|
||||||
// we can check if it makes sense to make it dependant on the concrete StatementArg; that
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
// is, when dealing with a `None` it would be a single field element (zero value), and when
|
// NOTE for @ax0: I removed the old comment because may `to_fields` implementations do
|
||||||
// dealing with `Literal` it would be of length 4.
|
// padding and we need fixed output length for the circuits.
|
||||||
let f = match self {
|
let f = match self {
|
||||||
StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN],
|
StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN],
|
||||||
StatementArg::Literal(v) => v
|
StatementArg::Literal(v) => v
|
||||||
|
|
@ -414,8 +417,8 @@ impl ToFields for StatementArg {
|
||||||
.chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
|
.chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
|
||||||
.collect(),
|
.collect(),
|
||||||
StatementArg::Key(ak) => {
|
StatementArg::Key(ak) => {
|
||||||
let mut fields = ak.pod_id.to_fields(_params);
|
let mut fields = ak.pod_id.to_fields(params);
|
||||||
fields.extend(ak.key.to_fields(_params));
|
fields.extend(ak.key.to_fields(params));
|
||||||
fields
|
fields
|
||||||
}
|
}
|
||||||
StatementArg::WildcardLiteral(v) => v
|
StatementArg::WildcardLiteral(v) => v
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue