From 82481e88d721cac56a6759e79e4370d9e79af7b1 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Thu, 22 May 2025 15:13:02 +0200 Subject: [PATCH] allow SELF in st_tmpl (#240) * allow SELF in st_tmpl * add some tests * Update src/backends/plonky2/circuits/mainpod.rs Co-authored-by: Ahmad Afuni --------- Co-authored-by: Ahmad Afuni --- .github/workflows/build.yml | 25 ++++++++ src/backends/plonky2/circuits/mainpod.rs | 38 +++++++++-- src/backends/plonky2/circuits/metrics.rs | 6 +- src/backends/plonky2/mainpod/mod.rs | 4 +- src/examples/custom.rs | 53 ++++------------ src/examples/mod.rs | 2 +- src/frontend/custom.rs | 43 ++++++++++--- src/middleware/custom.rs | 81 ++++++++++++++++++------ src/middleware/operation.rs | 13 ++-- 9 files changed, 178 insertions(+), 87 deletions(-) create mode 100644 .github/workflows/build.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..3bb217a --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,25 @@ +name: Rust Build with features + +on: + pull_request: + branches: [ main ] + types: [ready_for_review, opened, synchronize, reopened] + push: + branches: [ main ] + +jobs: + test: + if: github.event.pull_request.draft == false + name: Rust tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + - name: Build default + run: cargo build + - name: Build metrics + run: cargo build --features metrics + diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index b3fc21b..ad2aba8 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -1,4 +1,4 @@ -use std::{array, sync::Arc}; +use std::{array, iter, sync::Arc}; use itertools::{zip_eq, Itertools}; use plonky2::{ @@ -767,12 +767,18 @@ impl CustomOperationVerifyGadget { // optimization: ak_id_wc_index and wc_index use the same signals, so we only need to do one // random access to resolve both of them assert_eq!(ak_id_wc_index, wc_index); + // optimization: the wildcard indices have an offset of +1. This allows us to set a fixed + // SELF in args[0] to resolve SelfOrWildcard::SELF encoded as a wildcard at index 0. + let value_self = ValueTarget::from_slice(&builder.constants(&SELF.0 .0)); + let args = iter::once(value_self) + .chain(args.iter().cloned()) + .collect_vec(); // If the index is not used, use a 0 instead to still pass the range constraints from // vec_ref let first_index = ak_id_wc_index; let is_first_index_valid = builder.or(is_ak, is_wc_literal); let first_index = builder.select(is_first_index_valid, first_index, zero); - let resolved_ak_id = builder.vec_ref(&self.params, args, first_index); + let resolved_ak_id = builder.vec_ref(&self.params, &args, first_index); let resolved_wc = resolved_ak_id; // If the index is not used, use a 0 instead to still pass the range constraints from @@ -780,7 +786,7 @@ impl CustomOperationVerifyGadget { let second_index = ak_key_wc_index; let is_second_index_valid = builder.and(is_ak, is_ak_key_wc); let second_index = builder.select(is_second_index_valid, second_index, zero); - let resolved_ak_key = builder.vec_ref(&self.params, args, second_index); + let resolved_ak_key = builder.vec_ref(&self.params, &args, second_index); let ak_key = ak_key_lit; // is_ak_key_lit let ak_key = @@ -1278,7 +1284,7 @@ mod tests { frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder}, middleware::{ hash_str, hash_values, Hash, Key, KeyOrWildcard, OperationType, PodId, Predicate, - RawValue, StatementTmpl, StatementTmplArg, Wildcard, WildcardValue, + RawValue, SelfOrWildcard, StatementTmpl, StatementTmplArg, Wildcard, WildcardValue, }, }; @@ -2222,7 +2228,7 @@ mod tests { // case: AnchoredKey(id_wildcard, key_literal) let st_tmpl_arg = StatementTmplArg::AnchoredKey( - Wildcard::new("a".to_string(), 1), + SelfOrWildcard::Wildcard(Wildcard::new("a".to_string(), 1)), KeyOrWildcard::Key(Key::from("foo")), ); let args = vec![Value::from(1), Value::from(pod_id.0), Value::from(3)]; @@ -2231,13 +2237,31 @@ mod tests { // case: AnchoredKey(id_wildcard, key_wildcard) let st_tmpl_arg = StatementTmplArg::AnchoredKey( - Wildcard::new("a".to_string(), 1), + SelfOrWildcard::Wildcard(Wildcard::new("a".to_string(), 1)), KeyOrWildcard::Wildcard(Wildcard::new("b".to_string(), 2)), ); let args = vec![Value::from(1), Value::from(pod_id.0), Value::from("key")]; let expected_st_arg = StatementArg::Key(AnchoredKey::new(pod_id, Key::from("key"))); helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + // case: AnchoredKey(SELF, key_literal) + let st_tmpl_arg = StatementTmplArg::AnchoredKey( + SelfOrWildcard::SELF, + KeyOrWildcard::Key(Key::from("foo")), + ); + let args = vec![Value::from(1), Value::from(pod_id.0), Value::from(3)]; + let expected_st_arg = StatementArg::Key(AnchoredKey::new(SELF, Key::from("foo"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + + // case: AnchoredKey(SELF, key_wildcard) + let st_tmpl_arg = StatementTmplArg::AnchoredKey( + SelfOrWildcard::SELF, + KeyOrWildcard::Wildcard(Wildcard::new("b".to_string(), 2)), + ); + let args = vec![Value::from(1), Value::from(pod_id.0), Value::from("key")]; + let expected_st_arg = StatementArg::Key(AnchoredKey::new(SELF, Key::from("key"))); + helper_statement_arg_from_template(¶ms, st_tmpl_arg, args, expected_st_arg)?; + // case: WildcardLiteral(wildcard) let st_tmpl_arg = StatementTmplArg::WildcardLiteral(Wildcard::new("a".to_string(), 1)); let args = vec![Value::from(1), Value::from("key"), Value::from(3)]; @@ -2294,7 +2318,7 @@ mod tests { pred: Predicate::Native(NativePredicate::ValueOf), args: vec![ StatementTmplArg::AnchoredKey( - Wildcard::new("a".to_string(), 1), + SelfOrWildcard::Wildcard(Wildcard::new("a".to_string(), 1)), KeyOrWildcard::Key(Key::from("key")), ), StatementTmplArg::Literal(Value::from("value")), diff --git a/src/backends/plonky2/circuits/metrics.rs b/src/backends/plonky2/circuits/metrics.rs index 15f896c..46a6dea 100644 --- a/src/backends/plonky2/circuits/metrics.rs +++ b/src/backends/plonky2/circuits/metrics.rs @@ -67,7 +67,7 @@ pub mod measure_macros { #[macro_export] macro_rules! measure_gates_begin { ($builder:expr, $name:expr) => {{ - use $crate::backends::plonky2::circuits::utils::METRICS; + use $crate::backends::plonky2::circuits::metrics::METRICS; let mut metrics = METRICS.lock().unwrap(); metrics.begin($builder, $name) }}; @@ -76,7 +76,7 @@ pub mod measure_macros { #[macro_export] macro_rules! measure_gates_end { ($builder:expr, $measure:expr) => {{ - use $crate::backends::plonky2::circuits::utils::METRICS; + use $crate::backends::plonky2::circuits::metrics::METRICS; let mut metrics = METRICS.lock().unwrap(); metrics.end($builder, $measure); }}; @@ -85,7 +85,7 @@ pub mod measure_macros { #[macro_export] macro_rules! measure_gates_print { () => {{ - use $crate::backends::plonky2::circuits::utils::METRICS; + use $crate::backends::plonky2::circuits::metrics::METRICS; let metrics = METRICS.lock().unwrap(); metrics.print(); }}; diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 251e9ff..fac5241 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -737,11 +737,11 @@ pub mod tests { max_statements: 26, max_public_statements: 5, max_signed_pod_values: 8, - max_statement_args: 6, + max_statement_args: 3, max_operation_args: 4, max_custom_predicate_arity: 4, max_custom_batch_size: 3, - max_custom_predicate_wildcards: 12, + max_custom_predicate_wildcards: 6, max_custom_predicate_verifications: 8, ..Default::default() }; diff --git a/src/examples/custom.rs b/src/examples/custom.rs index f059b12..3ca357d 100644 --- a/src/examples/custom.rs +++ b/src/examples/custom.rs @@ -21,7 +21,7 @@ pub fn eth_friend_batch(params: &Params, mock: bool) -> Result Result Result Result Result Result KeyOrWildcardStr { KeyOrWildcardStr::Key(s.to_string()) @@ -27,11 +34,21 @@ pub fn key(s: &str) -> KeyOrWildcardStr { #[derive(Clone)] pub enum BuilderArg { Literal(Value), - /// Key: (origin, key), where origin is a Wildcard and key can be both Key or Wildcard - Key(String, KeyOrWildcardStr), + /// Key: (origin, key), where origin is SELF or Wildcard and key is Key or Wildcard + Key(SelfOrWildcardStr, KeyOrWildcardStr), WildcardLiteral(String), } +impl From<&str> for SelfOrWildcardStr { + fn from(origin: &str) -> Self { + if origin == "SELF" { + SelfOrWildcardStr::SELF + } else { + SelfOrWildcardStr::Wildcard(origin.into()) + } + } +} + /// When defining a `BuilderArg`, it can be done from 3 different inputs: /// i. (&str, literal): this is to set a POD and a field, ie. (POD, literal("field")) /// ii. (&str, &str): this is to define a origin-key wildcard pair, ie. (src_origin, src_dest) @@ -40,11 +57,6 @@ pub enum BuilderArg { /// case i. impl From<(&str, KeyOrWildcardStr)> for BuilderArg { fn from((origin, lit): (&str, KeyOrWildcardStr)) -> Self { - // ensure that `lit` is of HashOrWildcardStr::Hash type - match lit { - KeyOrWildcardStr::Key(_) => (), - _ => panic!("not supported"), - }; Self::Key(origin.into(), lit) } } @@ -197,7 +209,7 @@ impl CustomPredicateBatchBuilder { .map(|a| match a { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Key(pod_id, key) => StatementTmplArg::AnchoredKey( - resolve_wildcard(args, priv_args, pod_id), + resolve_self_or_wildcard(args, priv_args, pod_id), resolve_key_or_wildcard(args, priv_args, key), ), BuilderArg::WildcardLiteral(v) => { @@ -227,6 +239,19 @@ impl CustomPredicateBatchBuilder { } } +fn resolve_self_or_wildcard( + args: &[&str], + priv_args: &[&str], + v: &SelfOrWildcardStr, +) -> SelfOrWildcard { + match v { + SelfOrWildcardStr::SELF => SelfOrWildcard::SELF, + SelfOrWildcardStr::Wildcard(s) => { + SelfOrWildcard::Wildcard(resolve_wildcard(args, priv_args, s)) + } + } +} + fn resolve_key_or_wildcard( args: &[&str], priv_args: &[&str], diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 27109b9..fb5465c 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -29,7 +29,7 @@ impl fmt::Display for Wildcard { impl ToFields for Wildcard { fn to_fields(&self, _params: &Params) -> Vec { - vec![F::from_canonical_u64(self.index as u64)] + vec![F::from_canonical_u64(self.index as u64 + 1)] } } @@ -52,11 +52,11 @@ impl fmt::Display for KeyOrWildcard { impl ToFields for KeyOrWildcard { // Encoding: // - Key(k) => [[k]] - // - Wildcard(index) => [[index], 0, 0, 0] + // - Wildcard(index) => [[index + 1], 0, 0, 0] fn to_fields(&self, params: &Params) -> Vec { match self { KeyOrWildcard::Key(k) => k.hash().to_fields(params), - KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64)) + KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64 + 1)) .chain(iter::repeat(F::ZERO)) .take(HASH_SIZE) .collect(), @@ -64,13 +64,41 @@ impl ToFields for KeyOrWildcard { } } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", content = "value")] +pub enum SelfOrWildcard { + SELF, + Wildcard(Wildcard), +} + +impl fmt::Display for SelfOrWildcard { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::SELF => write!(f, "SELF"), + Self::Wildcard(wc) => write!(f, "{}", wc), + } + } +} + +impl ToFields for SelfOrWildcard { + // Encoding: + // - Self => [0] + // - Wildcard(index) => [index+1] + fn to_fields(&self, _params: &Params) -> Vec { + match self { + SelfOrWildcard::SELF => vec![F::ZERO], + SelfOrWildcard::Wildcard(wc) => vec![F::from_canonical_u64(wc.index as u64 + 1)], + } + } +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "type", content = "value")] pub enum StatementTmplArg { None, Literal(Value), // AnchoredKey - AnchoredKey(Wildcard, KeyOrWildcard), + AnchoredKey(SelfOrWildcard, KeyOrWildcard), // TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard... // Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key? WildcardLiteral(Wildcard), @@ -438,6 +466,9 @@ mod tests { fn kow_wc(i: usize) -> KOW { KOW::Wildcard(wc(i)) } + fn sow_wc(i: usize) -> SOW { + SOW::Wildcard(wc(i)) + } fn wc(i: usize) -> Wildcard { Wildcard { name: format!("{}", i), @@ -447,6 +478,7 @@ mod tests { type STA = StatementTmplArg; type KOW = KeyOrWildcard; + type SOW = SelfOrWildcard; type P = Predicate; type NP = NativePredicate; @@ -468,14 +500,17 @@ mod tests { vec![ st( P::Native(NP::ValueOf), - vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(2.into())], + vec![ + STA::AnchoredKey(sow_wc(4), kow_wc(5)), + STA::Literal(2.into()), + ], ), st( P::Native(NP::ProductOf), vec![ - STA::AnchoredKey(wc(0), kow_wc(1)), - STA::AnchoredKey(wc(4), kow_wc(5)), - STA::AnchoredKey(wc(2), kow_wc(3)), + STA::AnchoredKey(sow_wc(0), kow_wc(1)), + STA::AnchoredKey(sow_wc(4), kow_wc(5)), + STA::AnchoredKey(sow_wc(2), kow_wc(3)), ], ), ], @@ -523,22 +558,22 @@ mod tests { st( P::Native(NP::ValueOf), vec![ - STA::AnchoredKey(wc(4), KeyOrWildcard::Key("type".into())), + STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("type".into())), STA::Literal(PodType::Signed.into()), ], ), st( P::Native(NP::Equal), vec![ - STA::AnchoredKey(wc(4), KeyOrWildcard::Key("signer".into())), - STA::AnchoredKey(wc(0), kow_wc(1)), + STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("signer".into())), + STA::AnchoredKey(sow_wc(0), kow_wc(1)), ], ), st( P::Native(NP::Equal), vec![ - STA::AnchoredKey(wc(4), KeyOrWildcard::Key("attestation".into())), - STA::AnchoredKey(wc(2), kow_wc(3)), + STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("attestation".into())), + STA::AnchoredKey(sow_wc(2), kow_wc(3)), ], ), ], @@ -556,13 +591,16 @@ mod tests { st( P::Native(NP::Equal), vec![ - STA::AnchoredKey(wc(0), kow_wc(1)), - STA::AnchoredKey(wc(2), kow_wc(3)), + STA::AnchoredKey(sow_wc(0), kow_wc(1)), + STA::AnchoredKey(sow_wc(2), kow_wc(3)), ], ), st( P::Native(NP::ValueOf), - vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(0.into())], + vec![ + STA::AnchoredKey(sow_wc(4), kow_wc(5)), + STA::Literal(0.into()), + ], ), ], 6, @@ -586,14 +624,17 @@ mod tests { ), st( P::Native(NP::ValueOf), - vec![STA::AnchoredKey(wc(6), kow_wc(7)), STA::Literal(1.into())], + vec![ + STA::AnchoredKey(sow_wc(6), kow_wc(7)), + STA::Literal(1.into()), + ], ), st( P::Native(NP::SumOf), vec![ - STA::AnchoredKey(wc(4), kow_wc(5)), - STA::AnchoredKey(wc(8), kow_wc(9)), - STA::AnchoredKey(wc(6), kow_wc(7)), + STA::AnchoredKey(sow_wc(4), kow_wc(5)), + STA::AnchoredKey(sow_wc(8), kow_wc(9)), + STA::AnchoredKey(sow_wc(6), kow_wc(7)), ], ), st( diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 08efd93..3bdfa68 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -8,8 +8,8 @@ use crate::{ backends::plonky2::primitives::merkletree::MerkleProof, middleware::{ custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, - NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg, - ToFields, Wildcard, WildcardValue, F, SELF, + NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg, + StatementTmplArg, ToFields, Wildcard, WildcardValue, F, SELF, }, }; @@ -363,10 +363,15 @@ pub fn check_st_tmpl( (StatementTmplArg::None, StatementArg::None) => true, (StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true, ( - StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc), + StatementTmplArg::AnchoredKey(self_or_wc, key_or_wc), StatementArg::Key(AnchoredKey { pod_id, key }), ) => { - let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map); + let pod_id_ok = match self_or_wc { + SelfOrWildcard::SELF => SELF == *pod_id, + SelfOrWildcard::Wildcard(pod_id_wc) => { + check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map) + } + }; let key_ok = match key_or_wc { KeyOrWildcard::Key(tmpl_key) => tmpl_key == key, KeyOrWildcard::Wildcard(key_wc) => {