Allow literals in statements (#276)

Implements #229 and #261.
This commit is contained in:
Daniel Gulotta 2025-06-13 10:27:19 -07:00 committed by GitHub
parent 21ab3c2d0d
commit 7d0d3ad769
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 992 additions and 825 deletions

View file

@ -527,7 +527,7 @@ mod tests {
"_".into(),
vec![
st(
P::Native(NP::ValueOf),
P::Native(NP::Equal),
vec![
STA::AnchoredKey(sow_wc(4), kow_wc(5)),
STA::Literal(2.into()),
@ -558,8 +558,8 @@ mod tests {
let custom_deduction = Operation::Custom(
CustomPredicateRef::new(cust_pred_batch, 0),
vec![
Statement::ValueOf(AnchoredKey::from((SELF, "Some constant")), 2.into()),
Statement::ProductOf(
Statement::equal(AnchoredKey::from((SELF, "Some constant")), 2),
Statement::product_of(
AnchoredKey::from((SELF, "Some value")),
AnchoredKey::from((SELF, "Some constant")),
AnchoredKey::from((SELF, "Some other value")),
@ -585,7 +585,7 @@ mod tests {
"eth_friend_cp".into(),
vec![
st(
P::Native(NP::ValueOf),
P::Native(NP::Equal),
vec![
STA::AnchoredKey(sow_wc(4), KeyOrWildcard::Key("type".into())),
STA::Literal(PodType::Signed.into()),
@ -626,7 +626,7 @@ mod tests {
],
),
st(
P::Native(NP::ValueOf),
P::Native(NP::Equal),
vec![
STA::AnchoredKey(sow_wc(4), kow_wc(5)),
STA::Literal(0.into()),
@ -654,7 +654,7 @@ mod tests {
],
),
st(
P::Native(NP::ValueOf),
P::Native(NP::Equal),
vec![
STA::AnchoredKey(sow_wc(6), kow_wc(7)),
STA::Literal(1.into()),
@ -776,8 +776,8 @@ mod tests {
WildcardValue::Key(Key::from("Six")),
],
),
Statement::ValueOf(AnchoredKey::from((SELF, "One")), 1.into()),
Statement::SumOf(
Statement::equal(AnchoredKey::from((SELF, "One")), 1),
Statement::sum_of(
AnchoredKey::from((SELF, "Seven")),
AnchoredKey::from((pod_id4, "Six")),
AnchoredKey::from((SELF, "One")),

View file

@ -671,8 +671,7 @@ impl Default for Params {
max_signed_pod_values: 8,
max_public_statements: 10,
num_public_statements_id: 16,
// TODO: Reduce to 5 or less after https://github.com/0xPARC/pod2/issues/229
max_statement_args: 6,
max_statement_args: 5,
max_operation_args: 5,
max_custom_predicate_batches: 2,
max_custom_predicate_verifications: 5,
@ -793,7 +792,7 @@ pub trait Pod: fmt::Debug + DynClone + Any {
self.pub_statements()
.into_iter()
.filter_map(|st| match st {
Statement::ValueOf(ak, v) => Some((ak, v)),
Statement::Equal(ValueRef::Key(ak), ValueRef::Literal(v)) => Some((ak, v)),
_ => None,
})
.collect()

View file

@ -7,9 +7,9 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::primitives::merkletree::MerkleProof,
middleware::{
custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg,
StatementTmplArg, ToFields, Wildcard, WildcardValue, F, SELF,
custom::KeyOrWildcard, hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef,
Error, NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg,
StatementTmplArg, ToFields, Value, ValueRef, Wildcard, WildcardValue, F, SELF,
},
};
@ -86,6 +86,12 @@ pub enum NativeOperation {
GtToNotEqual = 1008,
}
impl NativeOperation {
pub fn is_syntactic_sugar(self) -> bool {
(self as usize) >= 1000
}
}
impl ToFields for NativeOperation {
fn to_fields(&self, _params: &Params) -> Vec<F> {
vec![F::from_canonical_u64(*self as u64)]
@ -100,7 +106,7 @@ impl OperationType {
match self {
OperationType::Native(native_op) => match native_op {
NativeOperation::None => Some(Predicate::Native(NativePredicate::None)),
NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::ValueOf)),
NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::Equal)),
NativeOperation::CopyStatement => None,
NativeOperation::EqualFromEntries => {
Some(Predicate::Native(NativePredicate::Equal))
@ -161,6 +167,22 @@ pub enum Operation {
Custom(CustomPredicateRef, Vec<Statement>),
}
pub(crate) fn sum_op(x: i64, y: i64) -> i64 {
x + y
}
pub(crate) fn prod_op(x: i64, y: i64) -> i64 {
x * y
}
pub(crate) fn max_op(x: i64, y: i64) -> i64 {
x.max(y)
}
pub(crate) fn hash_op(x: Value, y: Value) -> Value {
Value::from(hash_values(&[x, y]))
}
impl Operation {
pub fn op_type(&self) -> OperationType {
type OT = OperationType;
@ -219,56 +241,53 @@ impl Operation {
pub fn op(op_code: OperationType, args: &[Statement], aux: &OperationAux) -> Result<Self> {
type OA = OperationAux;
type NO = NativeOperation;
let arg_tup = (
args.first().cloned(),
args.get(1).cloned(),
args.get(2).cloned(),
);
Ok(match op_code {
OperationType::Native(o) => match (o, arg_tup, aux.clone(), args.len()) {
(NO::None, (None, None, None), OA::None, 0) => Self::None,
(NO::NewEntry, (None, None, None), OA::None, 0) => Self::NewEntry,
(NO::CopyStatement, (Some(s), None, None), OA::None, 1) => Self::CopyStatement(s),
(NO::EqualFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::EqualFromEntries(s1, s2)
OperationType::Native(o) => match (o, &args, aux.clone()) {
(NO::None, &[], OA::None) => Self::None,
(NO::NewEntry, &[], OA::None) => Self::NewEntry,
(NO::CopyStatement, &[s], OA::None) => Self::CopyStatement(s.clone()),
(NO::EqualFromEntries, &[s1, s2], OA::None) => {
Self::EqualFromEntries(s1.clone(), s2.clone())
}
(NO::NotEqualFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::NotEqualFromEntries(s1, s2)
(NO::NotEqualFromEntries, &[s1, s2], OA::None) => {
Self::NotEqualFromEntries(s1.clone(), s2.clone())
}
(NO::LtEqFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::LtEqFromEntries(s1, s2)
(NO::LtEqFromEntries, &[s1, s2], OA::None) => {
Self::LtEqFromEntries(s1.clone(), s2.clone())
}
(NO::LtFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::LtFromEntries(s1, s2)
(NO::LtFromEntries, &[s1, s2], OA::None) => {
Self::LtFromEntries(s1.clone(), s2.clone())
}
(
NO::ContainsFromEntries,
(Some(s1), Some(s2), Some(s3)),
OA::MerkleProof(pf),
3,
) => Self::ContainsFromEntries(s1, s2, s3, pf),
(
NO::NotContainsFromEntries,
(Some(s1), Some(s2), None),
OA::MerkleProof(pf),
2,
) => Self::NotContainsFromEntries(s1, s2, pf),
(NO::SumOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => Self::SumOf(s1, s2, s3),
(NO::ProductOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => {
Self::ProductOf(s1, s2, s3)
(NO::ContainsFromEntries, &[s1, s2, s3], OA::MerkleProof(pf)) => {
Self::ContainsFromEntries(s1.clone(), s2.clone(), s3.clone(), pf)
}
(NO::MaxOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => Self::MaxOf(s1, s2, s3),
(NO::HashOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => {
Self::HashOf(s1, s2, s3)
(NO::NotContainsFromEntries, &[s1, s2], OA::MerkleProof(pf)) => {
Self::NotContainsFromEntries(s1.clone(), s2.clone(), pf)
}
(NO::SumOf, &[s1, s2, s3], OA::None) => {
Self::SumOf(s1.clone(), s2.clone(), s3.clone())
}
(NO::ProductOf, &[s1, s2, s3], OA::None) => {
Self::ProductOf(s1.clone(), s2.clone(), s3.clone())
}
(NO::MaxOf, &[s1, s2, s3], OA::None) => {
Self::MaxOf(s1.clone(), s2.clone(), s3.clone())
}
(NO::HashOf, &[s1, s2, s3], OA::None) => {
Self::HashOf(s1.clone(), s2.clone(), s3.clone())
}
_ => Err(Error::custom(format!(
"Ill-formed operation {:?} with arguments {:?}.",
op_code, args
"Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.",
op_code,
args.len(),
args,
aux
)))?,
},
OperationType::Custom(cpr) => Self::Custom(cpr, args.to_vec()),
})
}
/// Checks the given operation against a statement, and prints information if the check does not pass
pub fn check_and_log(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
let valid: bool = self.check(params, output_statement)?;
@ -278,59 +297,69 @@ impl Operation {
}
Ok(valid)
}
pub(crate) fn check_int_fn(
v1: &Value,
v2: &Value,
v3: &Value,
f: impl FnOnce(i64, i64) -> i64,
) -> Result<bool> {
let i1: i64 = v1.typed().try_into()?;
let i2: i64 = v2.typed().try_into()?;
let i3: i64 = v3.typed().try_into()?;
Ok(i1 == f(i2, i3))
}
/// Checks the given operation against a statement.
pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
use Statement::*;
match (self, output_statement) {
(Self::None, None) => Ok(true),
(Self::NewEntry, ValueOf(AnchoredKey { pod_id, .. }, _)) => Ok(pod_id == &SELF),
(Self::CopyStatement(s1), s2) => Ok(s1 == s2),
(Self::EqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Equal(ak3, ak4)) => {
Ok(v1 == v2 && ak3 == ak1 && ak4 == ak2)
}
(Self::NotEqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), NotEqual(ak3, ak4)) => {
Ok(v1 != v2 && ak3 == ak1 && ak4 == ak2)
}
(Self::LtEqFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), LtEq(ak3, ak4)) => {
Ok(v1 <= v2 && ak3 == ak1 && ak4 == ak2)
}
(Self::LtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Lt(ak3, ak4)) => {
Ok(v1 < v2 && ak3 == ak1 && ak4 == ak2)
let deduction_err = || Error::invalid_deduction(self.clone(), output_statement.clone());
let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err);
let b = match (self, output_statement) {
(Self::None, None) => true,
(Self::NewEntry, Equal(ValueRef::Key(AnchoredKey { pod_id, .. }), _)) => {
pod_id == &SELF
}
(Self::CopyStatement(s1), s2) => s1 == s2,
(Self::EqualFromEntries(s1, s2), Equal(v3, v4)) => val(v3, s1)? == val(v4, s2)?,
(Self::NotEqualFromEntries(s1, s2), NotEqual(v3, v4)) => val(v3, s1)? != val(v4, s2)?,
(Self::LtEqFromEntries(s1, s2), LtEq(v3, v4)) => val(v3, s1)? <= val(v4, s2)?,
(Self::LtFromEntries(s1, s2), Lt(v3, v4)) => val(v3, s1)? < val(v4, s2)?,
(Self::ContainsFromEntries(_, _, _, _), Contains(_, _, _)) =>
/* TODO */
{
Ok(true)
true
}
(Self::NotContainsFromEntries(_, _, _), NotContains(_, _)) =>
/* TODO */
{
Ok(true)
true
}
(
Self::TransitiveEqualFromStatements(Equal(ak1, ak2), Equal(ak3, ak4)),
Equal(ak5, ak6),
) => Ok(ak2 == ak3 && ak5 == ak1 && ak6 == ak4),
(Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => Ok(ak1 == ak3 && ak2 == ak4),
(
Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)),
SumOf(ak4, ak5, ak6),
) => {
let v1: i64 = v1.typed().try_into()?;
let v2: i64 = v2.typed().try_into()?;
let v3: i64 = v3.typed().try_into()?;
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
) => ak2 == ak3 && ak5 == ak1 && ak6 == ak4,
(Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => ak1 == ak3 && ak2 == ak4,
(Self::SumOf(s1, s2, s3), SumOf(v4, v5, v6)) => {
Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, sum_op)?
}
(Self::ProductOf(s1, s2, s3), ProductOf(v4, v5, v6)) => {
Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, prod_op)?
}
(Self::MaxOf(s1, s2, s3), ProductOf(v4, v5, v6)) => {
Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, max_op)?
}
(Self::HashOf(s1, s2, s3), ProductOf(v4, v5, v6)) => {
val(v4, s1)? == hash_op(val(v5, s2)?, val(v6, s3)?)
}
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index =>
{
check_custom_pred(params, cpr, args, s_args)
check_custom_pred(params, cpr, args, s_args)?
}
_ => Err(Error::invalid_deduction(
self.clone(),
output_statement.clone(),
)),
}
_ => return Err(deduction_err()),
};
Ok(b)
}
}
@ -494,3 +523,15 @@ impl fmt::Display for Operation {
Ok(())
}
}
/// Returns the value associated with `output_ref`.
/// If `output_ref` is a concrete value, returns that value.
/// Otherwise, `output_ref` was constructed using an `Equal` statement, and `input_st`
/// must be that statement.
pub(crate) fn value_from_op(input_st: &Statement, output_ref: &ValueRef) -> Option<Value> {
match (input_st, output_ref) {
(Statement::None, ValueRef::Literal(v)) => Some(v.clone()),
(Statement::Equal(r1, ValueRef::Literal(v)), r2) if r1 == r2 => Some(v.clone()),
_ => None,
}
}

View file

@ -23,17 +23,16 @@ pub const OPERATION_AUX_F_LEN: usize = 2;
pub enum NativePredicate {
None = 0, // Always true
False = 1, // Always false
ValueOf = 2,
Equal = 3,
NotEqual = 4,
LtEq = 5,
Lt = 6,
Contains = 7,
NotContains = 8,
SumOf = 9,
ProductOf = 10,
MaxOf = 11,
HashOf = 12,
Equal = 2,
NotEqual = 3,
LtEq = 4,
Lt = 5,
Contains = 6,
NotContains = 7,
SumOf = 8,
ProductOf = 9,
MaxOf = 10,
HashOf = 11,
// Syntactic sugar predicates. These predicates are not supported by the backend. The
// frontend compiler is responsible of translating these predicates into the predicates above.
@ -168,33 +167,58 @@ impl fmt::Display for Predicate {
#[serde(tag = "predicate", content = "args")]
pub enum Statement {
None,
ValueOf(AnchoredKey, Value),
Equal(AnchoredKey, AnchoredKey),
NotEqual(AnchoredKey, AnchoredKey),
LtEq(AnchoredKey, AnchoredKey),
Lt(AnchoredKey, AnchoredKey),
Equal(ValueRef, ValueRef),
NotEqual(ValueRef, ValueRef),
LtEq(ValueRef, ValueRef),
Lt(ValueRef, ValueRef),
Contains(
/* root */ AnchoredKey,
/* key */ AnchoredKey,
/* value */ AnchoredKey,
/* root */ ValueRef,
/* key */ ValueRef,
/* value */ ValueRef,
),
NotContains(/* root */ AnchoredKey, /* key */ AnchoredKey),
SumOf(AnchoredKey, AnchoredKey, AnchoredKey),
ProductOf(AnchoredKey, AnchoredKey, AnchoredKey),
MaxOf(AnchoredKey, AnchoredKey, AnchoredKey),
HashOf(AnchoredKey, AnchoredKey, AnchoredKey),
NotContains(/* root */ ValueRef, /* key */ ValueRef),
SumOf(ValueRef, ValueRef, ValueRef),
ProductOf(ValueRef, ValueRef, ValueRef),
MaxOf(ValueRef, ValueRef, ValueRef),
HashOf(ValueRef, ValueRef, ValueRef),
Custom(CustomPredicateRef, Vec<WildcardValue>),
}
macro_rules! statement_constructor {
($var_name: ident, $cons_name: ident, 2) => {
pub fn $var_name(v1: impl Into<ValueRef>, v2: impl Into<ValueRef>) -> Self {
Self::$cons_name(v1.into(), v2.into())
}
};
($var_name: ident, $cons_name: ident, 3) => {
pub fn $var_name(
v1: impl Into<ValueRef>,
v2: impl Into<ValueRef>,
v3: impl Into<ValueRef>,
) -> Self {
Self::$cons_name(v1.into(), v2.into(), v3.into())
}
};
}
impl Statement {
pub fn is_none(&self) -> bool {
self == &Self::None
}
statement_constructor!(equal, Equal, 2);
statement_constructor!(not_equal, NotEqual, 2);
statement_constructor!(lt_eq, LtEq, 2);
statement_constructor!(lt, Lt, 2);
statement_constructor!(contains, Contains, 3);
statement_constructor!(not_contains, NotContains, 2);
statement_constructor!(sum_of, SumOf, 3);
statement_constructor!(product_of, ProductOf, 3);
statement_constructor!(max_of, MaxOf, 3);
statement_constructor!(hash_of, HashOf, 3);
pub fn predicate(&self) -> Predicate {
use Predicate::*;
match self {
Self::None => Native(NativePredicate::None),
Self::ValueOf(_, _) => Native(NativePredicate::ValueOf),
Self::Equal(_, _) => Native(NativePredicate::Equal),
Self::NotEqual(_, _) => Native(NativePredicate::NotEqual),
Self::LtEq(_, _) => Native(NativePredicate::LtEq),
@ -212,117 +236,66 @@ impl Statement {
use StatementArg::*;
match self.clone() {
Self::None => vec![],
Self::ValueOf(ak, v) => vec![Key(ak), Literal(v)],
Self::Equal(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::NotEqual(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::LtEq(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::Lt(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::Contains(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::NotContains(ak1, ak2) => vec![Key(ak1), Key(ak2)],
Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::HashOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::Equal(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::NotEqual(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::LtEq(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::Lt(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::Contains(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::NotContains(ak1, ak2) => vec![ak1.into(), ak2.into()],
Self::SumOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::ProductOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::MaxOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::HashOf(ak1, ak2, ak3) => vec![ak1.into(), ak2.into(), ak3.into()],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(WildcardLiteral)),
}
}
pub fn as_entry(&self) -> Option<(&AnchoredKey, &Value)> {
if let Self::Equal(ValueRef::Key(k), ValueRef::Literal(v)) = self {
Some((k, v))
} else {
None
}
}
pub fn from_args(pred: Predicate, args: Vec<StatementArg>) -> Result<Self> {
use Predicate::*;
let st: Result<Self> = match pred {
Native(NativePredicate::None) => Ok(Self::None),
Native(NativePredicate::ValueOf) => {
if let (StatementArg::Key(a0), StatementArg::Literal(v1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::ValueOf(a0, v1))
} else {
Err(Error::incorrect_statements_args())
}
let st = match (pred, &args.as_slice()) {
(Native(NativePredicate::None), &[]) => Self::None,
(Native(NativePredicate::Equal), &[a1, a2]) => {
Self::Equal(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::Equal) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::Equal(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::NotEqual), &[a1, a2]) => {
Self::NotEqual(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::NotEqual) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::NotEqual(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::LtEq), &[a1, a2]) => {
Self::LtEq(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::LtEq) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::LtEq(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::Lt), &[a1, a2]) => Self::Lt(a1.try_into()?, a2.try_into()?),
(Native(NativePredicate::Contains), &[a1, a2, a3]) => {
Self::Contains(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::Lt) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::Lt(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::NotContains), &[a1, a2]) => {
Self::NotContains(a1.try_into()?, a2.try_into()?)
}
Native(NativePredicate::Contains) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::Contains(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::SumOf), &[a1, a2, a3]) => {
Self::SumOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::NotContains) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1)) =
(args[0].clone(), args[1].clone())
{
Ok(Self::NotContains(a0, a1))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::ProductOf), &[a1, a2, a3]) => {
Self::ProductOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::SumOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::SumOf(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::MaxOf), &[a1, a2, a3]) => {
Self::MaxOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::ProductOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::ProductOf(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(NativePredicate::HashOf), &[a1, a2, a3]) => {
Self::HashOf(a1.try_into()?, a2.try_into()?, a3.try_into()?)
}
Native(NativePredicate::MaxOf) => {
if let (StatementArg::Key(a0), StatementArg::Key(a1), StatementArg::Key(a2)) =
(args[0].clone(), args[1].clone(), args[2].clone())
{
Ok(Self::MaxOf(a0, a1, a2))
} else {
Err(Error::incorrect_statements_args())
}
(Native(np), _) => {
return Err(Error::custom(format!("Predicate {:?} is syntax sugar", np)))
}
Native(np) => Err(Error::custom(format!("Predicate {:?} is syntax sugar", np))),
BatchSelf(_) => unreachable!(),
Custom(cpr) => {
(BatchSelf(_), _) => unreachable!(),
(Custom(cpr), _) => {
let v_args: Result<Vec<WildcardValue>> = args
.iter()
.map(|x| match x {
@ -330,10 +303,10 @@ impl Statement {
_ => Err(Error::incorrect_statements_args()),
})
.collect();
Ok(Self::Custom(cpr, v_args?))
Self::Custom(cpr, v_args?)
}
};
st
Ok(st)
}
}
@ -437,6 +410,57 @@ impl ToFields for StatementArg {
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub enum ValueRef {
Literal(Value),
Key(AnchoredKey),
}
impl From<ValueRef> for StatementArg {
fn from(value: ValueRef) -> Self {
match value {
ValueRef::Literal(v) => StatementArg::Literal(v),
ValueRef::Key(v) => StatementArg::Key(v),
}
}
}
impl TryFrom<StatementArg> for ValueRef {
type Error = crate::middleware::Error;
fn try_from(value: StatementArg) -> std::result::Result<Self, Self::Error> {
match value {
StatementArg::Literal(v) => Ok(Self::Literal(v)),
StatementArg::Key(k) => Ok(Self::Key(k)),
_ => Err(Self::Error::invalid_statement_arg(
value,
"literal or key".to_string(),
)),
}
}
}
impl TryFrom<&StatementArg> for ValueRef {
type Error = crate::middleware::Error;
fn try_from(value: &StatementArg) -> std::result::Result<Self, Self::Error> {
value.clone().try_into()
}
}
impl From<AnchoredKey> for ValueRef {
fn from(value: AnchoredKey) -> Self {
Self::Key(value)
}
}
impl<T> From<T> for ValueRef
where
T: Into<Value>,
{
fn from(value: T) -> Self {
Self::Literal(value.into())
}
}
#[cfg(test)]
mod tests {
use super::*;