Refactor frontend/middleware types (#194)

* unify fe/be NativeOp and NativePred

* remove Origin in favour of PodId

* Combine string and hash in Key

* use middleware::AnchoredKey in frontend

* merge frontend/middleware types

* refactor custom predicates

* clean up a bit

* fix middleware custom tests

* clean up

* clean up 2

* add acronyms in typos list
This commit is contained in:
Eduard S. 2025-04-16 11:59:30 +02:00 committed by GitHub
parent 9e860ef262
commit c232c8dae5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1985 additions and 2800 deletions

View file

@ -2,14 +2,16 @@ use std::{fmt, iter};
use anyhow::{anyhow, Result};
use plonky2::field::types::Field;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
// use schemars::JsonSchema;
// use serde::{Deserialize, Serialize};
use strum_macros::FromRepr;
use crate::middleware::{
AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, VALUE_SIZE,
AnchoredKey, CustomPredicateRef, Key, Params, PodId, Predicate, RawValue, ToFields, Value, F,
VALUE_SIZE,
};
// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static
// hash(KEY_SIGNER) = [2145458785152392366, 15113074911296146791, 15323228995597834291, 11804480340100333725]
pub const KEY_SIGNER: &str = "_signer";
// hash(KEY_TYPE) = [17948789436443445142, 12513915140657440811, 15878361618879468769, 938231894693848619]
@ -18,7 +20,7 @@ pub const STATEMENT_ARG_F_LEN: usize = 8;
pub const OPERATION_ARG_F_LEN: usize = 1;
pub const OPERATION_AUX_F_LEN: usize = 1;
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash)]
pub enum NativePredicate {
None = 0,
ValueOf = 1,
@ -31,6 +33,14 @@ pub enum NativePredicate {
SumOf = 8,
ProductOf = 9,
MaxOf = 10,
// Syntactic sugar predicates. These predicates are not supported by the backend. The
// frontend compiler is responsible of translating these predicates into the predicates above.
DictContains = 1000,
DictNotContains = 1001,
SetContains = 1002,
SetNotContains = 1003,
ArrayContains = 1004, // there is no ArrayNotContains
}
impl ToFields for NativePredicate {
@ -39,8 +49,41 @@ impl ToFields for NativePredicate {
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum WildcardValue {
PodId(PodId),
Key(Key),
}
impl WildcardValue {
pub fn raw(&self) -> RawValue {
match self {
WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0),
WildcardValue::Key(key) => key.raw(),
}
}
}
impl fmt::Display for WildcardValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id),
WildcardValue::Key(key) => write!(f, "{}", key),
}
}
}
impl ToFields for WildcardValue {
fn to_fields(&self, params: &Params) -> Vec<F> {
match self {
WildcardValue::PodId(pod_id) => pod_id.to_fields(params),
WildcardValue::Key(key) => key.to_fields(params),
}
}
}
/// Type encapsulating statements with their associated arguments.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq)]
pub enum Statement {
None,
ValueOf(AnchoredKey, Value),
@ -57,7 +100,7 @@ pub enum Statement {
SumOf(AnchoredKey, AnchoredKey, AnchoredKey),
ProductOf(AnchoredKey, AnchoredKey, AnchoredKey),
MaxOf(AnchoredKey, AnchoredKey, AnchoredKey),
Custom(CustomPredicateRef, Vec<AnchoredKey>),
Custom(CustomPredicateRef, Vec<WildcardValue>),
}
impl Statement {
@ -95,7 +138,7 @@ impl Statement {
Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(Key)),
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(WildcardLiteral)),
}
}
pub fn from_args(pred: Predicate, args: Vec<StatementArg>) -> Result<Self> {
@ -103,35 +146,45 @@ impl Statement {
let st: Result<Self> = match pred {
Native(NativePredicate::None) => Ok(Self::None),
Native(NativePredicate::ValueOf) => {
if let (StatementArg::Key(a0), StatementArg::Literal(v1)) = (args[0], args[1]) {
if let (StatementArg::Key(a0), StatementArg::Literal(v1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::ValueOf(a0, v1))
} else {
Err(anyhow!("Incorrect statement args"))
}
}
Native(NativePredicate::Equal) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::Equal(a0, a1))
} else {
Err(anyhow!("Incorrect statement args"))
}
}
Native(NativePredicate::NotEqual) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::NotEqual(a0, a1))
} else {
Err(anyhow!("Incorrect statement args"))
}
}
Native(NativePredicate::Gt) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::Gt(a0, a1))
} else {
Err(anyhow!("Incorrect statement args"))
}
}
Native(NativePredicate::Lt) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::Lt(a0, a1))
} else {
Err(anyhow!("Incorrect statement args"))
@ -139,7 +192,7 @@ impl Statement {
}
Native(NativePredicate::Contains) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0], args[1], args[2])
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::Contains(a0, a1, a2))
} else {
@ -147,7 +200,9 @@ impl Statement {
}
}
Native(NativePredicate::NotContains) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) = (args[0], args[1]) {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::NotContains(a0, a1))
} else {
Err(anyhow!("Incorrect statement args"))
@ -155,7 +210,7 @@ impl Statement {
}
Native(NativePredicate::SumOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0], args[1], args[2])
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::SumOf(a0, a1, a2))
} else {
@ -164,7 +219,7 @@ impl Statement {
}
Native(NativePredicate::ProductOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0], args[1], args[2])
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::ProductOf(a0, a1, a2))
} else {
@ -173,23 +228,24 @@ impl Statement {
}
Native(NativePredicate::MaxOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0], args[1], args[2])
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::MaxOf(a0, a1, a2))
} else {
Err(anyhow!("Incorrect statement args"))
}
}
Native(np) => Err(anyhow!("Predicate {:?} is syntax sugar", np)),
BatchSelf(_) => unreachable!(),
Custom(cpr) => {
let ak_args: Result<Vec<AnchoredKey>> = args
let v_args: Result<Vec<WildcardValue>> = args
.iter()
.map(|x| match x {
StatementArg::Key(ak) => Ok(*ak),
StatementArg::WildcardLiteral(v) => Ok(v.clone()),
_ => Err(anyhow!("Incorrect statement args")),
})
.collect();
Ok(Self::Custom(cpr, ak_args?))
Ok(Self::Custom(cpr, v_args?))
}
};
st
@ -207,23 +263,24 @@ impl ToFields for Statement {
impl fmt::Display for Statement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?} ", self.predicate())?;
write!(f, "{}(", self.predicate())?;
for (i, arg) in self.args().iter().enumerate() {
if i != 0 {
write!(f, " ")?;
write!(f, ", ")?;
}
write!(f, "{}", arg)?;
}
Ok(())
write!(f, ")")
}
}
/// Statement argument type. Useful for statement decompositions.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq)]
pub enum StatementArg {
None,
Literal(Value),
Key(AnchoredKey),
WildcardLiteral(WildcardValue),
}
impl fmt::Display for StatementArg {
@ -231,7 +288,8 @@ impl fmt::Display for StatementArg {
match self {
StatementArg::None => write!(f, "none"),
StatementArg::Literal(v) => write!(f, "{}", v),
StatementArg::Key(r) => write!(f, "{}.{}", r.0, r.1),
StatementArg::Key(r) => write!(f, "{}.{}", r.pod_id, r.key),
StatementArg::WildcardLiteral(v) => write!(f, "{}", v),
}
}
}
@ -242,13 +300,13 @@ impl StatementArg {
}
pub fn literal(&self) -> Result<Value> {
match self {
Self::Literal(value) => Ok(*value),
Self::Literal(value) => Ok(value.clone()),
_ => Err(anyhow!("Statement argument {:?} is not a literal.", self)),
}
}
pub fn key(&self) -> Result<AnchoredKey> {
match self {
Self::Key(ak) => Ok(*ak),
Self::Key(ak) => Ok(ak.clone()),
_ => Err(anyhow!("Statement argument {:?} is not a key.", self)),
}
}
@ -265,16 +323,23 @@ impl ToFields for StatementArg {
// dealing with `Literal` it would be of length 4.
let f = match self {
StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN],
StatementArg::Literal(v) => {
v.0.into_iter()
.chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
.collect()
}
StatementArg::Literal(v) => v
.raw()
.0
.into_iter()
.chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
.collect(),
StatementArg::Key(ak) => {
let mut fields = ak.0.to_fields(_params);
fields.extend(ak.1.to_fields(_params));
let mut fields = ak.pod_id.to_fields(_params);
fields.extend(ak.key.to_fields(_params));
fields
}
StatementArg::WildcardLiteral(v) => v
.raw()
.0
.into_iter()
.chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
.collect(),
};
assert_eq!(f.len(), STATEMENT_ARG_F_LEN); // sanity check
f