Constraints for custom predicates (#227)

* add target types for custom predicates

* simplify

* fix clippy

* fix typo

* don't use ref for NativePredicate

* fix wrong len

* precalculate CustomPredicateBatch id

* wip

* wip

* move code back

* great progress

* wip

* code complete, hopefully; missing tests

* fill aux for custom predicate op

* fix clippy warnings

* fix typos

* fix test import

* fix missing assignment in lt_mask, test custom_operation_verify_gadget

* fix mistake

* wip

* fix

* debug revert except for let entry = CustomPredicateVerifyEntryTarget

* fix batch_id calculation by fixing padding

* oops

* remove completed TODOs
This commit is contained in:
Eduard S. 2025-05-13 11:00:45 +02:00 committed by GitHub
parent 4fa9e20ecd
commit 024ed8bd04
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1597 additions and 291 deletions

View file

@ -5,7 +5,8 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::middleware::{
hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE,
hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value,
EMPTY_HASH, F, HASH_SIZE, VALUE_SIZE,
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -49,12 +50,15 @@ impl fmt::Display for KeyOrWildcard {
}
impl ToFields for KeyOrWildcard {
// Encoding:
// - Key(k) => [[k]]
// - Wildcard(index) => [[index], 0, 0, 0]
fn to_fields(&self, params: &Params) -> Vec<F> {
match self {
KeyOrWildcard::Key(k) => k.hash().to_fields(params),
KeyOrWildcard::Wildcard(wc) => iter::once(F::ZERO)
.take(HASH_SIZE - 1)
.chain(iter::once(F::from_canonical_u64(wc.index as u64)))
KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64))
.chain(iter::repeat(F::ZERO))
.take(HASH_SIZE)
.collect(),
}
}
@ -66,7 +70,7 @@ pub enum StatementTmplArg {
None,
Literal(Value),
// AnchoredKey
Key(Wildcard, KeyOrWildcard),
AnchoredKey(Wildcard, KeyOrWildcard),
// TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard...
// Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key?
WildcardLiteral(Wildcard),
@ -76,7 +80,7 @@ pub enum StatementTmplArg {
pub enum StatementTmplArgPrefix {
None = 0,
Literal = 1,
Key = 2,
AnchoredKey = 2,
WildcardLiteral = 3,
}
@ -88,11 +92,11 @@ impl From<StatementTmplArgPrefix> for F {
impl ToFields for StatementTmplArg {
fn to_fields(&self, params: &Params) -> Vec<F> {
// None => (0, ...)
// Literal(value) => (1, [value], 0, 0, 0, 0)
// Key(wildcard1_index, key_or_wildcard2)
// => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2])
// WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0)
// Encoding:
// None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
// Literal(v) => (1, [v ], 0, 0, 0, 0)
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
match self {
StatementTmplArg::None => {
@ -105,13 +109,15 @@ impl ToFields for StatementTmplArg {
StatementTmplArg::Literal(v) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
.chain(v.raw().to_fields(params))
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect();
fields
}
StatementTmplArg::Key(wc1, kw2) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Key))
StatementTmplArg::AnchoredKey(wc1, kw2) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
.chain(wc1.to_fields(params))
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
.chain(kw2.to_fields(params))
.collect();
fields
@ -119,7 +125,8 @@ impl ToFields for StatementTmplArg {
StatementTmplArg::WildcardLiteral(wc) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
.chain(wc.to_fields(params))
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect();
fields
}
@ -132,7 +139,7 @@ impl fmt::Display for StatementTmplArg {
match self {
Self::None => write!(f, "none"),
Self::Literal(v) => write!(f, "{}", v),
Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key),
Self::AnchoredKey(pod_id, key) => write!(f, "({}, {})", pod_id, key),
Self::WildcardLiteral(v) => write!(f, "{}", v),
}
}
@ -177,7 +184,11 @@ impl ToFields for StatementTmpl {
// instead of at the `to_fields` method, where we should assume that the
// values are already valid
if self.args.len() > params.max_statement_args {
panic!("Statement template has too many arguments");
panic!(
"Statement template has too many arguments {} > {}",
self.args.len(),
params.max_statement_args
);
}
let mut fields: Vec<F> = self
@ -206,25 +217,36 @@ pub struct CustomPredicate {
}
impl CustomPredicate {
pub fn empty() -> Self {
Self {
name: "empty".to_string(),
conjunction: false,
statements: vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::None),
args: vec![],
}],
args_len: 0,
}
}
pub fn and(
name: String,
params: &Params,
name: String,
statements: Vec<StatementTmpl>,
args_len: usize,
) -> Result<Self> {
Self::new(name, params, true, statements, args_len)
Self::new(params, name, true, statements, args_len)
}
pub fn or(
name: String,
params: &Params,
name: String,
statements: Vec<StatementTmpl>,
args_len: usize,
) -> Result<Self> {
Self::new(name, params, false, statements, args_len)
Self::new(params, name, false, statements, args_len)
}
pub fn new(
name: String,
params: &Params,
name: String,
conjunction: bool,
statements: Vec<StatementTmpl>,
args_len: usize,
@ -236,6 +258,13 @@ impl CustomPredicate {
params.max_custom_predicate_arity,
));
}
if args_len > params.max_statement_args {
return Err(Error::max_length(
"statement_args.len".to_string(),
args_len,
params.max_statement_args,
));
}
Ok(Self {
name,
@ -244,6 +273,16 @@ impl CustomPredicate {
args_len,
})
}
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
StatementTmpl {
pred: Predicate::Native(if self.conjunction {
NativePredicate::False
} else {
NativePredicate::None
}),
args: vec![],
}
}
}
impl ToFields for CustomPredicate {
@ -262,11 +301,17 @@ impl ToFields for CustomPredicate {
panic!("Custom predicate depends on too many statements");
}
let mut fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
let pad_st = self.pad_statement_tmpl();
let fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
.chain(iter::once(F::from_canonical_usize(self.args_len)))
.chain(self.statements.iter().flat_map(|st| st.to_fields(params)))
.chain(
self.statements
.iter()
.chain(iter::repeat(&pad_st))
.take(params.max_custom_predicate_arity)
.flat_map(|st| st.to_fields(params)),
)
.collect();
fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0));
fields
}
}
@ -298,8 +343,9 @@ impl fmt::Display for CustomPredicate {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct CustomPredicateBatch {
id: Hash,
pub name: String,
pub predicates: Vec<CustomPredicate>,
predicates: Vec<CustomPredicate>,
}
impl ToFields for CustomPredicateBatch {
@ -313,27 +359,45 @@ impl ToFields for CustomPredicateBatch {
panic!("Predicate batch exceeds maximum size");
}
let mut fields: Vec<F> = self
let pad_pred = CustomPredicate::empty();
let fields: Vec<F> = self
.predicates
.iter()
.chain(iter::repeat(&pad_pred))
.take(params.max_custom_batch_size)
.flat_map(|p| p.to_fields(params))
.collect();
fields.resize_with(params.custom_predicate_batch_size_field_elts(), || {
F::from_canonical_u64(0)
});
fields
}
}
impl CustomPredicateBatch {
pub fn new(params: &Params, name: String, predicates: Vec<CustomPredicate>) -> Arc<Self> {
let mut cpb = Self {
id: EMPTY_HASH,
name,
predicates,
};
let id = cpb.calculate_id(params);
cpb.id = id;
Arc::new(cpb)
}
/// Cryptographic identifier for the batch.
pub fn id(&self, params: &Params) -> Hash {
fn calculate_id(&self, params: &Params) -> Hash {
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
// but ideally we want to use the root of a merkle tree built from the custom predicates.
let input = self.to_fields(params);
hash_fields(&input)
}
pub fn id(&self) -> Hash {
self.id
}
pub fn predicates(&self) -> &[CustomPredicate] {
&self.predicates
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -347,13 +411,16 @@ impl CustomPredicateRef {
Self { batch, index }
}
pub fn arg_len(&self) -> usize {
self.batch.predicates[self.index].args_len
self.predicate().args_len
}
pub fn predicate(&self) -> &CustomPredicate {
&self.batch.predicates[self.index]
}
}
#[cfg(test)]
mod tests {
use std::{array, sync::Arc};
use std::array;
use plonky2::field::goldilocks_field::GoldilocksField;
@ -392,28 +459,29 @@ mod tests {
p:value_of(Constant, 2),
p:product_of(S1, Constant, S2)
*/
let cust_pred_batch = Arc::new(CustomPredicateBatch {
name: "is_double".to_string(),
predicates: vec![CustomPredicate::and(
"_".into(),
let cust_pred_batch = CustomPredicateBatch::new(
&params,
"is_double".to_string(),
vec![CustomPredicate::and(
&params,
"_".into(),
vec![
st(
P::Native(NP::ValueOf),
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(2.into())],
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(2.into())],
),
st(
P::Native(NP::ProductOf),
vec![
STA::Key(wc(0), kow_wc(1)),
STA::Key(wc(4), kow_wc(5)),
STA::Key(wc(2), kow_wc(3)),
STA::AnchoredKey(wc(0), kow_wc(1)),
STA::AnchoredKey(wc(4), kow_wc(5)),
STA::AnchoredKey(wc(2), kow_wc(3)),
],
),
],
2,
)?],
});
);
let custom_statement = Statement::Custom(
CustomPredicateRef::new(cust_pred_batch.clone(), 0),
@ -444,55 +512,57 @@ mod tests {
fn ethdos_test() -> Result<()> {
let params = Params {
max_custom_predicate_wildcards: 12,
max_statement_args: 6,
..Default::default()
};
let eth_friend_cp = CustomPredicate::and(
"eth_friend_cp".into(),
&params,
"eth_friend_cp".into(),
vec![
st(
P::Native(NP::ValueOf),
vec![
STA::Key(wc(4), KeyOrWildcard::Key("type".into())),
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("type".into())),
STA::Literal(PodType::Signed.into()),
],
),
st(
P::Native(NP::Equal),
vec![
STA::Key(wc(4), KeyOrWildcard::Key("signer".into())),
STA::Key(wc(0), kow_wc(1)),
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("signer".into())),
STA::AnchoredKey(wc(0), kow_wc(1)),
],
),
st(
P::Native(NP::Equal),
vec![
STA::Key(wc(4), KeyOrWildcard::Key("attestation".into())),
STA::Key(wc(2), kow_wc(3)),
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("attestation".into())),
STA::AnchoredKey(wc(2), kow_wc(3)),
],
),
],
4,
)?;
let eth_friend_batch = Arc::new(CustomPredicateBatch {
name: "eth_friend".to_string(),
predicates: vec![eth_friend_cp],
});
let eth_friend_batch =
CustomPredicateBatch::new(&params, "eth_friend".to_string(), vec![eth_friend_cp]);
// 0
let eth_dos_base = CustomPredicate::and(
"eth_dos_base".into(),
&params,
"eth_dos_base".into(),
vec![
st(
P::Native(NP::Equal),
vec![STA::Key(wc(0), kow_wc(1)), STA::Key(wc(2), kow_wc(3))],
vec![
STA::AnchoredKey(wc(0), kow_wc(1)),
STA::AnchoredKey(wc(2), kow_wc(3)),
],
),
st(
P::Native(NP::ValueOf),
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(0.into())],
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(0.into())],
),
],
6,
@ -500,8 +570,8 @@ mod tests {
// 1
let eth_dos_ind = CustomPredicate::and(
"eth_dos_ind".into(),
&params,
"eth_dos_ind".into(),
vec![
st(
P::BatchSelf(2),
@ -516,14 +586,14 @@ mod tests {
),
st(
P::Native(NP::ValueOf),
vec![STA::Key(wc(6), kow_wc(7)), STA::Literal(1.into())],
vec![STA::AnchoredKey(wc(6), kow_wc(7)), STA::Literal(1.into())],
),
st(
P::Native(NP::SumOf),
vec![
STA::Key(wc(4), kow_wc(5)),
STA::Key(wc(8), kow_wc(9)),
STA::Key(wc(6), kow_wc(7)),
STA::AnchoredKey(wc(4), kow_wc(5)),
STA::AnchoredKey(wc(8), kow_wc(9)),
STA::AnchoredKey(wc(6), kow_wc(7)),
],
),
st(
@ -541,8 +611,8 @@ mod tests {
// 2
let eth_dos_distance_either = CustomPredicate::or(
"eth_dos_distance_either".into(),
&params,
"eth_dos_distance_either".into(),
vec![
st(
P::BatchSelf(0),
@ -570,10 +640,11 @@ mod tests {
6,
)?;
let eth_dos_distance_batch = Arc::new(CustomPredicateBatch {
name: "ETHDoS_distance".to_string(),
predicates: vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
});
let eth_dos_distance_batch = CustomPredicateBatch::new(
&params,
"ETHDoS_distance".to_string(),
vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
);
// Some POD IDs
let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64))));