feat(backend): implement gadgets for remaining ops (#228)

* Implement gadgets for remaining ops

* Use overflowing arithmetic ops

* Code review

* Formatting
This commit is contained in:
Ahmad Afuni 2025-05-13 07:34:35 +10:00 committed by GitHub
parent b2cb563eb6
commit 4fa9e20ecd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 570 additions and 3 deletions

View file

@ -118,6 +118,9 @@ impl OperationVerifyGadget {
self.eval_transitive_eq(builder, st, op, &resolved_op_args),
self.eval_lt_to_neq(builder, st, op, &resolved_op_args),
self.eval_hash_of(builder, st, op, &resolved_op_args),
self.eval_sum_of(builder, st, op, &resolved_op_args),
self.eval_product_of(builder, st, op, &resolved_op_args),
self.eval_max_of(builder, st, op, &resolved_op_args),
]
},
// Skip these if there are no resolved Merkle claims
@ -386,6 +389,121 @@ impl OperationVerifyGadget {
builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok])
}
fn eval_sum_of(
&self,
builder: &mut CircuitBuilder<F, D>,
st: &StatementTarget,
op: &OperationTarget,
resolved_op_args: &[StatementTarget],
) -> BoolTarget {
let value_zero = ValueTarget::zero(builder);
let op_code_ok = op.has_native_type(builder, NativeOperation::SumOf);
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
self.first_n_args_as_values(builder, resolved_op_args);
// Select to avoid overflow.
let summand1 = builder.select_value(op_code_ok, arg2_value, value_zero);
let summand2 = builder.select_value(op_code_ok, arg3_value, value_zero);
let expected_sum = builder.i64_add(summand1, summand2);
let sum_ok = builder.is_equal_slice(&arg1_value.elements, &expected_sum.elements);
let arg1_key = resolved_op_args[0].args[0].clone();
let arg2_key = resolved_op_args[1].args[0].clone();
let arg3_key = resolved_op_args[2].args[0].clone();
let expected_statement = StatementTarget::new_native(
builder,
&self.params,
NativePredicate::SumOf,
&[arg1_key, arg2_key, arg3_key],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
builder.all([op_code_ok, arg_types_ok, sum_ok, st_ok])
}
fn eval_product_of(
&self,
builder: &mut CircuitBuilder<F, D>,
st: &StatementTarget,
op: &OperationTarget,
resolved_op_args: &[StatementTarget],
) -> BoolTarget {
let value_zero = ValueTarget::zero(builder);
let op_code_ok = op.has_native_type(builder, NativeOperation::ProductOf);
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
self.first_n_args_as_values(builder, resolved_op_args);
// Select to avoid overflow.
let factor1 = builder.select_value(op_code_ok, arg2_value, value_zero);
let factor2 = builder.select_value(op_code_ok, arg3_value, value_zero);
let expected_product = builder.i64_mul(factor1, factor2);
let product_ok = builder.is_equal_slice(&arg1_value.elements, &expected_product.elements);
let arg1_key = resolved_op_args[0].args[0].clone();
let arg2_key = resolved_op_args[1].args[0].clone();
let arg3_key = resolved_op_args[2].args[0].clone();
let expected_statement = StatementTarget::new_native(
builder,
&self.params,
NativePredicate::ProductOf,
&[arg1_key, arg2_key, arg3_key],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
builder.all([op_code_ok, arg_types_ok, product_ok, st_ok])
}
fn eval_max_of(
&self,
builder: &mut CircuitBuilder<F, D>,
st: &StatementTarget,
op: &OperationTarget,
resolved_op_args: &[StatementTarget],
) -> BoolTarget {
let op_code_ok = op.has_native_type(builder, NativeOperation::MaxOf);
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
self.first_n_args_as_values(builder, resolved_op_args);
// Check that arg1_value is equal to one of the other two
// values.
let arg1_eq_arg2 = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements);
let arg1_eq_arg3 = builder.is_equal_slice(&arg1_value.elements, &arg3_value.elements);
let all_eq = builder.and(arg1_eq_arg2, arg1_eq_arg3);
let not_all_eq = builder.not(all_eq);
let arg1_check = builder.or(arg1_eq_arg2, arg1_eq_arg3);
// If it is not equal to any of the other two values, it must be greater than it.
let lower_bound = builder.select_value(arg1_eq_arg2, arg3_value, arg2_value);
// Only check lower bound if not all args are equal.
let lt_check_enabled = builder.and(not_all_eq, op_code_ok);
builder.assert_i64_less_if(lt_check_enabled, lower_bound, arg1_value);
let arg1_key = resolved_op_args[0].args[0].clone();
let arg2_key = resolved_op_args[1].args[0].clone();
let arg3_key = resolved_op_args[2].args[0].clone();
let expected_statement = StatementTarget::new_native(
builder,
&self.params,
NativePredicate::MaxOf,
&[arg1_key, arg2_key, arg3_key],
);
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
builder.all([op_code_ok, arg_types_ok, arg1_check, st_ok])
}
fn eval_transitive_eq(
&self,
builder: &mut CircuitBuilder<F, D>,
@ -684,6 +802,8 @@ impl MainPodVerifyCircuit {
#[cfg(test)]
mod tests {
use std::ops::Not;
use plonky2::{
field::{goldilocks_field::GoldilocksField, types::Field},
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
@ -693,6 +813,7 @@ mod tests {
use crate::{
backends::plonky2::{
basetypes::C,
circuits::common::tests::I64_TEST_PAIRS,
mainpod::{OperationArg, OperationAux},
primitives::merkletree::{MerkleClaimAndProof, MerkleTree},
},
@ -1236,6 +1357,185 @@ mod tests {
operation_verify(st, op, prev_statements, vec![])
}
#[test]
fn test_operation_verify_sumof() -> Result<()> {
I64_TEST_PAIRS
.into_iter()
.flat_map(|(a, b)| {
let (sum, overflow) = a.overflowing_add(b);
overflow.not().then_some((a, b, sum))
})
.try_for_each(|(a, b, sum)| {
let st1: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
sum.into(),
)
.into();
let st2: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
a.into(),
)
.into();
let st3: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
b.into(),
)
.into();
let st: mainpod::Statement = Statement::SumOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
)
.into();
let op = mainpod::Operation(
OperationType::Native(NativeOperation::SumOf),
vec![
OperationArg::Index(0),
OperationArg::Index(1),
OperationArg::Index(2),
],
OperationAux::None,
);
let prev_statements = vec![st1, st2, st3];
operation_verify(st, op, prev_statements, vec![])
})
}
#[test]
fn test_operation_verify_productof() -> Result<()> {
I64_TEST_PAIRS
.into_iter()
.flat_map(|(a, b)| {
let (prod, overflow) = a.overflowing_mul(b);
overflow.not().then_some((a, b, prod))
})
.try_for_each(|(a, b, prod)| {
let st1: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
prod.into(),
)
.into();
let st2: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
a.into(),
)
.into();
let st3: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
b.into(),
)
.into();
let st: mainpod::Statement = Statement::ProductOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
)
.into();
let op = mainpod::Operation(
OperationType::Native(NativeOperation::ProductOf),
vec![
OperationArg::Index(0),
OperationArg::Index(1),
OperationArg::Index(2),
],
OperationAux::None,
);
let prev_statements = vec![st1, st2, st3];
operation_verify(st, op, prev_statements, vec![])
})
}
#[test]
fn test_operation_verify_maxof() -> Result<()> {
I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| {
let max = i64::max(a, b);
let st1: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
max.into(),
)
.into();
let st2: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
a.into(),
)
.into();
let st3: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
b.into(),
)
.into();
let st: mainpod::Statement = Statement::MaxOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
)
.into();
let op = mainpod::Operation(
OperationType::Native(NativeOperation::MaxOf),
vec![
OperationArg::Index(0),
OperationArg::Index(1),
OperationArg::Index(2),
],
OperationAux::None,
);
let prev_statements = vec![st1, st2, st3];
operation_verify(st, op, prev_statements, vec![])
})
}
#[test]
fn test_operation_verify_maxof_failures() {
[(5, 3, 4), (5, 5, 8), (3, 4, 5)]
.into_iter()
.for_each(|(max, a, b)| {
let st1: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
max.into(),
)
.into();
let st2: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
a.into(),
)
.into();
let st3: mainpod::Statement = Statement::ValueOf(
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
b.into(),
)
.into();
let st: mainpod::Statement = Statement::MaxOf(
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
)
.into();
let op = mainpod::Operation(
OperationType::Native(NativeOperation::MaxOf),
vec![
OperationArg::Index(0),
OperationArg::Index(1),
OperationArg::Index(2),
],
OperationAux::None,
);
let prev_statements = vec![st1, st2, st3];
assert!(operation_verify(st, op, prev_statements, vec![]).is_err())
})
}
#[test]
fn test_operation_verify_lt_to_neq() -> Result<()> {
let st: mainpod::Statement = Statement::NotEqual(