Constraints for custom predicates (#227)
* add target types for custom predicates * simplify * fix clippy * fix typo * don't use ref for NativePredicate * fix wrong len * precalculate CustomPredicateBatch id * wip * wip * move code back * great progress * wip * code complete, hopefully; missing tests * fill aux for custom predicate op * fix clippy warnings * fix typos * fix test import * fix missing assignment in lt_mask, test custom_operation_verify_gadget * fix mistake * wip * fix * debug revert except for let entry = CustomPredicateVerifyEntryTarget * fix batch_id calculation by fixing padding * oops * remove completed TODOs
This commit is contained in:
parent
4fa9e20ecd
commit
024ed8bd04
12 changed files with 1597 additions and 291 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
use std::{array, iter};
|
||||
|
||||
use itertools::Itertools;
|
||||
use plonky2::{
|
||||
field::{
|
||||
extension::Extendable,
|
||||
|
|
@ -12,23 +13,28 @@ use plonky2::{
|
|||
poseidon::PoseidonHash,
|
||||
},
|
||||
iop::{
|
||||
generator::{GeneratedValues, SimpleGenerator},
|
||||
target::{BoolTarget, Target},
|
||||
witness::{PartialWitness, WitnessWrite},
|
||||
witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite},
|
||||
},
|
||||
plonk::circuit_builder::CircuitBuilder,
|
||||
plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData},
|
||||
util::serialization::{Buffer, IoResult, Read, Write},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
backends::plonky2::{
|
||||
basetypes::D,
|
||||
circuits::mainpod::CustomPredicateVerification,
|
||||
error::Result,
|
||||
mainpod::{Operation, OperationArg, Statement},
|
||||
primitives::merkletree::MerkleClaimAndProofTarget,
|
||||
},
|
||||
middleware::{
|
||||
NativeOperation, NativePredicate, Params, Predicate, PredicatePrefix, RawValue,
|
||||
StatementArg, StatementTmplArgPrefix, ToFields, EMPTY_VALUE, F, HASH_SIZE,
|
||||
OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
||||
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
||||
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
|
||||
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, WildcardValue,
|
||||
EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN,
|
||||
VALUE_SIZE,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -65,6 +71,10 @@ impl ValueTarget {
|
|||
elements: array::from_fn(|i| xs[i]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_targets(&self, pw: &mut PartialWitness<F>, value: &Value) -> Result<()> {
|
||||
Ok(pw.set_target_arr(&self.elements, &value.raw().0)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -82,7 +92,7 @@ impl StatementArgTarget {
|
|||
Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?)
|
||||
}
|
||||
|
||||
fn new(first: ValueTarget, second: ValueTarget) -> Self {
|
||||
pub fn new(first: ValueTarget, second: ValueTarget) -> Self {
|
||||
let elements: Vec<_> = first.elements.into_iter().chain(second.elements).collect();
|
||||
StatementArgTarget {
|
||||
elements: elements.try_into().expect("size STATEMENT_ARG_F_LEN"),
|
||||
|
|
@ -107,6 +117,11 @@ impl StatementArgTarget {
|
|||
Self::new(*pod_id, *key)
|
||||
}
|
||||
|
||||
pub fn wildcard_literal(builder: &mut CircuitBuilder<F, D>, value: &ValueTarget) -> Self {
|
||||
let empty = builder.constant_value(EMPTY_VALUE);
|
||||
Self::new(*value, empty)
|
||||
}
|
||||
|
||||
/// StatementArgTarget to ValueTarget coercion. Make sure to check
|
||||
/// that the arg is a value using the `statement_arg_is_value` method
|
||||
/// first!
|
||||
|
|
@ -138,6 +153,7 @@ impl<T> Build<T> for T {
|
|||
}
|
||||
|
||||
impl StatementTarget {
|
||||
/// Build a new native StatementTarget
|
||||
pub fn new_native(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
params: &Params,
|
||||
|
|
@ -187,10 +203,60 @@ impl StatementTarget {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OperationTypeTarget {
|
||||
pub elements: [Target; Params::operation_type_size()],
|
||||
}
|
||||
|
||||
impl OperationTypeTarget {
|
||||
pub fn new_custom(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
batch_id: HashOutTarget,
|
||||
index: Target,
|
||||
) -> Self {
|
||||
// TODO: Use an enum for these prefixes
|
||||
let three = builder.constant(F::from_canonical_usize(3));
|
||||
let id = batch_id.elements;
|
||||
Self {
|
||||
elements: [three, id[0], id[1], id[2], id[3], index],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_custom(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
) -> (BoolTarget, HashOutTarget, Target) {
|
||||
// TODO: Use an enum for these prefixes
|
||||
let three = builder.constant(F::from_canonical_usize(3));
|
||||
let op_is_custom = builder.is_equal(self.elements[0], three);
|
||||
let batch_id = HashOutTarget::from_vec(self.elements[1..5].to_vec());
|
||||
let index = self.elements[5];
|
||||
(op_is_custom, batch_id, index)
|
||||
}
|
||||
|
||||
pub fn has_native(&self, builder: &mut CircuitBuilder<F, D>, t: NativeOperation) -> BoolTarget {
|
||||
// TODO: Use an enum for these prefixes
|
||||
let one = builder.one();
|
||||
let op_is_native = builder.is_equal(self.elements[0], one);
|
||||
let op_code = builder.constant(F::from_canonical_u64(t as u64));
|
||||
let op_code_matches = builder.is_equal(self.elements[1], op_code);
|
||||
builder.and(op_is_native, op_code_matches)
|
||||
}
|
||||
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
op_type: &OperationType,
|
||||
) -> Result<()> {
|
||||
Ok(pw.set_target_arr(&self.elements, &op_type.to_fields(params))?)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Implement Operation::to_field to determine the size of each element
|
||||
#[derive(Clone)]
|
||||
pub struct OperationTarget {
|
||||
pub op_type: [Target; Params::operation_type_size()],
|
||||
pub op_type: OperationTypeTarget,
|
||||
pub args: Vec<[Target; OPERATION_ARG_F_LEN]>,
|
||||
pub aux: [Target; OPERATION_AUX_F_LEN],
|
||||
}
|
||||
|
|
@ -202,7 +268,7 @@ impl OperationTarget {
|
|||
params: &Params,
|
||||
op: &Operation,
|
||||
) -> Result<()> {
|
||||
pw.set_target_arr(&self.op_type, &op.op_type().to_fields(params))?;
|
||||
self.op_type.set_targets(pw, params, &op.op_type())?;
|
||||
for (i, arg) in op
|
||||
.args()
|
||||
.iter()
|
||||
|
|
@ -215,18 +281,6 @@ impl OperationTarget {
|
|||
pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn has_native_type(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
t: NativeOperation,
|
||||
) -> BoolTarget {
|
||||
let one = builder.one();
|
||||
let op_is_native = builder.is_equal(self.op_type[0], one);
|
||||
let op_code = builder.constant(F::from_canonical_u64(t as u64));
|
||||
let op_code_matches = builder.is_equal(self.op_type[1], op_code);
|
||||
builder.and(op_is_native, op_code_matches)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -304,17 +358,37 @@ impl PredicateTarget {
|
|||
}
|
||||
}
|
||||
|
||||
/// Mirrors `middleware::KeyOrWildcard`
|
||||
#[derive(Clone)]
|
||||
pub struct KeyOrWildcardTarget {
|
||||
pub struct LiteralOrWildcardTarget {
|
||||
pub elements: [Target; VALUE_SIZE],
|
||||
}
|
||||
|
||||
impl KeyOrWildcardTarget {
|
||||
impl LiteralOrWildcardTarget {
|
||||
fn from_slice(v: &[Target]) -> Self {
|
||||
Self {
|
||||
elements: v.try_into().expect("len is VALUE_SIZE"),
|
||||
}
|
||||
}
|
||||
/// cases: ((is_key, key), (is_wildcard, wildcard_index))
|
||||
pub fn cases(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
) -> ((BoolTarget, ValueTarget), (BoolTarget, Target)) {
|
||||
let zero = builder.zero();
|
||||
let is_zero_tail: Vec<_> = (1..4)
|
||||
.map(|i| builder.is_equal(self.elements[i], zero))
|
||||
.collect();
|
||||
let is_wildcard = is_zero_tail
|
||||
.into_iter()
|
||||
.reduce(|acc, x| builder.and(acc, x))
|
||||
.expect("len > 1");
|
||||
let is_key = builder.not(is_wildcard);
|
||||
let key = ValueTarget::from_slice(&self.elements);
|
||||
let wildcard_index = self.elements[0];
|
||||
|
||||
((is_key, key), (is_wildcard, wildcard_index))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -327,28 +401,40 @@ impl StatementTmplArgTarget {
|
|||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::None));
|
||||
builder.is_equal(self.elements[0], prefix)
|
||||
}
|
||||
|
||||
pub fn as_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, ValueTarget) {
|
||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal));
|
||||
let case_ok = builder.is_equal(self.elements[0], prefix);
|
||||
let value = ValueTarget::from_slice(&self.elements[1..5]);
|
||||
(case_ok, value)
|
||||
}
|
||||
pub fn as_key(
|
||||
|
||||
pub fn as_anchored_key(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
) -> (BoolTarget, Target, KeyOrWildcardTarget) {
|
||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::Key));
|
||||
) -> (BoolTarget, Target, LiteralOrWildcardTarget) {
|
||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::AnchoredKey));
|
||||
let case_ok = builder.is_equal(self.elements[0], prefix);
|
||||
let id_wildcard_index = self.elements[1];
|
||||
let value_key_or_wildcard = KeyOrWildcardTarget::from_slice(&self.elements[5..9]);
|
||||
let value_key_or_wildcard = LiteralOrWildcardTarget::from_slice(&self.elements[5..9]);
|
||||
(case_ok, id_wildcard_index, value_key_or_wildcard)
|
||||
}
|
||||
|
||||
pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder<F, D>) -> (BoolTarget, Target) {
|
||||
let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral));
|
||||
let case_ok = builder.is_equal(self.elements[0], prefix);
|
||||
let wildcard_index = self.elements[1];
|
||||
(case_ok, wildcard_index)
|
||||
}
|
||||
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
st_tmpl_arg: &StatementTmplArg,
|
||||
) -> Result<()> {
|
||||
Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields(params))?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -357,6 +443,17 @@ pub struct StatementTmplTarget {
|
|||
pub args: Vec<StatementTmplArgTarget>,
|
||||
}
|
||||
|
||||
impl StatementTmplTarget {
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
st_tmpl: &StatementTmpl,
|
||||
) -> Result<()> {
|
||||
Ok(pw.set_target_arr(&self.flatten(), &st_tmpl.to_fields(params))?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CustomPredicateTarget {
|
||||
pub conjunction: BoolTarget,
|
||||
|
|
@ -365,6 +462,17 @@ pub struct CustomPredicateTarget {
|
|||
pub args_len: Target,
|
||||
}
|
||||
|
||||
impl CustomPredicateTarget {
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
custom_predicate: &CustomPredicate,
|
||||
) -> Result<()> {
|
||||
Ok(pw.set_target_arr(&self.flatten(), &custom_predicate.to_fields(params))?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CustomPredicateBatchTarget {
|
||||
pub predicates: Vec<CustomPredicateTarget>,
|
||||
|
|
@ -375,6 +483,161 @@ impl CustomPredicateBatchTarget {
|
|||
let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
|
||||
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
|
||||
}
|
||||
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
custom_predicate_batch: &CustomPredicateBatch,
|
||||
) -> Result<()> {
|
||||
let pad_predicate = CustomPredicate::empty();
|
||||
for (i, predicate) in custom_predicate_batch
|
||||
.predicates()
|
||||
.iter()
|
||||
.chain(iter::repeat(&pad_predicate))
|
||||
.take(params.max_custom_batch_size)
|
||||
.enumerate()
|
||||
{
|
||||
self.predicates[i].set_targets(pw, params, predicate)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom predicate table entry
|
||||
pub struct CustomPredicateEntryTarget {
|
||||
pub id: HashOutTarget,
|
||||
pub index: Target,
|
||||
pub predicate: CustomPredicateTarget,
|
||||
}
|
||||
|
||||
impl CustomPredicateEntryTarget {
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
predicate: &CustomPredicateRef,
|
||||
) -> Result<()> {
|
||||
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
|
||||
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
|
||||
self.predicate
|
||||
.set_targets(pw, params, predicate.predicate())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Flattenable for CustomPredicateEntryTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
self.id
|
||||
.elements
|
||||
.iter()
|
||||
.chain(iter::once(&self.index))
|
||||
.chain(self.predicate.flatten().iter())
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
|
||||
Self {
|
||||
id: HashOutTarget::from_flattened(params, &vs[0..4]),
|
||||
index: vs[4],
|
||||
predicate: CustomPredicateTarget::from_flattened(params, &vs[5..]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomPredicateEntryTarget {
|
||||
pub fn hash(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
|
||||
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
|
||||
}
|
||||
}
|
||||
|
||||
// Custom predicate verification table entry
|
||||
pub struct CustomPredicateVerifyEntryTarget {
|
||||
pub custom_predicate_table_index: Target,
|
||||
pub custom_predicate: CustomPredicateEntryTarget,
|
||||
pub args: Vec<ValueTarget>,
|
||||
pub query: CustomPredicateVerifyQueryTarget,
|
||||
}
|
||||
|
||||
impl CustomPredicateVerifyEntryTarget {
|
||||
pub fn set_targets(
|
||||
&self,
|
||||
pw: &mut PartialWitness<F>,
|
||||
params: &Params,
|
||||
cpv: &CustomPredicateVerification,
|
||||
) -> Result<()> {
|
||||
pw.set_target(
|
||||
self.custom_predicate_table_index,
|
||||
F::from_canonical_usize(cpv.custom_predicate_table_index),
|
||||
)?;
|
||||
self.custom_predicate
|
||||
.set_targets(pw, params, &cpv.custom_predicate)?;
|
||||
let pad_arg = WildcardValue::None;
|
||||
for (arg_target, arg) in self.args.iter().zip_eq(
|
||||
cpv.args
|
||||
.iter()
|
||||
.chain(iter::repeat(&pad_arg))
|
||||
.take(params.max_custom_predicate_wildcards),
|
||||
) {
|
||||
arg_target.set_targets(pw, &Value::from(arg.raw()))?;
|
||||
}
|
||||
let pad_op_arg = Statement(Predicate::Native(NativePredicate::None), vec![]);
|
||||
for (op_arg_target, op_arg) in self.query.op_args.iter().zip_eq(
|
||||
cpv.op_args
|
||||
.iter()
|
||||
.chain(iter::repeat(&pad_op_arg))
|
||||
.take(params.max_operation_args),
|
||||
) {
|
||||
op_arg_target.set_targets(pw, params, op_arg)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Query for the custom predicate verification table
|
||||
pub struct CustomPredicateVerifyQueryTarget {
|
||||
pub statement: StatementTarget,
|
||||
pub op_type: OperationTypeTarget,
|
||||
pub op_args: Vec<StatementTarget>,
|
||||
}
|
||||
|
||||
impl CustomPredicateVerifyQueryTarget {
|
||||
pub fn hash(&self, builder: &mut CircuitBuilder<F, D>) -> HashOutTarget {
|
||||
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.flatten())
|
||||
}
|
||||
}
|
||||
|
||||
impl Flattenable for CustomPredicateVerifyQueryTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
self.statement
|
||||
.flatten()
|
||||
.iter()
|
||||
.chain(self.op_type.elements.iter())
|
||||
.cloned()
|
||||
.chain(self.op_args.iter().flat_map(|op_arg| op_arg.flatten()))
|
||||
.collect()
|
||||
}
|
||||
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
|
||||
let (pos, size) = (0, params.statement_size());
|
||||
let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]);
|
||||
let (pos, size) = (pos + size, params.operation_size());
|
||||
let op_type = OperationTypeTarget {
|
||||
elements: vs[pos..pos + size]
|
||||
.try_into()
|
||||
.expect("len = operation_type_size"),
|
||||
};
|
||||
let (pos, size) = (pos + size, params.statement_size());
|
||||
let op_args = (0..params.max_operation_args)
|
||||
.map(|i| {
|
||||
StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size])
|
||||
})
|
||||
.collect();
|
||||
Self {
|
||||
statement,
|
||||
op_type,
|
||||
op_args,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for target structs that may be converted to and from vectors
|
||||
|
|
@ -408,6 +671,27 @@ impl From<MerkleClaimAndProofTarget> for MerkleClaimTarget {
|
|||
}
|
||||
}
|
||||
|
||||
impl Flattenable for HashOutTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
self.elements.to_vec()
|
||||
}
|
||||
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
|
||||
assert_eq!(vs.len(), HASH_SIZE);
|
||||
Self {
|
||||
elements: array::from_fn(|i| vs[i]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Flattenable for ValueTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
self.elements.to_vec()
|
||||
}
|
||||
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
|
||||
Self::from_slice(vs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Flattenable for MerkleClaimTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
[
|
||||
|
|
@ -543,8 +827,17 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
|||
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
|
||||
fn add_virtual_value(&mut self) -> ValueTarget;
|
||||
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget;
|
||||
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget;
|
||||
fn add_virtual_predicate(&mut self) -> PredicateTarget;
|
||||
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget;
|
||||
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
|
||||
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget;
|
||||
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget;
|
||||
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget;
|
||||
fn add_virtual_custom_predicate_batch(&mut self, params: &Params)
|
||||
-> CustomPredicateBatchTarget;
|
||||
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
|
||||
-> CustomPredicateEntryTarget;
|
||||
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
|
||||
fn constant_value(&mut self, v: RawValue) -> ValueTarget;
|
||||
|
|
@ -604,6 +897,9 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
|||
// Convenience methods for Boolean into-iters.
|
||||
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||
|
||||
// Return a bit-mask of size `len` that selects all positions lower than `n`
|
||||
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget>;
|
||||
}
|
||||
|
||||
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||
|
|
@ -629,22 +925,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
|||
StatementTarget {
|
||||
predicate,
|
||||
args: (0..params.max_statement_args)
|
||||
.map(|_| StatementArgTarget {
|
||||
elements: self.add_virtual_target_arr(),
|
||||
})
|
||||
.map(|_| self.add_virtual_statement_arg())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget {
|
||||
StatementArgTarget {
|
||||
elements: self.add_virtual_target_arr(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_predicate(&mut self) -> PredicateTarget {
|
||||
PredicateTarget {
|
||||
elements: self.add_virtual_target_arr(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget {
|
||||
OperationTypeTarget {
|
||||
elements: self.add_virtual_target_arr(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget {
|
||||
OperationTarget {
|
||||
op_type: self.add_virtual_target_arr(),
|
||||
op_type: self.add_virtual_operation_type(),
|
||||
args: (0..params.max_operation_args)
|
||||
.map(|_| self.add_virtual_target_arr())
|
||||
.collect(),
|
||||
|
|
@ -652,6 +958,55 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
|||
}
|
||||
}
|
||||
|
||||
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget {
|
||||
StatementTmplArgTarget {
|
||||
elements: self.add_virtual_target_arr(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget {
|
||||
let args = (0..params.max_statement_args)
|
||||
.map(|_| self.add_virtual_statement_tmpl_arg())
|
||||
.collect();
|
||||
StatementTmplTarget {
|
||||
pred: self.add_virtual_predicate(),
|
||||
args,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget {
|
||||
let statements = (0..params.max_custom_predicate_arity)
|
||||
.map(|_| self.add_virtual_statement_tmpl(params))
|
||||
.collect();
|
||||
CustomPredicateTarget {
|
||||
conjunction: self.add_virtual_bool_target_safe(),
|
||||
statements,
|
||||
args_len: self.add_virtual_target(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_custom_predicate_batch(
|
||||
&mut self,
|
||||
params: &Params,
|
||||
) -> CustomPredicateBatchTarget {
|
||||
CustomPredicateBatchTarget {
|
||||
predicates: (0..params.max_custom_batch_size)
|
||||
.map(|_| self.add_virtual_custom_predicate(params))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_virtual_custom_predicate_entry(
|
||||
&mut self,
|
||||
params: &Params,
|
||||
) -> CustomPredicateEntryTarget {
|
||||
CustomPredicateEntryTarget {
|
||||
id: self.add_virtual_hash(),
|
||||
index: self.add_virtual_target(),
|
||||
predicate: self.add_virtual_custom_predicate(params),
|
||||
}
|
||||
}
|
||||
|
||||
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
||||
ValueTarget {
|
||||
elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])),
|
||||
|
|
@ -876,6 +1231,12 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
|||
)
|
||||
}
|
||||
|
||||
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
|
||||
// The idea would be the following: Take the array `ts` and hash each element. Then do the
|
||||
// random access on the hash result. Finally "unhash" to recover the resolved element.
|
||||
// We don't want to hash each element from the array each time, so we should cache the hashed
|
||||
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
|
||||
// then do `ts: &[HashCache<T>]`.
|
||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
|
||||
// TODO: Revisit this when we need more than 64 statements.
|
||||
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
|
||||
|
|
@ -944,6 +1305,73 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
|||
.reduce(|a, b| self.or(a, b))
|
||||
.unwrap_or(self._false())
|
||||
}
|
||||
|
||||
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget> {
|
||||
let zero = self.zero();
|
||||
let mask: Vec<_> = (0..len)
|
||||
.map(|_| self.add_virtual_bool_target_safe())
|
||||
.collect();
|
||||
self.add_simple_generator(LtMaskGenerator {
|
||||
n,
|
||||
mask: mask.iter().map(|bt| bt.target).collect(),
|
||||
});
|
||||
// We have `n` ones in the mask
|
||||
let mask_sum = mask
|
||||
.iter()
|
||||
.map(|b| b.target)
|
||||
.reduce(|acc, x| self.add(acc, x))
|
||||
.unwrap_or(zero);
|
||||
self.connect(n, mask_sum);
|
||||
|
||||
// The elements in the mask can only transition from 1 to 0 or 0 to 0.
|
||||
for i in 0..len - 1 {
|
||||
let diff = self.sub(mask[i].target, mask[i + 1].target);
|
||||
self.assert_bool(BoolTarget::new_unsafe(diff));
|
||||
}
|
||||
|
||||
mask
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct LtMaskGenerator {
|
||||
pub(crate) n: Target,
|
||||
pub(crate) mask: Vec<Target>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for LtMaskGenerator {
|
||||
fn id(&self) -> String {
|
||||
"LtMaskGenerator".to_string()
|
||||
}
|
||||
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
vec![self.n]
|
||||
}
|
||||
|
||||
fn run_once(
|
||||
&self,
|
||||
witness: &PartitionWitness<F>,
|
||||
out_buffer: &mut GeneratedValues<F>,
|
||||
) -> anyhow::Result<()> {
|
||||
let n = witness.get_target(self.n).to_canonical_u64();
|
||||
|
||||
for (i, mask_i) in self.mask.iter().enumerate() {
|
||||
let v = if (i as u64) < n { F::ONE } else { F::ZERO };
|
||||
out_buffer.set_target(*mask_i, v)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
|
||||
dst.write_target(self.n)?;
|
||||
dst.write_target_vec(&self.mask)
|
||||
}
|
||||
|
||||
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
|
||||
let n = src.read_target()?;
|
||||
let mask = src.read_target_vec()?;
|
||||
Ok(Self { n, mask })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -1013,13 +1441,14 @@ pub(crate) mod tests {
|
|||
|
||||
let custom_predicate_batch = eth_friend_batch(¶ms)?;
|
||||
|
||||
for (i, cp) in custom_predicate_batch.predicates.iter().enumerate() {
|
||||
for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() {
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
|
||||
let flattened = cp.to_fields(¶ms);
|
||||
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
|
||||
let cp_target = CustomPredicateTarget::from_flattened(¶ms, &flatteend_target);
|
||||
// Round trip of from_flattened to flattened
|
||||
let flatteend_target_rt = cp_target.flatten();
|
||||
// TODO: Instead of connect, assign witness to result
|
||||
builder.connect_slice(&flatteend_target, &flatteend_target_rt);
|
||||
|
||||
let pw = PartialWitness::<F>::new();
|
||||
|
|
@ -1033,51 +1462,22 @@ pub(crate) mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn test_custom_predicate_batch_target_id(
|
||||
fn helper_custom_predicate_batch_target_id(
|
||||
params: &Params,
|
||||
custom_predicate_batch: &CustomPredicateBatch,
|
||||
) -> frontend::Result<()> {
|
||||
) -> Result<()> {
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let zero = builder.zero();
|
||||
let predicate_targets = custom_predicate_batch
|
||||
.predicates
|
||||
.iter()
|
||||
.map(|cp| {
|
||||
let flattened = cp.to_fields(params);
|
||||
let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec();
|
||||
CustomPredicateTarget::from_flattened(params, &flatteend_target)
|
||||
})
|
||||
.chain(iter::repeat({
|
||||
let empty_flatteend_target = iter::repeat(zero)
|
||||
.take(params.custom_predicate_size())
|
||||
.collect_vec();
|
||||
CustomPredicateTarget::from_flattened(params, &empty_flatteend_target)
|
||||
}))
|
||||
.take(params.max_custom_batch_size)
|
||||
.collect();
|
||||
|
||||
let custom_predicate_batch_target = CustomPredicateBatchTarget {
|
||||
predicates: predicate_targets,
|
||||
};
|
||||
let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(params);
|
||||
|
||||
// Calculate the id in constraints and compare it against the id calculated natively
|
||||
let id_target = custom_predicate_batch_target.id(&mut builder);
|
||||
let id = custom_predicate_batch.id(params);
|
||||
|
||||
let id_expected_target = HashOutTarget {
|
||||
elements: id
|
||||
.to_fields(params)
|
||||
.iter()
|
||||
.map(|v| builder.constant(*v))
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
};
|
||||
builder.connect_array(id_target.elements, id_expected_target.elements);
|
||||
|
||||
let pw = PartialWitness::<F>::new();
|
||||
let mut pw = PartialWitness::<F>::new();
|
||||
custom_predicate_batch_target.set_targets(&mut pw, params, custom_predicate_batch)?;
|
||||
let id = custom_predicate_batch.id();
|
||||
pw.set_target_arr(&id_target.elements, &id.0)?;
|
||||
|
||||
// generate & verify proof
|
||||
let data = builder.build::<C>();
|
||||
|
|
@ -1088,7 +1488,7 @@ pub(crate) mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn custom_predicate_batch_target() -> frontend::Result<()> {
|
||||
fn test_custom_predicate_batch_target_id() -> frontend::Result<()> {
|
||||
let params = Params {
|
||||
max_statement_args: 6,
|
||||
max_custom_predicate_wildcards: 12,
|
||||
|
|
@ -1096,17 +1496,21 @@ pub(crate) mod tests {
|
|||
};
|
||||
|
||||
// Empty case
|
||||
let mut cpb_builder = CustomPredicateBatchBuilder::new("empty".into());
|
||||
_ = cpb_builder.predicate_and("empty", ¶ms, &[], &[], &[])?;
|
||||
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
|
||||
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
|
||||
let custom_predicate_batch = cpb_builder.finish();
|
||||
test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?;
|
||||
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||
|
||||
// Some cases from the examples
|
||||
let custom_predicate_batch = eth_friend_batch(¶ms)?;
|
||||
test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?;
|
||||
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||
|
||||
let custom_predicate_batch = eth_dos_batch(¶ms)?;
|
||||
test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?;
|
||||
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||
|
||||
let custom_predicate_batch =
|
||||
CustomPredicateBatch::new(¶ms, "empty".to_string(), vec![CustomPredicate::empty()]);
|
||||
helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue