chore(backend): implement some circuit op logic (#165)
* Initial circuit op work * Fix copy op * Add more ops * Fixes * Code review
This commit is contained in:
parent
3b2860beeb
commit
30f26a94ef
2 changed files with 465 additions and 78 deletions
|
|
@ -1,8 +1,12 @@
|
|||
//! Common functionality to build Pod circuits with plonky2
|
||||
|
||||
use crate::backends::plonky2::basetypes::D;
|
||||
use crate::backends::plonky2::mock::mainpod::Statement;
|
||||
use crate::backends::plonky2::mock::mainpod::{Operation, OperationArg};
|
||||
use crate::middleware::{Params, StatementArg, ToFields, Value, F, HASH_SIZE, VALUE_SIZE};
|
||||
use crate::middleware::{
|
||||
NativeOperation, NativePredicate, Params, Predicate, StatementArg, ToFields, Value, F,
|
||||
HASH_SIZE, VALUE_SIZE,
|
||||
};
|
||||
use crate::middleware::{OPERATION_ARG_F_LEN, STATEMENT_ARG_F_LEN};
|
||||
use anyhow::Result;
|
||||
use plonky2::field::extension::Extendable;
|
||||
|
|
@ -11,13 +15,42 @@ use plonky2::hash::hash_types::RichField;
|
|||
use plonky2::iop::target::{BoolTarget, Target};
|
||||
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use std::iter;
|
||||
use std::{array, iter};
|
||||
|
||||
pub const CODE_SIZE: usize = HASH_SIZE + 2;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ValueTarget {
|
||||
pub elements: [Target; VALUE_SIZE],
|
||||
}
|
||||
|
||||
impl ValueTarget {
|
||||
pub fn zero(builder: &mut CircuitBuilder<F, D>) -> Self {
|
||||
Self {
|
||||
elements: [builder.zero(); VALUE_SIZE],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn one(builder: &mut CircuitBuilder<F, D>) -> Self {
|
||||
Self {
|
||||
elements: array::from_fn(|i| {
|
||||
if i == 0 {
|
||||
builder.one()
|
||||
} else {
|
||||
builder.zero()
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_slice(xs: &[Target]) -> Self {
|
||||
assert_eq!(xs.len(), VALUE_SIZE);
|
||||
Self {
|
||||
elements: array::from_fn(|i| xs[i]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StatementTarget {
|
||||
pub predicate: [Target; Params::predicate_size()],
|
||||
|
|
@ -25,12 +58,24 @@ pub struct StatementTarget {
|
|||
}
|
||||
|
||||
impl StatementTarget {
|
||||
pub fn to_flattened(&self) -> Vec<Target> {
|
||||
self.predicate
|
||||
.iter()
|
||||
.chain(self.args.iter().flatten())
|
||||
.cloned()
|
||||
.collect()
|
||||
pub fn new_native(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
params: &Params,
|
||||
predicate: NativePredicate,
|
||||
args: &[[Target; STATEMENT_ARG_F_LEN]],
|
||||
) -> Self {
|
||||
let predicate_vec = builder.constants(&Predicate::Native(predicate).to_fields(params));
|
||||
Self {
|
||||
predicate: array::from_fn(|i| predicate_vec[i]),
|
||||
args: args
|
||||
.iter()
|
||||
.map(|arg| *arg)
|
||||
.chain(
|
||||
iter::repeat([builder.zero(); STATEMENT_ARG_F_LEN])
|
||||
.take(params.max_statement_args - args.len()),
|
||||
)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_targets(
|
||||
|
|
@ -51,6 +96,16 @@ impl StatementTarget {
|
|||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn has_native_type(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
params: &Params,
|
||||
t: NativePredicate,
|
||||
) -> BoolTarget {
|
||||
let st_code = builder.constants(&Predicate::Native(t).to_fields(params));
|
||||
builder.is_equal_slice(&self.predicate, &st_code)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Implement Operation::to_field to determine the size of each element
|
||||
|
|
@ -79,6 +134,49 @@ impl OperationTarget {
|
|||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn has_native_type(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
t: NativeOperation,
|
||||
) -> BoolTarget {
|
||||
let one = builder.one();
|
||||
let op_is_native = builder.is_equal(self.op_type[0], one);
|
||||
let op_code = builder.constant(F::from_canonical_u64(t as u64));
|
||||
let op_code_matches = builder.is_equal(self.op_type[1], op_code);
|
||||
builder.and(op_is_native, op_code_matches)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for target structs that may be converted to and from vectors
|
||||
/// of targets.
|
||||
pub trait Flattenable {
|
||||
fn flatten(&self) -> Vec<Target>;
|
||||
fn from_flattened(vs: &[Target]) -> Self;
|
||||
}
|
||||
|
||||
impl Flattenable for StatementTarget {
|
||||
fn flatten(&self) -> Vec<Target> {
|
||||
self.predicate
|
||||
.iter()
|
||||
.chain(self.args.iter().flatten())
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn from_flattened(v: &[Target]) -> Self {
|
||||
let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN;
|
||||
assert_eq!(
|
||||
v.len(),
|
||||
Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN
|
||||
);
|
||||
let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]);
|
||||
let args = (0..num_args)
|
||||
.map(|i| array::from_fn(|j| v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j]))
|
||||
.collect();
|
||||
|
||||
Self { predicate, args }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
||||
|
|
@ -91,11 +189,27 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
|||
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
|
||||
fn constant_value(&mut self, v: Value) -> ValueTarget;
|
||||
fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget;
|
||||
|
||||
// Convenience methods for checking values.
|
||||
/// Checks whether `xs` is right-padded with 0s so as to represent a `Value`.
|
||||
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget;
|
||||
/// Checks whether `x < y` if `b` is true. This involves checking
|
||||
/// that `x` and `y` each consist of two `u32` limbs.
|
||||
fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
|
||||
|
||||
// Convenience methods for accessing and connecting elements of
|
||||
// (vectors of) flattenables.
|
||||
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T;
|
||||
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T;
|
||||
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
|
||||
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
|
||||
|
||||
// Convenience methods for Boolean into-iters.
|
||||
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
|
||||
for CircuitBuilder<F, D>
|
||||
{
|
||||
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
|
||||
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]) {
|
||||
assert_eq!(xs.len(), ys.len());
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
|
|
@ -157,4 +271,110 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
|
|||
self.and(ok, is_eq)
|
||||
})
|
||||
}
|
||||
|
||||
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget {
|
||||
let zeros = iter::repeat(self.zero())
|
||||
.take(STATEMENT_ARG_F_LEN - VALUE_SIZE)
|
||||
.collect::<Vec<_>>();
|
||||
self.is_equal_slice(&xs[VALUE_SIZE..], &zeros)
|
||||
}
|
||||
|
||||
fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) {
|
||||
const NUM_BITS: usize = 32;
|
||||
|
||||
// Lt assertion with 32-bit range check.
|
||||
let assert_limb_lt = |builder: &mut Self, x, y| {
|
||||
// Check that targets fit within `NUM_BITS` bits.
|
||||
builder.range_check(x, NUM_BITS);
|
||||
builder.range_check(y, NUM_BITS);
|
||||
// Check that `y-1-x` fits within `NUM_BITS` bits.
|
||||
let one = builder.one();
|
||||
let y_minus_one = builder.sub(y, one);
|
||||
let expr = builder.sub(y_minus_one, x);
|
||||
builder.range_check(expr, NUM_BITS);
|
||||
};
|
||||
|
||||
// If b is false, replace `x` and `y` with dummy values.
|
||||
let zero = ValueTarget::zero(self);
|
||||
let one = ValueTarget::one(self);
|
||||
let x = self.select_value(b, x, zero);
|
||||
let y = self.select_value(b, y, one);
|
||||
|
||||
// `x` and `y` should only have two limbs each.
|
||||
x.elements
|
||||
.into_iter()
|
||||
.skip(2)
|
||||
.for_each(|l| self.assert_zero(l));
|
||||
y.elements
|
||||
.into_iter()
|
||||
.skip(2)
|
||||
.for_each(|l| self.assert_zero(l));
|
||||
|
||||
let big_limbs_eq = self.is_equal(x.elements[1], y.elements[1]);
|
||||
let lhs = self.select(big_limbs_eq, x.elements[0], x.elements[1]);
|
||||
let rhs = self.select(big_limbs_eq, y.elements[0], y.elements[1]);
|
||||
assert_limb_lt(self, lhs, rhs);
|
||||
}
|
||||
|
||||
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T {
|
||||
// TODO: Revisit this when we need more than 64 statements.
|
||||
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
|
||||
assert!(v.len() <= 64);
|
||||
builder.random_access(i, v.to_vec())
|
||||
};
|
||||
let matrix_row_ref = |builder: &mut CircuitBuilder<F, D>, m: &[Vec<Target>], i| {
|
||||
let num_rows = m.len();
|
||||
let num_columns = m
|
||||
.get(0)
|
||||
.map(|row| {
|
||||
let row_len = row.len();
|
||||
assert!(m.iter().all(|row| row.len() == row_len));
|
||||
row_len
|
||||
})
|
||||
.unwrap_or(0);
|
||||
(0..num_columns)
|
||||
.map(|j| {
|
||||
vector_ref(
|
||||
builder,
|
||||
&(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(),
|
||||
i,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>();
|
||||
T::from_flattened(&matrix_row_ref(self, &flattened_ts, i))
|
||||
}
|
||||
|
||||
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T {
|
||||
let flattened_x = x.flatten();
|
||||
let flattened_y = y.flatten();
|
||||
|
||||
T::from_flattened(
|
||||
&iter::zip(flattened_x, flattened_y)
|
||||
.map(|(x, y)| self.select(b, x, y))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) {
|
||||
self.connect_slice(&xs.flatten(), &ys.flatten())
|
||||
}
|
||||
|
||||
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget {
|
||||
self.is_equal_slice(&xs.flatten(), &ys.flatten())
|
||||
}
|
||||
|
||||
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
|
||||
xs.into_iter()
|
||||
.reduce(|a, b| self.and(a, b))
|
||||
.unwrap_or(self._true())
|
||||
}
|
||||
|
||||
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
|
||||
xs.into_iter()
|
||||
.reduce(|a, b| self.or(a, b))
|
||||
.unwrap_or(self._false())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue