From 726f95483de7e70c22f0446d655b6f89713bd0e1 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Wed, 7 May 2025 11:09:38 +0200 Subject: [PATCH] add target types for custom predicates (#223) * add target types for custom predicates * simplify * fix clippy * fix typo * don't use ref for NativePredicate * fix wrong len * apply feedback from @ax0 --- src/backends/plonky2/circuits/common.rs | 427 +++++++++++++++++++-- src/backends/plonky2/circuits/mainpod.rs | 19 +- src/backends/plonky2/circuits/signedpod.rs | 24 +- src/middleware/custom.rs | 102 ++--- src/middleware/operation.rs | 2 +- src/middleware/statement.rs | 76 +++- 6 files changed, 527 insertions(+), 123 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 72f129d..87101f6 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -26,9 +26,9 @@ use crate::{ primitives::merkletree::MerkleClaimAndProofTarget, }, middleware::{ - NativeOperation, NativePredicate, Params, Predicate, RawValue, StatementArg, ToFields, - EMPTY_VALUE, F, HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, - VALUE_SIZE, + NativeOperation, NativePredicate, Params, Predicate, PredicatePrefix, RawValue, + StatementArg, StatementTmplArgPrefix, ToFields, EMPTY_VALUE, F, HASH_SIZE, + OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE, }, }; @@ -117,20 +117,37 @@ impl StatementArgTarget { #[derive(Clone)] pub struct StatementTarget { - pub predicate: [Target; Params::predicate_size()], + pub predicate: PredicateTarget, pub args: Vec, } +pub trait Build { + fn build(self, builder: &mut CircuitBuilder, params: &Params) -> T; +} + +impl Build for NativePredicate { + fn build(self, builder: &mut CircuitBuilder, params: &Params) -> NativePredicateTarget { + NativePredicateTarget::constant(builder, params, self) + } +} + +impl Build for T { + fn build(self, _builder: &mut CircuitBuilder, _params: &Params) -> T { + self + } +} + impl StatementTarget { pub fn new_native( builder: &mut CircuitBuilder, params: &Params, - predicate: NativePredicate, + native_predicate: impl Build, args: &[StatementArgTarget], ) -> Self { - let predicate_vec = builder.constants(&Predicate::Native(predicate).to_fields(params)); + // if native_predicate is const then NativePredicate -> NativePredicateTarget + // else just use as is Self { - predicate: array::from_fn(|i| predicate_vec[i]), + predicate: PredicateTarget::new_native(builder, params, native_predicate), args: args .iter() .cloned() @@ -146,7 +163,7 @@ impl StatementTarget { params: &Params, st: &Statement, ) -> Result<()> { - pw.set_target_arr(&self.predicate, &st.predicate().to_fields(params))?; + self.predicate.set_targets(pw, params, st.predicate())?; for (i, arg) in st .args() .iter() @@ -165,8 +182,8 @@ impl StatementTarget { params: &Params, t: NativePredicate, ) -> BoolTarget { - let st_code = builder.constants(&Predicate::Native(t).to_fields(params)); - builder.is_equal_slice(&self.predicate, &st_code) + let expected_predicate = PredicateTarget::new_native(builder, params, t); + builder.is_equal_flattenable(&self.predicate, &expected_predicate) } } @@ -212,11 +229,159 @@ impl OperationTarget { } } +#[derive(Clone)] +pub struct NativePredicateTarget(Target); + +impl NativePredicateTarget { + pub fn constant( + builder: &mut CircuitBuilder, + params: &Params, + native_predicate: NativePredicate, + ) -> Self { + let id = native_predicate.to_fields(params); + assert_eq!(1, id.len()); + Self(builder.constant(id[0])) + } + + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + native_predicate: NativePredicate, + ) -> Result<()> { + let id = native_predicate.to_fields(params); + assert_eq!(1, id.len()); + Ok(pw.set_target(self.0, id[0])?) + } +} + +#[derive(Clone)] +pub struct PredicateTarget { + elements: [Target; Params::predicate_size()], +} + +impl PredicateTarget { + pub fn new_native( + builder: &mut CircuitBuilder, + params: &Params, + native_predicate: impl Build, + ) -> Self { + let prefix = builder.constant(F::from(PredicatePrefix::Native)); + let id = native_predicate.build(builder, params).0; + let zero = builder.zero(); + Self { + elements: [prefix, id, zero, zero, zero, zero], + } + } + + pub fn new_batch_self(builder: &mut CircuitBuilder, index: Target) -> Self { + let prefix = builder.constant(F::from(PredicatePrefix::BatchSelf)); + let zero = builder.zero(); + Self { + elements: [prefix, index, zero, zero, zero, zero], + } + } + + pub fn new_custom( + builder: &mut CircuitBuilder, + batch_id: HashOutTarget, + index: Target, + ) -> Self { + let prefix = builder.constant(F::from(PredicatePrefix::Custom)); + let id = batch_id.elements; + Self { + elements: [prefix, id[0], id[1], id[2], id[3], index], + } + } + + pub fn set_targets( + &self, + pw: &mut PartialWitness, + params: &Params, + predicate: Predicate, + ) -> Result<()> { + Ok(pw.set_target_arr(&self.elements, &predicate.to_fields(params))?) + } +} + +#[derive(Clone)] +pub struct KeyOrWildcardTarget { + pub elements: [Target; VALUE_SIZE], +} + +impl KeyOrWildcardTarget { + fn from_slice(v: &[Target]) -> Self { + Self { + elements: v.try_into().expect("len is VALUE_SIZE"), + } + } +} + +#[derive(Clone)] +pub struct StatementTmplArgTarget { + pub elements: [Target; Params::statement_tmpl_arg_size()], +} + +impl StatementTmplArgTarget { + pub fn as_none(&self, builder: &mut CircuitBuilder) -> BoolTarget { + let prefix = builder.constant(F::from(StatementTmplArgPrefix::None)); + builder.is_equal(self.elements[0], prefix) + } + pub fn as_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, ValueTarget) { + let prefix = builder.constant(F::from(StatementTmplArgPrefix::Literal)); + let case_ok = builder.is_equal(self.elements[0], prefix); + let value = ValueTarget::from_slice(&self.elements[1..5]); + (case_ok, value) + } + pub fn as_key( + &self, + builder: &mut CircuitBuilder, + ) -> (BoolTarget, Target, KeyOrWildcardTarget) { + let prefix = builder.constant(F::from(StatementTmplArgPrefix::Key)); + let case_ok = builder.is_equal(self.elements[0], prefix); + let id_wildcard_index = self.elements[1]; + let value_key_or_wildcard = KeyOrWildcardTarget::from_slice(&self.elements[5..9]); + (case_ok, id_wildcard_index, value_key_or_wildcard) + } + pub fn as_wildcard_literal(&self, builder: &mut CircuitBuilder) -> (BoolTarget, Target) { + let prefix = builder.constant(F::from(StatementTmplArgPrefix::WildcardLiteral)); + let case_ok = builder.is_equal(self.elements[0], prefix); + let wildcard_index = self.elements[1]; + (case_ok, wildcard_index) + } +} + +#[derive(Clone)] +pub struct StatementTmplTarget { + pub pred: PredicateTarget, + pub args: Vec, +} + +#[derive(Clone)] +pub struct CustomPredicateTarget { + pub conjunction: BoolTarget, + // len = params.max_custom_predicate_arity + pub statements: Vec, + pub args_len: Target, +} + +#[derive(Clone)] +pub struct CustomPredicateBatchTarget { + pub predicates: Vec, +} + +impl CustomPredicateBatchTarget { + pub fn id(&self, builder: &mut CircuitBuilder) -> HashOutTarget { + let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect(); + builder.hash_n_to_hash_no_pad::(flattened) + } +} + /// Trait for target structs that may be converted to and from vectors /// of targets. pub trait Flattenable { fn flatten(&self) -> Vec; - fn from_flattened(vs: &[Target]) -> Self; + fn from_flattened(params: &Params, vs: &[Target]) -> Self; } /// For the purpose of op verification, we need only look up the @@ -255,7 +420,7 @@ impl Flattenable for MerkleClaimTarget { .concat() } - fn from_flattened(vs: &[Target]) -> Self { + fn from_flattened(_params: &Params, vs: &[Target]) -> Self { Self { enabled: BoolTarget::new_unsafe(vs[0]), root: HashOutTarget::from_vec(vs[1..1 + NUM_HASH_OUT_ELTS].to_vec()), @@ -270,22 +435,34 @@ impl Flattenable for MerkleClaimTarget { } } +impl Flattenable for PredicateTarget { + fn flatten(&self) -> Vec { + self.elements.to_vec() + } + + fn from_flattened(_params: &Params, v: &[Target]) -> Self { + Self { + elements: v.try_into().expect("len is predicate_size"), + } + } +} + impl Flattenable for StatementTarget { fn flatten(&self) -> Vec { self.predicate - .iter() - .chain(self.args.iter().flat_map(|a| &a.elements)) - .cloned() + .flatten() + .into_iter() + .chain(self.args.iter().flat_map(|a| &a.elements).cloned()) .collect() } - fn from_flattened(v: &[Target]) -> Self { + fn from_flattened(params: &Params, v: &[Target]) -> Self { let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN; assert_eq!( v.len(), Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN ); - let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]); + let predicate = PredicateTarget::from_flattened(params, &v[..Params::predicate_size()]); let args = (0..num_args) .map(|i| StatementArgTarget { elements: array::from_fn(|j| { @@ -298,11 +475,75 @@ impl Flattenable for StatementTarget { } } +impl Flattenable for CustomPredicateTarget { + fn flatten(&self) -> Vec { + iter::once(self.conjunction.target) + .chain(iter::once(self.args_len)) + .chain(self.statements.iter().flat_map(|s| s.flatten())) + .collect() + } + + fn from_flattened(params: &Params, v: &[Target]) -> Self { + // We assume that `from_flattened` is always called with the output of `flattened`, so + // this `BoolTarget` should actually safe. + let conjunction = BoolTarget::new_unsafe(v[0]); + let args_len = v[1]; + let st_tmpl_size = params.statement_tmpl_size(); + let statements = (0..params.max_custom_predicate_arity) + .map(|i| { + let st_v = &v[2 + st_tmpl_size * i..2 + st_tmpl_size * (i + 1)]; + StatementTmplTarget::from_flattened(params, st_v) + }) + .collect(); + Self { + conjunction, + statements, + args_len, + } + } +} + +impl Flattenable for StatementTmplTarget { + fn flatten(&self) -> Vec { + self.pred + .flatten() + .into_iter() + .chain(self.args.iter().flat_map(|sta| sta.flatten())) + .collect() + } + + fn from_flattened(params: &Params, v: &[Target]) -> Self { + let pred_end = Params::predicate_size(); + let pred = PredicateTarget::from_flattened(params, &v[..pred_end]); + let sta_size = Params::statement_tmpl_arg_size(); + let args = (0..params.max_statement_args) + .map(|i| { + let sta_v = &v[pred_end + sta_size * i..pred_end + sta_size * (i + 1)]; + StatementTmplArgTarget::from_flattened(params, sta_v) + }) + .collect(); + Self { pred, args } + } +} + +impl Flattenable for StatementTmplArgTarget { + fn flatten(&self) -> Vec { + self.elements.to_vec() + } + + fn from_flattened(_params: &Params, v: &[Target]) -> Self { + Self { + elements: v.try_into().expect("len is statement_tmpl_arg_size"), + } + } +} + pub trait CircuitBuilderPod, const D: usize> { fn connect_values(&mut self, x: ValueTarget, y: ValueTarget); fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); fn add_virtual_value(&mut self) -> ValueTarget; fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget; + fn add_virtual_predicate(&mut self) -> PredicateTarget; fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget; @@ -329,8 +570,14 @@ pub trait CircuitBuilderPod, const D: usize> { // Convenience methods for accessing and connecting elements of // (vectors of) flattenables. - fn vec_ref(&mut self, ts: &[T], i: Target) -> T; - fn select_flattenable(&mut self, b: BoolTarget, x: &T, y: &T) -> T; + fn vec_ref(&mut self, params: &Params, ts: &[T], i: Target) -> T; + fn select_flattenable( + &mut self, + params: &Params, + b: BoolTarget, + x: &T, + y: &T, + ) -> T; fn connect_flattenable(&mut self, xs: &T, ys: &T); fn is_equal_flattenable(&mut self, xs: &T, ys: &T) -> BoolTarget; @@ -358,8 +605,9 @@ impl CircuitBuilderPod for CircuitBuilder { } fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget { + let predicate = self.add_virtual_predicate(); StatementTarget { - predicate: self.add_virtual_target_arr(), + predicate, args: (0..params.max_statement_args) .map(|_| StatementArgTarget { elements: self.add_virtual_target_arr(), @@ -368,6 +616,12 @@ impl CircuitBuilderPod for CircuitBuilder { } } + fn add_virtual_predicate(&mut self) -> PredicateTarget { + PredicateTarget { + elements: self.add_virtual_target_arr(), + } + } + fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { OperationTarget { op_type: self.add_virtual_target_arr(), @@ -470,7 +724,7 @@ impl CircuitBuilderPod for CircuitBuilder { ) } - fn vec_ref(&mut self, ts: &[T], i: Target) -> T { + fn vec_ref(&mut self, params: &Params, ts: &[T], i: Target) -> T { // TODO: Revisit this when we need more than 64 statements. let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| { assert!(v.len() <= 64); @@ -498,14 +752,21 @@ impl CircuitBuilderPod for CircuitBuilder { }; let flattened_ts = ts.iter().map(|t| t.flatten()).collect::>(); - T::from_flattened(&matrix_row_ref(self, &flattened_ts, i)) + T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i)) } - fn select_flattenable(&mut self, b: BoolTarget, x: &T, y: &T) -> T { + fn select_flattenable( + &mut self, + params: &Params, + b: BoolTarget, + x: &T, + y: &T, + ) -> T { let flattened_x = x.flatten(); let flattened_y = y.flatten(); T::from_flattened( + params, &iter::zip(flattened_x, flattened_y) .map(|(x, y)| self.select(b, x, y)) .collect::>(), @@ -532,3 +793,123 @@ impl CircuitBuilderPod for CircuitBuilder { .unwrap_or(self._false()) } } + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}; + + use super::*; + use crate::{ + backends::plonky2::basetypes::C, + examples::custom::{eth_dos_batch, eth_friend_batch}, + frontend, + frontend::CustomPredicateBatchBuilder, + middleware::CustomPredicateBatch, + }; + + #[test] + fn custom_predicate_target() -> frontend::Result<()> { + let params = Params::default(); + let config = CircuitConfig::standard_recursion_config(); + + let custom_predicate_batch = eth_friend_batch(¶ms)?; + + for (i, cp) in custom_predicate_batch.predicates.iter().enumerate() { + let mut builder = CircuitBuilder::::new(config.clone()); + let flattened = cp.to_fields(¶ms); + let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec(); + let cp_target = CustomPredicateTarget::from_flattened(¶ms, &flatteend_target); + // Round trip of from_flattened to flattened + let flatteend_target_rt = cp_target.flatten(); + builder.connect_slice(&flatteend_target, &flatteend_target_rt); + + let pw = PartialWitness::::new(); + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).expect(&format!("predicate {}", i)); + data.verify(proof.clone()).unwrap(); + } + + Ok(()) + } + + fn test_custom_predicate_batch_target_id( + params: &Params, + custom_predicate_batch: &CustomPredicateBatch, + ) -> frontend::Result<()> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let zero = builder.zero(); + let predicate_targets = custom_predicate_batch + .predicates + .iter() + .map(|cp| { + let flattened = cp.to_fields(params); + let flatteend_target = flattened.iter().map(|v| builder.constant(*v)).collect_vec(); + CustomPredicateTarget::from_flattened(params, &flatteend_target) + }) + .chain(iter::repeat({ + let empty_flatteend_target = iter::repeat(zero) + .take(params.custom_predicate_size()) + .collect_vec(); + CustomPredicateTarget::from_flattened(params, &empty_flatteend_target) + })) + .take(params.max_custom_batch_size) + .collect(); + + let custom_predicate_batch_target = CustomPredicateBatchTarget { + predicates: predicate_targets, + }; + + // Calculate the id in constraints and compare it against the id calculated natively + let id_target = custom_predicate_batch_target.id(&mut builder); + let id = custom_predicate_batch.id(params); + + let id_expected_target = HashOutTarget { + elements: id + .to_fields(params) + .iter() + .map(|v| builder.constant(*v)) + .collect_vec() + .try_into() + .unwrap(), + }; + builder.connect_array(id_target.elements, id_expected_target.elements); + + let pw = PartialWitness::::new(); + + // generate & verify proof + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof.clone()).unwrap(); + + Ok(()) + } + + #[test] + fn custom_predicate_batch_target() -> frontend::Result<()> { + let params = Params { + max_statement_args: 6, + max_custom_predicate_wildcards: 12, + ..Default::default() + }; + + // Empty case + let mut cpb_builder = CustomPredicateBatchBuilder::new("empty".into()); + _ = cpb_builder.predicate_and("empty", ¶ms, &[], &[], &[])?; + let custom_predicate_batch = cpb_builder.finish(); + test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?; + + // Some cases from the examples + let custom_predicate_batch = eth_friend_batch(¶ms)?; + test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?; + + let custom_predicate_batch = eth_dos_batch(¶ms)?; + test_custom_predicate_batch_target_id(¶ms, &custom_predicate_batch)?; + + Ok(()) + } +} diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 398c805..12fe60d 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -85,14 +85,14 @@ impl OperationVerifyGadget { op.args .iter() .flatten() - .map(|&i| builder.vec_ref(prev_statements, i)) + .map(|&i| builder.vec_ref(&self.params, prev_statements, i)) .collect::>() }; // Certain operations (Contains/NotContains) will refer to one // of the provided Merkle proofs (if any). These proofs have already // been verified, so we need only look up the claim. - let resolved_merkle_claim = - (!merkle_claims.is_empty()).then(|| builder.vec_ref(merkle_claims, op.aux[0])); + let resolved_merkle_claim = (!merkle_claims.is_empty()) + .then(|| builder.vec_ref(&self.params, merkle_claims, op.aux[0])); // The verification may require aux data which needs to be stored in the // `OperationVerifyTarget` so that we can set during witness generation. @@ -455,7 +455,7 @@ impl OperationVerifyGadget { let individual_checks = prev_statements .iter() .map(|ps| { - let same_predicate = builder.is_equal_slice(&st.predicate, &ps.predicate); + let same_predicate = builder.is_equal_flattenable(&st.predicate, &ps.predicate); let same_anchored_key = builder.is_equal_slice(&st.args[0].elements, &ps.args[0].elements); builder.and(same_predicate, same_anchored_key) @@ -575,15 +575,7 @@ impl MainPodVerifyGadget { .collect(); // 2. Calculate the Pod Id from the public statements - let pub_statements_flattened = pub_statements - .iter() - .flat_map(|s| { - s.predicate - .iter() - .chain(s.args.iter().flat_map(|a| &a.elements)) - }) - .cloned() - .collect(); + let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect(); let id = builder.hash_n_to_hash_no_pad::(pub_statements_flattened); // 4. Verify type @@ -591,6 +583,7 @@ impl MainPodVerifyGadget { // TODO: Store this hash in a global static with lazy init so that we don't have to // compute it every time. let expected_type_statement = StatementTarget::from_flattened( + &self.params, &builder.constants( &Statement::ValueOf( AnchoredKey::from((SELF, KEY_TYPE)), diff --git a/src/backends/plonky2/circuits/signedpod.rs b/src/backends/plonky2/circuits/signedpod.rs index 3b06d19..52c66a0 100644 --- a/src/backends/plonky2/circuits/signedpod.rs +++ b/src/backends/plonky2/circuits/signedpod.rs @@ -3,17 +3,16 @@ use std::iter; use itertools::Itertools; use plonky2::{ hash::hash_types::{HashOut, HashOutTarget}, - iop::{ - target::Target, - witness::{PartialWitness, WitnessWrite}, - }, + iop::witness::{PartialWitness, WitnessWrite}, plonk::circuit_builder::CircuitBuilder, }; use crate::{ backends::plonky2::{ basetypes::D, - circuits::common::{CircuitBuilderPod, StatementArgTarget, StatementTarget, ValueTarget}, + circuits::common::{ + CircuitBuilderPod, PredicateTarget, StatementArgTarget, StatementTarget, ValueTarget, + }, error::Result, primitives::{ merkletree::{ @@ -24,8 +23,8 @@ use crate::{ signedpod::SignedPod, }, middleware::{ - hash_str, Key, NativePredicate, Params, PodType, Predicate, RawValue, ToFields, Value, F, - KEY_SIGNER, KEY_TYPE, SELF, + hash_str, Key, NativePredicate, Params, PodType, RawValue, Value, F, KEY_SIGNER, KEY_TYPE, + SELF, }, }; @@ -91,10 +90,8 @@ impl SignedPodVerifyTarget { self_id: bool, ) -> Vec { let mut statements = Vec::new(); - let predicate: [Target; Params::predicate_size()] = builder - .constants(&Predicate::Native(NativePredicate::ValueOf).to_fields(&self.params)) - .try_into() - .expect("size predicate_size"); + let predicate = + PredicateTarget::new_native(builder, &self.params, NativePredicate::ValueOf); let pod_id = if self_id { builder.constant_value(SELF.0.into()) } else { @@ -111,7 +108,10 @@ impl SignedPodVerifyTarget { .chain(iter::repeat_with(|| StatementArgTarget::none(builder))) .take(self.params.max_statement_args) .collect(); - let statement = StatementTarget { predicate, args }; + let statement = StatementTarget { + predicate: predicate.clone(), + args, + }; statements.push(statement); } statements diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 23891ce..e08859f 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -5,7 +5,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use crate::middleware::{ - hash_fields, Error, Hash, Key, NativePredicate, Params, Result, ToFields, Value, F, HASH_SIZE, + hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -72,40 +72,54 @@ pub enum StatementTmplArg { WildcardLiteral(Wildcard), } +#[derive(Clone, Copy)] +pub enum StatementTmplArgPrefix { + None = 0, + Literal = 1, + Key = 2, + WildcardLiteral = 3, +} + +impl From for F { + fn from(prefix: StatementTmplArgPrefix) -> Self { + Self::from_canonical_usize(prefix as usize) + } +} + impl ToFields for StatementTmplArg { fn to_fields(&self, params: &Params) -> Vec { // None => (0, ...) // Literal(value) => (1, [value], 0, 0, 0, 0) - // Key(wildcard1, key_or_wildcard2) - // => (2, [wildcard1], [key_or_wildcard2]) - // WildcardLiteral(wildcard) => (3, [wildcard], 0, 0, 0, 0) + // Key(wildcard1_index, key_or_wildcard2) + // => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2]) + // WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0) // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements - let statement_tmpl_arg_size = 2 * HASH_SIZE + 1; match self { StatementTmplArg::None => { - let fields: Vec = iter::repeat_with(|| F::from_canonical_u64(0)) - .take(statement_tmpl_arg_size) + let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::None)) + .chain(iter::repeat(F::ZERO)) + .take(Params::statement_tmpl_arg_size()) .collect(); fields } StatementTmplArg::Literal(v) => { - let fields: Vec = iter::once(F::from_canonical_u64(1)) + let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::Literal)) .chain(v.raw().to_fields(params)) - .chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) + .chain(iter::repeat(F::ZERO).take(HASH_SIZE)) .collect(); fields } StatementTmplArg::Key(wc1, kw2) => { - let fields: Vec = iter::once(F::from_canonical_u64(2)) + let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::Key)) .chain(wc1.to_fields(params)) .chain(kw2.to_fields(params)) .collect(); fields } StatementTmplArg::WildcardLiteral(wc) => { - let fields: Vec = iter::once(F::from_canonical_u64(3)) + let fields: Vec = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral)) .chain(wc.to_fields(params)) - .chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) + .chain(iter::repeat(F::ZERO).take(HASH_SIZE)) .collect(); fields } @@ -312,7 +326,10 @@ impl ToFields for CustomPredicateBatch { } impl CustomPredicateBatch { - pub fn hash(&self, params: &Params) -> Hash { + /// Cryptographic identifier for the batch. + pub fn id(&self, params: &Params) -> Hash { + // NOTE: This implementation just hashes the concatenation of all the custom predicates, + // but ideally we want to use the root of a merkle tree built from the custom predicates. let input = self.to_fields(params); hash_fields(&input) @@ -334,65 +351,6 @@ impl CustomPredicateRef { } } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] -#[serde(tag = "type", content = "value")] -pub enum Predicate { - Native(NativePredicate), - BatchSelf(usize), - Custom(CustomPredicateRef), -} - -impl From for Predicate { - fn from(v: NativePredicate) -> Self { - Self::Native(v) - } -} - -impl ToFields for Predicate { - fn to_fields(&self, params: &Params) -> Vec { - // serialize: - // NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize - // BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize - // CustomPredicateRef(pb, i) as - // (2, [hash of pb], i) -- pb hashes to 4 field elements - // -- i: usize - - // in every case: pad to (hash_size + 2) field elements - let mut fields: Vec = match self { - Self::Native(p) => iter::once(F::from_canonical_u64(1)) - .chain(p.to_fields(params)) - .collect(), - Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2)) - .chain(iter::once(F::from_canonical_usize(*i))) - .collect(), - Self::Custom(CustomPredicateRef { batch, index }) => { - iter::once(F::from_canonical_u64(3)) - .chain(batch.hash(params).0) - .chain(iter::once(F::from_canonical_usize(*index))) - .collect() - } - }; - fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0)); - fields - } -} - -impl fmt::Display for Predicate { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Native(p) => write!(f, "{:?}", p), - Self::BatchSelf(i) => write!(f, "self.{}", i), - Self::Custom(CustomPredicateRef { batch, index }) => { - write!( - f, - "{}.{}[{}]", - batch.name, index, batch.predicates[*index].name - ) - } - } - } -} - #[cfg(test)] mod tests { use std::{array, sync::Arc}; diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index bdb182e..6cde0a7 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -43,7 +43,7 @@ impl ToFields for OperationType { .collect(), Self::Custom(CustomPredicateRef { batch, index }) => { iter::once(F::from_canonical_u64(3)) - .chain(batch.hash(params).0) + .chain(batch.id(params).0) .chain(iter::once(F::from_canonical_usize(*index))) .collect() } diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 2f6f278..6246e94 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize}; use strum_macros::FromRepr; use crate::middleware::{ - AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, Predicate, RawValue, Result, - ToFields, Value, F, VALUE_SIZE, + AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value, + F, VALUE_SIZE, }; // TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static @@ -84,6 +84,78 @@ impl ToFields for WildcardValue { } } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", content = "value")] +pub enum Predicate { + Native(NativePredicate), + BatchSelf(usize), + Custom(CustomPredicateRef), +} + +impl From for Predicate { + fn from(v: NativePredicate) -> Self { + Self::Native(v) + } +} + +#[derive(Clone, Copy)] +pub enum PredicatePrefix { + Native = 1, + BatchSelf = 2, + Custom = 3, +} + +impl From for F { + fn from(prefix: PredicatePrefix) -> Self { + Self::from_canonical_usize(prefix as usize) + } +} + +impl ToFields for Predicate { + fn to_fields(&self, params: &Params) -> Vec { + // serialize: + // NativePredicate(id) as (1, id, 0, 0, 0, 0) -- id: usize + // BatchSelf(i) as (2, i, 0, 0, 0, 0) -- i: usize + // CustomPredicateRef(pb, i) as + // (3, [hash of pb], i) -- pb hashes to 4 field elements + // -- i: usize + + // in every case: pad to (hash_size + 2) field elements + let mut fields: Vec = match self { + Self::Native(p) => iter::once(F::from(PredicatePrefix::Native)) + .chain(p.to_fields(params)) + .collect(), + Self::BatchSelf(i) => iter::once(F::from(PredicatePrefix::BatchSelf)) + .chain(iter::once(F::from_canonical_usize(*i))) + .collect(), + Self::Custom(CustomPredicateRef { batch, index }) => { + iter::once(F::from(PredicatePrefix::Custom)) + .chain(batch.id(params).0) + .chain(iter::once(F::from_canonical_usize(*index))) + .collect() + } + }; + fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0)); + fields + } +} + +impl fmt::Display for Predicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Native(p) => write!(f, "{:?}", p), + Self::BatchSelf(i) => write!(f, "self.{}", i), + Self::Custom(CustomPredicateRef { batch, index }) => { + write!( + f, + "{}.{}[{}]", + batch.name, index, batch.predicates[*index].name + ) + } + } + } +} + /// Type encapsulating statements with their associated arguments. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "predicate", content = "args")]