feat(backend): implement gadgets for remaining ops (#228)
* Implement gadgets for remaining ops * Use overflowing arithmetic ops * Code review * Formatting
This commit is contained in:
parent
b2cb563eb6
commit
4fa9e20ecd
2 changed files with 570 additions and 3 deletions
|
|
@ -565,6 +565,26 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
||||||
/// and `y` each consist of two `u32` limbs.
|
/// and `y` each consist of two `u32` limbs.
|
||||||
fn assert_i64_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
|
fn assert_i64_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
|
||||||
|
|
||||||
|
/// Computes `x + y` assuming `x` and `y` are assigned `i64`
|
||||||
|
/// values.
|
||||||
|
fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||||
|
|
||||||
|
/// Computes `x + y` assuming `x` and `y` are assigned `i64`
|
||||||
|
/// values. Enforces no overflow.
|
||||||
|
fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||||
|
|
||||||
|
/// Computes `x * y` assuming `x` and `y` are assigned `i64`
|
||||||
|
/// values. Enforces no overflow.
|
||||||
|
fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||||
|
|
||||||
|
/// Computes the canonical involution of `x` in `i64`, i.e. the
|
||||||
|
/// negation of `x` as an `i64`.
|
||||||
|
fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget;
|
||||||
|
|
||||||
|
/// Computes the absolute value of `x` *as an element of
|
||||||
|
/// `i64`*. Includes sign indicator (true if negative).
|
||||||
|
fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget);
|
||||||
|
|
||||||
/// Creates value target that is a hash of two given values.
|
/// Creates value target that is a hash of two given values.
|
||||||
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||||
|
|
||||||
|
|
@ -716,6 +736,138 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||||
assert_limb_lt(self, lhs, rhs);
|
assert_limb_lt(self, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
||||||
|
let zero = self.zero();
|
||||||
|
|
||||||
|
// Add components and carry where appropriate.
|
||||||
|
let (_, sum) = std::iter::zip(&x.elements[..2], &y.elements[..2]).fold(
|
||||||
|
(zero, vec![]),
|
||||||
|
|(carry, out), (&a, &b)| {
|
||||||
|
let sum = [a, b, carry]
|
||||||
|
.into_iter()
|
||||||
|
.reduce(|alpha, beta| self.add(alpha, beta))
|
||||||
|
.expect("Iterator should be nonempty.");
|
||||||
|
let (sum_residue, sum_quotient) = self.split_low_high(sum, NUM_BITS, F::BITS);
|
||||||
|
(sum_quotient, [out, vec![sum_residue]].concat())
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
ValueTarget::from_slice(&[sum[0], sum[1], zero, zero])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
||||||
|
let zero = self.zero();
|
||||||
|
let sum = self.i64_wrapping_add(x, y);
|
||||||
|
|
||||||
|
// Overflow check.
|
||||||
|
let x_is_negative = self.i64_is_negative(x);
|
||||||
|
let x_is_nonnegative = self.not(x_is_negative);
|
||||||
|
let y_is_negative = self.i64_is_negative(y);
|
||||||
|
let y_is_nonnegative = self.not(y_is_negative);
|
||||||
|
|
||||||
|
let sum_is_negative = self.i64_is_negative(sum);
|
||||||
|
let sum_is_nonnegative = self.not(sum_is_negative);
|
||||||
|
|
||||||
|
let overflow_conditions = [
|
||||||
|
self.all([x_is_negative, y_is_negative, sum_is_nonnegative]),
|
||||||
|
self.all([x_is_nonnegative, y_is_nonnegative, sum_is_negative]),
|
||||||
|
];
|
||||||
|
|
||||||
|
let overflow = self.any(overflow_conditions);
|
||||||
|
|
||||||
|
self.connect(overflow.target, zero);
|
||||||
|
|
||||||
|
sum
|
||||||
|
}
|
||||||
|
|
||||||
|
fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
||||||
|
let zero = self.zero();
|
||||||
|
let i64_min = ValueTarget::from_slice(&self.constants(&RawValue::from(i64::MIN).0));
|
||||||
|
let (abs_x, x_is_negative) = self.i64_abs(x);
|
||||||
|
let (abs_y, y_is_negative) = self.i64_abs(y);
|
||||||
|
|
||||||
|
// Sign indicators.
|
||||||
|
let same_sign_ind = self.is_equal(x_is_negative.target, y_is_negative.target);
|
||||||
|
let prod_sign = self.not(same_sign_ind);
|
||||||
|
|
||||||
|
// Determine product of absolute values.
|
||||||
|
let x = abs_x.elements[..2].to_vec();
|
||||||
|
let y = abs_y.elements[..2].to_vec();
|
||||||
|
|
||||||
|
let prods = [
|
||||||
|
self.mul(x[0], y[0]),
|
||||||
|
self.mul(x[0], y[1]),
|
||||||
|
self.mul(x[1], y[0]),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.map(|p| self.split_low_high(p, NUM_BITS, F::BITS))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let prod_lower = prods[0].0;
|
||||||
|
|
||||||
|
let (prod_upper, _) = {
|
||||||
|
let sum1 = self.add(prods[1].0, prods[2].0);
|
||||||
|
let sum2 = self.add(sum1, prods[0].1);
|
||||||
|
self.split_low_high(sum2, NUM_BITS, F::BITS)
|
||||||
|
};
|
||||||
|
|
||||||
|
let abs_prod = ValueTarget::from_slice(&[prod_lower, prod_upper, zero, zero]);
|
||||||
|
|
||||||
|
// Overflow check: The latter two products in `prods` should
|
||||||
|
// have zero higher-order coefficients.
|
||||||
|
let no_spillovers = [
|
||||||
|
self.is_equal(prods[1].1, zero),
|
||||||
|
self.is_equal(prods[2].1, zero),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.reduce(|a, b| self.and(a, b))
|
||||||
|
.expect("Iterator should be nonempty.");
|
||||||
|
|
||||||
|
// Overflow check: The product of the higher-order
|
||||||
|
// coefficients should be zero.
|
||||||
|
let higher_prod = self.mul(x[1], y[1]);
|
||||||
|
let higher_prod_is_zero = self.is_equal(higher_prod, zero);
|
||||||
|
|
||||||
|
// Overflow check: The product of the absolute values is
|
||||||
|
// either nonnegative or negative and equal to `i64::MIN`.
|
||||||
|
let abs_prod_is_negative = self.i64_is_negative(abs_prod);
|
||||||
|
let abs_prod_is_nonnegative = self.not(abs_prod_is_negative);
|
||||||
|
let abs_prod_is_min = self.is_equal_slice(&abs_prod.elements, &i64_min.elements);
|
||||||
|
let abs_prod_sign_ok = self.and(abs_prod_is_min, prod_sign);
|
||||||
|
let abs_prod_sign_ok = self.or(abs_prod_sign_ok, abs_prod_is_nonnegative);
|
||||||
|
|
||||||
|
// Combine the above conditions.
|
||||||
|
let no_overflow = self.and(abs_prod_sign_ok, higher_prod_is_zero);
|
||||||
|
let no_overflow = self.and(no_overflow, no_spillovers);
|
||||||
|
self.assert_one(no_overflow.target);
|
||||||
|
|
||||||
|
// Take sign into account.
|
||||||
|
let minus_abs_prod = self.i64_inv(abs_prod);
|
||||||
|
|
||||||
|
self.select_value(prod_sign, minus_abs_prod, abs_prod)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget {
|
||||||
|
let zero = self.zero();
|
||||||
|
let one = ValueTarget::one(self);
|
||||||
|
let u32_max = self.constant(F::from_canonical_u32(u32::MAX));
|
||||||
|
|
||||||
|
let flipped_x = ValueTarget::from_slice(&[
|
||||||
|
self.sub(u32_max, x.elements[0]),
|
||||||
|
self.sub(u32_max, x.elements[1]),
|
||||||
|
zero,
|
||||||
|
zero,
|
||||||
|
]);
|
||||||
|
|
||||||
|
self.i64_wrapping_add(one, flipped_x)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget) {
|
||||||
|
let x_is_negative = self.i64_is_negative(x);
|
||||||
|
let minus_x = self.i64_inv(x);
|
||||||
|
(self.select_value(x_is_negative, minus_x, x), x_is_negative)
|
||||||
|
}
|
||||||
|
|
||||||
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
|
||||||
ValueTarget::from_slice(
|
ValueTarget::from_slice(
|
||||||
&self
|
&self
|
||||||
|
|
@ -795,9 +947,13 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
pub(crate) mod tests {
|
||||||
|
use anyhow::anyhow;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig};
|
use plonky2::plonk::{
|
||||||
|
circuit_builder::CircuitBuilder, circuit_data::CircuitConfig,
|
||||||
|
config::PoseidonGoldilocksConfig,
|
||||||
|
};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
|
|
@ -808,6 +964,48 @@ mod tests {
|
||||||
middleware::CustomPredicateBatch,
|
middleware::CustomPredicateBatch,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub(crate) const I64_TEST_PAIRS: [(i64, i64); 36] = [
|
||||||
|
// Nonnegative numbers
|
||||||
|
(0, 0),
|
||||||
|
(0, 50),
|
||||||
|
(35, 50),
|
||||||
|
(483748374, 221672),
|
||||||
|
(2, 1 << 31),
|
||||||
|
(2, 1 << 62),
|
||||||
|
(0, 1 << 62),
|
||||||
|
(1 << 31, 1 << 62),
|
||||||
|
(1 << 32, 1 << 32),
|
||||||
|
(1 << 62, 1 << 62),
|
||||||
|
(0, i64::MAX),
|
||||||
|
(i64::MAX, 1 << 62),
|
||||||
|
(i64::MAX, i64::MAX),
|
||||||
|
// Negative numbers
|
||||||
|
(-35, -50),
|
||||||
|
(-483748374, -221672),
|
||||||
|
(-(1 << 33), -1),
|
||||||
|
(-(1 << 32), -(1 << 32)),
|
||||||
|
(-(1 << 33), -(1 << 29)),
|
||||||
|
(-(1 << 33), -(1 << 30)),
|
||||||
|
(-(1 << 33), -(1 << 62)),
|
||||||
|
(-(1 << 62), -(1 << 62)),
|
||||||
|
(i64::MIN, -1),
|
||||||
|
(i64::MIN, -(1 << 31)),
|
||||||
|
(i64::MIN, -(1 << 62)),
|
||||||
|
(i64::MIN, i64::MIN),
|
||||||
|
// Mix of numbers
|
||||||
|
(-35, 50),
|
||||||
|
(-483748374, 221672),
|
||||||
|
(-(1 << 32), (1 << 32)),
|
||||||
|
(-(1 << 33), (1 << 30) - 1),
|
||||||
|
(-(1 << 33), (1 << 30)),
|
||||||
|
(-(1 << 62), (1 << 62)),
|
||||||
|
(i64::MIN, 0),
|
||||||
|
(i64::MIN, 1),
|
||||||
|
(i64::MIN, 1 << 31),
|
||||||
|
(i64::MIN, 1 << 62),
|
||||||
|
(i64::MIN, i64::MAX),
|
||||||
|
];
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn custom_predicate_target() -> frontend::Result<()> {
|
fn custom_predicate_target() -> frontend::Result<()> {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
|
|
@ -828,7 +1026,7 @@ mod tests {
|
||||||
|
|
||||||
// generate & verify proof
|
// generate & verify proof
|
||||||
let data = builder.build::<C>();
|
let data = builder.build::<C>();
|
||||||
let proof = data.prove(pw).expect(&format!("predicate {}", i));
|
let proof = data.prove(pw).unwrap_or_else(|_| panic!("predicate {}", i));
|
||||||
data.verify(proof.clone()).unwrap();
|
data.verify(proof.clone()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -912,4 +1110,73 @@ mod tests {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_i64_addition() -> Result<(), anyhow::Error> {
|
||||||
|
// Circuit declaration
|
||||||
|
let config = CircuitConfig::standard_recursion_config();
|
||||||
|
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||||
|
let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
|
||||||
|
let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
|
||||||
|
|
||||||
|
let sum_target = builder.i64_add(x_target, y_target);
|
||||||
|
|
||||||
|
let data = builder.build::<PoseidonGoldilocksConfig>();
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
|
||||||
|
let mut pw = PartialWitness::<F>::new();
|
||||||
|
let (sum, overflow) = x.overflowing_add(y);
|
||||||
|
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(¶ms))?;
|
||||||
|
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(¶ms))?;
|
||||||
|
pw.set_target_arr(
|
||||||
|
&sum_target.elements,
|
||||||
|
&RawValue::from(sum).to_fields(¶ms),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let proof = data.prove(pw);
|
||||||
|
|
||||||
|
match (overflow, proof) {
|
||||||
|
(false, Ok(pf)) => data.verify(pf),
|
||||||
|
(false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)),
|
||||||
|
(true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")),
|
||||||
|
(true, Err(_)) => Ok(()),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_i64_multiplication() -> Result<(), anyhow::Error> {
|
||||||
|
// Circuit declaration
|
||||||
|
let config = CircuitConfig::standard_recursion_config();
|
||||||
|
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||||
|
let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
|
||||||
|
let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
|
||||||
|
|
||||||
|
let prod_target = builder.i64_mul(x_target, y_target);
|
||||||
|
|
||||||
|
let data = builder.build::<PoseidonGoldilocksConfig>();
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
|
||||||
|
println!("{}, {}", x, y);
|
||||||
|
let mut pw = PartialWitness::<F>::new();
|
||||||
|
let (prod, overflow) = x.overflowing_mul(y);
|
||||||
|
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(¶ms))?;
|
||||||
|
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(¶ms))?;
|
||||||
|
pw.set_target_arr(
|
||||||
|
&prod_target.elements,
|
||||||
|
&RawValue::from(prod).to_fields(¶ms),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let proof = data.prove(pw);
|
||||||
|
|
||||||
|
match (overflow, proof) {
|
||||||
|
(false, Ok(pf)) => data.verify(pf),
|
||||||
|
(false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)),
|
||||||
|
(true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")),
|
||||||
|
(true, Err(_)) => Ok(()),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,9 @@ impl OperationVerifyGadget {
|
||||||
self.eval_transitive_eq(builder, st, op, &resolved_op_args),
|
self.eval_transitive_eq(builder, st, op, &resolved_op_args),
|
||||||
self.eval_lt_to_neq(builder, st, op, &resolved_op_args),
|
self.eval_lt_to_neq(builder, st, op, &resolved_op_args),
|
||||||
self.eval_hash_of(builder, st, op, &resolved_op_args),
|
self.eval_hash_of(builder, st, op, &resolved_op_args),
|
||||||
|
self.eval_sum_of(builder, st, op, &resolved_op_args),
|
||||||
|
self.eval_product_of(builder, st, op, &resolved_op_args),
|
||||||
|
self.eval_max_of(builder, st, op, &resolved_op_args),
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
// Skip these if there are no resolved Merkle claims
|
// Skip these if there are no resolved Merkle claims
|
||||||
|
|
@ -386,6 +389,121 @@ impl OperationVerifyGadget {
|
||||||
builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok])
|
builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn eval_sum_of(
|
||||||
|
&self,
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
st: &StatementTarget,
|
||||||
|
op: &OperationTarget,
|
||||||
|
resolved_op_args: &[StatementTarget],
|
||||||
|
) -> BoolTarget {
|
||||||
|
let value_zero = ValueTarget::zero(builder);
|
||||||
|
|
||||||
|
let op_code_ok = op.has_native_type(builder, NativeOperation::SumOf);
|
||||||
|
|
||||||
|
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
|
||||||
|
self.first_n_args_as_values(builder, resolved_op_args);
|
||||||
|
|
||||||
|
// Select to avoid overflow.
|
||||||
|
let summand1 = builder.select_value(op_code_ok, arg2_value, value_zero);
|
||||||
|
let summand2 = builder.select_value(op_code_ok, arg3_value, value_zero);
|
||||||
|
|
||||||
|
let expected_sum = builder.i64_add(summand1, summand2);
|
||||||
|
|
||||||
|
let sum_ok = builder.is_equal_slice(&arg1_value.elements, &expected_sum.elements);
|
||||||
|
|
||||||
|
let arg1_key = resolved_op_args[0].args[0].clone();
|
||||||
|
let arg2_key = resolved_op_args[1].args[0].clone();
|
||||||
|
let arg3_key = resolved_op_args[2].args[0].clone();
|
||||||
|
let expected_statement = StatementTarget::new_native(
|
||||||
|
builder,
|
||||||
|
&self.params,
|
||||||
|
NativePredicate::SumOf,
|
||||||
|
&[arg1_key, arg2_key, arg3_key],
|
||||||
|
);
|
||||||
|
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
|
||||||
|
|
||||||
|
builder.all([op_code_ok, arg_types_ok, sum_ok, st_ok])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn eval_product_of(
|
||||||
|
&self,
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
st: &StatementTarget,
|
||||||
|
op: &OperationTarget,
|
||||||
|
resolved_op_args: &[StatementTarget],
|
||||||
|
) -> BoolTarget {
|
||||||
|
let value_zero = ValueTarget::zero(builder);
|
||||||
|
|
||||||
|
let op_code_ok = op.has_native_type(builder, NativeOperation::ProductOf);
|
||||||
|
|
||||||
|
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
|
||||||
|
self.first_n_args_as_values(builder, resolved_op_args);
|
||||||
|
|
||||||
|
// Select to avoid overflow.
|
||||||
|
let factor1 = builder.select_value(op_code_ok, arg2_value, value_zero);
|
||||||
|
let factor2 = builder.select_value(op_code_ok, arg3_value, value_zero);
|
||||||
|
|
||||||
|
let expected_product = builder.i64_mul(factor1, factor2);
|
||||||
|
|
||||||
|
let product_ok = builder.is_equal_slice(&arg1_value.elements, &expected_product.elements);
|
||||||
|
|
||||||
|
let arg1_key = resolved_op_args[0].args[0].clone();
|
||||||
|
let arg2_key = resolved_op_args[1].args[0].clone();
|
||||||
|
let arg3_key = resolved_op_args[2].args[0].clone();
|
||||||
|
let expected_statement = StatementTarget::new_native(
|
||||||
|
builder,
|
||||||
|
&self.params,
|
||||||
|
NativePredicate::ProductOf,
|
||||||
|
&[arg1_key, arg2_key, arg3_key],
|
||||||
|
);
|
||||||
|
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
|
||||||
|
|
||||||
|
builder.all([op_code_ok, arg_types_ok, product_ok, st_ok])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn eval_max_of(
|
||||||
|
&self,
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
st: &StatementTarget,
|
||||||
|
op: &OperationTarget,
|
||||||
|
resolved_op_args: &[StatementTarget],
|
||||||
|
) -> BoolTarget {
|
||||||
|
let op_code_ok = op.has_native_type(builder, NativeOperation::MaxOf);
|
||||||
|
|
||||||
|
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
|
||||||
|
self.first_n_args_as_values(builder, resolved_op_args);
|
||||||
|
|
||||||
|
// Check that arg1_value is equal to one of the other two
|
||||||
|
// values.
|
||||||
|
let arg1_eq_arg2 = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements);
|
||||||
|
let arg1_eq_arg3 = builder.is_equal_slice(&arg1_value.elements, &arg3_value.elements);
|
||||||
|
|
||||||
|
let all_eq = builder.and(arg1_eq_arg2, arg1_eq_arg3);
|
||||||
|
let not_all_eq = builder.not(all_eq);
|
||||||
|
|
||||||
|
let arg1_check = builder.or(arg1_eq_arg2, arg1_eq_arg3);
|
||||||
|
|
||||||
|
// If it is not equal to any of the other two values, it must be greater than it.
|
||||||
|
let lower_bound = builder.select_value(arg1_eq_arg2, arg3_value, arg2_value);
|
||||||
|
|
||||||
|
// Only check lower bound if not all args are equal.
|
||||||
|
let lt_check_enabled = builder.and(not_all_eq, op_code_ok);
|
||||||
|
builder.assert_i64_less_if(lt_check_enabled, lower_bound, arg1_value);
|
||||||
|
|
||||||
|
let arg1_key = resolved_op_args[0].args[0].clone();
|
||||||
|
let arg2_key = resolved_op_args[1].args[0].clone();
|
||||||
|
let arg3_key = resolved_op_args[2].args[0].clone();
|
||||||
|
let expected_statement = StatementTarget::new_native(
|
||||||
|
builder,
|
||||||
|
&self.params,
|
||||||
|
NativePredicate::MaxOf,
|
||||||
|
&[arg1_key, arg2_key, arg3_key],
|
||||||
|
);
|
||||||
|
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
|
||||||
|
|
||||||
|
builder.all([op_code_ok, arg_types_ok, arg1_check, st_ok])
|
||||||
|
}
|
||||||
|
|
||||||
fn eval_transitive_eq(
|
fn eval_transitive_eq(
|
||||||
&self,
|
&self,
|
||||||
builder: &mut CircuitBuilder<F, D>,
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
|
@ -684,6 +802,8 @@ impl MainPodVerifyCircuit {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::ops::Not;
|
||||||
|
|
||||||
use plonky2::{
|
use plonky2::{
|
||||||
field::{goldilocks_field::GoldilocksField, types::Field},
|
field::{goldilocks_field::GoldilocksField, types::Field},
|
||||||
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
|
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
|
||||||
|
|
@ -693,6 +813,7 @@ mod tests {
|
||||||
use crate::{
|
use crate::{
|
||||||
backends::plonky2::{
|
backends::plonky2::{
|
||||||
basetypes::C,
|
basetypes::C,
|
||||||
|
circuits::common::tests::I64_TEST_PAIRS,
|
||||||
mainpod::{OperationArg, OperationAux},
|
mainpod::{OperationArg, OperationAux},
|
||||||
primitives::merkletree::{MerkleClaimAndProof, MerkleTree},
|
primitives::merkletree::{MerkleClaimAndProof, MerkleTree},
|
||||||
},
|
},
|
||||||
|
|
@ -1236,6 +1357,185 @@ mod tests {
|
||||||
operation_verify(st, op, prev_statements, vec![])
|
operation_verify(st, op, prev_statements, vec![])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_operation_verify_sumof() -> Result<()> {
|
||||||
|
I64_TEST_PAIRS
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|(a, b)| {
|
||||||
|
let (sum, overflow) = a.overflowing_add(b);
|
||||||
|
overflow.not().then_some((a, b, sum))
|
||||||
|
})
|
||||||
|
.try_for_each(|(a, b, sum)| {
|
||||||
|
let st1: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
sum.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st2: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
a.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st3: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
b.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st: mainpod::Statement = Statement::SumOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
let op = mainpod::Operation(
|
||||||
|
OperationType::Native(NativeOperation::SumOf),
|
||||||
|
vec![
|
||||||
|
OperationArg::Index(0),
|
||||||
|
OperationArg::Index(1),
|
||||||
|
OperationArg::Index(2),
|
||||||
|
],
|
||||||
|
OperationAux::None,
|
||||||
|
);
|
||||||
|
let prev_statements = vec![st1, st2, st3];
|
||||||
|
operation_verify(st, op, prev_statements, vec![])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_operation_verify_productof() -> Result<()> {
|
||||||
|
I64_TEST_PAIRS
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|(a, b)| {
|
||||||
|
let (prod, overflow) = a.overflowing_mul(b);
|
||||||
|
overflow.not().then_some((a, b, prod))
|
||||||
|
})
|
||||||
|
.try_for_each(|(a, b, prod)| {
|
||||||
|
let st1: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
prod.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st2: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
a.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st3: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
b.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st: mainpod::Statement = Statement::ProductOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
let op = mainpod::Operation(
|
||||||
|
OperationType::Native(NativeOperation::ProductOf),
|
||||||
|
vec![
|
||||||
|
OperationArg::Index(0),
|
||||||
|
OperationArg::Index(1),
|
||||||
|
OperationArg::Index(2),
|
||||||
|
],
|
||||||
|
OperationAux::None,
|
||||||
|
);
|
||||||
|
let prev_statements = vec![st1, st2, st3];
|
||||||
|
operation_verify(st, op, prev_statements, vec![])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_operation_verify_maxof() -> Result<()> {
|
||||||
|
I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| {
|
||||||
|
let max = i64::max(a, b);
|
||||||
|
let st1: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
max.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st2: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
a.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st3: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
b.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st: mainpod::Statement = Statement::MaxOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
let op = mainpod::Operation(
|
||||||
|
OperationType::Native(NativeOperation::MaxOf),
|
||||||
|
vec![
|
||||||
|
OperationArg::Index(0),
|
||||||
|
OperationArg::Index(1),
|
||||||
|
OperationArg::Index(2),
|
||||||
|
],
|
||||||
|
OperationAux::None,
|
||||||
|
);
|
||||||
|
let prev_statements = vec![st1, st2, st3];
|
||||||
|
operation_verify(st, op, prev_statements, vec![])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_operation_verify_maxof_failures() {
|
||||||
|
[(5, 3, 4), (5, 5, 8), (3, 4, 5)]
|
||||||
|
.into_iter()
|
||||||
|
.for_each(|(max, a, b)| {
|
||||||
|
let st1: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
max.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st2: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
a.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st3: mainpod::Statement = Statement::ValueOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
b.into(),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let st: mainpod::Statement = Statement::MaxOf(
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||||
|
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
let op = mainpod::Operation(
|
||||||
|
OperationType::Native(NativeOperation::MaxOf),
|
||||||
|
vec![
|
||||||
|
OperationArg::Index(0),
|
||||||
|
OperationArg::Index(1),
|
||||||
|
OperationArg::Index(2),
|
||||||
|
],
|
||||||
|
OperationAux::None,
|
||||||
|
);
|
||||||
|
let prev_statements = vec![st1, st2, st3];
|
||||||
|
assert!(operation_verify(st, op, prev_statements, vec![]).is_err())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_operation_verify_lt_to_neq() -> Result<()> {
|
fn test_operation_verify_lt_to_neq() -> Result<()> {
|
||||||
let st: mainpod::Statement = Statement::NotEqual(
|
let st: mainpod::Statement = Statement::NotEqual(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue