feat(backend): implement gadgets for remaining ops (#228)

* Implement gadgets for remaining ops

* Use overflowing arithmetic ops

* Code review

* Formatting
This commit is contained in:
Ahmad Afuni 2025-05-13 07:34:35 +10:00 committed by GitHub
parent b2cb563eb6
commit 4fa9e20ecd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 570 additions and 3 deletions

View file

@ -565,6 +565,26 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
/// and `y` each consist of two `u32` limbs.
fn assert_i64_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
/// Computes `x + y` assuming `x` and `y` are assigned `i64`
/// values.
fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
/// Computes `x + y` assuming `x` and `y` are assigned `i64`
/// values. Enforces no overflow.
fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
/// Computes `x * y` assuming `x` and `y` are assigned `i64`
/// values. Enforces no overflow.
fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
/// Computes the canonical involution of `x` in `i64`, i.e. the
/// negation of `x` as an `i64`.
fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget;
/// Computes the absolute value of `x` *as an element of
/// `i64`*. Includes sign indicator (true if negative).
fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget);
/// Creates value target that is a hash of two given values.
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
@ -716,6 +736,138 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
assert_limb_lt(self, lhs, rhs);
}
fn i64_wrapping_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
let zero = self.zero();
// Add components and carry where appropriate.
let (_, sum) = std::iter::zip(&x.elements[..2], &y.elements[..2]).fold(
(zero, vec![]),
|(carry, out), (&a, &b)| {
let sum = [a, b, carry]
.into_iter()
.reduce(|alpha, beta| self.add(alpha, beta))
.expect("Iterator should be nonempty.");
let (sum_residue, sum_quotient) = self.split_low_high(sum, NUM_BITS, F::BITS);
(sum_quotient, [out, vec![sum_residue]].concat())
},
);
ValueTarget::from_slice(&[sum[0], sum[1], zero, zero])
}
fn i64_add(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
let zero = self.zero();
let sum = self.i64_wrapping_add(x, y);
// Overflow check.
let x_is_negative = self.i64_is_negative(x);
let x_is_nonnegative = self.not(x_is_negative);
let y_is_negative = self.i64_is_negative(y);
let y_is_nonnegative = self.not(y_is_negative);
let sum_is_negative = self.i64_is_negative(sum);
let sum_is_nonnegative = self.not(sum_is_negative);
let overflow_conditions = [
self.all([x_is_negative, y_is_negative, sum_is_nonnegative]),
self.all([x_is_nonnegative, y_is_nonnegative, sum_is_negative]),
];
let overflow = self.any(overflow_conditions);
self.connect(overflow.target, zero);
sum
}
fn i64_mul(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
let zero = self.zero();
let i64_min = ValueTarget::from_slice(&self.constants(&RawValue::from(i64::MIN).0));
let (abs_x, x_is_negative) = self.i64_abs(x);
let (abs_y, y_is_negative) = self.i64_abs(y);
// Sign indicators.
let same_sign_ind = self.is_equal(x_is_negative.target, y_is_negative.target);
let prod_sign = self.not(same_sign_ind);
// Determine product of absolute values.
let x = abs_x.elements[..2].to_vec();
let y = abs_y.elements[..2].to_vec();
let prods = [
self.mul(x[0], y[0]),
self.mul(x[0], y[1]),
self.mul(x[1], y[0]),
]
.into_iter()
.map(|p| self.split_low_high(p, NUM_BITS, F::BITS))
.collect::<Vec<_>>();
let prod_lower = prods[0].0;
let (prod_upper, _) = {
let sum1 = self.add(prods[1].0, prods[2].0);
let sum2 = self.add(sum1, prods[0].1);
self.split_low_high(sum2, NUM_BITS, F::BITS)
};
let abs_prod = ValueTarget::from_slice(&[prod_lower, prod_upper, zero, zero]);
// Overflow check: The latter two products in `prods` should
// have zero higher-order coefficients.
let no_spillovers = [
self.is_equal(prods[1].1, zero),
self.is_equal(prods[2].1, zero),
]
.into_iter()
.reduce(|a, b| self.and(a, b))
.expect("Iterator should be nonempty.");
// Overflow check: The product of the higher-order
// coefficients should be zero.
let higher_prod = self.mul(x[1], y[1]);
let higher_prod_is_zero = self.is_equal(higher_prod, zero);
// Overflow check: The product of the absolute values is
// either nonnegative or negative and equal to `i64::MIN`.
let abs_prod_is_negative = self.i64_is_negative(abs_prod);
let abs_prod_is_nonnegative = self.not(abs_prod_is_negative);
let abs_prod_is_min = self.is_equal_slice(&abs_prod.elements, &i64_min.elements);
let abs_prod_sign_ok = self.and(abs_prod_is_min, prod_sign);
let abs_prod_sign_ok = self.or(abs_prod_sign_ok, abs_prod_is_nonnegative);
// Combine the above conditions.
let no_overflow = self.and(abs_prod_sign_ok, higher_prod_is_zero);
let no_overflow = self.and(no_overflow, no_spillovers);
self.assert_one(no_overflow.target);
// Take sign into account.
let minus_abs_prod = self.i64_inv(abs_prod);
self.select_value(prod_sign, minus_abs_prod, abs_prod)
}
fn i64_inv(&mut self, x: ValueTarget) -> ValueTarget {
let zero = self.zero();
let one = ValueTarget::one(self);
let u32_max = self.constant(F::from_canonical_u32(u32::MAX));
let flipped_x = ValueTarget::from_slice(&[
self.sub(u32_max, x.elements[0]),
self.sub(u32_max, x.elements[1]),
zero,
zero,
]);
self.i64_wrapping_add(one, flipped_x)
}
fn i64_abs(&mut self, x: ValueTarget) -> (ValueTarget, BoolTarget) {
let x_is_negative = self.i64_is_negative(x);
let minus_x = self.i64_inv(x);
(self.select_value(x_is_negative, minus_x, x), x_is_negative)
}
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget {
ValueTarget::from_slice(
&self
@ -795,9 +947,13 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
}
#[cfg(test)]
mod tests {
pub(crate) mod tests {
use anyhow::anyhow;
use itertools::Itertools;
use plonky2::plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig};
use plonky2::plonk::{
circuit_builder::CircuitBuilder, circuit_data::CircuitConfig,
config::PoseidonGoldilocksConfig,
};
use super::*;
use crate::{
@ -808,6 +964,48 @@ mod tests {
middleware::CustomPredicateBatch,
};
pub(crate) const I64_TEST_PAIRS: [(i64, i64); 36] = [
// Nonnegative numbers
(0, 0),
(0, 50),
(35, 50),
(483748374, 221672),
(2, 1 << 31),
(2, 1 << 62),
(0, 1 << 62),
(1 << 31, 1 << 62),
(1 << 32, 1 << 32),
(1 << 62, 1 << 62),
(0, i64::MAX),
(i64::MAX, 1 << 62),
(i64::MAX, i64::MAX),
// Negative numbers
(-35, -50),
(-483748374, -221672),
(-(1 << 33), -1),
(-(1 << 32), -(1 << 32)),
(-(1 << 33), -(1 << 29)),
(-(1 << 33), -(1 << 30)),
(-(1 << 33), -(1 << 62)),
(-(1 << 62), -(1 << 62)),
(i64::MIN, -1),
(i64::MIN, -(1 << 31)),
(i64::MIN, -(1 << 62)),
(i64::MIN, i64::MIN),
// Mix of numbers
(-35, 50),
(-483748374, 221672),
(-(1 << 32), (1 << 32)),
(-(1 << 33), (1 << 30) - 1),
(-(1 << 33), (1 << 30)),
(-(1 << 62), (1 << 62)),
(i64::MIN, 0),
(i64::MIN, 1),
(i64::MIN, 1 << 31),
(i64::MIN, 1 << 62),
(i64::MIN, i64::MAX),
];
#[test]
fn custom_predicate_target() -> frontend::Result<()> {
let params = Params::default();
@ -828,7 +1026,7 @@ mod tests {
// generate & verify proof
let data = builder.build::<C>();
let proof = data.prove(pw).expect(&format!("predicate {}", i));
let proof = data.prove(pw).unwrap_or_else(|_| panic!("predicate {}", i));
data.verify(proof.clone()).unwrap();
}
@ -912,4 +1110,73 @@ mod tests {
Ok(())
}
#[test]
fn test_i64_addition() -> Result<(), anyhow::Error> {
// Circuit declaration
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let sum_target = builder.i64_add(x_target, y_target);
let data = builder.build::<PoseidonGoldilocksConfig>();
let params = Params::default();
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
let mut pw = PartialWitness::<F>::new();
let (sum, overflow) = x.overflowing_add(y);
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(&params))?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(&params))?;
pw.set_target_arr(
&sum_target.elements,
&RawValue::from(sum).to_fields(&params),
)?;
let proof = data.prove(pw);
match (overflow, proof) {
(false, Ok(pf)) => data.verify(pf),
(false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)),
(true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")),
(true, Err(_)) => Ok(()),
}
})
}
#[test]
fn test_i64_multiplication() -> Result<(), anyhow::Error> {
// Circuit declaration
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let y_target = ValueTarget::from_slice(&builder.add_virtual_target_arr::<VALUE_SIZE>());
let prod_target = builder.i64_mul(x_target, y_target);
let data = builder.build::<PoseidonGoldilocksConfig>();
let params = Params::default();
I64_TEST_PAIRS.into_iter().try_for_each(|(x, y)| {
println!("{}, {}", x, y);
let mut pw = PartialWitness::<F>::new();
let (prod, overflow) = x.overflowing_mul(y);
pw.set_target_arr(&x_target.elements, &RawValue::from(x).to_fields(&params))?;
pw.set_target_arr(&y_target.elements, &RawValue::from(y).to_fields(&params))?;
pw.set_target_arr(
&prod_target.elements,
&RawValue::from(prod).to_fields(&params),
)?;
let proof = data.prove(pw);
match (overflow, proof) {
(false, Ok(pf)) => data.verify(pf),
(false, Err(e)) => Err(anyhow!("Proof failure despite no overflow: {}", e)),
(true, Ok(_)) => Err(anyhow!("Proof success despite overflow.")),
(true, Err(_)) => Ok(()),
}
})
}
}