Fix handling of Lt, LtEq (#393)

Changed the middleware to only allow comparison of integers and to
use the implementation of Ord for i64.  This matches the backend
behavior.

Also fixed a separate bug where LtEqFromEntries was producing a
NotEquals statement.
This commit is contained in:
Daniel Gulotta 2025-08-18 07:54:20 -07:00 committed by GitHub
parent 1508dd6126
commit f76197c602
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 36 additions and 27 deletions

View file

@ -1095,4 +1095,14 @@ pub mod tests {
let std_common = &*cache_get_standard_rec_main_pod_common_circuit_data(); let std_common = &*cache_get_standard_rec_main_pod_common_circuit_data();
assert_eq!(std_common.0, main_common.0); assert_eq!(std_common.0, main_common.0);
} }
#[test]
fn test_negative_less_than_zero() -> frontend::Result<()> {
let params = Params::default();
let mut builder = MainPodBuilder::new(&params, &DEFAULT_VD_SET);
builder.pub_op(frontend::Operation::lt(-1, 0))?;
let prover = Prover {};
builder.prove(&prover)?;
Ok(())
}
} }

View file

@ -408,8 +408,8 @@ impl MainPodBuilder {
} }
} }
(LtFromEntries, &[a1, a2]) => { (LtFromEntries, &[a1, a2]) => {
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; let (r1, v1) = a1.int_value_and_ref().ok_or_else(native_arg_error)?;
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; let (r2, v2) = a2.int_value_and_ref().ok_or_else(native_arg_error)?;
if v1 < v2 { if v1 < v2 {
Statement::lt(r1, r2) Statement::lt(r1, r2)
} else { } else {
@ -417,10 +417,10 @@ impl MainPodBuilder {
} }
} }
(LtEqFromEntries, &[a1, a2]) => { (LtEqFromEntries, &[a1, a2]) => {
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; let (r1, v1) = a1.int_value_and_ref().ok_or_else(native_arg_error)?;
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; let (r2, v2) = a2.int_value_and_ref().ok_or_else(native_arg_error)?;
if v1 <= v2 { if v1 <= v2 {
Statement::not_equal(r1, r2) Statement::lt_eq(r1, r2)
} else { } else {
return Err(native_arg_error()); return Err(native_arg_error());
} }

View file

@ -4,7 +4,7 @@ use crate::{
frontend::{MainPod, SignedPod}, frontend::{MainPod, SignedPod},
middleware::{ middleware::{
AnchoredKey, CustomPredicateRef, NativeOperation, OperationAux, OperationType, Statement, AnchoredKey, CustomPredicateRef, NativeOperation, OperationAux, OperationType, Statement,
Value, ValueRef, TypedValue, Value, ValueRef,
}, },
}; };
@ -33,6 +33,13 @@ impl OperationArg {
_ => None, _ => None,
} }
} }
pub(crate) fn int_value_and_ref(&self) -> Option<(ValueRef, i64)> {
self.value_and_ref().and_then(|(r, v)| match v.typed() {
&TypedValue::Int(i) => Some((r, i)),
_ => None,
})
}
} }
impl fmt::Display for OperationArg { impl fmt::Display for OperationArg {

View file

@ -7,10 +7,7 @@ use hex::ToHex;
use itertools::Itertools; use itertools::Itertools;
use strum_macros::FromRepr; use strum_macros::FromRepr;
mod basetypes; mod basetypes;
use std::{ use std::{cmp::PartialEq, hash};
cmp::{Ordering, PartialEq, PartialOrd},
hash,
};
use containers::{Array, Dictionary, Set}; use containers::{Array, Dictionary, Set};
use schemars::JsonSchema; use schemars::JsonSchema;
@ -254,7 +251,7 @@ impl fmt::Display for TypedValue {
} }
TypedValue::Set(s) => { TypedValue::Set(s) => {
write!(f, "#[")?; write!(f, "#[")?;
let values: Vec<_> = s.set().iter().sorted().collect(); let values: Vec<_> = s.set().iter().sorted_by_key(|k| k.raw()).collect();
for (i, v) in values.iter().enumerate() { for (i, v) in values.iter().enumerate() {
if i > 0 { if i > 0 {
write!(f, ", ")?; write!(f, ", ")?;
@ -461,18 +458,6 @@ impl PartialEq for Value {
impl Eq for Value {} impl Eq for Value {}
impl PartialOrd for Value {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Value {
fn cmp(&self, other: &Self) -> Ordering {
self.raw.cmp(&other.raw)
}
}
impl hash::Hash for Value { impl hash::Hash for Value {
fn hash<H: hash::Hasher>(&self, state: &mut H) { fn hash<H: hash::Hasher>(&self, state: &mut H) {
self.raw.hash(state) self.raw.hash(state)

View file

@ -15,7 +15,7 @@ use crate::{
middleware::{ middleware::{
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, NativePredicate, hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, NativePredicate,
Params, Predicate, Result, Statement, StatementArg, StatementTmpl, StatementTmplArg, Params, Predicate, Result, Statement, StatementArg, StatementTmpl, StatementTmplArg,
ToFields, Value, ValueRef, Wildcard, F, SELF, ToFields, TypedValue, Value, ValueRef, Wildcard, F, SELF,
}, },
}; };
@ -415,6 +415,13 @@ impl Operation {
use Statement::*; use Statement::*;
let deduction_err = || Error::invalid_deduction(self.clone(), output_statement.clone()); let deduction_err = || Error::invalid_deduction(self.clone(), output_statement.clone());
let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err); let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err);
let int_val = |v, s| {
let v_op = value_from_op(s, v).ok_or_else(deduction_err)?;
match v_op.typed() {
&TypedValue::Int(i) => Ok(i),
_ => Err(deduction_err()),
}
};
let b = match (self, output_statement) { let b = match (self, output_statement) {
(Self::None, None) => true, (Self::None, None) => true,
(Self::NewEntry, Equal(ValueRef::Key(AnchoredKey { pod_id, .. }), _)) => { (Self::NewEntry, Equal(ValueRef::Key(AnchoredKey { pod_id, .. }), _)) => {
@ -423,8 +430,8 @@ impl Operation {
(Self::CopyStatement(s1), s2) => s1 == s2, (Self::CopyStatement(s1), s2) => s1 == s2,
(Self::EqualFromEntries(s1, s2), Equal(v3, v4)) => val(v3, s1)? == val(v4, s2)?, (Self::EqualFromEntries(s1, s2), Equal(v3, v4)) => val(v3, s1)? == val(v4, s2)?,
(Self::NotEqualFromEntries(s1, s2), NotEqual(v3, v4)) => val(v3, s1)? != val(v4, s2)?, (Self::NotEqualFromEntries(s1, s2), NotEqual(v3, v4)) => val(v3, s1)? != val(v4, s2)?,
(Self::LtEqFromEntries(s1, s2), LtEq(v3, v4)) => val(v3, s1)? <= val(v4, s2)?, (Self::LtEqFromEntries(s1, s2), LtEq(v3, v4)) => int_val(v3, s1)? <= int_val(v4, s2)?,
(Self::LtFromEntries(s1, s2), Lt(v3, v4)) => val(v3, s1)? < val(v4, s2)?, (Self::LtFromEntries(s1, s2), Lt(v3, v4)) => int_val(v3, s1)? < int_val(v4, s2)?,
( (
Self::ContainsFromEntries(root_s, key_s, val_s, pf), Self::ContainsFromEntries(root_s, key_s, val_s, pf),
Contains(root_v, key_v, val_v), Contains(root_v, key_v, val_v),

View file

@ -139,7 +139,7 @@ where
{ {
let mut set = serializer.serialize_seq(Some(value.len()))?; let mut set = serializer.serialize_seq(Some(value.len()))?;
let mut sorted_values: Vec<&Value> = value.iter().collect(); let mut sorted_values: Vec<&Value> = value.iter().collect();
sorted_values.sort(); sorted_values.sort_by_key(|v| v.raw());
for v in sorted_values { for v in sorted_values {
set.serialize_element(v)?; set.serialize_element(v)?;
} }