Allow literals in statements (#276)

Implements #229 and #261.
This commit is contained in:
Daniel Gulotta 2025-06-13 10:27:19 -07:00 committed by GitHub
parent 21ab3c2d0d
commit 7d0d3ad769
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 992 additions and 825 deletions

View file

@ -23,17 +23,16 @@ pub const OPERATION_AUX_F_LEN: usize = 2;
pub enum NativePredicate {
None = 0, // Always true
False = 1, // Always false
ValueOf = 2,
Equal = 3,
NotEqual = 4,
LtEq = 5,
Lt = 6,
Contains = 7,
NotContains = 8,
SumOf = 9,
ProductOf = 10,
MaxOf = 11,
HashOf = 12,
Equal = 2,
NotEqual = 3,
LtEq = 4,
Lt = 5,
Contains = 6,
NotContains = 7,
SumOf = 8,
ProductOf = 9,
MaxOf = 10,
HashOf = 11,
// Syntactic sugar predicates. These predicates are not supported by the backend. The
// frontend compiler is responsible of translating these predicates into the predicates above.
@ -168,33 +167,58 @@ impl fmt::Display for Predicate {
#[serde(tag = "predicate", content = "args")]
pub enum Statement {
None,
ValueOf(AnchoredKey, Value),
Equal(AnchoredKey, AnchoredKey),
NotEqual(AnchoredKey, AnchoredKey),
LtEq(AnchoredKey, AnchoredKey),
Lt(AnchoredKey, AnchoredKey),
Equal(ValueRef, ValueRef),
NotEqual(ValueRef, ValueRef),
LtEq(ValueRef, ValueRef),
Lt(ValueRef, ValueRef),
Contains(
/* root */ AnchoredKey,
/* key */ AnchoredKey,
/* value */ AnchoredKey,
/* root */ ValueRef,
/* key */ ValueRef,
/* value */ ValueRef,
),
NotContains(/* root */ AnchoredKey, /* key */ AnchoredKey),
SumOf(AnchoredKey, AnchoredKey, AnchoredKey),
ProductOf(AnchoredKey, AnchoredKey, AnchoredKey),
MaxOf(AnchoredKey, AnchoredKey, AnchoredKey),
HashOf(AnchoredKey, AnchoredKey, AnchoredKey),
NotContains(/* root */ ValueRef, /* key */ ValueRef),
SumOf(ValueRef, ValueRef, ValueRef),
ProductOf(ValueRef, ValueRef, ValueRef),
MaxOf(ValueRef, ValueRef, ValueRef),
HashOf(ValueRef, ValueRef, ValueRef),
Custom(CustomPredicateRef, Vec<WildcardValue>),
}
macro_rules! statement_constructor {
($var_name: ident, $cons_name: ident, 2) => {
pub fn $var_name(v1: impl Into<ValueRef>, v2: impl Into<ValueRef>) -> Self {
Self::$cons_name(v1.into(), v2.into())
}
};
($var_name: ident, $cons_name: ident, 3) => {
pub fn $var_name(
v1: impl Into<ValueRef>,
v2: impl Into<ValueRef>,
v3: impl Into<ValueRef>,
) -> Self {
Self::$cons_name(v1.into(), v2.into(), v3.into())
}
};
}
impl Statement {
pub fn is_none(&self) -> bool {
self == &Self::None
}
statement_constructor!(equal, Equal, 2);
statement_constructor!(not_equal, NotEqual, 2);
statement_constructor!(lt_eq, LtEq, 2);
statement_constructor!(lt, Lt, 2);
statement_constructor!(contains, Contains, 3);
statement_constructor!(not_contains, NotContains, 2);
statement_constructor!(sum_of, SumOf, 3);
statement_constructor!(product_of, ProductOf, 3);
statement_constructor!(max_of, MaxOf, 3);
statement_constructor!(hash_of, HashOf, 3);
pub fn predicate(&self) -> Predicate {
use Predicate::*;
match self {
Self::None => Native(NativePredicate::None),
Self::ValueOf(_, _) => Native(NativePredicate::ValueOf),
Self::Equal(_, _) => Native(NativePredicate::Equal),
Self::NotEqual(_, _) => Native(NativePredicate::NotEqual),
Self::LtEq(_, _) => Native(NativePredicate::LtEq),
@ -212,117 +236,66 @@ impl Statement {
use StatementArg::*;
match self.clone() {
Self::None => vec![],
Self::ValueOf(ak, v) => vec![Key(ak), Literal(v)],
Self::Equal(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::NotEqual(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::LtEq(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::Lt(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::Contains(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::NotContains(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::HashOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::Equal(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::NotEqual(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::LtEq(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::Lt(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::Contains(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::NotContains(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::SumOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::ProductOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::MaxOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::HashOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(WildcardLiteral)),
}
}
pub fn as_entry(&self) -> Option<(&AnchoredKey, &Value)> {
if let Self::Equal(ValueRef::Key(k), ValueRef::Literal(v)) = self {
Some((k, v))
} else {
None
}
}
pub fn from_args(pred: Predicate, args: Vec<StatementArg>) -> Result<Self> {
use Predicate::*;
let st: Result<Self> = match pred {
Native(NativePredicate::None) => Ok(Self::None),
Native(NativePredicate::ValueOf) => {
if let (StatementArg::Key(a0), StatementArg::Literal(v1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::ValueOf(a0, v1))
} else {
Err(Error::incorrect_statements_args())
}
let st = match (pred, &args.as_slice()) {
(Native(NativePredicate::None), &[]) => Self::None,
(Native(NativePredicate::Equal), &[a1, a2]) => {
Self::Equal(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::Equal) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::Equal(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::NotEqual), &[a1, a2]) => {
Self::NotEqual(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::NotEqual) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::NotEqual(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::LtEq), &[a1, a2]) => {
Self::LtEq(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::LtEq) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::LtEq(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::Lt), &[a1, a2]) => Self::Lt(a1.try_into()?, a2.try_into()?),
(Native(NativePredicate::Contains), &[a1, a2, a3]) => {
Self::Contains(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::Lt) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::Lt(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::NotContains), &[a1, a2]) => {
Self::NotContains(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::Contains) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::Contains(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::SumOf), &[a1, a2, a3]) => {
Self::SumOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::NotContains) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::NotContains(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::ProductOf), &[a1, a2, a3]) => {
Self::ProductOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::SumOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::SumOf(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::MaxOf), &[a1, a2, a3]) => {
Self::MaxOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::ProductOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::ProductOf(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::HashOf), &[a1, a2, a3]) => {
Self::HashOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::MaxOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::MaxOf(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(np), _) => {
return Err(Error::custom(format!("Predicate {:?} is syntax sugar", np)))
}
Native(np) => Err(Error::custom(format!("Predicate {:?} is syntax sugar", np))),
BatchSelf(_) => unreachable!(),
Custom(cpr) => {
(BatchSelf(_), _) => unreachable!(),
(Custom(cpr), _) => {
let v_args: Result<Vec<WildcardValue>> = args
.iter()
.map(|x| match x {
@ -330,10 +303,10 @@ impl Statement {
_ => Err(Error::incorrect_statements_args()),
})
.collect();
Ok(Self::Custom(cpr, v_args?))
Self::Custom(cpr, v_args?)
}
};
st
Ok(st)
}
}
@ -437,6 +410,57 @@ impl ToFields for StatementArg {
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub enum ValueRef {
Literal(Value),
Key(AnchoredKey),
}
impl From<ValueRef> for StatementArg {
fn from(value: ValueRef) -> Self {
match value {
ValueRef::Literal(v) => StatementArg::Literal(v),
ValueRef::Key(v) => StatementArg::Key(v),
}
}
}
impl TryFrom<StatementArg> for ValueRef {
type Error = crate::middleware::Error;
fn try_from(value: StatementArg) -> std::result::Result<Self, Self::Error> {
match value {
StatementArg::Literal(v) => Ok(Self::Literal(v)),
StatementArg::Key(k) => Ok(Self::Key(k)),
_ => Err(Self::Error::invalid_statement_arg(
value,
"literal or key".to_string(),
)),
}
}
}
impl TryFrom<&StatementArg> for ValueRef {
type Error = crate::middleware::Error;
fn try_from(value: &StatementArg) -> std::result::Result<Self, Self::Error> {
value.clone().try_into()
}
}
impl From<AnchoredKey> for ValueRef {
fn from(value: AnchoredKey) -> Self {
Self::Key(value)
}
}
impl<T> From<T> for ValueRef
where
T: Into<Value>,
{
fn from(value: T) -> Self {
Self::Literal(value.into())
}
}
#[cfg(test)]
mod tests {
use super::*;