Allow literals in statements (#276)

Implements #229 and #261.
This commit is contained in:
Daniel Gulotta 2025-06-13 10:27:19 -07:00 committed by GitHub
parent 21ab3c2d0d
commit 7d0d3ad769
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 992 additions and 825 deletions

View file

@ -7,9 +7,9 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::primitives::merkletree::MerkleProof,
middleware::{
custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg,
StatementTmplArg, ToFields, Wildcard, WildcardValue, F, SELF,
custom::KeyOrWildcard, hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef,
Error, NativePredicate, Params, Predicate, Result, SelfOrWildcard, Statement, StatementArg,
StatementTmplArg, ToFields, Value, ValueRef, Wildcard, WildcardValue, F, SELF,
},
};
@ -86,6 +86,12 @@ pub enum NativeOperation {
GtToNotEqual = 1008,
}
impl NativeOperation {
pub fn is_syntactic_sugar(self) -> bool {
(self as usize) >= 1000
}
}
impl ToFields for NativeOperation {
fn to_fields(&self, _params: &Params) -> Vec<F> {
vec![F::from_canonical_u64(*self as u64)]
@ -100,7 +106,7 @@ impl OperationType {
match self {
OperationType::Native(native_op) => match native_op {
NativeOperation::None => Some(Predicate::Native(NativePredicate::None)),
NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::ValueOf)),
NativeOperation::NewEntry => Some(Predicate::Native(NativePredicate::Equal)),
NativeOperation::CopyStatement => None,
NativeOperation::EqualFromEntries => {
Some(Predicate::Native(NativePredicate::Equal))
@ -161,6 +167,22 @@ pub enum Operation {
Custom(CustomPredicateRef, Vec<Statement>),
}
pub(crate) fn sum_op(x: i64, y: i64) -> i64 {
x + y
}
pub(crate) fn prod_op(x: i64, y: i64) -> i64 {
x * y
}
pub(crate) fn max_op(x: i64, y: i64) -> i64 {
x.max(y)
}
pub(crate) fn hash_op(x: Value, y: Value) -> Value {
Value::from(hash_values(&[x, y]))
}
impl Operation {
pub fn op_type(&self) -> OperationType {
type OT = OperationType;
@ -219,56 +241,53 @@ impl Operation {
pub fn op(op_code: OperationType, args: &[Statement], aux: &OperationAux) -> Result<Self> {
type OA = OperationAux;
type NO = NativeOperation;
let arg_tup = (
args.first().cloned(),
args.get(1).cloned(),
args.get(2).cloned(),
);
Ok(match op_code {
OperationType::Native(o) => match (o, arg_tup, aux.clone(), args.len()) {
(NO::None, (None, None, None), OA::None, 0) => Self::None,
(NO::NewEntry, (None, None, None), OA::None, 0) => Self::NewEntry,
(NO::CopyStatement, (Some(s), None, None), OA::None, 1) => Self::CopyStatement(s),
(NO::EqualFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::EqualFromEntries(s1, s2)
OperationType::Native(o) => match (o, &args, aux.clone()) {
(NO::None, &[], OA::None) => Self::None,
(NO::NewEntry, &[], OA::None) => Self::NewEntry,
(NO::CopyStatement, &[s], OA::None) => Self::CopyStatement(s.clone()),
(NO::EqualFromEntries, &[s1, s2], OA::None) => {
Self::EqualFromEntries(s1.clone(), s2.clone())
}
(NO::NotEqualFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::NotEqualFromEntries(s1, s2)
(NO::NotEqualFromEntries, &[s1, s2], OA::None) => {
Self::NotEqualFromEntries(s1.clone(), s2.clone())
}
(NO::LtEqFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::LtEqFromEntries(s1, s2)
(NO::LtEqFromEntries, &[s1, s2], OA::None) => {
Self::LtEqFromEntries(s1.clone(), s2.clone())
}
(NO::LtFromEntries, (Some(s1), Some(s2), None), OA::None, 2) => {
Self::LtFromEntries(s1, s2)
(NO::LtFromEntries, &[s1, s2], OA::None) => {
Self::LtFromEntries(s1.clone(), s2.clone())
}
(
NO::ContainsFromEntries,
(Some(s1), Some(s2), Some(s3)),
OA::MerkleProof(pf),
3,
) => Self::ContainsFromEntries(s1, s2, s3, pf),
(
NO::NotContainsFromEntries,
(Some(s1), Some(s2), None),
OA::MerkleProof(pf),
2,
) => Self::NotContainsFromEntries(s1, s2, pf),
(NO::SumOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => Self::SumOf(s1, s2, s3),
(NO::ProductOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => {
Self::ProductOf(s1, s2, s3)
(NO::ContainsFromEntries, &[s1, s2, s3], OA::MerkleProof(pf)) => {
Self::ContainsFromEntries(s1.clone(), s2.clone(), s3.clone(), pf)
}
(NO::MaxOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => Self::MaxOf(s1, s2, s3),
(NO::HashOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => {
Self::HashOf(s1, s2, s3)
(NO::NotContainsFromEntries, &[s1, s2], OA::MerkleProof(pf)) => {
Self::NotContainsFromEntries(s1.clone(), s2.clone(), pf)
}
(NO::SumOf, &[s1, s2, s3], OA::None) => {
Self::SumOf(s1.clone(), s2.clone(), s3.clone())
}
(NO::ProductOf, &[s1, s2, s3], OA::None) => {
Self::ProductOf(s1.clone(), s2.clone(), s3.clone())
}
(NO::MaxOf, &[s1, s2, s3], OA::None) => {
Self::MaxOf(s1.clone(), s2.clone(), s3.clone())
}
(NO::HashOf, &[s1, s2, s3], OA::None) => {
Self::HashOf(s1.clone(), s2.clone(), s3.clone())
}
_ => Err(Error::custom(format!(
"Ill-formed operation {:?} with arguments {:?}.",
op_code, args
"Ill-formed operation {:?} with {} arguments {:?} and aux {:?}.",
op_code,
args.len(),
args,
aux
)))?,
},
OperationType::Custom(cpr) => Self::Custom(cpr, args.to_vec()),
})
}
/// Checks the given operation against a statement, and prints information if the check does not pass
pub fn check_and_log(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
let valid: bool = self.check(params, output_statement)?;
@ -278,59 +297,69 @@ impl Operation {
}
Ok(valid)
}
pub(crate) fn check_int_fn(
v1: &Value,
v2: &Value,
v3: &Value,
f: impl FnOnce(i64, i64) -> i64,
) -> Result<bool> {
let i1: i64 = v1.typed().try_into()?;
let i2: i64 = v2.typed().try_into()?;
let i3: i64 = v3.typed().try_into()?;
Ok(i1 == f(i2, i3))
}
/// Checks the given operation against a statement.
pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
use Statement::*;
match (self, output_statement) {
(Self::None, None) => Ok(true),
(Self::NewEntry, ValueOf(AnchoredKey { pod_id, .. }, _)) => Ok(pod_id == &SELF),
(Self::CopyStatement(s1), s2) => Ok(s1 == s2),
(Self::EqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Equal(ak3, ak4)) => {
Ok(v1 == v2 && ak3 == ak1 && ak4 == ak2)
}
(Self::NotEqualFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), NotEqual(ak3, ak4)) => {
Ok(v1 != v2 && ak3 == ak1 && ak4 == ak2)
}
(Self::LtEqFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), LtEq(ak3, ak4)) => {
Ok(v1 <= v2 && ak3 == ak1 && ak4 == ak2)
}
(Self::LtFromEntries(ValueOf(ak1, v1), ValueOf(ak2, v2)), Lt(ak3, ak4)) => {
Ok(v1 < v2 && ak3 == ak1 && ak4 == ak2)
let deduction_err = || Error::invalid_deduction(self.clone(), output_statement.clone());
let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err);
let b = match (self, output_statement) {
(Self::None, None) => true,
(Self::NewEntry, Equal(ValueRef::Key(AnchoredKey { pod_id, .. }), _)) => {
pod_id == &SELF
}
(Self::CopyStatement(s1), s2) => s1 == s2,
(Self::EqualFromEntries(s1, s2), Equal(v3, v4)) => val(v3, s1)? == val(v4, s2)?,
(Self::NotEqualFromEntries(s1, s2), NotEqual(v3, v4)) => val(v3, s1)? != val(v4, s2)?,
(Self::LtEqFromEntries(s1, s2), LtEq(v3, v4)) => val(v3, s1)? <= val(v4, s2)?,
(Self::LtFromEntries(s1, s2), Lt(v3, v4)) => val(v3, s1)? < val(v4, s2)?,
(Self::ContainsFromEntries(_, _, _, _), Contains(_, _, _)) =>
/* TODO */
{
Ok(true)
true
}
(Self::NotContainsFromEntries(_, _, _), NotContains(_, _)) =>
/* TODO */
{
Ok(true)
true
}
(
Self::TransitiveEqualFromStatements(Equal(ak1, ak2), Equal(ak3, ak4)),
Equal(ak5, ak6),
) => Ok(ak2 == ak3 && ak5 == ak1 && ak6 == ak4),
(Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => Ok(ak1 == ak3 && ak2 == ak4),
(
Self::SumOf(ValueOf(ak1, v1), ValueOf(ak2, v2), ValueOf(ak3, v3)),
SumOf(ak4, ak5, ak6),
) => {
let v1: i64 = v1.typed().try_into()?;
let v2: i64 = v2.typed().try_into()?;
let v3: i64 = v3.typed().try_into()?;
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
) => ak2 == ak3 && ak5 == ak1 && ak6 == ak4,
(Self::LtToNotEqual(Lt(ak1, ak2)), NotEqual(ak3, ak4)) => ak1 == ak3 && ak2 == ak4,
(Self::SumOf(s1, s2, s3), SumOf(v4, v5, v6)) => {
Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, sum_op)?
}
(Self::ProductOf(s1, s2, s3), ProductOf(v4, v5, v6)) => {
Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, prod_op)?
}
(Self::MaxOf(s1, s2, s3), ProductOf(v4, v5, v6)) => {
Self::check_int_fn(&val(v4, s1)?, &val(v5, s2)?, &val(v6, s3)?, max_op)?
}
(Self::HashOf(s1, s2, s3), ProductOf(v4, v5, v6)) => {
val(v4, s1)? == hash_op(val(v5, s2)?, val(v6, s3)?)
}
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index =>
{
check_custom_pred(params, cpr, args, s_args)
check_custom_pred(params, cpr, args, s_args)?
}
_ => Err(Error::invalid_deduction(
self.clone(),
output_statement.clone(),
)),
}
_ => return Err(deduction_err()),
};
Ok(b)
}
}
@ -494,3 +523,15 @@ impl fmt::Display for Operation {
Ok(())
}
}
/// Returns the value associated with `output_ref`.
/// If `output_ref` is a concrete value, returns that value.
/// Otherwise, `output_ref` was constructed using an `Equal` statement, and `input_st`
/// must be that statement.
pub(crate) fn value_from_op(input_st: &Statement, output_ref: &ValueRef) -> Option<Value> {
match (input_st, output_ref) {
(Statement::None, ValueRef::Literal(v)) => Some(v.clone()),
(Statement::Equal(r1, ValueRef::Literal(v)), r2) if r1 == r2 => Some(v.clone()),
_ => None,
}
}