chore(backend): implement some circuit op logic (#165)

* Initial circuit op work

* Fix copy op

* Add more ops

* Fixes

* Code review
This commit is contained in:
Ahmad Afuni 2025-03-26 03:40:23 +10:00 committed by GitHub
parent 3b2860beeb
commit 30f26a94ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 465 additions and 78 deletions

View file

@ -1,8 +1,12 @@
//! Common functionality to build Pod circuits with plonky2
use crate::backends::plonky2::basetypes::D;
use crate::backends::plonky2::mock::mainpod::Statement;
use crate::backends::plonky2::mock::mainpod::{Operation, OperationArg};
use crate::middleware::{Params, StatementArg, ToFields, Value, F, HASH_SIZE, VALUE_SIZE};
use crate::middleware::{
NativeOperation, NativePredicate, Params, Predicate, StatementArg, ToFields, Value, F,
HASH_SIZE, VALUE_SIZE,
};
use crate::middleware::{OPERATION_ARG_F_LEN, STATEMENT_ARG_F_LEN};
use anyhow::Result;
use plonky2::field::extension::Extendable;
@ -11,13 +15,42 @@ use plonky2::hash::hash_types::RichField;
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use std::iter;
use std::{array, iter};
pub const CODE_SIZE: usize = HASH_SIZE + 2;
#[derive(Copy, Clone)]
pub struct ValueTarget {
pub elements: [Target; VALUE_SIZE],
}
impl ValueTarget {
pub fn zero(builder: &mut CircuitBuilder<F, D>) -> Self {
Self {
elements: [builder.zero(); VALUE_SIZE],
}
}
pub fn one(builder: &mut CircuitBuilder<F, D>) -> Self {
Self {
elements: array::from_fn(|i| {
if i == 0 {
builder.one()
} else {
builder.zero()
}
}),
}
}
pub fn from_slice(xs: &[Target]) -> Self {
assert_eq!(xs.len(), VALUE_SIZE);
Self {
elements: array::from_fn(|i| xs[i]),
}
}
}
#[derive(Clone)]
pub struct StatementTarget {
pub predicate: [Target; Params::predicate_size()],
@ -25,12 +58,24 @@ pub struct StatementTarget {
}
impl StatementTarget {
pub fn to_flattened(&self) -> Vec<Target> {
self.predicate
.iter()
.chain(self.args.iter().flatten())
.cloned()
.collect()
pub fn new_native(
builder: &mut CircuitBuilder<F, D>,
params: &Params,
predicate: NativePredicate,
args: &[[Target; STATEMENT_ARG_F_LEN]],
) -> Self {
let predicate_vec = builder.constants(&Predicate::Native(predicate).to_fields(params));
Self {
predicate: array::from_fn(|i| predicate_vec[i]),
args: args
.iter()
.map(|arg| *arg)
.chain(
iter::repeat([builder.zero(); STATEMENT_ARG_F_LEN])
.take(params.max_statement_args - args.len()),
)
.collect(),
}
}
pub fn set_targets(
@ -51,6 +96,16 @@ impl StatementTarget {
}
Ok(())
}
pub fn has_native_type(
&self,
builder: &mut CircuitBuilder<F, D>,
params: &Params,
t: NativePredicate,
) -> BoolTarget {
let st_code = builder.constants(&Predicate::Native(t).to_fields(params));
builder.is_equal_slice(&self.predicate, &st_code)
}
}
// TODO: Implement Operation::to_field to determine the size of each element
@ -79,6 +134,49 @@ impl OperationTarget {
}
Ok(())
}
pub fn has_native_type(
&self,
builder: &mut CircuitBuilder<F, D>,
t: NativeOperation,
) -> BoolTarget {
let one = builder.one();
let op_is_native = builder.is_equal(self.op_type[0], one);
let op_code = builder.constant(F::from_canonical_u64(t as u64));
let op_code_matches = builder.is_equal(self.op_type[1], op_code);
builder.and(op_is_native, op_code_matches)
}
}
/// Trait for target structs that may be converted to and from vectors
/// of targets.
pub trait Flattenable {
fn flatten(&self) -> Vec<Target>;
fn from_flattened(vs: &[Target]) -> Self;
}
impl Flattenable for StatementTarget {
fn flatten(&self) -> Vec<Target> {
self.predicate
.iter()
.chain(self.args.iter().flatten())
.cloned()
.collect()
}
fn from_flattened(v: &[Target]) -> Self {
let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN;
assert_eq!(
v.len(),
Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN
);
let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]);
let args = (0..num_args)
.map(|i| array::from_fn(|j| v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j]))
.collect();
Self { predicate, args }
}
}
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
@ -91,11 +189,27 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
fn constant_value(&mut self, v: Value) -> ValueTarget;
fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget;
// Convenience methods for checking values.
/// Checks whether `xs` is right-padded with 0s so as to represent a `Value`.
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget;
/// Checks whether `x < y` if `b` is true. This involves checking
/// that `x` and `y` each consist of two `u32` limbs.
fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
// Convenience methods for accessing and connecting elements of
// (vectors of) flattenables.
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T;
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T;
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
// Convenience methods for Boolean into-iters.
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
for CircuitBuilder<F, D>
{
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]) {
assert_eq!(xs.len(), ys.len());
for (x, y) in xs.iter().zip(ys.iter()) {
@ -157,4 +271,110 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
self.and(ok, is_eq)
})
}
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget {
let zeros = iter::repeat(self.zero())
.take(STATEMENT_ARG_F_LEN - VALUE_SIZE)
.collect::<Vec<_>>();
self.is_equal_slice(&xs[VALUE_SIZE..], &zeros)
}
fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) {
const NUM_BITS: usize = 32;
// Lt assertion with 32-bit range check.
let assert_limb_lt = |builder: &mut Self, x, y| {
// Check that targets fit within `NUM_BITS` bits.
builder.range_check(x, NUM_BITS);
builder.range_check(y, NUM_BITS);
// Check that `y-1-x` fits within `NUM_BITS` bits.
let one = builder.one();
let y_minus_one = builder.sub(y, one);
let expr = builder.sub(y_minus_one, x);
builder.range_check(expr, NUM_BITS);
};
// If b is false, replace `x` and `y` with dummy values.
let zero = ValueTarget::zero(self);
let one = ValueTarget::one(self);
let x = self.select_value(b, x, zero);
let y = self.select_value(b, y, one);
// `x` and `y` should only have two limbs each.
x.elements
.into_iter()
.skip(2)
.for_each(|l| self.assert_zero(l));
y.elements
.into_iter()
.skip(2)
.for_each(|l| self.assert_zero(l));
let big_limbs_eq = self.is_equal(x.elements[1], y.elements[1]);
let lhs = self.select(big_limbs_eq, x.elements[0], x.elements[1]);
let rhs = self.select(big_limbs_eq, y.elements[0], y.elements[1]);
assert_limb_lt(self, lhs, rhs);
}
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T {
// TODO: Revisit this when we need more than 64 statements.
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
assert!(v.len() <= 64);
builder.random_access(i, v.to_vec())
};
let matrix_row_ref = |builder: &mut CircuitBuilder<F, D>, m: &[Vec<Target>], i| {
let num_rows = m.len();
let num_columns = m
.get(0)
.map(|row| {
let row_len = row.len();
assert!(m.iter().all(|row| row.len() == row_len));
row_len
})
.unwrap_or(0);
(0..num_columns)
.map(|j| {
vector_ref(
builder,
&(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(),
i,
)
})
.collect::<Vec<_>>()
};
let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>();
T::from_flattened(&matrix_row_ref(self, &flattened_ts, i))
}
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T {
let flattened_x = x.flatten();
let flattened_y = y.flatten();
T::from_flattened(
&iter::zip(flattened_x, flattened_y)
.map(|(x, y)| self.select(b, x, y))
.collect::<Vec<_>>(),
)
}
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) {
self.connect_slice(&xs.flatten(), &ys.flatten())
}
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget {
self.is_equal_slice(&xs.flatten(), &ys.flatten())
}
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
xs.into_iter()
.reduce(|a, b| self.and(a, b))
.unwrap_or(self._true())
}
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
xs.into_iter()
.reduce(|a, b| self.or(a, b))
.unwrap_or(self._false())
}
}