feat(backend): implement gadgets for remaining ops (#228)
* Implement gadgets for remaining ops * Use overflowing arithmetic ops * Code review * Formatting
This commit is contained in:
parent
b2cb563eb6
commit
4fa9e20ecd
2 changed files with 570 additions and 3 deletions
|
|
@ -118,6 +118,9 @@ impl OperationVerifyGadget {
|
|||
self.eval_transitive_eq(builder, st, op, &resolved_op_args),
|
||||
self.eval_lt_to_neq(builder, st, op, &resolved_op_args),
|
||||
self.eval_hash_of(builder, st, op, &resolved_op_args),
|
||||
self.eval_sum_of(builder, st, op, &resolved_op_args),
|
||||
self.eval_product_of(builder, st, op, &resolved_op_args),
|
||||
self.eval_max_of(builder, st, op, &resolved_op_args),
|
||||
]
|
||||
},
|
||||
// Skip these if there are no resolved Merkle claims
|
||||
|
|
@ -386,6 +389,121 @@ impl OperationVerifyGadget {
|
|||
builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok])
|
||||
}
|
||||
|
||||
fn eval_sum_of(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
st: &StatementTarget,
|
||||
op: &OperationTarget,
|
||||
resolved_op_args: &[StatementTarget],
|
||||
) -> BoolTarget {
|
||||
let value_zero = ValueTarget::zero(builder);
|
||||
|
||||
let op_code_ok = op.has_native_type(builder, NativeOperation::SumOf);
|
||||
|
||||
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
|
||||
self.first_n_args_as_values(builder, resolved_op_args);
|
||||
|
||||
// Select to avoid overflow.
|
||||
let summand1 = builder.select_value(op_code_ok, arg2_value, value_zero);
|
||||
let summand2 = builder.select_value(op_code_ok, arg3_value, value_zero);
|
||||
|
||||
let expected_sum = builder.i64_add(summand1, summand2);
|
||||
|
||||
let sum_ok = builder.is_equal_slice(&arg1_value.elements, &expected_sum.elements);
|
||||
|
||||
let arg1_key = resolved_op_args[0].args[0].clone();
|
||||
let arg2_key = resolved_op_args[1].args[0].clone();
|
||||
let arg3_key = resolved_op_args[2].args[0].clone();
|
||||
let expected_statement = StatementTarget::new_native(
|
||||
builder,
|
||||
&self.params,
|
||||
NativePredicate::SumOf,
|
||||
&[arg1_key, arg2_key, arg3_key],
|
||||
);
|
||||
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
|
||||
|
||||
builder.all([op_code_ok, arg_types_ok, sum_ok, st_ok])
|
||||
}
|
||||
|
||||
fn eval_product_of(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
st: &StatementTarget,
|
||||
op: &OperationTarget,
|
||||
resolved_op_args: &[StatementTarget],
|
||||
) -> BoolTarget {
|
||||
let value_zero = ValueTarget::zero(builder);
|
||||
|
||||
let op_code_ok = op.has_native_type(builder, NativeOperation::ProductOf);
|
||||
|
||||
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
|
||||
self.first_n_args_as_values(builder, resolved_op_args);
|
||||
|
||||
// Select to avoid overflow.
|
||||
let factor1 = builder.select_value(op_code_ok, arg2_value, value_zero);
|
||||
let factor2 = builder.select_value(op_code_ok, arg3_value, value_zero);
|
||||
|
||||
let expected_product = builder.i64_mul(factor1, factor2);
|
||||
|
||||
let product_ok = builder.is_equal_slice(&arg1_value.elements, &expected_product.elements);
|
||||
|
||||
let arg1_key = resolved_op_args[0].args[0].clone();
|
||||
let arg2_key = resolved_op_args[1].args[0].clone();
|
||||
let arg3_key = resolved_op_args[2].args[0].clone();
|
||||
let expected_statement = StatementTarget::new_native(
|
||||
builder,
|
||||
&self.params,
|
||||
NativePredicate::ProductOf,
|
||||
&[arg1_key, arg2_key, arg3_key],
|
||||
);
|
||||
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
|
||||
|
||||
builder.all([op_code_ok, arg_types_ok, product_ok, st_ok])
|
||||
}
|
||||
|
||||
fn eval_max_of(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
st: &StatementTarget,
|
||||
op: &OperationTarget,
|
||||
resolved_op_args: &[StatementTarget],
|
||||
) -> BoolTarget {
|
||||
let op_code_ok = op.has_native_type(builder, NativeOperation::MaxOf);
|
||||
|
||||
let (arg_types_ok, [arg1_value, arg2_value, arg3_value]) =
|
||||
self.first_n_args_as_values(builder, resolved_op_args);
|
||||
|
||||
// Check that arg1_value is equal to one of the other two
|
||||
// values.
|
||||
let arg1_eq_arg2 = builder.is_equal_slice(&arg1_value.elements, &arg2_value.elements);
|
||||
let arg1_eq_arg3 = builder.is_equal_slice(&arg1_value.elements, &arg3_value.elements);
|
||||
|
||||
let all_eq = builder.and(arg1_eq_arg2, arg1_eq_arg3);
|
||||
let not_all_eq = builder.not(all_eq);
|
||||
|
||||
let arg1_check = builder.or(arg1_eq_arg2, arg1_eq_arg3);
|
||||
|
||||
// If it is not equal to any of the other two values, it must be greater than it.
|
||||
let lower_bound = builder.select_value(arg1_eq_arg2, arg3_value, arg2_value);
|
||||
|
||||
// Only check lower bound if not all args are equal.
|
||||
let lt_check_enabled = builder.and(not_all_eq, op_code_ok);
|
||||
builder.assert_i64_less_if(lt_check_enabled, lower_bound, arg1_value);
|
||||
|
||||
let arg1_key = resolved_op_args[0].args[0].clone();
|
||||
let arg2_key = resolved_op_args[1].args[0].clone();
|
||||
let arg3_key = resolved_op_args[2].args[0].clone();
|
||||
let expected_statement = StatementTarget::new_native(
|
||||
builder,
|
||||
&self.params,
|
||||
NativePredicate::MaxOf,
|
||||
&[arg1_key, arg2_key, arg3_key],
|
||||
);
|
||||
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
|
||||
|
||||
builder.all([op_code_ok, arg_types_ok, arg1_check, st_ok])
|
||||
}
|
||||
|
||||
fn eval_transitive_eq(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
|
|
@ -684,6 +802,8 @@ impl MainPodVerifyCircuit {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Not;
|
||||
|
||||
use plonky2::{
|
||||
field::{goldilocks_field::GoldilocksField, types::Field},
|
||||
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
|
||||
|
|
@ -693,6 +813,7 @@ mod tests {
|
|||
use crate::{
|
||||
backends::plonky2::{
|
||||
basetypes::C,
|
||||
circuits::common::tests::I64_TEST_PAIRS,
|
||||
mainpod::{OperationArg, OperationAux},
|
||||
primitives::merkletree::{MerkleClaimAndProof, MerkleTree},
|
||||
},
|
||||
|
|
@ -1236,6 +1357,185 @@ mod tests {
|
|||
operation_verify(st, op, prev_statements, vec![])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_operation_verify_sumof() -> Result<()> {
|
||||
I64_TEST_PAIRS
|
||||
.into_iter()
|
||||
.flat_map(|(a, b)| {
|
||||
let (sum, overflow) = a.overflowing_add(b);
|
||||
overflow.not().then_some((a, b, sum))
|
||||
})
|
||||
.try_for_each(|(a, b, sum)| {
|
||||
let st1: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
sum.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st2: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
a.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st3: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
b.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st: mainpod::Statement = Statement::SumOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
)
|
||||
.into();
|
||||
let op = mainpod::Operation(
|
||||
OperationType::Native(NativeOperation::SumOf),
|
||||
vec![
|
||||
OperationArg::Index(0),
|
||||
OperationArg::Index(1),
|
||||
OperationArg::Index(2),
|
||||
],
|
||||
OperationAux::None,
|
||||
);
|
||||
let prev_statements = vec![st1, st2, st3];
|
||||
operation_verify(st, op, prev_statements, vec![])
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_operation_verify_productof() -> Result<()> {
|
||||
I64_TEST_PAIRS
|
||||
.into_iter()
|
||||
.flat_map(|(a, b)| {
|
||||
let (prod, overflow) = a.overflowing_mul(b);
|
||||
overflow.not().then_some((a, b, prod))
|
||||
})
|
||||
.try_for_each(|(a, b, prod)| {
|
||||
let st1: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
prod.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st2: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
a.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st3: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
b.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st: mainpod::Statement = Statement::ProductOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
)
|
||||
.into();
|
||||
let op = mainpod::Operation(
|
||||
OperationType::Native(NativeOperation::ProductOf),
|
||||
vec![
|
||||
OperationArg::Index(0),
|
||||
OperationArg::Index(1),
|
||||
OperationArg::Index(2),
|
||||
],
|
||||
OperationAux::None,
|
||||
);
|
||||
let prev_statements = vec![st1, st2, st3];
|
||||
operation_verify(st, op, prev_statements, vec![])
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_operation_verify_maxof() -> Result<()> {
|
||||
I64_TEST_PAIRS.into_iter().try_for_each(|(a, b)| {
|
||||
let max = i64::max(a, b);
|
||||
let st1: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
max.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st2: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
a.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st3: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
b.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st: mainpod::Statement = Statement::MaxOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
)
|
||||
.into();
|
||||
let op = mainpod::Operation(
|
||||
OperationType::Native(NativeOperation::MaxOf),
|
||||
vec![
|
||||
OperationArg::Index(0),
|
||||
OperationArg::Index(1),
|
||||
OperationArg::Index(2),
|
||||
],
|
||||
OperationAux::None,
|
||||
);
|
||||
let prev_statements = vec![st1, st2, st3];
|
||||
operation_verify(st, op, prev_statements, vec![])
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_operation_verify_maxof_failures() {
|
||||
[(5, 3, 4), (5, 5, 8), (3, 4, 5)]
|
||||
.into_iter()
|
||||
.for_each(|(max, a, b)| {
|
||||
let st1: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
max.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st2: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
a.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st3: mainpod::Statement = Statement::ValueOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
b.into(),
|
||||
)
|
||||
.into();
|
||||
|
||||
let st: mainpod::Statement = Statement::MaxOf(
|
||||
AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")),
|
||||
AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")),
|
||||
AnchoredKey::from((PodId(RawValue::from(256).into()), "!")),
|
||||
)
|
||||
.into();
|
||||
let op = mainpod::Operation(
|
||||
OperationType::Native(NativeOperation::MaxOf),
|
||||
vec![
|
||||
OperationArg::Index(0),
|
||||
OperationArg::Index(1),
|
||||
OperationArg::Index(2),
|
||||
],
|
||||
OperationAux::None,
|
||||
);
|
||||
let prev_statements = vec![st1, st2, st3];
|
||||
assert!(operation_verify(st, op, prev_statements, vec![]).is_err())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_operation_verify_lt_to_neq() -> Result<()> {
|
||||
let st: mainpod::Statement = Statement::NotEqual(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue