diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index a03cf94..638ecb2 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -1095,4 +1095,14 @@ pub mod tests { let std_common = &*cache_get_standard_rec_main_pod_common_circuit_data(); assert_eq!(std_common.0, main_common.0); } + + #[test] + fn test_negative_less_than_zero() -> frontend::Result<()> { + let params = Params::default(); + let mut builder = MainPodBuilder::new(¶ms, &DEFAULT_VD_SET); + builder.pub_op(frontend::Operation::lt(-1, 0))?; + let prover = Prover {}; + builder.prove(&prover)?; + Ok(()) + } } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 1df3192..158714f 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -408,8 +408,8 @@ impl MainPodBuilder { } } (LtFromEntries, &[a1, a2]) => { - let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; - let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + let (r1, v1) = a1.int_value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.int_value_and_ref().ok_or_else(native_arg_error)?; if v1 < v2 { Statement::lt(r1, r2) } else { @@ -417,10 +417,10 @@ impl MainPodBuilder { } } (LtEqFromEntries, &[a1, a2]) => { - let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?; - let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?; + let (r1, v1) = a1.int_value_and_ref().ok_or_else(native_arg_error)?; + let (r2, v2) = a2.int_value_and_ref().ok_or_else(native_arg_error)?; if v1 <= v2 { - Statement::not_equal(r1, r2) + Statement::lt_eq(r1, r2) } else { return Err(native_arg_error()); } diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index 514c58f..967928a 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -4,7 +4,7 @@ use crate::{ frontend::{MainPod, SignedPod}, middleware::{ AnchoredKey, CustomPredicateRef, NativeOperation, OperationAux, OperationType, Statement, - Value, ValueRef, + TypedValue, Value, ValueRef, }, }; @@ -33,6 +33,13 @@ impl OperationArg { _ => None, } } + + pub(crate) fn int_value_and_ref(&self) -> Option<(ValueRef, i64)> { + self.value_and_ref().and_then(|(r, v)| match v.typed() { + &TypedValue::Int(i) => Some((r, i)), + _ => None, + }) + } } impl fmt::Display for OperationArg { diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index a001989..aadc778 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -7,10 +7,7 @@ use hex::ToHex; use itertools::Itertools; use strum_macros::FromRepr; mod basetypes; -use std::{ - cmp::{Ordering, PartialEq, PartialOrd}, - hash, -}; +use std::{cmp::PartialEq, hash}; use containers::{Array, Dictionary, Set}; use schemars::JsonSchema; @@ -254,7 +251,7 @@ impl fmt::Display for TypedValue { } TypedValue::Set(s) => { write!(f, "#[")?; - let values: Vec<_> = s.set().iter().sorted().collect(); + let values: Vec<_> = s.set().iter().sorted_by_key(|k| k.raw()).collect(); for (i, v) in values.iter().enumerate() { if i > 0 { write!(f, ", ")?; @@ -461,18 +458,6 @@ impl PartialEq for Value { impl Eq for Value {} -impl PartialOrd for Value { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Value { - fn cmp(&self, other: &Self) -> Ordering { - self.raw.cmp(&other.raw) - } -} - impl hash::Hash for Value { fn hash(&self, state: &mut H) { self.raw.hash(state) diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index bb5a009..4b378e2 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -15,7 +15,7 @@ use crate::{ middleware::{ hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmpl, StatementTmplArg, - ToFields, Value, ValueRef, Wildcard, F, SELF, + ToFields, TypedValue, Value, ValueRef, Wildcard, F, SELF, }, }; @@ -415,6 +415,13 @@ impl Operation { use Statement::*; let deduction_err = || Error::invalid_deduction(self.clone(), output_statement.clone()); let val = |v, s| value_from_op(s, v).ok_or_else(deduction_err); + let int_val = |v, s| { + let v_op = value_from_op(s, v).ok_or_else(deduction_err)?; + match v_op.typed() { + &TypedValue::Int(i) => Ok(i), + _ => Err(deduction_err()), + } + }; let b = match (self, output_statement) { (Self::None, None) => true, (Self::NewEntry, Equal(ValueRef::Key(AnchoredKey { pod_id, .. }), _)) => { @@ -423,8 +430,8 @@ impl Operation { (Self::CopyStatement(s1), s2) => s1 == s2, (Self::EqualFromEntries(s1, s2), Equal(v3, v4)) => val(v3, s1)? == val(v4, s2)?, (Self::NotEqualFromEntries(s1, s2), NotEqual(v3, v4)) => val(v3, s1)? != val(v4, s2)?, - (Self::LtEqFromEntries(s1, s2), LtEq(v3, v4)) => val(v3, s1)? <= val(v4, s2)?, - (Self::LtFromEntries(s1, s2), Lt(v3, v4)) => val(v3, s1)? < val(v4, s2)?, + (Self::LtEqFromEntries(s1, s2), LtEq(v3, v4)) => int_val(v3, s1)? <= int_val(v4, s2)?, + (Self::LtFromEntries(s1, s2), Lt(v3, v4)) => int_val(v3, s1)? < int_val(v4, s2)?, ( Self::ContainsFromEntries(root_s, key_s, val_s, pf), Contains(root_v, key_v, val_v), diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index 81b86ff..68e6efb 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -139,7 +139,7 @@ where { let mut set = serializer.serialize_seq(Some(value.len()))?; let mut sorted_values: Vec<&Value> = value.iter().collect(); - sorted_values.sort(); + sorted_values.sort_by_key(|v| v.raw()); for v in sorted_values { set.serialize_element(v)?; }