allow SELF in st_tmpl (#240)

* allow SELF in st_tmpl

* add some tests

* Update src/backends/plonky2/circuits/mainpod.rs

Co-authored-by: Ahmad Afuni <root@ahmadafuni.com>

---------

Co-authored-by: Ahmad Afuni <root@ahmadafuni.com>
This commit is contained in:
Eduard S. 2025-05-22 15:13:02 +02:00 committed by GitHub
parent b4a4c72328
commit 82481e88d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 178 additions and 87 deletions

25
.github/workflows/build.yml vendored Normal file
View file

@ -0,0 +1,25 @@
name: Rust Build with features
on:
pull_request:
branches: [ main ]
types: [ready_for_review, opened, synchronize, reopened]
push:
branches: [ main ]
jobs:
test:
if: github.event.pull_request.draft == false
name: Rust tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Build default
run: cargo build
- name: Build metrics
run: cargo build --features metrics

View file

@ -1,4 +1,4 @@
use std::{array, sync::Arc};
use std::{array, iter, sync::Arc};
use itertools::{zip_eq, Itertools};
use plonky2::{
@ -767,12 +767,18 @@ impl CustomOperationVerifyGadget {
// optimization: ak_id_wc_index and wc_index use the same signals, so we only need to do one
// random access to resolve both of them
assert_eq!(ak_id_wc_index, wc_index);
// optimization: the wildcard indices have an offset of +1. This allows us to set a fixed
// SELF in args[0] to resolve SelfOrWildcard::SELF encoded as a wildcard at index 0.
let value_self = ValueTarget::from_slice(&builder.constants(&SELF.0 .0));
let args = iter::once(value_self)
.chain(args.iter().cloned())
.collect_vec();
// If the index is not used, use a 0 instead to still pass the range constraints from
// vec_ref
let first_index = ak_id_wc_index;
let is_first_index_valid = builder.or(is_ak, is_wc_literal);
let first_index = builder.select(is_first_index_valid, first_index, zero);
let resolved_ak_id = builder.vec_ref(&self.params, args, first_index);
let resolved_ak_id = builder.vec_ref(&self.params, &args, first_index);
let resolved_wc = resolved_ak_id;
// If the index is not used, use a 0 instead to still pass the range constraints from
@ -780,7 +786,7 @@ impl CustomOperationVerifyGadget {
let second_index = ak_key_wc_index;
let is_second_index_valid = builder.and(is_ak, is_ak_key_wc);
let second_index = builder.select(is_second_index_valid, second_index, zero);
let resolved_ak_key = builder.vec_ref(&self.params, args, second_index);
let resolved_ak_key = builder.vec_ref(&self.params, &args, second_index);
let ak_key = ak_key_lit; // is_ak_key_lit
let ak_key =
@ -1278,7 +1284,7 @@ mod tests {
frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
middleware::{
hash_str, hash_values, Hash, Key, KeyOrWildcard, OperationType, PodId, Predicate,
RawValue, StatementTmpl, StatementTmplArg, Wildcard, WildcardValue,
RawValue, SelfOrWildcard, StatementTmpl, StatementTmplArg, Wildcard, WildcardValue,
},
};
@ -2222,7 +2228,7 @@ mod tests {
// case: AnchoredKey(id_wildcard, key_literal)
let st_tmpl_arg = StatementTmplArg::AnchoredKey(
Wildcard::new("a".to_string(), 1),
SelfOrWildcard::Wildcard(Wildcard::new("a".to_string(), 1)),
KeyOrWildcard::Key(Key::from("foo")),
);
let args = vec![Value::from(1), Value::from(pod_id.0), Value::from(3)];
@ -2231,13 +2237,31 @@ mod tests {
// case: AnchoredKey(id_wildcard, key_wildcard)
let st_tmpl_arg = StatementTmplArg::AnchoredKey(
Wildcard::new("a".to_string(), 1),
SelfOrWildcard::Wildcard(Wildcard::new("a".to_string(), 1)),
KeyOrWildcard::Wildcard(Wildcard::new("b".to_string(), 2)),
);
let args = vec![Value::from(1), Value::from(pod_id.0), Value::from("key")];
let expected_st_arg = StatementArg::Key(AnchoredKey::new(pod_id, Key::from("key")));
helper_statement_arg_from_template(&params, st_tmpl_arg, args, expected_st_arg)?;
// case: AnchoredKey(SELF, key_literal)
let st_tmpl_arg = StatementTmplArg::AnchoredKey(
SelfOrWildcard::SELF,
KeyOrWildcard::Key(Key::from("foo")),
);
let args = vec![Value::from(1), Value::from(pod_id.0), Value::from(3)];
let expected_st_arg = StatementArg::Key(AnchoredKey::new(SELF, Key::from("foo")));
helper_statement_arg_from_template(&params, st_tmpl_arg, args, expected_st_arg)?;
// case: AnchoredKey(SELF, key_wildcard)
let st_tmpl_arg = StatementTmplArg::AnchoredKey(
SelfOrWildcard::SELF,
KeyOrWildcard::Wildcard(Wildcard::new("b".to_string(), 2)),
);
let args = vec![Value::from(1), Value::from(pod_id.0), Value::from("key")];
let expected_st_arg = StatementArg::Key(AnchoredKey::new(SELF, Key::from("key")));
helper_statement_arg_from_template(&params, st_tmpl_arg, args, expected_st_arg)?;
// case: WildcardLiteral(wildcard)
let st_tmpl_arg = StatementTmplArg::WildcardLiteral(Wildcard::new("a".to_string(), 1));
let args = vec![Value::from(1), Value::from("key"), Value::from(3)];
@ -2294,7 +2318,7 @@ mod tests {
pred: Predicate::Native(NativePredicate::ValueOf),
args: vec![
StatementTmplArg::AnchoredKey(
Wildcard::new("a".to_string(), 1),
SelfOrWildcard::Wildcard(Wildcard::new("a".to_string(), 1)),
KeyOrWildcard::Key(Key::from("key")),
),
StatementTmplArg::Literal(Value::from("value")),

View file

@ -67,7 +67,7 @@ pub mod measure_macros {
#[macro_export]
macro_rules! measure_gates_begin {
($builder:expr, $name:expr) => {{
use $crate::backends::plonky2::circuits::utils::METRICS;
use $crate::backends::plonky2::circuits::metrics::METRICS;
let mut metrics = METRICS.lock().unwrap();
metrics.begin($builder, $name)
}};
@ -76,7 +76,7 @@ pub mod measure_macros {
#[macro_export]
macro_rules! measure_gates_end {
($builder:expr, $measure:expr) => {{
use $crate::backends::plonky2::circuits::utils::METRICS;
use $crate::backends::plonky2::circuits::metrics::METRICS;
let mut metrics = METRICS.lock().unwrap();
metrics.end($builder, $measure);
}};
@ -85,7 +85,7 @@ pub mod measure_macros {
#[macro_export]
macro_rules! measure_gates_print {
() => {{
use $crate::backends::plonky2::circuits::utils::METRICS;
use $crate::backends::plonky2::circuits::metrics::METRICS;
let metrics = METRICS.lock().unwrap();
metrics.print();
}};

View file

@ -737,11 +737,11 @@ pub mod tests {
max_statements: 26,
max_public_statements: 5,
max_signed_pod_values: 8,
max_statement_args: 6,
max_statement_args: 3,
max_operation_args: 4,
max_custom_predicate_arity: 4,
max_custom_batch_size: 3,
max_custom_predicate_wildcards: 12,
max_custom_predicate_wildcards: 6,
max_custom_predicate_verifications: 8,
..Default::default()
};

View file

@ -21,7 +21,7 @@ pub fn eth_friend_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredica
let _eth_friend = builder.predicate_and(
"eth_friend",
// arguments:
&["src_ori", "src_key", "dst_ori", "dst_key"],
&["src_key", "dst_key"],
// private arguments:
&["attestation_pod"],
// statement templates:
@ -33,11 +33,11 @@ pub fn eth_friend_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredica
// the attestation pod is signed by (src_or, src_key)
STB::new(NP::Equal)
.arg(("attestation_pod", key(KEY_SIGNER)))
.arg(("src_ori", "src_key")),
.arg(("SELF", "src_key")),
// that same attestation pod has an "attestation"
STB::new(NP::Equal)
.arg(("attestation_pod", key("attestation")))
.arg(("dst_ori", "dst_key")),
.arg(("SELF", "dst_key")),
],
)?;
@ -59,11 +59,8 @@ pub fn eth_dos_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredicateB
"eth_dos_distance_base",
&[
// arguments:
"src_ori",
"src_key",
"dst_ori",
"dst_key",
"distance_ori",
"distance_key",
],
&[ // private arguments:
@ -71,10 +68,10 @@ pub fn eth_dos_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredicateB
&[
// statement templates:
STB::new(NP::Equal)
.arg(("src_ori", "src_key"))
.arg(("dst_ori", "dst_key")),
.arg(("SELF", "src_key"))
.arg(("SELF", "dst_key")),
STB::new(NP::ValueOf)
.arg(("distance_ori", "distance_key"))
.arg(("SELF", "distance_key"))
.arg(literal(0)),
],
)?;
@ -89,45 +86,32 @@ pub fn eth_dos_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredicateB
"eth_dos_distance_ind",
&[
// arguments:
"src_ori",
"src_key",
"dst_ori",
"dst_key",
"distance_ori",
"distance_key",
],
&[
// private arguments:
"one_ori",
"one_key",
"shorter_distance_ori",
"shorter_distance_key",
"intermed_ori",
"intermed_key",
],
&[
// statement templates:
STB::new(eth_dos_distance)
.arg("src_ori")
.arg("src_key")
.arg("intermed_ori")
.arg("intermed_key")
.arg("shorter_distance_ori")
.arg("shorter_distance_key"),
// distance == shorter_distance + 1
STB::new(NP::ValueOf)
.arg(("one_ori", "one_key"))
.arg(("SELF", "one_key"))
.arg(literal(1)),
STB::new(NP::SumOf)
.arg(("distance_ori", "distance_key"))
.arg(("shorter_distance_ori", "shorter_distance_key"))
.arg(("one_ori", "one_key")),
.arg(("SELF", "distance_key"))
.arg(("SELF", "shorter_distance_key"))
.arg(("SELF", "one_key")),
// intermed is a friend of dst
STB::new(eth_friend)
.arg("intermed_ori")
.arg("intermed_key")
.arg("dst_ori")
.arg("dst_key"),
STB::new(eth_friend).arg("intermed_key").arg("dst_key"),
],
)?;
@ -138,29 +122,16 @@ pub fn eth_dos_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredicateB
let _eth_dos_distance = builder.predicate_or(
"eth_dos_distance",
&[
"src_ori",
"src_key",
"dst_ori",
"dst_key",
"distance_ori",
"distance_key",
],
&["src_key", "dst_key", "distance_key"],
&[],
&[
STB::new(eth_dos_distance_base)
.arg("src_ori")
.arg("src_key")
.arg("dst_ori")
.arg("dst_key")
.arg("distance_ori")
.arg("distance_key"),
STB::new(eth_dos_distance_ind)
.arg("src_ori")
.arg("src_key")
.arg("dst_ori")
.arg("dst_key")
.arg("distance_ori")
.arg("distance_key"),
],
)?;

View file

@ -115,7 +115,7 @@ pub fn eth_dos_pod_builder(
let zero = alice_bob_ethdos.priv_literal(0)?;
let alice_equals_alice = alice_bob_ethdos.priv_op(op!(
eq,
(alice_attestation, KEY_SIGNER),
alice_pubkey_copy.clone(),
alice_pubkey_copy.clone()
))?;
let ethdos_alice_alice_is_zero_base = alice_bob_ethdos.priv_op(op!(

View file

@ -7,7 +7,8 @@ use crate::{
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
middleware::{
self, hash_str, CustomPredicate, CustomPredicateBatch, Key, KeyOrWildcard, NativePredicate,
Params, PodId, Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
Params, PodId, Predicate, SelfOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value,
Wildcard,
},
};
@ -18,6 +19,12 @@ pub enum KeyOrWildcardStr {
Wildcard(String),
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SelfOrWildcardStr {
SELF,
Wildcard(String),
}
/// helper to build a literal KeyOrWildcardStr::Key from the given str
pub fn key(s: &str) -> KeyOrWildcardStr {
KeyOrWildcardStr::Key(s.to_string())
@ -27,11 +34,21 @@ pub fn key(s: &str) -> KeyOrWildcardStr {
#[derive(Clone)]
pub enum BuilderArg {
Literal(Value),
/// Key: (origin, key), where origin is a Wildcard and key can be both Key or Wildcard
Key(String, KeyOrWildcardStr),
/// Key: (origin, key), where origin is SELF or Wildcard and key is Key or Wildcard
Key(SelfOrWildcardStr, KeyOrWildcardStr),
WildcardLiteral(String),
}
impl From<&str> for SelfOrWildcardStr {
fn from(origin: &str) -> Self {
if origin == "SELF" {
SelfOrWildcardStr::SELF
} else {
SelfOrWildcardStr::Wildcard(origin.into())
}
}
}
/// When defining a `BuilderArg`, it can be done from 3 different inputs:
/// i. (&str, literal): this is to set a POD and a field, ie. (POD, literal("field"))
/// ii. (&str, &str): this is to define a origin-key wildcard pair, ie. (src_origin, src_dest)
@ -40,11 +57,6 @@ pub enum BuilderArg {
/// case i.
impl From<(&str, KeyOrWildcardStr)> for BuilderArg {
fn from((origin, lit): (&str, KeyOrWildcardStr)) -> Self {
// ensure that `lit` is of HashOrWildcardStr::Hash type
match lit {
KeyOrWildcardStr::Key(_) => (),
_ => panic!("not supported"),
};
Self::Key(origin.into(), lit)
}
}
@ -197,7 +209,7 @@ impl CustomPredicateBatchBuilder {
.map(|a| match a {
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
BuilderArg::Key(pod_id, key) => StatementTmplArg::AnchoredKey(
resolve_wildcard(args, priv_args, pod_id),
resolve_self_or_wildcard(args, priv_args, pod_id),
resolve_key_or_wildcard(args, priv_args, key),
),
BuilderArg::WildcardLiteral(v) => {
@ -227,6 +239,19 @@ impl CustomPredicateBatchBuilder {
}
}
fn resolve_self_or_wildcard(
args: &[&str],
priv_args: &[&str],
v: &SelfOrWildcardStr,
) -> SelfOrWildcard {
match v {
SelfOrWildcardStr::SELF => SelfOrWildcard::SELF,
SelfOrWildcardStr::Wildcard(s) => {
SelfOrWildcard::Wildcard(resolve_wildcard(args, priv_args, s))
}
}
}
fn resolve_key_or_wildcard(
args: &[&str],
priv_args: &[&str],

View file

@ -29,7 +29,7 @@ impl fmt::Display for Wildcard {
impl ToFields for Wildcard {
fn to_fields(&self, _params: &Params) -> Vec<F> {
vec![F::from_canonical_u64(self.index as u64)]
vec![F::from_canonical_u64(self.index as u64 + 1)]
}
}
@ -52,11 +52,11 @@ impl fmt::Display for KeyOrWildcard {
impl ToFields for KeyOrWildcard {
// Encoding:
// - Key(k) => [[k]]
// - Wildcard(index) => [[index], 0, 0, 0]
// - Wildcard(index) => [[index + 1], 0, 0, 0]
fn to_fields(&self, params: &Params) -> Vec<F> {
match self {
KeyOrWildcard::Key(k) => k.hash().to_fields(params),
KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64))
KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64 + 1))
.chain(iter::repeat(F::ZERO))
.take(HASH_SIZE)
.collect(),
@ -64,13 +64,41 @@ impl ToFields for KeyOrWildcard {
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "value")]
pub enum SelfOrWildcard {
SELF,
Wildcard(Wildcard),
}
impl fmt::Display for SelfOrWildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::SELF => write!(f, "SELF"),
Self::Wildcard(wc) => write!(f, "{}", wc),
}
}
}
impl ToFields for SelfOrWildcard {
// Encoding:
// - Self => [0]
// - Wildcard(index) => [index+1]
fn to_fields(&self, _params: &Params) -> Vec<F> {
match self {
SelfOrWildcard::SELF => vec![F::ZERO],
SelfOrWildcard::Wildcard(wc) => vec![F::from_canonical_u64(wc.index as u64 + 1)],
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "value")]
pub enum StatementTmplArg {
None,
Literal(Value),
// AnchoredKey
AnchoredKey(Wildcard, KeyOrWildcard),
AnchoredKey(SelfOrWildcard, KeyOrWildcard),
// TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard...
// Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key?
WildcardLiteral(Wildcard),
@ -438,6 +466,9 @@ mod tests {
fn kow_wc(i: usize) -> KOW {
KOW::Wildcard(wc(i))
}
fn sow_wc(i: usize) -> SOW {
SOW::Wildcard(wc(i))
}
fn wc(i: usize) -> Wildcard {
Wildcard {
name: format!("{}", i),
@ -447,6 +478,7 @@ mod tests {
type STA = StatementTmplArg;
type KOW = KeyOrWildcard;
type SOW = SelfOrWildcard;
type P = Predicate;
type NP = NativePredicate;
@ -468,14 +500,17 @@ mod tests {
vec![
st(
P::Native(NP::ValueOf),
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(2.into())],
vec![
STA::AnchoredKey(sow_wc(4), kow_wc(5)),
STA::Literal(2.into()),
],
),
st(
P::Native(NP::ProductOf),
vec![
STA::AnchoredKey(wc(0), kow_wc(1)),
STA::AnchoredKey(wc(4), kow_wc(5)),
STA::AnchoredKey(wc(2), kow_wc(3)),
STA::AnchoredKey(sow_wc(0), kow_wc(1)),
STA::AnchoredKey(sow_wc(4), kow_wc(5)),
STA::AnchoredKey(sow_wc(2), kow_wc(3)),
],
),
],
@ -523,22 +558,22 @@ mod tests {
st(
P::Native(NP::ValueOf),
vec![
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("type".into())),
STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("type".into())),
STA::Literal(PodType::Signed.into()),
],
),
st(
P::Native(NP::Equal),
vec![
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("signer".into())),
STA::AnchoredKey(wc(0), kow_wc(1)),
STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("signer".into())),
STA::AnchoredKey(sow_wc(0), kow_wc(1)),
],
),
st(
P::Native(NP::Equal),
vec![
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("attestation".into())),
STA::AnchoredKey(wc(2), kow_wc(3)),
STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("attestation".into())),
STA::AnchoredKey(sow_wc(2), kow_wc(3)),
],
),
],
@ -556,13 +591,16 @@ mod tests {
st(
P::Native(NP::Equal),
vec![
STA::AnchoredKey(wc(0), kow_wc(1)),
STA::AnchoredKey(wc(2), kow_wc(3)),
STA::AnchoredKey(sow_wc(0), kow_wc(1)),
STA::AnchoredKey(sow_wc(2), kow_wc(3)),
],
),
st(
P::Native(NP::ValueOf),
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(0.into())],
vec![
STA::AnchoredKey(sow_wc(4), kow_wc(5)),
STA::Literal(0.into()),
],
),
],
6,
@ -586,14 +624,17 @@ mod tests {
),
st(
P::Native(NP::ValueOf),
vec![STA::AnchoredKey(wc(6), kow_wc(7)), STA::Literal(1.into())],
vec![
STA::AnchoredKey(sow_wc(6), kow_wc(7)),
STA::Literal(1.into()),
],
),
st(
P::Native(NP::SumOf),
vec![
STA::AnchoredKey(wc(4), kow_wc(5)),
STA::AnchoredKey(wc(8), kow_wc(9)),
STA::AnchoredKey(wc(6), kow_wc(7)),
STA::AnchoredKey(sow_wc(4), kow_wc(5)),
STA::AnchoredKey(sow_wc(8), kow_wc(9)),
STA::AnchoredKey(sow_wc(6), kow_wc(7)),
],
),
st(

View file

@ -8,8 +8,8 @@ use crate::{
backends::plonky2::primitives::merkletree::MerkleProof,
middleware::{
custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg,
ToFields, Wildcard, WildcardValue, F, SELF,
NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg,
StatementTmplArg, ToFields, Wildcard, WildcardValue, F, SELF,
},
};
@ -363,10 +363,15 @@ pub fn check_st_tmpl(
(StatementTmplArg::None, StatementArg::None) => true,
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
(
StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc),
StatementTmplArg::AnchoredKey(self_or_wc, key_or_wc),
StatementArg::Key(AnchoredKey { pod_id, key }),
) => {
let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map);
let pod_id_ok = match self_or_wc {
SelfOrWildcard::SELF => SELF == *pod_id,
SelfOrWildcard::Wildcard(pod_id_wc) => {
check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map)
}
};
let key_ok = match key_or_wc {
KeyOrWildcard::Key(tmpl_key) => tmpl_key == key,
KeyOrWildcard::Wildcard(key_wc) => {