diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 87101f6..ef16bd7 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -565,6 +565,26 @@ pub trait CircuitBuilderPod, const D: usize> { /// and `y` each consist of two `u32` limbs. fn assert_i64_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget); + /// Computes `x + y` assuming `x` and `y` are assigned `i64` + /// values. + fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget; + + /// Computes `x + y` assuming `x` and `y` are assigned `i64` + /// values. Enforces no overflow. + fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget; + + /// Computes `x * y` assuming `x` and `y` are assigned `i64` + /// values. Enforces no overflow. + fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget; + + /// Computes the canonical involution of `x` in `i64`, i.e. the + /// negation of `x` as an `i64`. + fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget; + + /// Computes the absolute value of `x` *as an element of + /// `i64`*. Includes sign indicator (true if negative). + fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget); + /// Creates value target that is a hash of two given values. fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget; @@ -716,6 +736,138 @@ impl CircuitBuilderPod for CircuitBuilder { assert_limb_lt(self, lhs, rhs); } + fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget { + let zero = self.zero(); + + // Add components and carry where appropriate. + let (_, sum) = std::iter::zip(&x.elements[..2], &y.elements[..2]).fold( + (zero, vec![]), + |(carry, out), (&a, &b)| { + let sum = [a, b, carry] + .into_iter() + .reduce(|alpha, beta| self.add(alpha, beta)) + .expect("Iterator should be nonempty."); + let (sum_residue, sum_quotient) = self.split_low_high(sum, NUM_BITS, F::BITS); + (sum_quotient, [out, vec![sum_residue]].concat()) + }, + ); + + ValueTarget::from_slice(&[sum[0], sum[1], zero, zero]) + } + + fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget { + let zero = self.zero(); + let sum = self.i64_wrapping_add(x, y); + + // Overflow check. + let x_is_negative = self.i64_is_negative(x); + let x_is_nonnegative = self.not(x_is_negative); + let y_is_negative = self.i64_is_negative(y); + let y_is_nonnegative = self.not(y_is_negative); + + let sum_is_negative = self.i64_is_negative(sum); + let sum_is_nonnegative = self.not(sum_is_negative); + + let overflow_conditions = [ + self.all([x_is_negative, y_is_negative, sum_is_nonnegative]), + self.all([x_is_nonnegative, y_is_nonnegative, sum_is_negative]), + ]; + + let overflow = self.any(overflow_conditions); + + self.connect(overflow.target, zero); + + sum + } + + fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget { + let zero = self.zero(); + let i64_min = ValueTarget::from_slice(&self.constants(&RawValue::from(i64::MIN).0)); + let (abs_x, x_is_negative) = self.i64_abs(x); + let (abs_y, y_is_negative) = self.i64_abs(y); + + // Sign indicators. + let same_sign_ind = self.is_equal(x_is_negative.target, y_is_negative.target); + let prod_sign = self.not(same_sign_ind); + + // Determine product of absolute values. + let x = abs_x.elements[..2].to_vec(); + let y = abs_y.elements[..2].to_vec(); + + let prods = [ + self.mul(x[0], y[0]), + self.mul(x[0], y[1]), + self.mul(x[1], y[0]), + ] + .into_iter() + .map(|p| self.split_low_high(p, NUM_BITS, F::BITS)) + .collect::>(); + + let prod_lower = prods[0].0; + + let (prod_upper, _) = { + let sum1 = self.add(prods[1].0, prods[2].0); + let sum2 = self.add(sum1, prods[0].1); + self.split_low_high(sum2, NUM_BITS, F::BITS) + }; + + let abs_prod = ValueTarget::from_slice(&[prod_lower, prod_upper, zero, zero]); + + // Overflow check: The latter two products in `prods` should + // have zero higher-order coefficients. + let no_spillovers = [ + self.is_equal(prods[1].1, zero), + self.is_equal(prods[2].1, zero), + ] + .into_iter() + .reduce(|a, b| self.and(a, b)) + .expect("Iterator should be nonempty."); + + // Overflow check: The product of the higher-order + // coefficients should be zero. + let higher_prod = self.mul(x[1], y[1]); + let higher_prod_is_zero = self.is_equal(higher_prod, zero); + + // Overflow check: The product of the absolute values is + // either nonnegative or negative and equal to `i64::MIN`. + let abs_prod_is_negative = self.i64_is_negative(abs_prod); + let abs_prod_is_nonnegative = self.not(abs_prod_is_negative); + let abs_prod_is_min = self.is_equal_slice(&abs_prod.elements, &i64_min.elements); + let abs_prod_sign_ok = self.and(abs_prod_is_min, prod_sign); + let abs_prod_sign_ok = self.or(abs_prod_sign_ok, abs_prod_is_nonnegative); + + // Combine the above conditions. + let no_overflow = self.and(abs_prod_sign_ok, higher_prod_is_zero); + let no_overflow = self.and(no_overflow, no_spillovers); + self.assert_one(no_overflow.target); + + // Take sign into account. + let minus_abs_prod = self.i64_inv(abs_prod); + + self.select_value(prod_sign, minus_abs_prod, abs_prod) + } + + fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget { + let zero = self.zero(); + let one = ValueTarget::one(self); + let u32_max = self.constant(F::from_canonical_u32(u32::MAX)); + + let flipped_x = ValueTarget::from_slice(&[ + self.sub(u32_max, x.elements[0]), + self.sub(u32_max, x.elements[1]), + zero, + zero, + ]); + + self.i64_wrapping_add(one, flipped_x) + } + + fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget) { + let x_is_negative = self.i64_is_negative(x); + let minus_x = self.i64_inv(x); + (self.select_value(x_is_negative, minus_x, x), x_is_negative) + } + fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget { ValueTarget::from_slice( &self @@ -795,9 +947,13 @@ impl CircuitBuilderPod for CircuitBuilder { } #[cfg(test)] -mod tests { +pub(crate) mod tests { + use anyhow::anyhow; use itertools::Itertools; - use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}; + use plonky2::plonk::{ + circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, + config::PoseidonGoldilocksConfig, + }; use super::*; use crate::{ @@ -808,6 +964,48 @@ mod tests { middleware::CustomPredicateBatch, }; + pub(crate) const I64_TEST_PAIRS: [(i64, i64); 36] = [ + // Nonnegative numbers + (0, 0), + (0, 50), + (35, 50), + (483748374, 221672), + (2, 1 << 31), + (2, 1 << 62), + (0, 1 << 62), + (1 << 31, 1 << 62), + (1 << 32, 1 << 32), + (1 << 62, 1 << 62), + (0, i64::MAX), + (i64::MAX, 1 << 62), + (i64::MAX, i64::MAX), + // Negative numbers + (-35, -50), + (-483748374, -221672), + (-(1 << 33), -1), + (-(1 << 32), -(1 << 32)), + (-(1 << 33), -(1 << 29)), + (-(1 << 33), -(1 << 30)), + (-(1 << 33), -(1 << 62)), + (-(1 << 62), -(1 << 62)), + (i64::MIN, -1), + (i64::MIN, -(1 << 31)), + (i64::MIN, -(1 << 62)), + (i64::MIN, i64::MIN), + // Mix of numbers + (-35, 50), + (-483748374, 221672), + (-(1 << 32), (1 << 32)), + (-(1 << 33), (1 << 30) - 1), + (-(1 << 33), (1 << 30)), + (-(1 << 62), (1 << 62)), + (i64::MIN, 0), + (i64::MIN, 1), + (i64::MIN, 1 << 31), + (i64::MIN, 1 << 62), + (i64::MIN, i64::MAX), + ]; + #[test] fn custom_predicate_target() -> frontend::Result<()> { let params = Params::default(); @@ -828,7 +1026,7 @@ mod tests { // generate & verify proof let data = builder.build::(); - let proof = data.prove(pw).expect(&format!("predicate {}", i)); + let proof = data.prove(pw).unwrap_or_else(|_| panic!("predicate {}", i)); data.verify(proof.clone()).unwrap(); } @@ -912,4 +1110,73 @@ mod tests { Ok(()) } + + #[test] + fn test_i64_addition() -> Result<(), anyhow::Error> { + // Circuit declaration + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::()); + let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::()); + + let sum_target = builder.i64_add(x_target, y_target); + + let data = builder.build::(); + let params = Params::default(); + + I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| { + let mut pw = PartialWitness::::new(); + let (sum, overflow) = x.overflowing_add(y); + pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(¶ms))?; + pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(¶ms))?; + pw.set_target_arr( + &sum_target.elements, + &RawValue::from(sum).to_fields(¶ms), + )?; + + let proof = data.prove(pw); + + match (overflow, proof) { + (false, Ok(pf)) => data.verify(pf), + (false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)), + (true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")), + (true, Err(_)) => Ok(()), + } + }) + } + + #[test] + fn test_i64_multiplication() -> Result<(), anyhow::Error> { + // Circuit declaration + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::()); + let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::()); + + let prod_target = builder.i64_mul(x_target, y_target); + + let data = builder.build::(); + let params = Params::default(); + + I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| { + println!("{}, {}", x, y); + let mut pw = PartialWitness::::new(); + let (prod, overflow) = x.overflowing_mul(y); + pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(¶ms))?; + pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(¶ms))?; + pw.set_target_arr( + &prod_target.elements, + &RawValue::from(prod).to_fields(¶ms), + )?; + + let proof = data.prove(pw); + + match (overflow, proof) { + (false, Ok(pf)) => data.verify(pf), + (false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)), + (true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")), + (true, Err(_)) => Ok(()), + } + }) + } } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 12fe60d..95d34be 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -118,6 +118,9 @@ impl OperationVerifyGadget { self.eval_transitive_eq(builder, st, op, &resolved_op_args), self.eval_lt_to_neq(builder, st, op, &resolved_op_args), self.eval_hash_of(builder, st, op, &resolved_op_args), + self.eval_sum_of(builder, st, op, &resolved_op_args), + self.eval_product_of(builder, st, op, &resolved_op_args), + self.eval_max_of(builder, st, op, &resolved_op_args), ] }, // Skip these if there are no resolved Merkle claims @@ -386,6 +389,121 @@ impl OperationVerifyGadget { builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok]) } + fn eval_sum_of( + &self, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op: &OperationTarget, + resolved_op_args: &[StatementTarget], + ) -> BoolTarget { + let value_zero = ValueTarget::zero(builder); + + let op_code_ok = op.has_native_type(builder, NativeOperation::SumOf); + + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = + self.first_n_args_as_values(builder, resolved_op_args); + + // Select to avoid overflow. + let summand1 = builder.select_value(op_code_ok, arg2_value, value_zero); + let summand2 = builder.select_value(op_code_ok, arg3_value, value_zero); + + let expected_sum = builder.i64_add(summand1, summand2); + + let sum_ok = builder.is_equal_slice(&arg1_value.elements, &expected_sum.elements); + + let arg1_key = resolved_op_args[0].args[0].clone(); + let arg2_key = resolved_op_args[1].args[0].clone(); + let arg3_key = resolved_op_args[2].args[0].clone(); + let expected_statement = StatementTarget::new_native( + builder, + &self.params, + NativePredicate::SumOf, + &[arg1_key, arg2_key, arg3_key], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); + + builder.all([op_code_ok, arg_types_ok, sum_ok, st_ok]) + } + + fn eval_product_of( + &self, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op: &OperationTarget, + resolved_op_args: &[StatementTarget], + ) -> BoolTarget { + let value_zero = ValueTarget::zero(builder); + + let op_code_ok = op.has_native_type(builder, NativeOperation::ProductOf); + + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = + self.first_n_args_as_values(builder, resolved_op_args); + + // Select to avoid overflow. + let factor1 = builder.select_value(op_code_ok, arg2_value, value_zero); + let factor2 = builder.select_value(op_code_ok, arg3_value, value_zero); + + let expected_product = builder.i64_mul(factor1, factor2); + + let product_ok = builder.is_equal_slice(&arg1_value.elements, &expected_product.elements); + + let arg1_key = resolved_op_args[0].args[0].clone(); + let arg2_key = resolved_op_args[1].args[0].clone(); + let arg3_key = resolved_op_args[2].args[0].clone(); + let expected_statement = StatementTarget::new_native( + builder, + &self.params, + NativePredicate::ProductOf, + &[arg1_key, arg2_key, arg3_key], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); + + builder.all([op_code_ok, arg_types_ok, product_ok, st_ok]) + } + + fn eval_max_of( + &self, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op: &OperationTarget, + resolved_op_args: &[StatementTarget], + ) -> BoolTarget { + let op_code_ok = op.has_native_type(builder, NativeOperation::MaxOf); + + let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) = + self.first_n_args_as_values(builder, resolved_op_args); + + // Check that arg1_value is equal to one of the other two + // values. + let arg1_eq_arg2 = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements); + let arg1_eq_arg3 = builder.is_equal_slice(&arg1_value.elements, &arg3_value.elements); + + let all_eq = builder.and(arg1_eq_arg2, arg1_eq_arg3); + let not_all_eq = builder.not(all_eq); + + let arg1_check = builder.or(arg1_eq_arg2, arg1_eq_arg3); + + // If it is not equal to any of the other two values, it must be greater than it. + let lower_bound = builder.select_value(arg1_eq_arg2, arg3_value, arg2_value); + + // Only check lower bound if not all args are equal. + let lt_check_enabled = builder.and(not_all_eq, op_code_ok); + builder.assert_i64_less_if(lt_check_enabled, lower_bound, arg1_value); + + let arg1_key = resolved_op_args[0].args[0].clone(); + let arg2_key = resolved_op_args[1].args[0].clone(); + let arg3_key = resolved_op_args[2].args[0].clone(); + let expected_statement = StatementTarget::new_native( + builder, + &self.params, + NativePredicate::MaxOf, + &[arg1_key, arg2_key, arg3_key], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); + + builder.all([op_code_ok, arg_types_ok, arg1_check, st_ok]) + } + fn eval_transitive_eq( &self, builder: &mut CircuitBuilder, @@ -684,6 +802,8 @@ impl MainPodVerifyCircuit { #[cfg(test)] mod tests { + use std::ops::Not; + use plonky2::{ field::{goldilocks_field::GoldilocksField, types::Field}, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, @@ -693,6 +813,7 @@ mod tests { use crate::{ backends::plonky2::{ basetypes::C, + circuits::common::tests::I64_TEST_PAIRS, mainpod::{OperationArg, OperationAux}, primitives::merkletree::{MerkleClaimAndProof, MerkleTree}, }, @@ -1236,6 +1357,185 @@ mod tests { operation_verify(st, op, prev_statements, vec![]) } + #[test] + fn test_operation_verify_sumof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (sum, overflow) = a.overflowing_add(b); + overflow.not().then_some((a, b, sum)) + }) + .try_for_each(|(a, b, sum)| { + let st1: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + sum.into(), + ) + .into(); + + let st2: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + a.into(), + ) + .into(); + + let st3: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + b.into(), + ) + .into(); + + let st: mainpod::Statement = Statement::SumOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, vec![]) + }) + } + + #[test] + fn test_operation_verify_productof() -> Result<()> { + I64_TEST_PAIRS + .into_iter() + .flat_map(|(a, b)| { + let (prod, overflow) = a.overflowing_mul(b); + overflow.not().then_some((a, b, prod)) + }) + .try_for_each(|(a, b, prod)| { + let st1: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + prod.into(), + ) + .into(); + + let st2: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + a.into(), + ) + .into(); + + let st3: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + b.into(), + ) + .into(); + + let st: mainpod::Statement = Statement::ProductOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::ProductOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, vec![]) + }) + } + + #[test] + fn test_operation_verify_maxof() -> Result<()> { + I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| { + let max = i64::max(a, b); + let st1: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + max.into(), + ) + .into(); + + let st2: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + a.into(), + ) + .into(); + + let st3: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + b.into(), + ) + .into(); + + let st: mainpod::Statement = Statement::MaxOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, vec![]) + }) + } + + #[test] + fn test_operation_verify_maxof_failures() { + [(5, 3, 4), (5, 5, 8), (3, 4, 5)] + .into_iter() + .for_each(|(max, a, b)| { + let st1: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + max.into(), + ) + .into(); + + let st2: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + a.into(), + ) + .into(); + + let st3: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + b.into(), + ) + .into(); + + let st: mainpod::Statement = Statement::MaxOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::MaxOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + assert!(operation_verify(st, op, prev_statements, vec![]).is_err()) + }) + } + #[test] fn test_operation_verify_lt_to_neq() -> Result<()> { let st: mainpod::Statement = Statement::NotEqual(