add target types for custom predicates (#223)

* add target types for custom predicates

* simplify

* fix clippy

* fix typo

* don't use ref for NativePredicate

* fix wrong len

* apply feedback from @ax0
This commit is contained in:
Eduard S. 2025-05-07 11:09:38 +02:00 committed by GitHub
parent bf394eada3
commit 726f95483d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 527 additions and 123 deletions

View file

@ -5,7 +5,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::middleware::{
hash_fields, Error, Hash, Key, NativePredicate, Params, Result, ToFields, Value, F, HASH_SIZE,
hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE,
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -72,40 +72,54 @@ pub enum StatementTmplArg {
WildcardLiteral(Wildcard),
}
#[derive(Clone, Copy)]
pub enum StatementTmplArgPrefix {
None = 0,
Literal = 1,
Key = 2,
WildcardLiteral = 3,
}
impl From<StatementTmplArgPrefix> for F {
fn from(prefix: StatementTmplArgPrefix) -> Self {
Self::from_canonical_usize(prefix as usize)
}
}
impl ToFields for StatementTmplArg {
fn to_fields(&self, params: &Params) -> Vec<F> {
// None => (0, ...)
// Literal(value) => (1, [value], 0, 0, 0, 0)
// Key(wildcard1, key_or_wildcard2)
// => (2, [wildcard1], [key_or_wildcard2])
// WildcardLiteral(wildcard) => (3, [wildcard], 0, 0, 0, 0)
// Key(wildcard1_index, key_or_wildcard2)
// => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2])
// WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
let statement_tmpl_arg_size = 2 * HASH_SIZE + 1;
match self {
StatementTmplArg::None => {
let fields: Vec<F> = iter::repeat_with(|| F::from_canonical_u64(0))
.take(statement_tmpl_arg_size)
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::None))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect();
fields
}
StatementTmplArg::Literal(v) => {
let fields: Vec<F> = iter::once(F::from_canonical_u64(1))
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
.chain(v.raw().to_fields(params))
.chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE))
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
.collect();
fields
}
StatementTmplArg::Key(wc1, kw2) => {
let fields: Vec<F> = iter::once(F::from_canonical_u64(2))
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Key))
.chain(wc1.to_fields(params))
.chain(kw2.to_fields(params))
.collect();
fields
}
StatementTmplArg::WildcardLiteral(wc) => {
let fields: Vec<F> = iter::once(F::from_canonical_u64(3))
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
.chain(wc.to_fields(params))
.chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE))
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
.collect();
fields
}
@ -312,7 +326,10 @@ impl ToFields for CustomPredicateBatch {
}
impl CustomPredicateBatch {
pub fn hash(&self, params: &Params) -> Hash {
/// Cryptographic identifier for the batch.
pub fn id(&self, params: &Params) -> Hash {
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
// but ideally we want to use the root of a merkle tree built from the custom predicates.
let input = self.to_fields(params);
hash_fields(&input)
@ -334,65 +351,6 @@ impl CustomPredicateRef {
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "value")]
pub enum Predicate {
Native(NativePredicate),
BatchSelf(usize),
Custom(CustomPredicateRef),
}
impl From<NativePredicate> for Predicate {
fn from(v: NativePredicate) -> Self {
Self::Native(v)
}
}
impl ToFields for Predicate {
fn to_fields(&self, params: &Params) -> Vec<F> {
// serialize:
// NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize
// BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize
// CustomPredicateRef(pb, i) as
// (2, [hash of pb], i) -- pb hashes to 4 field elements
// -- i: usize
// in every case: pad to (hash_size + 2) field elements
let mut fields: Vec<F> = match self {
Self::Native(p) => iter::once(F::from_canonical_u64(1))
.chain(p.to_fields(params))
.collect(),
Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2))
.chain(iter::once(F::from_canonical_usize(*i)))
.collect(),
Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from_canonical_u64(3))
.chain(batch.hash(params).0)
.chain(iter::once(F::from_canonical_usize(*index)))
.collect()
}
};
fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0));
fields
}
}
impl fmt::Display for Predicate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Native(p) => write!(f, "{:?}", p),
Self::BatchSelf(i) => write!(f, "self.{}", i),
Self::Custom(CustomPredicateRef { batch, index }) => {
write!(
f,
"{}.{}[{}]",
batch.name, index, batch.predicates[*index].name
)
}
}
}
}
#[cfg(test)]
mod tests {
use std::{array, sync::Arc};

View file

@ -43,7 +43,7 @@ impl ToFields for OperationType {
.collect(),
Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from_canonical_u64(3))
.chain(batch.hash(params).0)
.chain(batch.id(params).0)
.chain(iter::once(F::from_canonical_usize(*index)))
.collect()
}

View file

@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize};
use strum_macros::FromRepr;
use crate::middleware::{
AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, Predicate, RawValue, Result,
ToFields, Value, F, VALUE_SIZE,
AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value,
F, VALUE_SIZE,
};
// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static
@ -84,6 +84,78 @@ impl ToFields for WildcardValue {
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "value")]
pub enum Predicate {
Native(NativePredicate),
BatchSelf(usize),
Custom(CustomPredicateRef),
}
impl From<NativePredicate> for Predicate {
fn from(v: NativePredicate) -> Self {
Self::Native(v)
}
}
#[derive(Clone, Copy)]
pub enum PredicatePrefix {
Native = 1,
BatchSelf = 2,
Custom = 3,
}
impl From<PredicatePrefix> for F {
fn from(prefix: PredicatePrefix) -> Self {
Self::from_canonical_usize(prefix as usize)
}
}
impl ToFields for Predicate {
fn to_fields(&self, params: &Params) -> Vec<F> {
// serialize:
// NativePredicate(id) as (1, id, 0, 0, 0, 0) -- id: usize
// BatchSelf(i) as (2, i, 0, 0, 0, 0) -- i: usize
// CustomPredicateRef(pb, i) as
// (3, [hash of pb], i) -- pb hashes to 4 field elements
// -- i: usize
// in every case: pad to (hash_size + 2) field elements
let mut fields: Vec<F> = match self {
Self::Native(p) => iter::once(F::from(PredicatePrefix::Native))
.chain(p.to_fields(params))
.collect(),
Self::BatchSelf(i) => iter::once(F::from(PredicatePrefix::BatchSelf))
.chain(iter::once(F::from_canonical_usize(*i)))
.collect(),
Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from(PredicatePrefix::Custom))
.chain(batch.id(params).0)
.chain(iter::once(F::from_canonical_usize(*index)))
.collect()
}
};
fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0));
fields
}
}
impl fmt::Display for Predicate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Native(p) => write!(f, "{:?}", p),
Self::BatchSelf(i) => write!(f, "self.{}", i),
Self::Custom(CustomPredicateRef { batch, index }) => {
write!(
f,
"{}.{}[{}]",
batch.name, index, batch.predicates[*index].name
)
}
}
}
}
/// Type encapsulating statements with their associated arguments.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "predicate", content = "args")]