From 024ed8bd04298a226ed4b7ef87ca984c0d24687c Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Tue, 13 May 2025 11:00:45 +0200 Subject: [PATCH] Constraints for custom predicates (#227) * add target types for custom predicates * simplify * fix clippy * fix typo * don't use ref for NativePredicate * fix wrong len * precalculate CustomPredicateBatch id * wip * wip * move code back * great progress * wip * code complete, hopefully; missing tests * fill aux for custom predicate op * fix clippy warnings * fix typos * fix test import * fix missing assignment in lt_mask, test custom_operation_verify_gadget * fix mistake * wip * fix * debug revert except for let entry = CustomPredicateVerifyEntryTarget * fix batch_id calculation by fixing padding * oops * remove completed TODOs --- src/backends/plonky2/circuits/common.rs | 550 ++++++++++++++-- src/backends/plonky2/circuits/mainpod.rs | 770 ++++++++++++++++++++-- src/backends/plonky2/mainpod/mod.rs | 152 ++++- src/backends/plonky2/mainpod/operation.rs | 12 +- src/backends/plonky2/mock/mainpod.rs | 1 + src/examples/custom.rs | 9 +- src/frontend/custom.rs | 46 +- src/frontend/mod.rs | 2 +- src/middleware/custom.rs | 201 ++++-- src/middleware/mod.rs | 6 + src/middleware/operation.rs | 76 ++- src/middleware/statement.rs | 63 +- 12 files changed, 1597 insertions(+), 291 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index ef16bd7..440803a 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -2,6 +2,7 @@ use std::{array, iter}; +use itertools::Itertools; use plonky2::{ field::{ extension::Extendable, @@ -12,23 +13,28 @@ use plonky2::{ poseidon::PoseidonHash, }, iop::{ + generator::{GeneratedValues, SimpleGenerator}, target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, + witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite}, }, - plonk::circuit_builder::CircuitBuilder, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CommonCircuitData}, + util::serialization::{Buffer, IoResult, Read, Write}, }; use crate::{ backends::plonky2::{ basetypes::D, + circuits::mainpod::CustomPredicateVerification, error::Result, mainpod::{Operation, OperationArg, Statement}, primitives::merkletree::MerkleClaimAndProofTarget, }, middleware::{ - NativeOperation, NativePredicate, Params, Predicate, PredicatePrefix, RawValue, - StatementArg, StatementTmplArgPrefix, ToFields, EMPTY_VALUE, F, HASH_SIZE, - OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE, + CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, + NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg, + StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, WildcardValue, + EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, + VALUE_SIZE, }, }; @@ -65,6 +71,10 @@ impl ValueTarget { elements: array::from_fn(|i| xs[i]), } } + + pub fn set_targets(&self, pw: &mut PartialWitness, value: &Value) -> Result<()> { + Ok(pw.set_target_arr(&self.elements, &value.raw().0)?) + } } #[derive(Clone)] @@ -82,7 +92,7 @@ impl StatementArgTarget { Ok(pw.set_target_arr(&self.elements, &arg.to_fields(params))?) } - fn new(first: ValueTarget, second: ValueTarget) -> Self { + pub fn new(first: ValueTarget, second: ValueTarget) -> Self { let elements: Vec<_> = first.elements.into_iter().chain(second.elements).collect(); StatementArgTarget { elements: elements.try_into().expect("size STATEMENT_ARG_F_LEN"), @@ -107,6 +117,11 @@ impl StatementArgTarget { Self::new(*pod_id, *key) } + pub fn wildcard_literal(builder: &mut CircuitBuilder, value: &ValueTarget) -> Self { + let empty = builder.constant_value(EMPTY_VALUE); + Self::new(*value, empty) + } + /// StatementArgTarget to ValueTarget coercion. Make sure to check /// that the arg is a value using the `statement_arg_is_value` method /// first! @@ -138,6 +153,7 @@ impl Build for T { } impl StatementTarget { + /// Build a new native StatementTarget pub fn new_native( builder: &mut CircuitBuilder, params: &Params, @@ -187,10 +203,60 @@ impl StatementTarget { } } +#[derive(Clone)] +pub struct OperationTypeTarget { + pub elements: [Target; Params::operation_type_size()], +} + +impl OperationTypeTarget { + pub fn new_custom( + builder: &mut CircuitBuilder, + batch_id: HashOutTarget, + index: Target, + ) -> Self { + // TODO: Use an enum for these prefixes + let three = builder.constant(F::from_canonical_usize(3)); + let id = batch_id.elements; + Self { + elements: [three, id[0], id[1], id[2], id[3], index], + } + } + + pub fn as_custom( + &self, + builder: &mut CircuitBuilder, + ) -> (BoolTarget, HashOutTarget, Target) { + // TODO: Use an enum for these prefixes + let three = builder.constant(F::from_canonical_usize(3)); + let op_is_custom = builder.is_equal(self.elements[0], three); + let batch_id = HashOutTarget::from_vec(self.elements[1..5].to_vec()); + let index = self.elements[5]; + (op_is_custom, batch_id, index) + } + + pub fn has_native(&self, builder: &mut CircuitBuilder, t: NativeOperation) -> BoolTarget { + // TODO: Use an enum for these prefixes + let one = builder.one(); + let op_is_native = builder.is_equal(self.elements[0], one); + let op_code = builder.constant(F::from_canonical_u64(t as u64)); + let op_code_matches = builder.is_equal(self.elements[1], op_code); + builder.and(op_is_native, op_code_matches) + } + + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + op_type: &OperationType, + ) -> Result<()> { + Ok(pw.set_target_arr(&self.elements, &op_type.to_fields(params))?) + } +} + // TODO: Implement Operation::to_field to determine the size of each element #[derive(Clone)] pub struct OperationTarget { - pub op_type: [Target; Params::operation_type_size()], + pub op_type: OperationTypeTarget, pub args: Vec<[Target; OPERATION_ARG_F_LEN]>, pub aux: [Target; OPERATION_AUX_F_LEN], } @@ -202,7 +268,7 @@ impl OperationTarget { params: &Params, op: &Operation, ) -> Result<()> { - pw.set_target_arr(&self.op_type, &op.op_type().to_fields(params))?; + self.op_type.set_targets(pw, params, &op.op_type())?; for (i, arg) in op .args() .iter() @@ -215,18 +281,6 @@ impl OperationTarget { pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?; Ok(()) } - - pub fn has_native_type( - &self, - builder: &mut CircuitBuilder, - t: NativeOperation, - ) -> BoolTarget { - let one = builder.one(); - let op_is_native = builder.is_equal(self.op_type[0], one); - let op_code = builder.constant(F::from_canonical_u64(t as u64)); - let op_code_matches = builder.is_equal(self.op_type[1], op_code); - builder.and(op_is_native, op_code_matches) - } } #[derive(Clone)] @@ -304,17 +358,37 @@ impl PredicateTarget { } } +/// Mirrors `middleware::KeyOrWildcard` #[derive(Clone)] -pub struct KeyOrWildcardTarget { +pub struct LiteralOrWildcardTarget { pub elements: [Target; VALUE_SIZE], } -impl KeyOrWildcardTarget { +impl LiteralOrWildcardTarget { fn from_slice(v: &[Target]) -> Self { Self { elements: v.try_into().expect("len is VALUE_SIZE"), } } + /// cases: ((is_key, key), (is_wildcard, wildcard_index)) + pub fn cases( + &self, + builder: &mut CircuitBuilder, + ) -> ((BoolTarget, ValueTarget), (BoolTarget, Target)) { + let zero = builder.zero(); + let is_zero_tail: Vec<_> = (1..4) + .map(|i| builder.is_equal(self.elements[i], zero)) + .collect(); + let is_wildcard = is_zero_tail + .into_iter() + .reduce(|acc, x| builder.and(acc, x)) + .expect("len > 1"); + let is_key = builder.not(is_wildcard); + let key = ValueTarget::from_slice(&self.elements); + let wildcard_index = self.elements[0]; + + ((is_key, key), (is_wildcard, wildcard_index)) + } } #[derive(Clone)] @@ -327,28 +401,40 @@ impl StatementTmplArgTarget { let prefix = builder.constant(F::from(StatementTmplArgPrefix::None)); builder.is_equal(self.elements[0], prefix) } + pub fn as_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, ValueTarget) { let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal)); let case_ok = builder.is_equal(self.elements[0], prefix); let value = ValueTarget::from_slice(&self.elements[1..5]); (case_ok, value) } - pub fn as_key( + + pub fn as_anchored_key( &self, builder: &mut CircuitBuilder, - ) -> (BoolTarget, Target, KeyOrWildcardTarget) { - let prefix = builder.constant(F::from(StatementTmplArgPrefix::Key)); + ) -> (BoolTarget, Target, LiteralOrWildcardTarget) { + let prefix = builder.constant(F::from(StatementTmplArgPrefix::AnchoredKey)); let case_ok = builder.is_equal(self.elements[0], prefix); let id_wildcard_index = self.elements[1]; - let value_key_or_wildcard = KeyOrWildcardTarget::from_slice(&self.elements[5..9]); + let value_key_or_wildcard = LiteralOrWildcardTarget::from_slice(&self.elements[5..9]); (case_ok, id_wildcard_index, value_key_or_wildcard) } + pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, Target) { let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral)); let case_ok = builder.is_equal(self.elements[0], prefix); let wildcard_index = self.elements[1]; (case_ok, wildcard_index) } + + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + st_tmpl_arg: &StatementTmplArg, + ) -> Result<()> { + Ok(pw.set_target_arr(&self.elements, &st_tmpl_arg.to_fields(params))?) + } } #[derive(Clone)] @@ -357,6 +443,17 @@ pub struct StatementTmplTarget { pub args: Vec, } +impl StatementTmplTarget { + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + st_tmpl: &StatementTmpl, + ) -> Result<()> { + Ok(pw.set_target_arr(&self.flatten(), &st_tmpl.to_fields(params))?) + } +} + #[derive(Clone)] pub struct CustomPredicateTarget { pub conjunction: BoolTarget, @@ -365,6 +462,17 @@ pub struct CustomPredicateTarget { pub args_len: Target, } +impl CustomPredicateTarget { + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + custom_predicate: &CustomPredicate, + ) -> Result<()> { + Ok(pw.set_target_arr(&self.flatten(), &custom_predicate.to_fields(params))?) + } +} + #[derive(Clone)] pub struct CustomPredicateBatchTarget { pub predicates: Vec, @@ -375,6 +483,161 @@ impl CustomPredicateBatchTarget { let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect(); builder.hash_n_to_hash_no_pad::(flattened) } + + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + custom_predicate_batch: &CustomPredicateBatch, + ) -> Result<()> { + let pad_predicate = CustomPredicate::empty(); + for (i, predicate) in custom_predicate_batch + .predicates() + .iter() + .chain(iter::repeat(&pad_predicate)) + .take(params.max_custom_batch_size) + .enumerate() + { + self.predicates[i].set_targets(pw, params, predicate)?; + } + Ok(()) + } +} + +/// Custom predicate table entry +pub struct CustomPredicateEntryTarget { + pub id: HashOutTarget, + pub index: Target, + pub predicate: CustomPredicateTarget, +} + +impl CustomPredicateEntryTarget { + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + predicate: &CustomPredicateRef, + ) -> Result<()> { + pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; + pw.set_target(self.index, F::from_canonical_usize(predicate.index))?; + self.predicate + .set_targets(pw, params, predicate.predicate())?; + Ok(()) + } +} + +impl Flattenable for CustomPredicateEntryTarget { + fn flatten(&self) -> Vec { + self.id + .elements + .iter() + .chain(iter::once(&self.index)) + .chain(self.predicate.flatten().iter()) + .cloned() + .collect() + } + fn from_flattened(params: &Params, vs: &[Target]) -> Self { + Self { + id: HashOutTarget::from_flattened(params, &vs[0..4]), + index: vs[4], + predicate: CustomPredicateTarget::from_flattened(params, &vs[5..]), + } + } +} + +impl CustomPredicateEntryTarget { + pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { + builder.hash_n_to_hash_no_pad::(self.flatten()) + } +} + +// Custom predicate verification table entry +pub struct CustomPredicateVerifyEntryTarget { + pub custom_predicate_table_index: Target, + pub custom_predicate: CustomPredicateEntryTarget, + pub args: Vec, + pub query: CustomPredicateVerifyQueryTarget, +} + +impl CustomPredicateVerifyEntryTarget { + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + cpv: &CustomPredicateVerification, + ) -> Result<()> { + pw.set_target( + self.custom_predicate_table_index, + F::from_canonical_usize(cpv.custom_predicate_table_index), + )?; + self.custom_predicate + .set_targets(pw, params, &cpv.custom_predicate)?; + let pad_arg = WildcardValue::None; + for (arg_target, arg) in self.args.iter().zip_eq( + cpv.args + .iter() + .chain(iter::repeat(&pad_arg)) + .take(params.max_custom_predicate_wildcards), + ) { + arg_target.set_targets(pw, &Value::from(arg.raw()))?; + } + let pad_op_arg = Statement(Predicate::Native(NativePredicate::None), vec![]); + for (op_arg_target, op_arg) in self.query.op_args.iter().zip_eq( + cpv.op_args + .iter() + .chain(iter::repeat(&pad_op_arg)) + .take(params.max_operation_args), + ) { + op_arg_target.set_targets(pw, params, op_arg)? + } + Ok(()) + } +} + +/// Query for the custom predicate verification table +pub struct CustomPredicateVerifyQueryTarget { + pub statement: StatementTarget, + pub op_type: OperationTypeTarget, + pub op_args: Vec, +} + +impl CustomPredicateVerifyQueryTarget { + pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget { + builder.hash_n_to_hash_no_pad::(self.flatten()) + } +} + +impl Flattenable for CustomPredicateVerifyQueryTarget { + fn flatten(&self) -> Vec { + self.statement + .flatten() + .iter() + .chain(self.op_type.elements.iter()) + .cloned() + .chain(self.op_args.iter().flat_map(|op_arg| op_arg.flatten())) + .collect() + } + fn from_flattened(params: &Params, vs: &[Target]) -> Self { + let (pos, size) = (0, params.statement_size()); + let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]); + let (pos, size) = (pos + size, params.operation_size()); + let op_type = OperationTypeTarget { + elements: vs[pos..pos + size] + .try_into() + .expect("len = operation_type_size"), + }; + let (pos, size) = (pos + size, params.statement_size()); + let op_args = (0..params.max_operation_args) + .map(|i| { + StatementTarget::from_flattened(params, &vs[pos + i * size..pos + (1 + i) * size]) + }) + .collect(); + Self { + statement, + op_type, + op_args, + } + } } /// Trait for target structs that may be converted to and from vectors @@ -408,6 +671,27 @@ impl From for MerkleClaimTarget { } } +impl Flattenable for HashOutTarget { + fn flatten(&self) -> Vec { + self.elements.to_vec() + } + fn from_flattened(_params: &Params, vs: &[Target]) -> Self { + assert_eq!(vs.len(), HASH_SIZE); + Self { + elements: array::from_fn(|i| vs[i]), + } + } +} + +impl Flattenable for ValueTarget { + fn flatten(&self) -> Vec { + self.elements.to_vec() + } + fn from_flattened(_params: &Params, vs: &[Target]) -> Self { + Self::from_slice(vs) + } +} + impl Flattenable for MerkleClaimTarget { fn flatten(&self) -> Vec { [ @@ -543,8 +827,17 @@ pub trait CircuitBuilderPod, const D: usize> { fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); fn add_virtual_value(&mut self) -> ValueTarget; fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget; + fn add_virtual_statement_arg(&mut self) -> StatementArgTarget; fn add_virtual_predicate(&mut self) -> PredicateTarget; + fn add_virtual_operation_type(&mut self) -> OperationTypeTarget; fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; + fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget; + fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget; + fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget; + fn add_virtual_custom_predicate_batch(&mut self, params: &Params) + -> CustomPredicateBatchTarget; + fn add_virtual_custom_predicate_entry(&mut self, params: &Params) + -> CustomPredicateEntryTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget; fn constant_value(&mut self, v: RawValue) -> ValueTarget; @@ -604,6 +897,9 @@ pub trait CircuitBuilderPod, const D: usize> { // Convenience methods for Boolean into-iters. fn all(&mut self, xs: impl IntoIterator) -> BoolTarget; fn any(&mut self, xs: impl IntoIterator) -> BoolTarget; + + // Return a bit-mask of size `len` that selects all positions lower than `n` + fn lt_mask(&mut self, len: usize, n: Target) -> Vec; } impl CircuitBuilderPod for CircuitBuilder { @@ -629,22 +925,32 @@ impl CircuitBuilderPod for CircuitBuilder { StatementTarget { predicate, args: (0..params.max_statement_args) - .map(|_| StatementArgTarget { - elements: self.add_virtual_target_arr(), - }) + .map(|_| self.add_virtual_statement_arg()) .collect(), } } + fn add_virtual_statement_arg(&mut self) -> StatementArgTarget { + StatementArgTarget { + elements: self.add_virtual_target_arr(), + } + } + fn add_virtual_predicate(&mut self) -> PredicateTarget { PredicateTarget { elements: self.add_virtual_target_arr(), } } + fn add_virtual_operation_type(&mut self) -> OperationTypeTarget { + OperationTypeTarget { + elements: self.add_virtual_target_arr(), + } + } + fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { OperationTarget { - op_type: self.add_virtual_target_arr(), + op_type: self.add_virtual_operation_type(), args: (0..params.max_operation_args) .map(|_| self.add_virtual_target_arr()) .collect(), @@ -652,6 +958,55 @@ impl CircuitBuilderPod for CircuitBuilder { } } + fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget { + StatementTmplArgTarget { + elements: self.add_virtual_target_arr(), + } + } + + fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget { + let args = (0..params.max_statement_args) + .map(|_| self.add_virtual_statement_tmpl_arg()) + .collect(); + StatementTmplTarget { + pred: self.add_virtual_predicate(), + args, + } + } + + fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget { + let statements = (0..params.max_custom_predicate_arity) + .map(|_| self.add_virtual_statement_tmpl(params)) + .collect(); + CustomPredicateTarget { + conjunction: self.add_virtual_bool_target_safe(), + statements, + args_len: self.add_virtual_target(), + } + } + + fn add_virtual_custom_predicate_batch( + &mut self, + params: &Params, + ) -> CustomPredicateBatchTarget { + CustomPredicateBatchTarget { + predicates: (0..params.max_custom_batch_size) + .map(|_| self.add_virtual_custom_predicate(params)) + .collect(), + } + } + + fn add_virtual_custom_predicate_entry( + &mut self, + params: &Params, + ) -> CustomPredicateEntryTarget { + CustomPredicateEntryTarget { + id: self.add_virtual_hash(), + index: self.add_virtual_target(), + predicate: self.add_virtual_custom_predicate(params), + } + } + fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget { ValueTarget { elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])), @@ -876,6 +1231,12 @@ impl CircuitBuilderPod for CircuitBuilder { ) } + // TODO: Implement a version of vec_ref for types `T` which are big and support hashing. + // The idea would be the following: Take the array `ts` and hash each element. Then do the + // random access on the hash result. Finally "unhash" to recover the resolved element. + // We don't want to hash each element from the array each time, so we should cache the hashed + // result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and + // then do `ts: &[HashCache]`. fn vec_ref(&mut self, params: &Params, ts: &[T], i: Target) -> T { // TODO: Revisit this when we need more than 64 statements. let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| { @@ -944,6 +1305,73 @@ impl CircuitBuilderPod for CircuitBuilder { .reduce(|a, b| self.or(a, b)) .unwrap_or(self._false()) } + + fn lt_mask(&mut self, len: usize, n: Target) -> Vec { + let zero = self.zero(); + let mask: Vec<_> = (0..len) + .map(|_| self.add_virtual_bool_target_safe()) + .collect(); + self.add_simple_generator(LtMaskGenerator { + n, + mask: mask.iter().map(|bt| bt.target).collect(), + }); + // We have `n` ones in the mask + let mask_sum = mask + .iter() + .map(|b| b.target) + .reduce(|acc, x| self.add(acc, x)) + .unwrap_or(zero); + self.connect(n, mask_sum); + + // The elements in the mask can only transition from 1 to 0 or 0 to 0. + for i in 0..len - 1 { + let diff = self.sub(mask[i].target, mask[i + 1].target); + self.assert_bool(BoolTarget::new_unsafe(diff)); + } + + mask + } +} + +#[derive(Debug, Default)] +pub struct LtMaskGenerator { + pub(crate) n: Target, + pub(crate) mask: Vec, +} + +impl, const D: usize> SimpleGenerator for LtMaskGenerator { + fn id(&self) -> String { + "LtMaskGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + vec![self.n] + } + + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> anyhow::Result<()> { + let n = witness.get_target(self.n).to_canonical_u64(); + + for (i, mask_i) in self.mask.iter().enumerate() { + let v = if (i as u64) < n { F::ONE } else { F::ZERO }; + out_buffer.set_target(*mask_i, v)?; + } + Ok(()) + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_target(self.n)?; + dst.write_target_vec(&self.mask) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let n = src.read_target()?; + let mask = src.read_target_vec()?; + Ok(Self { n, mask }) + } } #[cfg(test)] @@ -1013,13 +1441,14 @@ pub(crate) mod tests { let custom_predicate_batch = eth_friend_batch(¶ms)?; - for (i, cp) in custom_predicate_batch.predicates.iter().enumerate() { + for (i, cp) in custom_predicate_batch.predicates().iter().enumerate() { let mut builder = CircuitBuilder::::new(config.clone()); let flattened = cp.to_fields(¶ms); let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec(); let cp_target = CustomPredicateTarget::from_flattened(¶ms, &flatteend_target); // Round trip of from_flattened to flattened let flatteend_target_rt = cp_target.flatten(); + // TODO: Instead of connect, assign witness to result builder.connect_slice(&flatteend_target, &flatteend_target_rt); let pw = PartialWitness::::new(); @@ -1033,51 +1462,22 @@ pub(crate) mod tests { Ok(()) } - fn test_custom_predicate_batch_target_id( + fn helper_custom_predicate_batch_target_id( params: &Params, custom_predicate_batch: &CustomPredicateBatch, - ) -> frontend::Result<()> { + ) -> Result<()> { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - let zero = builder.zero(); - let predicate_targets = custom_predicate_batch - .predicates - .iter() - .map(|cp| { - let flattened = cp.to_fields(params); - let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec(); - CustomPredicateTarget::from_flattened(params, &flatteend_target) - }) - .chain(iter::repeat({ - let empty_flatteend_target = iter::repeat(zero) - .take(params.custom_predicate_size()) - .collect_vec(); - CustomPredicateTarget::from_flattened(params, &empty_flatteend_target) - })) - .take(params.max_custom_batch_size) - .collect(); - - let custom_predicate_batch_target = CustomPredicateBatchTarget { - predicates: predicate_targets, - }; + let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(params); // Calculate the id in constraints and compare it against the id calculated natively let id_target = custom_predicate_batch_target.id(&mut builder); - let id = custom_predicate_batch.id(params); - let id_expected_target = HashOutTarget { - elements: id - .to_fields(params) - .iter() - .map(|v| builder.constant(*v)) - .collect_vec() - .try_into() - .unwrap(), - }; - builder.connect_array(id_target.elements, id_expected_target.elements); - - let pw = PartialWitness::::new(); + let mut pw = PartialWitness::::new(); + custom_predicate_batch_target.set_targets(&mut pw, params, custom_predicate_batch)?; + let id = custom_predicate_batch.id(); + pw.set_target_arr(&id_target.elements, &id.0)?; // generate & verify proof let data = builder.build::(); @@ -1088,7 +1488,7 @@ pub(crate) mod tests { } #[test] - fn custom_predicate_batch_target() -> frontend::Result<()> { + fn test_custom_predicate_batch_target_id() -> frontend::Result<()> { let params = Params { max_statement_args: 6, max_custom_predicate_wildcards: 12, @@ -1096,17 +1496,21 @@ pub(crate) mod tests { }; // Empty case - let mut cpb_builder = CustomPredicateBatchBuilder::new("empty".into()); - _ = cpb_builder.predicate_and("empty", ¶ms, &[], &[], &[])?; + let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); + _ = cpb_builder.predicate_and("empty", &[], &[], &[])?; let custom_predicate_batch = cpb_builder.finish(); - test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?; + helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); // Some cases from the examples let custom_predicate_batch = eth_friend_batch(¶ms)?; - test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?; + helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); let custom_predicate_batch = eth_dos_batch(¶ms)?; - test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?; + helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); + + let custom_predicate_batch = + CustomPredicateBatch::new(¶ms, "empty".to_string(), vec![CustomPredicate::empty()]); + helper_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch).unwrap(); Ok(()) } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 95d34be..a12e2cc 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -1,7 +1,8 @@ -use std::array; +use std::{array, sync::Arc}; -use itertools::zip_eq; +use itertools::{zip_eq, Itertools}; use plonky2::{ + field::types::Field, hash::{hash_types::HashOutTarget, poseidon::PoseidonHash}, iop::{target::BoolTarget, witness::PartialWitness}, plonk::circuit_builder::CircuitBuilder, @@ -12,8 +13,11 @@ use crate::{ basetypes::D, circuits::{ common::{ - CircuitBuilderPod, Flattenable, MerkleClaimTarget, OperationTarget, - StatementArgTarget, StatementTarget, ValueTarget, + CircuitBuilderPod, CustomPredicateBatchTarget, CustomPredicateEntryTarget, + CustomPredicateVerifyEntryTarget, CustomPredicateVerifyQueryTarget, Flattenable, + MerkleClaimTarget, OperationTarget, OperationTypeTarget, PredicateTarget, + StatementArgTarget, StatementTarget, StatementTmplArgTarget, StatementTmplTarget, + ValueTarget, }, signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget}, }, @@ -25,8 +29,9 @@ use crate::{ signedpod::SignedPod, }, middleware::{ - AnchoredKey, NativeOperation, NativePredicate, Params, PodType, Statement, StatementArg, - ToFields, Value, F, KEY_TYPE, SELF, VALUE_SIZE, + AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, + NativePredicate, Params, PodType, Statement, StatementArg, ToFields, Value, WildcardValue, + F, KEY_TYPE, SELF, VALUE_SIZE, }, }; @@ -72,6 +77,7 @@ impl OperationVerifyGadget { op: &OperationTarget, prev_statements: &[StatementTarget], merkle_claims: &[MerkleClaimTarget], + custom_predicate_verification_table: &[HashOutTarget], ) -> Result<()> { let _true = builder._true(); let _false = builder._false(); @@ -80,7 +86,11 @@ impl OperationVerifyGadget { // can reference any of the `prev_statements`. // TODO: Clean this up. let resolved_op_args = if prev_statements.is_empty() { - vec![] + (0..self.params.max_operation_args) + .map(|_| { + StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]) + }) + .collect_vec() } else { op.args .iter() @@ -88,12 +98,30 @@ impl OperationVerifyGadget { .map(|&i| builder.vec_ref(&self.params, prev_statements, i)) .collect::>() }; + // TODO: Can we have a single table with merkel claims and verified custom predicates + // together (with an identifying prefix) and then we only need one random access instead of + // two? + // Currently we use one slot of aux for the index to merkle claim and another slot of aux + // for the index to the verified custom predicate. We can't use the same slot because then + // if one table is different size the random access to the smaller one may use an index + // that is too big and not pass the constraints. Possible solutions to use a single slot + // are: + // - a. Use a single table (mux both tables) + // - b. select the index or 0 by checking the operation type here; but that breaks the + // current abstraction a little bit. + // Certain operations (Contains/NotContains) will refer to one // of the provided Merkle proofs (if any). These proofs have already // been verified, so we need only look up the claim. let resolved_merkle_claim = (!merkle_claims.is_empty()) .then(|| builder.vec_ref(&self.params, merkle_claims, op.aux[0])); + // Operations from custom statements will refer to one + // of the provided custom predicates verifications (if any). These operations have already + // been verified, so we need only look up the entry. + let resolved_custom_pred_verification = (!custom_predicate_verification_table.is_empty()) + .then(|| builder.vec_ref(&self.params, custom_predicate_verification_table, op.aux[1])); + // The verification may require aux data which needs to be stored in the // `OperationVerifyTarget` so that we can set during witness generation. @@ -104,23 +132,23 @@ impl OperationVerifyGadget { // lie outside of the domain. let op_checks = [ vec![ - self.eval_none(builder, st, op), - self.eval_new_entry(builder, st, op, prev_statements), + self.eval_none(builder, st, &op.op_type), + self.eval_new_entry(builder, st, &op.op_type, prev_statements), ], // Skip these if there are no resolved op args if resolved_op_args.is_empty() { vec![] } else { vec![ - self.eval_copy(builder, st, op, &resolved_op_args)?, - self.eval_eq_neq_from_entries(builder, st, op, &resolved_op_args), - self.eval_lt_lteq_from_entries(builder, st, op, &resolved_op_args), - self.eval_transitive_eq(builder, st, op, &resolved_op_args), - self.eval_lt_to_neq(builder, st, op, &resolved_op_args), - self.eval_hash_of(builder, st, op, &resolved_op_args), - self.eval_sum_of(builder, st, op, &resolved_op_args), - self.eval_product_of(builder, st, op, &resolved_op_args), - self.eval_max_of(builder, st, op, &resolved_op_args), + self.eval_copy(builder, st, &op.op_type, &resolved_op_args)?, + self.eval_eq_neq_from_entries(builder, st, &op.op_type, &resolved_op_args), + self.eval_lt_lteq_from_entries(builder, st, &op.op_type, &resolved_op_args), + self.eval_transitive_eq(builder, st, &op.op_type, &resolved_op_args), + self.eval_lt_to_neq(builder, st, &op.op_type, &resolved_op_args), + self.eval_hash_of(builder, st, &op.op_type, &resolved_op_args), + self.eval_sum_of(builder, st, &op.op_type, &resolved_op_args), + self.eval_product_of(builder, st, &op.op_type, &resolved_op_args), + self.eval_max_of(builder, st, &op.op_type, &resolved_op_args), ] }, // Skip these if there are no resolved Merkle claims @@ -129,14 +157,14 @@ impl OperationVerifyGadget { self.eval_contains_from_entries( builder, st, - op, + &op.op_type, resolved_merkle_claim, &resolved_op_args, ), self.eval_not_contains_from_entries( builder, st, - op, + &op.op_type, resolved_merkle_claim, &resolved_op_args, ), @@ -144,6 +172,18 @@ impl OperationVerifyGadget { } else { vec![] }, + // Skip these if there are no resolved custom predicate verifications + if let Some(resolved_custom_pred_verification) = resolved_custom_pred_verification { + vec![self.eval_custom( + builder, + st, + &op.op_type, + resolved_custom_pred_verification, + &resolved_op_args, + )] + } else { + vec![] + }, ] .concat(); @@ -158,11 +198,11 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_merkle_claim: MerkleClaimTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { - let op_code_ok = op.has_native_type(builder, NativeOperation::ContainsFromEntries); + let op_code_ok = op_type.has_native(builder, NativeOperation::ContainsFromEntries); let (arg_types_ok, [merkle_root_value, key_value, value_value]) = self.first_n_args_as_values(builder, resolved_op_args); @@ -203,11 +243,11 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_merkle_claim: MerkleClaimTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { - let op_code_ok = op.has_native_type(builder, NativeOperation::NotContainsFromEntries); + let op_code_ok = op_type.has_native(builder, NativeOperation::NotContainsFromEntries); let (arg_types_ok, [merkle_root_value, key_value]) = self.first_n_args_as_values(builder, resolved_op_args); @@ -242,22 +282,42 @@ impl OperationVerifyGadget { builder.all([op_code_ok, arg_types_ok, merkle_proof_ok, st_ok]) } + fn eval_custom( + &self, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op_type: &OperationTypeTarget, + resolved_custom_pred_verification: HashOutTarget, + resolved_op_args: &[StatementTarget], + ) -> BoolTarget { + let query = CustomPredicateVerifyQueryTarget { + statement: st.clone(), + op_type: op_type.clone(), + op_args: resolved_op_args.to_vec(), + }; + let out_query_hash = query.hash(builder); + builder.is_equal_slice( + &resolved_custom_pred_verification.elements, + &out_query_hash.elements, + ) + } + /// Carries out the checks necessary for EqualFromEntries and /// NotEqualFromEntries. fn eval_eq_neq_from_entries( &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { let eq_op_st_code_ok = { - let op_code_ok = op.has_native_type(builder, NativeOperation::EqualFromEntries); + let op_code_ok = op_type.has_native(builder, NativeOperation::EqualFromEntries); let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::Equal); builder.and(op_code_ok, st_code_ok) }; let neq_op_st_code_ok = { - let op_code_ok = op.has_native_type(builder, NativeOperation::NotEqualFromEntries); + let op_code_ok = op_type.has_native(builder, NativeOperation::NotEqualFromEntries); let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::NotEqual); builder.and(op_code_ok, st_code_ok) }; @@ -296,19 +356,19 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { let zero = ValueTarget::zero(builder); let one = ValueTarget::one(builder); let lt_op_st_code_ok = { - let op_code_ok = op.has_native_type(builder, NativeOperation::LtFromEntries); + let op_code_ok = op_type.has_native(builder, NativeOperation::LtFromEntries); let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::Lt); builder.and(op_code_ok, st_code_ok) }; let lteq_op_st_code_ok = { - let op_code_ok = op.has_native_type(builder, NativeOperation::LtEqFromEntries); + let op_code_ok = op_type.has_native(builder, NativeOperation::LtEqFromEntries); let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::LtEq); builder.and(op_code_ok, st_code_ok) }; @@ -362,10 +422,10 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { - let op_code_ok = op.has_native_type(builder, NativeOperation::HashOf); + let op_code_ok = op_type.has_native(builder, NativeOperation::HashOf); let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = self.first_n_args_as_values(builder, resolved_op_args); @@ -393,12 +453,12 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { let value_zero = ValueTarget::zero(builder); - let op_code_ok = op.has_native_type(builder, NativeOperation::SumOf); + let op_code_ok = op_type.has_native(builder, NativeOperation::SumOf); let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = self.first_n_args_as_values(builder, resolved_op_args); @@ -429,12 +489,12 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { let value_zero = ValueTarget::zero(builder); - let op_code_ok = op.has_native_type(builder, NativeOperation::ProductOf); + let op_code_ok = op_type.has_native(builder, NativeOperation::ProductOf); let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = self.first_n_args_as_values(builder, resolved_op_args); @@ -465,10 +525,10 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { - let op_code_ok = op.has_native_type(builder, NativeOperation::MaxOf); + let op_code_ok = op_type.has_native(builder, NativeOperation::MaxOf); let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = self.first_n_args_as_values(builder, resolved_op_args); @@ -508,11 +568,11 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { let op_code_ok = - op.has_native_type(builder, NativeOperation::TransitiveEqualFromStatements); + op_type.has_native(builder, NativeOperation::TransitiveEqualFromStatements); let arg1_type_ok = resolved_op_args[0].has_native_type(builder, &self.params, NativePredicate::Equal); @@ -541,9 +601,9 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, ) -> BoolTarget { - let op_code_ok = op.has_native_type(builder, NativeOperation::None); + let op_code_ok = op_type.has_native(builder, NativeOperation::None); let expected_statement = StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]); @@ -556,10 +616,10 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, prev_statements: &[StatementTarget], ) -> BoolTarget { - let op_code_ok = op.has_native_type(builder, NativeOperation::NewEntry); + let op_code_ok = op_type.has_native(builder, NativeOperation::NewEntry); let st_code_ok = st.has_native_type(builder, &self.params, NativePredicate::ValueOf); @@ -591,10 +651,10 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> BoolTarget { - let op_code_ok = op.has_native_type(builder, NativeOperation::LtToNotEqual); + let op_code_ok = op_type.has_native(builder, NativeOperation::LtToNotEqual); let arg_type_ok = resolved_op_args[0].has_native_type(builder, &self.params, NativePredicate::Lt); @@ -617,10 +677,10 @@ impl OperationVerifyGadget { &self, builder: &mut CircuitBuilder, st: &StatementTarget, - op: &OperationTarget, + op_type: &OperationTypeTarget, resolved_op_args: &[StatementTarget], ) -> Result { - let op_code_ok = op.has_native_type(builder, NativeOperation::CopyStatement); + let op_code_ok = op_type.has_native(builder, NativeOperation::CopyStatement); let expected_statement = &resolved_op_args[0]; let st_ok = builder.is_equal_flattenable(st, expected_statement); @@ -629,6 +689,149 @@ impl OperationVerifyGadget { } } +struct CustomOperationVerifyGadget { + params: Params, +} + +// NOTE: This is a bit messy. The target types are defined in `common.rs` because they are used in +// `add_virtual_foo` methods in the trait for the `CircuitBuilder`. But the constraint logic is +// here. Maybe we want to move everything related to custom predicates to its own module, but then +// should we add a new trait for the `add_virtual_foo` methods so that everything is contained in a +// module? +impl CustomOperationVerifyGadget { + fn statement_arg_from_template( + &self, + builder: &mut CircuitBuilder, + st_tmpl_arg: &StatementTmplArgTarget, + args: &[ValueTarget], + ) -> StatementArgTarget { + let zero = builder.zero(); + let (is_literal, value_literal) = st_tmpl_arg.as_literal(builder); + let (is_ak, ak_id_wc_index, ak_key_lit_or_wc) = st_tmpl_arg.as_anchored_key(builder); + let (is_wc_literal, wc_index) = st_tmpl_arg.as_wildcard_literal(builder); + + let ((_is_ak_key_lit, ak_key_lit), (is_ak_key_wc, ak_key_wc_index)) = + ak_key_lit_or_wc.cases(builder); + + // optimization: ak_id_wc_index and wc_index use the same signals, so we only need to do one + // random access to resolve both of them + assert_eq!(ak_id_wc_index, wc_index); + // If the index is not used, use a 0 instead to still pass the range constraints from + // vec_ref + let first_index = ak_id_wc_index; + let is_first_index_valid = builder.or(is_ak, is_wc_literal); + let first_index = builder.select(is_first_index_valid, first_index, zero); + let resolved_ak_id = builder.vec_ref(&self.params, args, first_index); + let resolved_wc = resolved_ak_id; + + // If the index is not used, use a 0 instead to still pass the range constraints from + // vec_ref + let second_index = ak_key_wc_index; + let is_second_index_valid = builder.and(is_ak, is_ak_key_wc); + let second_index = builder.select(is_second_index_valid, second_index, zero); + let resolved_ak_key = builder.vec_ref(&self.params, args, second_index); + + let ak_key = ak_key_lit; // is_ak_key_lit + let ak_key = + builder.select_flattenable(&self.params, is_ak_key_wc, &resolved_ak_key, &ak_key); + + let first = ValueTarget::zero(builder); // is_none + let first = builder.select_flattenable(&self.params, is_literal, &value_literal, &first); + let first = builder.select_flattenable(&self.params, is_ak, &resolved_ak_id, &first); + let first = builder.select_flattenable(&self.params, is_wc_literal, &resolved_wc, &first); + + let second = ValueTarget::zero(builder); // is_none or is_literal or is_wc_literal + let second = builder.select_flattenable(&self.params, is_ak, &ak_key, &second); + + StatementArgTarget::new(first, second) + } + + fn statement_from_template( + &self, + builder: &mut CircuitBuilder, + st_tmpl: &StatementTmplTarget, + args: &[ValueTarget], + ) -> StatementTarget { + let args = st_tmpl + .args + .iter() + .map(|st_tmpl_arg| self.statement_arg_from_template(builder, st_tmpl_arg, args)) + .collect(); + StatementTarget { + predicate: st_tmpl.pred.clone(), + args, + } + } + + /// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard + /// values (args): + /// - Verify that the custom predicate is satisfied with the given statements + /// - Build the output statement + /// - Build the expected operation type + fn eval( + &self, + builder: &mut CircuitBuilder, + custom_predicate: &CustomPredicateEntryTarget, + op_args: &[StatementTarget], + args: &[ValueTarget], // arguments to the custom predicate, public and private + ) -> Result<(StatementTarget, OperationTypeTarget)> { + // Some sanity checks + assert_eq!(self.params.max_operation_args, op_args.len()); + assert_eq!(self.params.max_custom_predicate_wildcards, args.len()); + + let (batch_id, index) = (custom_predicate.id, custom_predicate.index); + let op_type = OperationTypeTarget::new_custom(builder, batch_id, index); + + // Build the statement + let st_predicate = PredicateTarget::new_custom(builder, batch_id, index); + let arg_none = ValueTarget::zero(builder); + let lt_mask = builder.lt_mask( + self.params.max_statement_args, + custom_predicate.predicate.args_len, + ); + let st_args = (0..self.params.max_statement_args) + .map(|i| { + let v = builder.select_flattenable(&self.params, lt_mask[i], &args[i], &arg_none); + StatementArgTarget::wildcard_literal(builder, &v) + }) + .collect(); + let statement = StatementTarget { + predicate: st_predicate, + args: st_args, + }; + + // Check the operation arguments + // From each statement template we generate an expected statement using replacing the + // wildcards by the arguments. Then we compare the expected statement with the operation + // argument. + let expected_sts: Vec<_> = custom_predicate + .predicate + .statements + .iter() + .map(|st_tmpl| self.statement_from_template(builder, st_tmpl, args)) + .collect(); + // expected_sts.len() == self.params.max_custom_predicate_arity + // op_args.len() == self.params.max_operation_args; + assert!(self.params.max_custom_predicate_arity <= self.params.max_operation_args); + let sts_eq: Vec<_> = expected_sts + .iter() + .zip(op_args.iter()) + .map(|(expected_st, st)| builder.is_equal_flattenable(expected_st, st)) + .collect(); + let all_st_eq = builder.all(sts_eq.clone()); + let some_st_eq = builder.any(sts_eq); + // NOTE: This BoolTarget is safe because both inputs to the select are safe + let is_op_args_ok = BoolTarget::new_unsafe(builder.select( + custom_predicate.predicate.conjunction, + all_st_eq.target, + some_st_eq.target, + )); + + builder.assert_one(is_op_args_ok.target); + Ok((statement, op_type)) + } +} + struct MainPodVerifyGadget { params: Params, } @@ -692,6 +895,74 @@ impl MainPodVerifyGadget { .map(|pf| pf.into()) .collect(); + // Table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as + // hash([batch_id, custom_predicate_index, custom_predicate]). While building the table we + // calculate the id of each batch. + let mut custom_predicate_table = + Vec::with_capacity(params.max_custom_predicate_batches * params.max_custom_batch_size); + let mut custom_predicate_batches = Vec::with_capacity(params.max_custom_predicate_batches); + for _ in 0..params.max_custom_predicate_batches { + let cpb = builder.add_virtual_custom_predicate_batch(&self.params); + let id = cpb.id(builder); // constrain the id + for (index, cp) in cpb.predicates.iter().enumerate() { + let entry = CustomPredicateEntryTarget { + id, // output + index: builder.constant(F::from_canonical_usize(index)), // constant + predicate: cp.clone(), // input + }; + let in_query_hash = entry.hash(builder); + custom_predicate_table.push(in_query_hash); + } + custom_predicate_batches.push(cpb); // We keep this for witness assignment + } + + // Table of [batch_id, custom_predicate_index, custom_predicate, args, st, op, op_args] + // with queryable part as hash([st, op, op_args]). While building the table we verify each + // custom predicate against the operation and statement. + let mut custom_predicate_verifications = + Vec::with_capacity(params.max_custom_predicate_verifications); + let mut custom_predicate_verification_table = + Vec::with_capacity(params.max_custom_predicate_verifications); + for _ in 0..params.max_custom_predicate_verifications { + let custom_predicate_table_index = builder.add_virtual_target(); + let custom_predicate = builder.add_virtual_custom_predicate_entry(&self.params); + let args = (0..params.max_custom_predicate_wildcards) + .map(|_| builder.add_virtual_value()) + .collect_vec(); + let op_args = (0..params.max_operation_args) + .map(|_| builder.add_virtual_statement(&self.params)) + .collect_vec(); + + // Verify the custom predicate operation + let (statement, op_type) = CustomOperationVerifyGadget { + params: params.clone(), + } + .eval(builder, &custom_predicate, &op_args, &args)?; + + // Check that the batch id is correct by querying the custom predicate batches table + let table_query_hash = builder.vec_ref( + &self.params, + &custom_predicate_table, + custom_predicate_table_index, + ); + let out_query_hash = custom_predicate.hash(builder); + builder.connect_array(table_query_hash.elements, out_query_hash.elements); + + let entry = CustomPredicateVerifyEntryTarget { + custom_predicate_table_index, // input + custom_predicate, // input + args, // input + query: CustomPredicateVerifyQueryTarget { + statement, // output + op_type, // output + op_args, // input + }, + }; + let in_query_hash = entry.query.hash(builder); + custom_predicate_verification_table.push(in_query_hash); + custom_predicate_verifications.push(entry); // We keep this for witness assignment + } + // 2. Calculate the Pod Id from the public statements let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect(); let id = builder.hash_n_to_hash_no_pad::(pub_statements_flattened); @@ -720,7 +991,14 @@ impl MainPodVerifyGadget { OperationVerifyGadget { params: params.clone(), } - .eval(builder, st, op, prev_statements, &merkle_claims)?; + .eval( + builder, + st, + op, + prev_statements, + &merkle_claims, + &custom_predicate_verification_table, + )?; } Ok(MainPodVerifyTarget { @@ -730,6 +1008,8 @@ impl MainPodVerifyGadget { statements: input_statements.to_vec(), operations, merkle_proofs, + custom_predicate_batches, + custom_predicate_verifications, }) } } @@ -742,6 +1022,15 @@ pub struct MainPodVerifyTarget { statements: Vec, operations: Vec, merkle_proofs: Vec, + custom_predicate_batches: Vec, + custom_predicate_verifications: Vec, +} + +pub struct CustomPredicateVerification { + pub custom_predicate_table_index: usize, + pub custom_predicate: CustomPredicateRef, + pub args: Vec, + pub op_args: Vec, } pub struct MainPodVerifyInput { @@ -749,6 +1038,8 @@ pub struct MainPodVerifyInput { pub statements: Vec, pub operations: Vec, pub merkle_proofs: Vec, + pub custom_predicate_batches: Vec>, + pub custom_predicate_verifications: Vec, } impl MainPodVerifyTarget { @@ -762,16 +1053,21 @@ impl MainPodVerifyTarget { self.signed_pods[i].set_targets(pw, signed_pod)?; } // Padding - // TODO: Instead of using an input for padding, use a canonical minimal SignedPod - let pad_pod = &input.signed_pods[0]; - for i in input.signed_pods.len()..self.params.max_input_signed_pods { - self.signed_pods[i].set_targets(pw, pad_pod)?; + if self.params.max_input_signed_pods > 0 { + // TODO: Instead of using an input for padding, use a canonical minimal SignedPod, + // without it a MainPod configured to support input signed pods must have at least one + // input signed pod :( + let pad_pod = &input.signed_pods[0]; + for i in input.signed_pods.len()..self.params.max_input_signed_pods { + self.signed_pods[i].set_targets(pw, pad_pod)?; + } } assert_eq!(input.statements.len(), self.params.max_statements); for (i, (st, op)) in zip_eq(&input.statements, &input.operations).enumerate() { self.statements[i].set_targets(pw, &self.params, st)?; self.operations[i].set_targets(pw, &self.params, op)?; } + assert!(input.merkle_proofs.len() <= self.params.max_merkle_proofs); for (i, mp) in input.merkle_proofs.iter().enumerate() { self.merkle_proofs[i].set_targets(pw, true, mp)?; @@ -781,6 +1077,46 @@ impl MainPodVerifyTarget { for i in input.merkle_proofs.len()..self.params.max_merkle_proofs { self.merkle_proofs[i].set_targets(pw, false, &pad_mp)?; } + + assert!(input.custom_predicate_batches.len() <= self.params.max_custom_predicate_batches); + for (i, cpb) in input.custom_predicate_batches.iter().enumerate() { + self.custom_predicate_batches[i].set_targets(pw, &self.params, cpb)?; + } + // Padding + let pad_cpb = CustomPredicateBatch::new( + &self.params, + "empty".to_string(), + vec![CustomPredicate::empty()], + ); + for i in input.custom_predicate_batches.len()..self.params.max_custom_predicate_batches { + self.custom_predicate_batches[i].set_targets(pw, &self.params, &pad_cpb)?; + } + + assert!( + input.custom_predicate_verifications.len() + <= self.params.max_custom_predicate_verifications + ); + for (i, cpv) in input.custom_predicate_verifications.iter().enumerate() { + self.custom_predicate_verifications[i].set_targets(pw, &self.params, cpv)?; + } + // Padding. Use the first input if it exists. If it doesnt, all batches in this MainPod + // are padding so refer to the first padding entry. + let empty_cpv = CustomPredicateVerification { + custom_predicate_table_index: 0, + custom_predicate: CustomPredicateRef::new(pad_cpb, 0), + args: vec![], + op_args: vec![], + }; + let pad_cpv = input + .custom_predicate_verifications + .first() + .unwrap_or(&empty_cpv); + for i in input.custom_predicate_verifications.len() + ..self.params.max_custom_predicate_verifications + { + self.custom_predicate_verifications[i].set_targets(pw, &self.params, pad_cpv)?; + } + Ok(()) } } @@ -817,7 +1153,11 @@ mod tests { mainpod::{OperationArg, OperationAux}, primitives::merkletree::{MerkleClaimAndProof, MerkleTree}, }, - middleware::{hash_values, Hash, OperationType, PodId, RawValue}, + frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::{ + hash_str, hash_values, Hash, Key, KeyOrWildcard, OperationType, PodId, Predicate, + RawValue, StatementTmpl, StatementTmplArg, Wildcard, WildcardValue, + }, }; fn operation_verify( @@ -826,7 +1166,11 @@ mod tests { prev_statements: Vec, merkle_proofs: Vec, ) -> Result<()> { - let params = Params::default(); + let params = Params { + max_custom_predicate_batches: 0, + max_custom_predicate_verifications: 0, + ..Default::default() + }; let mp_gadget = MerkleProofGadget { max_depth: params.max_depth_mt_gadget, }; @@ -848,6 +1192,7 @@ mod tests { .into_iter() .map(|pf| pf.into()) .collect(); + let custom_predicate_verification_table = vec![]; OperationVerifyGadget { params: params.clone(), @@ -858,6 +1203,7 @@ mod tests { &op_target, &prev_statements_target, &merkle_claims_target, + &custom_predicate_verification_table, )?; let mut pw = PartialWitness::::new(); @@ -1670,4 +2016,318 @@ mod tests { let prev_statements = vec![root_st, key_st, value_st]; operation_verify(st, op, prev_statements, merkle_proofs) } + + fn helper_statement_arg_from_template( + params: &Params, + st_tmpl_arg: StatementTmplArg, + args: Vec, + expected_st_arg: StatementArg, + ) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let gadget = CustomOperationVerifyGadget { + params: params.clone(), + }; + + let st_tmpl_arg_target = builder.add_virtual_statement_tmpl_arg(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_arg_target = + gadget.statement_arg_from_template(&mut builder, &st_tmpl_arg_target, &args_target); + // TODO: Instead of connect, assign witness to result + let expected_st_arg_target = builder.add_virtual_statement_arg(); + builder.connect_array(expected_st_arg_target.elements, st_arg_target.elements); + + let mut pw = PartialWitness::::new(); + + st_tmpl_arg_target.set_targets(&mut pw, params, &st_tmpl_arg)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_arg_target.set_targets(&mut pw, params, &expected_st_arg)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) + } + + #[test] + fn test_statement_arg_from_template() -> Result<()> { + let params = Params::default(); + + let pod_id = PodId(hash_str("pod_id")); + + // case: None + let st_tmpl_arg = StatementTmplArg::None; + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::None; + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: Literal + let st_tmpl_arg = StatementTmplArg::Literal(Value::from("foo")); + let args = vec![Value::from(1), Value::from(2), Value::from(3)]; + let expected_st_arg = StatementArg::Literal(Value::from("foo")); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: AnchoredKey(id_wildcard, key_literal) + let st_tmpl_arg = StatementTmplArg::AnchoredKey( + Wildcard::new("a".to_string(), 1), + KeyOrWildcard::Key(Key::from("foo")), + ); + let args = vec![Value::from(1), Value::from(pod_id.0), Value::from(3)]; + let expected_st_arg = StatementArg::Key(AnchoredKey::new(pod_id, Key::from("foo"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: AnchoredKey(id_wildcard, key_wildcard) + let st_tmpl_arg = StatementTmplArg::AnchoredKey( + Wildcard::new("a".to_string(), 1), + KeyOrWildcard::Wildcard(Wildcard::new("b".to_string(), 2)), + ); + let args = vec![Value::from(1), Value::from(pod_id.0), Value::from("key")]; + let expected_st_arg = StatementArg::Key(AnchoredKey::new(pod_id, Key::from("key"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: WildcardLiteral(wildcard) + let st_tmpl_arg = StatementTmplArg::WildcardLiteral(Wildcard::new("a".to_string(), 1)); + let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; + let expected_st_arg = StatementArg::WildcardLiteral(WildcardValue::Key(Key::from("key"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + Ok(()) + } + + fn helper_statement_from_template( + params: &Params, + st_tmpl: StatementTmpl, + args: Vec, + expected_st: Statement, + ) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let gadget = CustomOperationVerifyGadget { + params: params.clone(), + }; + + let st_tmpl_target = builder.add_virtual_statement_tmpl(params); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = gadget.statement_from_template(&mut builder, &st_tmpl_target, &args_target); + // TODO: Instead of connect, assign witness to result + let expected_st_target = builder.add_virtual_statement(params); + builder.connect_flattenable(&expected_st_target, &st_target); + + let mut pw = PartialWitness::::new(); + + st_tmpl_target.set_targets(&mut pw, params, &st_tmpl)?; + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, arg)?; + } + expected_st_target.set_targets(&mut pw, params, &expected_st.into())?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) + } + + #[test] + fn test_statement_from_template() -> Result<()> { + let params = Params::default(); + + let pod_id = PodId(hash_str("pod_id")); + + let st_tmpl = StatementTmpl { + pred: Predicate::Native(NativePredicate::ValueOf), + args: vec![ + StatementTmplArg::AnchoredKey( + Wildcard::new("a".to_string(), 1), + KeyOrWildcard::Key(Key::from("key")), + ), + StatementTmplArg::Literal(Value::from("value")), + ], + }; + let args = vec![Value::from(1), Value::from(pod_id.0), Value::from(3)]; + let expected_st = Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("key")), + Value::from("value"), + ); + helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?; + + Ok(()) + } + + fn helper_custom_operation_verify_gadget( + params: &Params, + custom_predicate: CustomPredicateRef, + op_args: Vec, + args: Vec, + expected_st: Statement, + ) -> Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let gadget = CustomOperationVerifyGadget { + params: params.clone(), + }; + + let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params); + let op_args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_statement(params)) + .collect(); + let args_target: Vec<_> = (0..args.len()) + .map(|_| builder.add_virtual_value()) + .collect(); + let (st_target, op_type_target) = gadget.eval( + &mut builder, + &custom_predicate_target, + &op_args_target, + &args_target, + )?; + + let mut pw = PartialWitness::::new(); + + // Input + custom_predicate_target.set_targets(&mut pw, params, &custom_predicate)?; + for (op_arg_target, op_arg) in op_args_target.iter().zip(op_args.into_iter()) { + op_arg_target.set_targets(&mut pw, params, &op_arg.into())?; + } + for (arg_target, arg) in args_target.iter().zip(args.iter()) { + arg_target.set_targets(&mut pw, &Value::from(arg.raw()))?; + } + // Expected Output + st_target.set_targets(&mut pw, params, &expected_st.into())?; + + let expected_op_type = OperationType::Custom(custom_predicate); + op_type_target.set_targets(&mut pw, params, &expected_op_type)?; + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) + } + + // TODO: Add negative tests + #[test] + fn test_custom_operation_verify_gadget() -> frontend::Result<()> { + // We set the parameters to the exact sizes we have in the test so that we don't have to + // pad. + let params = Params { + max_custom_predicate_arity: 2, + max_custom_predicate_wildcards: 2, + max_operation_args: 2, + max_statement_args: 2, + ..Default::default() + }; + + use NativePredicate as NP; + use StatementTmplBuilder as STB; + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb0 = STB::new(NP::ValueOf) + .arg(("id", key("score"))) + .arg(literal(42)); + let stb1 = STB::new(NP::ValueOf) + .arg(("id", "secret_key")) + .arg(literal(1234)); + let _ = builder.predicate_and( + "pred_and", + &["id"], + &["secret_key"], + &[stb0.clone(), stb1.clone()], + )?; + let _ = builder.predicate_or("pred_or", &["id"], &["secret_key"], &[stb0, stb1])?; + let batch = builder.finish(); + + let pod_id = PodId(hash_str("pod_id")); + + // AND + let custom_predicate = CustomPredicateRef::new(batch.clone(), 0); + let op_args = vec![ + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("score")), + Value::from(42), + ), + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("foo")), + Value::from(1234), + ), + ]; + let args = vec![ + WildcardValue::PodId(pod_id), + WildcardValue::Key(Key::from("foo")), + ]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), WildcardValue::None], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + expected_st, + ) + .unwrap(); + + // OR (1) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("score")), + Value::from(42), + ), + Statement::None, + ]; + let args = vec![WildcardValue::PodId(pod_id), WildcardValue::None]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), WildcardValue::None], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + expected_st, + ) + .unwrap(); + + // OR (2) + let custom_predicate = CustomPredicateRef::new(batch.clone(), 1); + let op_args = vec![ + Statement::None, + Statement::ValueOf( + AnchoredKey::new(pod_id, Key::from("foo")), + Value::from(1234), + ), + ]; + let args = vec![ + WildcardValue::PodId(pod_id), + WildcardValue::Key(Key::from("foo")), + ]; + let expected_st = Statement::Custom( + custom_predicate.clone(), + vec![args[0].clone(), WildcardValue::None], + ); + + helper_custom_operation_verify_gadget( + ¶ms, + custom_predicate, + op_args, + args, + expected_st, + ) + .unwrap(); + + Ok(()) + } } diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 83df6ae..397dad2 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1,6 +1,6 @@ pub mod operation; pub mod statement; -use std::any::Any; +use std::{any::Any, sync::Arc}; use itertools::Itertools; pub use operation::*; @@ -17,14 +17,17 @@ pub use statement::*; use crate::{ backends::plonky2::{ basetypes::{C, D}, - circuits::mainpod::{MainPodVerifyCircuit, MainPodVerifyInput}, + circuits::mainpod::{ + CustomPredicateVerification, MainPodVerifyCircuit, MainPodVerifyInput, + }, error::{Error, Result}, primitives::merkletree::MerkleClaimAndProof, signedpod::SignedPod, }, middleware::{ - self, AnchoredKey, DynError, Hash, MainPodInputs, NativeOperation, NonePod, OperationType, - Params, Pod, PodId, PodProver, PodType, StatementArg, ToFields, F, KEY_TYPE, SELF, + self, resolve_wildcard_values, AnchoredKey, CustomPredicateBatch, DynError, Hash, + MainPodInputs, NativeOperation, NonePod, OperationType, Params, Pod, PodId, PodProver, + PodType, StatementArg, ToFields, F, KEY_TYPE, SELF, }, }; @@ -37,7 +40,71 @@ pub(crate) fn hash_statements(statements: &[Statement], _params: &Params) -> mid Hash(PoseidonHash::hash_no_pad(&field_elems).elements) } -/// Extracts and pads Merkle proofs from Contains/NotContains ops. +/// Extracts unique `CustomPredicateBatch`es from Custom ops. +pub(crate) fn extract_custom_predicate_batches( + params: &Params, + operations: &[middleware::Operation], +) -> Result>> { + let custom_predicate_batches: Vec<_> = operations + .iter() + .flat_map(|op| match op { + middleware::Operation::Custom(cpr, _) => Some(cpr.batch.clone()), + _ => None, + }) + .unique_by(|cpr| cpr.id()) + .collect(); + if custom_predicate_batches.len() > params.max_custom_predicate_batches { + return Err(Error::custom(format!( + "The number of required `CustomPredicateBatch`es ({}) exceeds the maximum number ({}).", + custom_predicate_batches.len(), + params.max_custom_predicate_batches + ))); + } + Ok(custom_predicate_batches) +} + +/// Extracts all custom predicate operations with all the data required to verify them. +pub(crate) fn extract_custom_predicate_verifications( + params: &Params, + operations: &[middleware::Operation], + custom_predicate_batches: &[Arc], +) -> Result> { + let custom_predicate_data: Vec<_> = operations + .iter() + .flat_map(|op| match op { + middleware::Operation::Custom(cpr, sts) => Some((cpr, sts)), + _ => None, + }) + .map(|(cpr, sts)| { + let wildcard_values = + resolve_wildcard_values(params, cpr.predicate(), sts).expect("resolved wildcards"); + let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); + let batch_index = custom_predicate_batches + .iter() + .enumerate() + .find_map(|(i, cpb)| (cpb.id() == cpr.batch.id()).then_some(i)) + .expect("find the custom predicate from the extracted unique list"); + let custom_predicate_table_index = + batch_index * params.max_custom_predicate_batches + cpr.index; + CustomPredicateVerification { + custom_predicate_table_index, + custom_predicate: cpr.clone(), + args: wildcard_values, + op_args: sts, + } + }) + .collect(); + if custom_predicate_data.len() > params.max_custom_predicate_verifications { + return Err(Error::custom(format!( + "The number of required custom predicate verifications ({}) exceeds the maximum number ({}).", + custom_predicate_data.len(), + params.max_custom_predicate_verifications + ))); + } + Ok(custom_predicate_data) +} + +/// Extracts Merkle proofs from Contains/NotContains ops. pub(crate) fn extract_merkle_proofs( params: &Params, operations: &[middleware::Operation], @@ -98,11 +165,32 @@ fn find_op_arg(statements: &[Statement], op_arg: &middleware::Statement) -> Resu } /// Find the operation auxiliary data in the list of auxiliary data and return the index. +// NOTE: The `custom_predicate_verifications` is optional because in the MainPod we want to store +// the index of a custom predicate verification in the aux data, but in the MockMainPod we don't +// need that because we keep a reference to the custom predicate in the operation type, which +// removes the need for indexing. We could change the OperationType and Predicate for the backend +// to not keep a reference to the custom predicate and instead just keep the id and index and then +// do the same double indexing that the MainPod does to verify custom predicates. fn find_op_aux( merkle_proofs: &[MerkleClaimAndProof], - op_aux: &middleware::OperationAux, + custom_predicate_verifications: Option<&[CustomPredicateVerification]>, + op: &middleware::Operation, ) -> Result { - match op_aux { + let op_aux = op.aux(); + let op_type = op.op_type(); + if let (OperationType::Custom(cpr), Some(cpvs)) = (op_type, custom_predicate_verifications) { + return Ok(cpvs + .iter() + .enumerate() + .find_map(|(i, cpv)| { + (cpv.custom_predicate.batch.id() == cpr.batch.id() + && cpv.custom_predicate.index == cpr.index) + .then_some(i) + }) + .map(OperationAux::CustomPredVerifyIndex) + .expect("custom predicate verification in the list")); + } + match &op_aux { middleware::OperationAux::None => Ok(OperationAux::None), middleware::OperationAux::MerkleProof(pf_arg) => merkle_proofs .iter() @@ -217,6 +305,7 @@ pub(crate) fn process_private_statements_operations( params: &Params, statements: &[Statement], merkle_proofs: &[MerkleClaimAndProof], + custom_predicate_verifications: Option<&[CustomPredicateVerification]>, input_operations: &[middleware::Operation], ) -> Result> { let mut operations = Vec::new(); @@ -231,8 +320,7 @@ pub(crate) fn process_private_statements_operations( .map(|mid_arg| find_op_arg(statements, mid_arg)) .collect::>>()?; - let mid_aux = op.aux(); - let aux = find_op_aux(merkle_proofs, &mid_aux)?; + let aux = find_op_aux(merkle_proofs, custom_predicate_verifications, &op)?; pad_operation_args(params, &mut args); operations.push(Operation(op.op_type(), args, aux)); @@ -301,12 +389,19 @@ impl Prover { .collect_vec(); let merkle_proofs = extract_merkle_proofs(params, inputs.operations)?; + let custom_predicate_batches = extract_custom_predicate_batches(params, inputs.operations)?; + let custom_predicate_verifications = extract_custom_predicate_verifications( + params, + inputs.operations, + &custom_predicate_batches, + )?; let statements = layout_statements(params, &inputs); let operations = process_private_statements_operations( params, &statements, &merkle_proofs, + Some(&custom_predicate_verifications), inputs.operations, )?; let operations = process_public_statements_operations(params, &statements, operations)?; @@ -321,6 +416,8 @@ impl Prover { statements: statements[statements.len() - params.max_statements..].to_vec(), operations, merkle_proofs, + custom_predicate_batches, + custom_predicate_verifications, }; main_pod.set_targets(&mut pw, &input)?; @@ -505,4 +602,41 @@ pub mod tests { let pod = (kyc_pod.pod as Box).downcast::().unwrap(); pod.verify().unwrap() } + + #[test] + fn test_mainpod_small_empty() { + let params = middleware::Params { + max_input_signed_pods: 0, + max_input_main_pods: 0, + max_statements: 5, + max_signed_pod_values: 2, + max_public_statements: 2, + max_statement_args: 2, + max_operation_args: 3, + max_custom_predicate_batches: 2, + max_custom_predicate_verifications: 2, + max_custom_predicate_arity: 2, + max_custom_predicate_wildcards: 2, + max_custom_batch_size: 2, + max_merkle_proofs: 2, + max_depth_mt_gadget: 4, + }; + + let pod_builder = frontend::MainPodBuilder::new(¶ms); + + // Mock + let mut prover = MockProver {}; + let kyc_pod = pod_builder.prove(&mut prover, ¶ms).unwrap(); + let pod = (kyc_pod.pod as Box) + .downcast::() + .unwrap(); + pod.verify().unwrap(); + println!("{:#}", pod); + + // Real + let mut prover = Prover {}; + let kyc_pod = pod_builder.prove(&mut prover, ¶ms).unwrap(); + let pod = (kyc_pod.pod as Box).downcast::().unwrap(); + pod.verify().unwrap() + } } diff --git a/src/backends/plonky2/mainpod/operation.rs b/src/backends/plonky2/mainpod/operation.rs index 1ba2069..561ed82 100644 --- a/src/backends/plonky2/mainpod/operation.rs +++ b/src/backends/plonky2/mainpod/operation.rs @@ -38,15 +38,17 @@ impl OperationArg { pub enum OperationAux { None, MerkleProofIndex(usize), + CustomPredVerifyIndex(usize), } impl ToFields for OperationAux { fn to_fields(&self, _params: &Params) -> Vec { - let f = match self { - Self::None => F::ZERO, - Self::MerkleProofIndex(i) => F::from_canonical_usize(*i), + let fs = match self { + Self::None => [F::ZERO, F::ZERO], + Self::MerkleProofIndex(i) => [F::from_canonical_usize(*i), F::ZERO], + Self::CustomPredVerifyIndex(i) => [F::ZERO, F::from_canonical_usize(*i)], }; - vec![f] + vec![fs[0], fs[1]] } } @@ -78,6 +80,7 @@ impl Operation { .collect::>>()?; let deref_aux = match self.2 { OperationAux::None => crate::middleware::OperationAux::None, + OperationAux::CustomPredVerifyIndex(_) => crate::middleware::OperationAux::None, OperationAux::MerkleProofIndex(i) => crate::middleware::OperationAux::MerkleProof( merkle_proofs .get(i) @@ -111,6 +114,7 @@ impl fmt::Display for Operation { match self.2 { OperationAux::None => (), OperationAux::MerkleProofIndex(i) => write!(f, " merkle_proof_{:02}", i)?, + OperationAux::CustomPredVerifyIndex(i) => write!(f, " custom_pred_verify_{:02}", i)?, } Ok(()) } diff --git a/src/backends/plonky2/mock/mainpod.rs b/src/backends/plonky2/mock/mainpod.rs index 08960e7..c838c8c 100644 --- a/src/backends/plonky2/mock/mainpod.rs +++ b/src/backends/plonky2/mock/mainpod.rs @@ -147,6 +147,7 @@ impl MockMainPod { params, &statements, &merkle_proofs, + None, inputs.operations, )?; let operations = process_public_statements_operations(params, &statements, operations)?; diff --git a/src/examples/custom.rs b/src/examples/custom.rs index 3131555..65c4c57 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -12,10 +12,9 @@ use crate::{ /// Instantiates an ETH friend batch pub fn eth_friend_batch(params: &Params) -> Result> { - let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into()); + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "eth_friend".into()); let _eth_friend = builder.predicate_and( "eth_friend", - params, // arguments: &["src_ori", "src_key", "dst_ori", "dst_key"], // private arguments: @@ -44,7 +43,8 @@ pub fn eth_friend_batch(params: &Params) -> Result> { /// Instantiates an ETHDoS batch pub fn eth_dos_batch(params: &Params) -> Result> { let eth_friend = Predicate::Custom(CustomPredicateRef::new(eth_friend_batch(params)?, 0)); - let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into()); + let mut builder = + CustomPredicateBatchBuilder::new(params.clone(), "eth_dos_distance_base".into()); // eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and< // eq(src_or, src_key, dst_or, dst_key), @@ -52,7 +52,6 @@ pub fn eth_dos_batch(params: &Params) -> Result> { // > let eth_dos_distance_base = builder.predicate_and( "eth_dos_distance_base", - params, &[ // arguments: "src_ori", @@ -83,7 +82,6 @@ pub fn eth_dos_batch(params: &Params) -> Result> { let eth_dos_distance_ind = builder.predicate_and( "eth_dos_distance_ind", - params, &[ // arguments: "src_ori", @@ -135,7 +133,6 @@ pub fn eth_dos_batch(params: &Params) -> Result> { let _eth_dos_distance = builder.predicate_or( "eth_dos_distance", - params, &[ "src_ori", "src_key", diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 81ee86e..ace5720 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -128,13 +128,15 @@ impl StatementTmplBuilder { } pub struct CustomPredicateBatchBuilder { + params: Params, pub name: String, pub predicates: Vec, } impl CustomPredicateBatchBuilder { - pub fn new(name: String) -> Self { + pub fn new(params: Params, name: String) -> Self { Self { + params, name, predicates: Vec::new(), } @@ -143,23 +145,21 @@ impl CustomPredicateBatchBuilder { pub fn predicate_and( &mut self, name: &str, - params: &Params, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], ) -> Result { - self.predicate(name, params, true, args, priv_args, sts) + self.predicate(name, true, args, priv_args, sts) } pub fn predicate_or( &mut self, name: &str, - params: &Params, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], ) -> Result { - self.predicate(name, params, false, args, priv_args, sts) + self.predicate(name, false, args, priv_args, sts) } /// creates the custom predicate from the given input, adds it to the @@ -167,24 +167,23 @@ impl CustomPredicateBatchBuilder { fn predicate( &mut self, name: &str, - params: &Params, conjunction: bool, args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder], ) -> Result { - if args.len() > params.max_statement_args { + if args.len() > self.params.max_statement_args { return Err(Error::max_length( "args.len".to_string(), args.len(), - params.max_statement_args, + self.params.max_statement_args, )); } - if (args.len() + priv_args.len()) > params.max_custom_predicate_wildcards { + if (args.len() + priv_args.len()) > self.params.max_custom_predicate_wildcards { return Err(Error::max_length( "wildcards.len".to_string(), args.len() + priv_args.len(), - params.max_custom_predicate_wildcards, + self.params.max_custom_predicate_wildcards, )); } @@ -197,7 +196,7 @@ impl CustomPredicateBatchBuilder { .iter() .map(|a| match a { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), - BuilderArg::Key(pod_id, key) => StatementTmplArg::Key( + BuilderArg::Key(pod_id, key) => StatementTmplArg::AnchoredKey( resolve_wildcard(args, priv_args, pod_id), resolve_key_or_wildcard(args, priv_args, key), ), @@ -212,17 +211,19 @@ impl CustomPredicateBatchBuilder { } }) .collect(); - let custom_predicate = - CustomPredicate::new(name.into(), params, conjunction, statements, args.len())?; + let custom_predicate = CustomPredicate::new( + &self.params, + name.into(), + conjunction, + statements, + args.len(), + )?; self.predicates.push(custom_predicate); Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } pub fn finish(self) -> Arc { - Arc::new(CustomPredicateBatch { - name: self.name, - predicates: self.predicates, - }) + CustomPredicateBatch::new(&self.params, self.name, self.predicates) } } @@ -290,7 +291,7 @@ mod tests { #[test] fn test_desugared_gt_custom_pred() -> Result<()> { let params = Params::default(); - let mut builder = CustomPredicateBatchBuilder::new("gt_custom_pred".into()); + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "gt_custom_pred".into()); let gt_stb = StatementTmplBuilder::new(NativePredicate::Gt) .arg(("s1_origin", "s1_key")) @@ -298,7 +299,6 @@ mod tests { builder.predicate_and( "gt_custom_pred", - ¶ms, &["s1_origin", "s1_key", "s2_origin", "s2_key"], &[], &[gt_stb], @@ -322,7 +322,7 @@ mod tests { // Check that the desugared predicate is the same as the one in the statement template assert_eq!( desugared_gt.predicate(), - *batch_clone.predicates[0].statements[0].pred() + *batch_clone.predicates()[0].statements[0].pred() ); // Check that our custom predicate matches the statement template @@ -339,7 +339,8 @@ mod tests { #[test] fn test_desugared_set_contains_custom_pred() -> Result<()> { let params = Params::default(); - let mut builder = CustomPredicateBatchBuilder::new("set_contains_custom_pred".into()); + let mut builder = + CustomPredicateBatchBuilder::new(params.clone(), "set_contains_custom_pred".into()); let set_contains_stb = StatementTmplBuilder::new(NativePredicate::SetContains) .arg(("s1_origin", "s1_key")) @@ -347,7 +348,6 @@ mod tests { builder.predicate_and( "set_contains_custom_pred", - ¶ms, &["s1_origin", "s1_key", "s2_origin", "s2_key"], &[], &[set_contains_stb], @@ -368,7 +368,7 @@ mod tests { ); assert_eq!( set_contains.predicate(), - *batch_clone.predicates[0].statements[0].pred() + *batch_clone.predicates()[0].statements[0].pred() ); let set_contains_custom_pred = CustomPredicateRef::new(batch, 0); diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index f906887..41efddc 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -466,7 +466,7 @@ impl MainPodBuilder { )))?, }, OperationType::Custom(cpr) => { - let pred = &cpr.batch.predicates[cpr.index]; + let pred = &cpr.batch.predicates()[cpr.index]; if pred.statements.len() != args.len() { return Err(Error::custom(format!( "Custom predicate operation needs {} statements but has {}.", diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index e08859f..3f2fe2e 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -5,7 +5,8 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use crate::middleware::{ - hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE, + hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value, + EMPTY_HASH, F, HASH_SIZE, VALUE_SIZE, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -49,12 +50,15 @@ impl fmt::Display for KeyOrWildcard { } impl ToFields for KeyOrWildcard { + // Encoding: + // - Key(k) => [[k]] + // - Wildcard(index) => [[index], 0, 0, 0] fn to_fields(&self, params: &Params) -> Vec { match self { KeyOrWildcard::Key(k) => k.hash().to_fields(params), - KeyOrWildcard::Wildcard(wc) => iter::once(F::ZERO) - .take(HASH_SIZE - 1) - .chain(iter::once(F::from_canonical_u64(wc.index as u64))) + KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64)) + .chain(iter::repeat(F::ZERO)) + .take(HASH_SIZE) .collect(), } } @@ -66,7 +70,7 @@ pub enum StatementTmplArg { None, Literal(Value), // AnchoredKey - Key(Wildcard, KeyOrWildcard), + AnchoredKey(Wildcard, KeyOrWildcard), // TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard... // Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key? WildcardLiteral(Wildcard), @@ -76,7 +80,7 @@ pub enum StatementTmplArg { pub enum StatementTmplArgPrefix { None = 0, Literal = 1, - Key = 2, + AnchoredKey = 2, WildcardLiteral = 3, } @@ -88,11 +92,11 @@ impl From for F { impl ToFields for StatementTmplArg { fn to_fields(&self, params: &Params) -> Vec { - // None => (0, ...) - // Literal(value) => (1, [value], 0, 0, 0, 0) - // Key(wildcard1_index, key_or_wildcard2) - // => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2]) - // WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0) + // Encoding: + // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) + // Literal(v) => (1, [v ], 0, 0, 0, 0) + // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) + // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements match self { StatementTmplArg::None => { @@ -105,13 +109,15 @@ impl ToFields for StatementTmplArg { StatementTmplArg::Literal(v) => { let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::Literal)) .chain(v.raw().to_fields(params)) - .chain(iter::repeat(F::ZERO).take(HASH_SIZE)) + .chain(iter::repeat(F::ZERO)) + .take(Params::statement_tmpl_arg_size()) .collect(); fields } - StatementTmplArg::Key(wc1, kw2) => { - let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::Key)) + StatementTmplArg::AnchoredKey(wc1, kw2) => { + let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey)) .chain(wc1.to_fields(params)) + .chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1)) .chain(kw2.to_fields(params)) .collect(); fields @@ -119,7 +125,8 @@ impl ToFields for StatementTmplArg { StatementTmplArg::WildcardLiteral(wc) => { let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral)) .chain(wc.to_fields(params)) - .chain(iter::repeat(F::ZERO).take(HASH_SIZE)) + .chain(iter::repeat(F::ZERO)) + .take(Params::statement_tmpl_arg_size()) .collect(); fields } @@ -132,7 +139,7 @@ impl fmt::Display for StatementTmplArg { match self { Self::None => write!(f, "none"), Self::Literal(v) => write!(f, "{}", v), - Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), + Self::AnchoredKey(pod_id, key) => write!(f, "({}, {})", pod_id, key), Self::WildcardLiteral(v) => write!(f, "{}", v), } } @@ -177,7 +184,11 @@ impl ToFields for StatementTmpl { // instead of at the `to_fields` method, where we should assume that the // values are already valid if self.args.len() > params.max_statement_args { - panic!("Statement template has too many arguments"); + panic!( + "Statement template has too many arguments {} > {}", + self.args.len(), + params.max_statement_args + ); } let mut fields: Vec = self @@ -206,25 +217,36 @@ pub struct CustomPredicate { } impl CustomPredicate { + pub fn empty() -> Self { + Self { + name: "empty".to_string(), + conjunction: false, + statements: vec![StatementTmpl { + pred: Predicate::Native(NativePredicate::None), + args: vec![], + }], + args_len: 0, + } + } pub fn and( - name: String, params: &Params, + name: String, statements: Vec, args_len: usize, ) -> Result { - Self::new(name, params, true, statements, args_len) + Self::new(params, name, true, statements, args_len) } pub fn or( - name: String, params: &Params, + name: String, statements: Vec, args_len: usize, ) -> Result { - Self::new(name, params, false, statements, args_len) + Self::new(params, name, false, statements, args_len) } pub fn new( - name: String, params: &Params, + name: String, conjunction: bool, statements: Vec, args_len: usize, @@ -236,6 +258,13 @@ impl CustomPredicate { params.max_custom_predicate_arity, )); } + if args_len > params.max_statement_args { + return Err(Error::max_length( + "statement_args.len".to_string(), + args_len, + params.max_statement_args, + )); + } Ok(Self { name, @@ -244,6 +273,16 @@ impl CustomPredicate { args_len, }) } + pub fn pad_statement_tmpl(&self) -> StatementTmpl { + StatementTmpl { + pred: Predicate::Native(if self.conjunction { + NativePredicate::False + } else { + NativePredicate::None + }), + args: vec![], + } + } } impl ToFields for CustomPredicate { @@ -262,11 +301,17 @@ impl ToFields for CustomPredicate { panic!("Custom predicate depends on too many statements"); } - let mut fields: Vec = iter::once(F::from_bool(self.conjunction)) + let pad_st = self.pad_statement_tmpl(); + let fields: Vec = iter::once(F::from_bool(self.conjunction)) .chain(iter::once(F::from_canonical_usize(self.args_len))) - .chain(self.statements.iter().flat_map(|st| st.to_fields(params))) + .chain( + self.statements + .iter() + .chain(iter::repeat(&pad_st)) + .take(params.max_custom_predicate_arity) + .flat_map(|st| st.to_fields(params)), + ) .collect(); - fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0)); fields } } @@ -298,8 +343,9 @@ impl fmt::Display for CustomPredicate { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] pub struct CustomPredicateBatch { + id: Hash, pub name: String, - pub predicates: Vec, + predicates: Vec, } impl ToFields for CustomPredicateBatch { @@ -313,27 +359,45 @@ impl ToFields for CustomPredicateBatch { panic!("Predicate batch exceeds maximum size"); } - let mut fields: Vec = self + let pad_pred = CustomPredicate::empty(); + let fields: Vec = self .predicates .iter() + .chain(iter::repeat(&pad_pred)) + .take(params.max_custom_batch_size) .flat_map(|p| p.to_fields(params)) .collect(); - fields.resize_with(params.custom_predicate_batch_size_field_elts(), || { - F::from_canonical_u64(0) - }); fields } } impl CustomPredicateBatch { + pub fn new(params: &Params, name: String, predicates: Vec) -> Arc { + let mut cpb = Self { + id: EMPTY_HASH, + name, + predicates, + }; + let id = cpb.calculate_id(params); + cpb.id = id; + Arc::new(cpb) + } + /// Cryptographic identifier for the batch. - pub fn id(&self, params: &Params) -> Hash { + fn calculate_id(&self, params: &Params) -> Hash { // NOTE: This implementation just hashes the concatenation of all the custom predicates, // but ideally we want to use the root of a merkle tree built from the custom predicates. let input = self.to_fields(params); hash_fields(&input) } + + pub fn id(&self) -> Hash { + self.id + } + pub fn predicates(&self) -> &[CustomPredicate] { + &self.predicates + } } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -347,13 +411,16 @@ impl CustomPredicateRef { Self { batch, index } } pub fn arg_len(&self) -> usize { - self.batch.predicates[self.index].args_len + self.predicate().args_len + } + pub fn predicate(&self) -> &CustomPredicate { + &self.batch.predicates[self.index] } } #[cfg(test)] mod tests { - use std::{array, sync::Arc}; + use std::array; use plonky2::field::goldilocks_field::GoldilocksField; @@ -392,28 +459,29 @@ mod tests { p:value_of(Constant, 2), p:product_of(S1, Constant, S2) */ - let cust_pred_batch = Arc::new(CustomPredicateBatch { - name: "is_double".to_string(), - predicates: vec![CustomPredicate::and( - "_".into(), + let cust_pred_batch = CustomPredicateBatch::new( + ¶ms, + "is_double".to_string(), + vec![CustomPredicate::and( ¶ms, + "_".into(), vec![ st( P::Native(NP::ValueOf), - vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(2.into())], + vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(2.into())], ), st( P::Native(NP::ProductOf), vec![ - STA::Key(wc(0), kow_wc(1)), - STA::Key(wc(4), kow_wc(5)), - STA::Key(wc(2), kow_wc(3)), + STA::AnchoredKey(wc(0), kow_wc(1)), + STA::AnchoredKey(wc(4), kow_wc(5)), + STA::AnchoredKey(wc(2), kow_wc(3)), ], ), ], 2, )?], - }); + ); let custom_statement = Statement::Custom( CustomPredicateRef::new(cust_pred_batch.clone(), 0), @@ -444,55 +512,57 @@ mod tests { fn ethdos_test() -> Result<()> { let params = Params { max_custom_predicate_wildcards: 12, + max_statement_args: 6, ..Default::default() }; let eth_friend_cp = CustomPredicate::and( - "eth_friend_cp".into(), ¶ms, + "eth_friend_cp".into(), vec![ st( P::Native(NP::ValueOf), vec![ - STA::Key(wc(4), KeyOrWildcard::Key("type".into())), + STA::AnchoredKey(wc(4), KeyOrWildcard::Key("type".into())), STA::Literal(PodType::Signed.into()), ], ), st( P::Native(NP::Equal), vec![ - STA::Key(wc(4), KeyOrWildcard::Key("signer".into())), - STA::Key(wc(0), kow_wc(1)), + STA::AnchoredKey(wc(4), KeyOrWildcard::Key("signer".into())), + STA::AnchoredKey(wc(0), kow_wc(1)), ], ), st( P::Native(NP::Equal), vec![ - STA::Key(wc(4), KeyOrWildcard::Key("attestation".into())), - STA::Key(wc(2), kow_wc(3)), + STA::AnchoredKey(wc(4), KeyOrWildcard::Key("attestation".into())), + STA::AnchoredKey(wc(2), kow_wc(3)), ], ), ], 4, )?; - let eth_friend_batch = Arc::new(CustomPredicateBatch { - name: "eth_friend".to_string(), - predicates: vec![eth_friend_cp], - }); + let eth_friend_batch = + CustomPredicateBatch::new(¶ms, "eth_friend".to_string(), vec![eth_friend_cp]); // 0 let eth_dos_base = CustomPredicate::and( - "eth_dos_base".into(), ¶ms, + "eth_dos_base".into(), vec![ st( P::Native(NP::Equal), - vec![STA::Key(wc(0), kow_wc(1)), STA::Key(wc(2), kow_wc(3))], + vec![ + STA::AnchoredKey(wc(0), kow_wc(1)), + STA::AnchoredKey(wc(2), kow_wc(3)), + ], ), st( P::Native(NP::ValueOf), - vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(0.into())], + vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(0.into())], ), ], 6, @@ -500,8 +570,8 @@ mod tests { // 1 let eth_dos_ind = CustomPredicate::and( - "eth_dos_ind".into(), ¶ms, + "eth_dos_ind".into(), vec![ st( P::BatchSelf(2), @@ -516,14 +586,14 @@ mod tests { ), st( P::Native(NP::ValueOf), - vec![STA::Key(wc(6), kow_wc(7)), STA::Literal(1.into())], + vec![STA::AnchoredKey(wc(6), kow_wc(7)), STA::Literal(1.into())], ), st( P::Native(NP::SumOf), vec![ - STA::Key(wc(4), kow_wc(5)), - STA::Key(wc(8), kow_wc(9)), - STA::Key(wc(6), kow_wc(7)), + STA::AnchoredKey(wc(4), kow_wc(5)), + STA::AnchoredKey(wc(8), kow_wc(9)), + STA::AnchoredKey(wc(6), kow_wc(7)), ], ), st( @@ -541,8 +611,8 @@ mod tests { // 2 let eth_dos_distance_either = CustomPredicate::or( - "eth_dos_distance_either".into(), ¶ms, + "eth_dos_distance_either".into(), vec![ st( P::BatchSelf(0), @@ -570,10 +640,11 @@ mod tests { 6, )?; - let eth_dos_distance_batch = Arc::new(CustomPredicateBatch { - name: "ETHDoS_distance".to_string(), - predicates: vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either], - }); + let eth_dos_distance_batch = CustomPredicateBatch::new( + ¶ms, + "ETHDoS_distance".to_string(), + vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either], + ); // Some POD IDs let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64)))); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 99f5caf..df8dc20 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -584,6 +584,10 @@ pub struct Params { pub max_public_statements: usize, pub max_statement_args: usize, pub max_operation_args: usize, + // max number of custom predicates batches that a MainPod can use + pub max_custom_predicate_batches: usize, + // max number of operations using custom predicates that can be verified in the MainPod + pub max_custom_predicate_verifications: usize, // max number of statements that can be ANDed or ORed together // in a custom predicate pub max_custom_predicate_arity: usize, @@ -605,6 +609,8 @@ impl Default for Params { max_public_statements: 10, max_statement_args: 5, max_operation_args: 5, + max_custom_predicate_batches: 2, + max_custom_predicate_verifications: 5, max_custom_predicate_arity: 5, max_custom_predicate_wildcards: 10, max_custom_batch_size: 5, diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 6cde0a7..08efd93 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,4 +1,4 @@ -use std::{fmt, iter, sync::Arc}; +use std::{fmt, iter}; use log::error; use plonky2::field::types::Field; @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::primitives::merkletree::MerkleProof, middleware::{ - custom::KeyOrWildcard, AnchoredKey, CustomPredicateBatch, CustomPredicateRef, Error, + custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg, ToFields, Wildcard, WildcardValue, F, SELF, }, @@ -36,6 +36,9 @@ impl fmt::Display for OperationAux { } impl ToFields for OperationType { + /// Encoding: + /// - Native(native_op) => [1, [native_op], 0, 0, 0, 0] + /// - Custom(batch, index) => [3, [batch.id], index] fn to_fields(&self, params: &Params) -> Vec { let mut fields: Vec = match self { Self::Native(p) => iter::once(F::from_canonical_u64(1)) @@ -43,7 +46,7 @@ impl ToFields for OperationType { .collect(), Self::Custom(CustomPredicateRef { batch, index }) => { iter::once(F::from_canonical_u64(3)) - .chain(batch.id(params).0) + .chain(batch.id().0) .chain(iter::once(F::from_canonical_usize(*index))) .collect() } @@ -321,7 +324,7 @@ impl Operation { (Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args)) if batch == &cpr.batch && index == &cpr.index => { - check_custom_pred(params, batch, *index, args, s_args) + check_custom_pred(params, cpr, args, s_args) } _ => Err(Error::invalid_deduction( self.clone(), @@ -360,7 +363,7 @@ pub fn check_st_tmpl( (StatementTmplArg::None, StatementArg::None) => true, (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true, ( - StatementTmplArg::Key(pod_id_wc, key_or_wc), + StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc), StatementArg::Key(AnchoredKey { pod_id, key }), ) => { let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map); @@ -379,14 +382,46 @@ pub fn check_st_tmpl( } } +pub fn resolve_wildcard_values( + params: &Params, + pred: &CustomPredicate, + args: &[Statement], +) -> Option> { + // Check that all wildcard have consistent values as assigned in the statements while storing a + // map of their values. + // NOTE: We assume the statements have the same order as defined in the custom predicate. For + // disjunctions we expect Statement::None for the unused statements. + let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards]; + for (st_tmpl, st) in pred.statements.iter().zip(args) { + let st_args = st.args(); + for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { + if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { + // TODO: Better errors. Example: + // println!("{} doesn't match {}", st_arg, st_tmpl_arg); + // println!("{} doesn't match {}", st, st_tmpl); + return None; + } + } + } + + // NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because + // they are beyond the number of used wildcards in this custom predicate, or they could be + // private arguments that are unused in a particular disjunction. + Some( + wildcard_map + .into_iter() + .map(|opt| opt.unwrap_or(WildcardValue::None)) + .collect(), + ) +} + fn check_custom_pred( params: &Params, - batch: &Arc, - index: usize, + custom_pred_ref: &CustomPredicateRef, args: &[Statement], s_args: &[WildcardValue], ) -> Result { - let pred = &batch.predicates[index]; + let pred = custom_pred_ref.predicate(); if pred.statements.len() != args.len() { return Err(Error::diff_amount( "custom predicate operation".to_string(), @@ -404,26 +439,12 @@ fn check_custom_pred( )); } - // Check that all wildcard have consistent values as assigned in the statements while storing a - // map of their values. Count the number of statements that match the templates by predicate. - // NOTE: We assume the statements have the same order as defined in the custom predicate. For - // disjunctions we expect Statement::None for the unused statements. + // Count the number of statements that match the templates by predicate. let mut num_matches = 0; - let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards]; for (st_tmpl, st) in pred.statements.iter().zip(args) { - let st_args = st.args(); - for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { - if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) { - // TODO: Better errors. Example: - // println!("{} doesn't match {}", st_arg, st_tmpl_arg); - // println!("{} doesn't match {}", st, st_tmpl); - return Ok(false); - } - } - let st_tmpl_pred = match &st_tmpl.pred { Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef { - batch: batch.clone(), + batch: custom_pred_ref.batch.clone(), index: *i, }), p => p.clone(), @@ -433,9 +454,14 @@ fn check_custom_pred( } } + let wildcard_map = match resolve_wildcard_values(params, pred, args) { + Some(wc_map) => wc_map, + None => return Ok(false), + }; + // Check that the resolved wildcard match the statement arguments. for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) { - if !wc_value.as_ref().is_none_or(|wc_value| *wc_value == *s_arg) { + if *wc_value != *s_arg { return Ok(false); } } diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 6246e94..aee4407 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -7,7 +7,7 @@ use strum_macros::FromRepr; use crate::middleware::{ AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value, - F, VALUE_SIZE, + EMPTY_VALUE, F, VALUE_SIZE, }; // TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static @@ -17,22 +17,23 @@ pub const KEY_SIGNER: &str = "_signer"; pub const KEY_TYPE: &str = "_type"; pub const STATEMENT_ARG_F_LEN: usize = 8; pub const OPERATION_ARG_F_LEN: usize = 1; -pub const OPERATION_AUX_F_LEN: usize = 1; +pub const OPERATION_AUX_F_LEN: usize = 2; #[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum NativePredicate { - None = 0, - ValueOf = 1, - Equal = 2, - NotEqual = 3, - LtEq = 4, - Lt = 5, - Contains = 6, - NotContains = 7, - SumOf = 8, - ProductOf = 9, - MaxOf = 10, - HashOf = 11, + None = 0, // Always true + False = 1, // Always false + ValueOf = 2, + Equal = 3, + NotEqual = 4, + LtEq = 5, + Lt = 6, + Contains = 7, + NotContains = 8, + SumOf = 9, + ProductOf = 10, + MaxOf = 11, + HashOf = 12, // Syntactic sugar predicates. These predicates are not supported by the backend. The // frontend compiler is responsible of translating these predicates into the predicates above. @@ -53,6 +54,7 @@ impl ToFields for NativePredicate { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] pub enum WildcardValue { + None, PodId(PodId), Key(Key), } @@ -60,6 +62,7 @@ pub enum WildcardValue { impl WildcardValue { pub fn raw(&self) -> RawValue { match self { + WildcardValue::None => EMPTY_VALUE, WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0), WildcardValue::Key(key) => key.raw(), } @@ -69,6 +72,7 @@ impl WildcardValue { impl fmt::Display for WildcardValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + WildcardValue::None => write!(f, "none"), WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id), WildcardValue::Key(key) => write!(f, "{}", key), } @@ -77,10 +81,7 @@ impl fmt::Display for WildcardValue { impl ToFields for WildcardValue { fn to_fields(&self, params: &Params) -> Vec { - match self { - WildcardValue::PodId(pod_id) => pod_id.to_fields(params), - WildcardValue::Key(key) => key.to_fields(params), - } + self.raw().to_fields(params) } } @@ -130,7 +131,7 @@ impl ToFields for Predicate { .collect(), Self::Custom(CustomPredicateRef { batch, index }) => { iter::once(F::from(PredicatePrefix::Custom)) - .chain(batch.id(params).0) + .chain(batch.id().0) .chain(iter::once(F::from_canonical_usize(*index))) .collect() } @@ -149,7 +150,9 @@ impl fmt::Display for Predicate { write!( f, "{}.{}[{}]", - batch.name, index, batch.predicates[*index].name + batch.name, + index, + batch.predicates()[*index].name ) } } @@ -397,14 +400,14 @@ impl StatementArg { } impl ToFields for StatementArg { - fn to_fields(&self, _params: &Params) -> Vec { - // NOTE: current version returns always the same amount of field elements in the returned - // vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case - // is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced - // in-circuit), we might be interested into reducing the length of it. If that's the case, - // we can check if it makes sense to make it dependant on the concrete StatementArg; that - // is, when dealing with a `None` it would be a single field element (zero value), and when - // dealing with `Literal` it would be of length 4. + /// Encoding: + /// - None => [0, 0, 0, 0, 0, 0, 0, 0] + /// - Literal(v) => [[v], 0, 0, 0, 0] + /// - Key(pod_id, key) => [[pod_id], [key]] + /// - WildcardLiteral(v) => [[v], 0, 0, 0, 0] + fn to_fields(&self, params: &Params) -> Vec { + // NOTE for @ax0: I removed the old comment because may `to_fields` implementations do + // padding and we need fixed output length for the circuits. let f = match self { StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN], StatementArg::Literal(v) => v @@ -414,8 +417,8 @@ impl ToFields for StatementArg { .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE)) .collect(), StatementArg::Key(ak) => { - let mut fields = ak.pod_id.to_fields(_params); - fields.extend(ak.key.to_fields(_params)); + let mut fields = ak.pod_id.to_fields(params); + fields.extend(ak.key.to_fields(params)); fields } StatementArg::WildcardLiteral(v) => v