support longer arrays in vec_ref (#367)
Support arrays up to 256 elements (hardcoded maximum just to avoid abuse) by combining multiple random_accesses. The index is now split into low and high parts. It's a bit more inconvenient than using a single Target but this allows avoiding bit decomposition.
This commit is contained in:
parent
bde35369d3
commit
4fa285d9fb
5 changed files with 175 additions and 61 deletions
|
|
@ -33,7 +33,7 @@ use crate::{
|
|||
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
||||
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
|
||||
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F,
|
||||
HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
||||
HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
@ -255,9 +255,9 @@ impl OperationTypeTarget {
|
|||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct OperationTarget {
|
||||
pub op_type: OperationTypeTarget,
|
||||
pub args: Vec<[Target; OPERATION_ARG_F_LEN]>,
|
||||
pub args: Vec<IndexTarget>,
|
||||
#[serde(with = "serde_arrays")]
|
||||
pub aux: [Target; OPERATION_AUX_F_LEN],
|
||||
pub aux: [IndexTarget; 2],
|
||||
}
|
||||
|
||||
impl OperationTarget {
|
||||
|
|
@ -275,9 +275,12 @@ impl OperationTarget {
|
|||
.take(params.max_operation_args)
|
||||
.enumerate()
|
||||
{
|
||||
pw.set_target_arr(&self.args[i], &arg.to_fields(params))?;
|
||||
self.args[i].set_targets(pw, arg.as_usize())?;
|
||||
}
|
||||
let indexes = op.aux().as_usizes();
|
||||
for (index_target, index) in self.aux.iter().zip_eq(indexes.iter()) {
|
||||
index_target.set_targets(pw, *index)?;
|
||||
}
|
||||
pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -584,7 +587,7 @@ impl CustomPredicateEntryTarget {
|
|||
// Custom predicate verification table entry
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct CustomPredicateVerifyEntryTarget {
|
||||
pub custom_predicate_table_index: Target,
|
||||
pub custom_predicate_table_index: IndexTarget,
|
||||
pub custom_predicate: CustomPredicateEntryTarget,
|
||||
pub args: Vec<ValueTarget>,
|
||||
pub op_args: Vec<StatementTarget>,
|
||||
|
|
@ -592,8 +595,13 @@ pub struct CustomPredicateVerifyEntryTarget {
|
|||
|
||||
impl CustomPredicateVerifyEntryTarget {
|
||||
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
|
||||
let custom_predicate_table_len =
|
||||
params.max_custom_predicate_batches * params.max_custom_batch_size;
|
||||
CustomPredicateVerifyEntryTarget {
|
||||
custom_predicate_table_index: builder.add_virtual_target(),
|
||||
custom_predicate_table_index: IndexTarget::new_virtual(
|
||||
custom_predicate_table_len,
|
||||
builder,
|
||||
),
|
||||
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
|
||||
args: (0..params.max_custom_predicate_wildcards)
|
||||
.map(|_| builder.add_virtual_value())
|
||||
|
|
@ -609,10 +617,8 @@ impl CustomPredicateVerifyEntryTarget {
|
|||
params: &Params,
|
||||
cpv: &CustomPredicateVerification,
|
||||
) -> Result<()> {
|
||||
pw.set_target(
|
||||
self.custom_predicate_table_index,
|
||||
F::from_canonical_usize(cpv.custom_predicate_table_index),
|
||||
)?;
|
||||
self.custom_predicate_table_index
|
||||
.set_targets(pw, cpv.custom_predicate_table_index)?;
|
||||
// Replace statement templates of batch-self with (id,index)
|
||||
self.custom_predicate
|
||||
.set_targets(pw, params, &cpv.custom_predicate)?;
|
||||
|
|
@ -665,7 +671,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
|
|||
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
|
||||
let (pos, size) = (0, params.statement_size());
|
||||
let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]);
|
||||
let (pos, size) = (pos + size, params.operation_size());
|
||||
let (pos, size) = (pos + size, params.operation_size(IndexTarget::f_len()));
|
||||
let op_type = OperationTypeTarget {
|
||||
elements: vs[pos..pos + size]
|
||||
.try_into()
|
||||
|
|
@ -867,6 +873,41 @@ impl Flattenable for StatementTmplArgTarget {
|
|||
}
|
||||
}
|
||||
|
||||
/// Index to an array for random access
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct IndexTarget {
|
||||
max_array_len: usize,
|
||||
low: Target,
|
||||
high: Target,
|
||||
}
|
||||
|
||||
impl IndexTarget {
|
||||
// Length in field elements
|
||||
pub const fn f_len() -> usize {
|
||||
2
|
||||
}
|
||||
pub fn new_virtual(max_array_len: usize, builder: &mut CircuitBuilder) -> Self {
|
||||
// Limit the maximum array length to avoid abusing `vec_ref`
|
||||
assert!(max_array_len <= 256);
|
||||
Self {
|
||||
max_array_len,
|
||||
low: builder.add_virtual_target(),
|
||||
high: if max_array_len > 64 {
|
||||
builder.add_virtual_target()
|
||||
} else {
|
||||
builder.zero()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_targets(&self, pw: &mut PartialWitness<F>, index: usize) -> Result<()> {
|
||||
assert!(index == 0 || index < self.max_array_len);
|
||||
pw.set_target(self.low, F::from_canonical_usize(index & ((1 << 6) - 1)))?;
|
||||
pw.set_target(self.high, F::from_canonical_usize(index >> 6))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
||||
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget);
|
||||
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
|
||||
|
|
@ -932,9 +973,14 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
|||
/// Creates value target that is a hash of two given values.
|
||||
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||
|
||||
// Convenience methods for accessing and connecting elements of
|
||||
// (vectors of) flattenables.
|
||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
|
||||
/// Like `random_access` but allows using longer arrays.
|
||||
fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target;
|
||||
|
||||
/// Convenience methods for accessing and connecting elements of
|
||||
/// (vectors of) flattenables.
|
||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
|
||||
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
|
||||
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
|
||||
fn select_flattenable<T: Flattenable>(
|
||||
&mut self,
|
||||
params: &Params,
|
||||
|
|
@ -945,11 +991,11 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
|||
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
|
||||
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
|
||||
|
||||
// Convenience methods for Boolean into-iters.
|
||||
/// Convenience methods for Boolean into-iters.
|
||||
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
|
||||
|
||||
// Return a bit-mask of size `len` that selects all positions lower than `n`
|
||||
/// Return a bit-mask of size `len` that selects all positions lower than `n`
|
||||
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget>;
|
||||
}
|
||||
|
||||
|
|
@ -1003,9 +1049,12 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
|||
OperationTarget {
|
||||
op_type: self.add_virtual_operation_type(),
|
||||
args: (0..params.max_operation_args)
|
||||
.map(|_| self.add_virtual_target_arr())
|
||||
.map(|_| IndexTarget::new_virtual(params.statement_table_size(), self))
|
||||
.collect(),
|
||||
aux: self.add_virtual_target_arr(),
|
||||
aux: [
|
||||
IndexTarget::new_virtual(params.max_merkle_proofs_containers, self),
|
||||
IndexTarget::new_virtual(params.max_custom_predicate_verifications, self),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1293,18 +1342,41 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
|||
)
|
||||
}
|
||||
|
||||
fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target {
|
||||
const CHUNK_LEN: usize = 64; // Max size of a single gate native random access
|
||||
assert!(array.len() <= i.max_array_len);
|
||||
// Limit to 4 chunks (combination of 4 random_access of CHUNK_LEN elements) to avoid
|
||||
// abusing this method.
|
||||
assert!(array.len() <= 4 * CHUNK_LEN);
|
||||
|
||||
// We do several random accesses over chunks of CHUNK_LEN using the lowest bits of the
|
||||
// index. Then we combine them using the highest bits of the index.
|
||||
let mut chunk_res = Vec::new();
|
||||
let num_chunks = array.len().div_ceil(CHUNK_LEN);
|
||||
for chunk in array.chunks(CHUNK_LEN) {
|
||||
let mut index_chunk = i.low;
|
||||
// I we have several chunks and the last one is smaller (it's index needs less than 6
|
||||
// bits), make it zero except when it's used so that the range check over the index
|
||||
// passes.
|
||||
if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 {
|
||||
let last_chunk_index_high = self.constant(F::from_canonical_usize(num_chunks - 1));
|
||||
let selector = self.is_equal(i.high, last_chunk_index_high);
|
||||
index_chunk = self.mul(index_chunk, selector.target);
|
||||
}
|
||||
let res = self.random_access(index_chunk, chunk.to_vec());
|
||||
chunk_res.push(res);
|
||||
}
|
||||
|
||||
self.random_access(i.high, chunk_res)
|
||||
}
|
||||
|
||||
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
|
||||
// The idea would be the following: Take the array `ts` and hash each element. Then do the
|
||||
// random access on the hash result. Finally "unhash" to recover the resolved element.
|
||||
// We don't want to hash each element from the array each time, so we should cache the hashed
|
||||
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
|
||||
// then do `ts: &[HashCache<T>]`.
|
||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
|
||||
// TODO: Revisit this when we need more than 64 statements.
|
||||
let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| {
|
||||
assert!(v.len() <= 64);
|
||||
builder.random_access(i, v.to_vec())
|
||||
};
|
||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
|
||||
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
|
||||
let num_rows = m.len();
|
||||
let num_columns = m
|
||||
|
|
@ -1317,11 +1389,8 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
|||
.unwrap_or(0);
|
||||
(0..num_columns)
|
||||
.map(|j| {
|
||||
vector_ref(
|
||||
builder,
|
||||
&(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(),
|
||||
i,
|
||||
)
|
||||
builder
|
||||
.random_access_long(i, &(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>())
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
|
@ -1330,6 +1399,19 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
|||
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
|
||||
}
|
||||
|
||||
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
|
||||
let zero = self.zero();
|
||||
self.vec_ref(
|
||||
params,
|
||||
ts,
|
||||
&IndexTarget {
|
||||
max_array_len: 64,
|
||||
low: i,
|
||||
high: zero,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn select_flattenable<T: Flattenable>(
|
||||
&mut self,
|
||||
params: &Params,
|
||||
|
|
@ -1639,4 +1721,36 @@ pub(crate) mod tests {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_access_long() -> Result<(), anyhow::Error> {
|
||||
let lens: [usize; _] = [10, 60, 64, 96, 126, 159, 190, 256];
|
||||
|
||||
for len in &lens {
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let array = builder.add_virtual_targets(*len);
|
||||
let index_target = IndexTarget::new_virtual(*len, &mut builder);
|
||||
let res = builder.random_access_long(&index_target, &array);
|
||||
|
||||
let data = builder.build::<PoseidonGoldilocksConfig>();
|
||||
|
||||
for i in 0..3 {
|
||||
let index = (len - 1) * i / 2;
|
||||
println!("len={}, index={}", len, index);
|
||||
let mut pw = PartialWitness::<F>::new();
|
||||
for (j, elem) in array.iter().enumerate() {
|
||||
pw.set_target(*elem, F::from_canonical_usize(j * 11))?;
|
||||
}
|
||||
index_target.set_targets(&mut pw, index)?;
|
||||
pw.set_target(res, F::from_canonical_usize(index * 11))?; // Expected
|
||||
|
||||
let proof = data.prove(pw)?;
|
||||
data.verify(proof)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,8 +95,7 @@ impl StatementCache {
|
|||
// converting a length 1 array into a scalar.
|
||||
op.args
|
||||
.iter()
|
||||
.flatten()
|
||||
.map(|&i| builder.vec_ref(params, prev_statements, i))
|
||||
.map(|i| builder.vec_ref(params, prev_statements, i))
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
assert!(params.max_operation_args >= 3);
|
||||
|
|
@ -224,7 +223,7 @@ fn verify_operation_circuit(
|
|||
// been verified, so we need only look up the claim.
|
||||
let measure_resolve_merkle_claim = measure_gates_begin!(builder, "ResolveMerkleClaim");
|
||||
let resolved_merkle_claim =
|
||||
(!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, op.aux[0]));
|
||||
(!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, &op.aux[0]));
|
||||
measure_gates_end!(builder, measure_resolve_merkle_claim);
|
||||
|
||||
// Operations from custom statements will refer to one
|
||||
|
|
@ -233,7 +232,7 @@ fn verify_operation_circuit(
|
|||
let measure_resolve_custom_pred_verification =
|
||||
measure_gates_begin!(builder, "ResolveCustomPredVerification");
|
||||
let resolved_custom_pred_verification = (!custom_predicate_verification_table.is_empty())
|
||||
.then(|| builder.vec_ref(params, custom_predicate_verification_table, op.aux[1]));
|
||||
.then(|| builder.vec_ref(params, custom_predicate_verification_table, &op.aux[1]));
|
||||
measure_gates_end!(builder, measure_resolve_custom_pred_verification);
|
||||
|
||||
// The verification may require aux data which needs to be stored in the
|
||||
|
|
@ -938,7 +937,7 @@ fn make_statement_arg_from_template_circuit(
|
|||
let first_index = ak_id_wc_index;
|
||||
let is_first_index_valid = builder.or(is_ak, is_wc_literal);
|
||||
let first_index = builder.select(is_first_index_valid, first_index, zero);
|
||||
let resolved_ak_id = builder.vec_ref(params, &args, first_index);
|
||||
let resolved_ak_id = builder.vec_ref_small(params, &args, first_index);
|
||||
let resolved_wc = resolved_ak_id;
|
||||
|
||||
// If the index is not used, use a 0 instead to still pass the range constraints from
|
||||
|
|
@ -946,7 +945,7 @@ fn make_statement_arg_from_template_circuit(
|
|||
let second_index = ak_key_wc_index;
|
||||
let is_second_index_valid = builder.and(is_ak, is_ak_key_wc);
|
||||
let second_index = builder.select(is_second_index_valid, second_index, zero);
|
||||
let resolved_ak_key = builder.vec_ref(params, &args, second_index);
|
||||
let resolved_ak_key = builder.vec_ref_small(params, &args, second_index);
|
||||
|
||||
let ak_key = ak_key_lit; // is_ak_key_lit
|
||||
let ak_key = builder.select_flattenable(params, is_ak_key_wc, &resolved_ak_key, &ak_key);
|
||||
|
|
@ -1244,7 +1243,7 @@ fn build_custom_predicate_verification_table_circuit(
|
|||
let table_query_hash = builder.vec_ref(
|
||||
params,
|
||||
custom_predicate_table,
|
||||
entry.custom_predicate_table_index,
|
||||
&entry.custom_predicate_table_index,
|
||||
);
|
||||
let out_query_hash = entry.custom_predicate.hash(builder);
|
||||
builder.connect_array(table_query_hash.elements, out_query_hash.elements);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
use std::fmt;
|
||||
|
||||
use plonky2::field::types::Field;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -9,7 +8,7 @@ use crate::{
|
|||
mainpod::Statement,
|
||||
primitives::merkletree::MerkleClaimAndProof,
|
||||
},
|
||||
middleware::{self, OperationType, Params, ToFields, F},
|
||||
middleware::{self, OperationType},
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
|
|
@ -18,20 +17,17 @@ pub enum OperationArg {
|
|||
Index(usize),
|
||||
}
|
||||
|
||||
impl ToFields for OperationArg {
|
||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
||||
let f = match self {
|
||||
Self::None => F::ZERO,
|
||||
Self::Index(i) => F::from_canonical_usize(*i),
|
||||
};
|
||||
vec![f]
|
||||
}
|
||||
}
|
||||
|
||||
impl OperationArg {
|
||||
pub fn is_none(&self) -> bool {
|
||||
matches!(self, OperationArg::None)
|
||||
}
|
||||
|
||||
pub fn as_usize(&self) -> usize {
|
||||
match self {
|
||||
Self::None => 0,
|
||||
Self::Index(i) => *i,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
|
|
@ -41,14 +37,13 @@ pub enum OperationAux {
|
|||
CustomPredVerifyIndex(usize),
|
||||
}
|
||||
|
||||
impl ToFields for OperationAux {
|
||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
||||
let fs = match self {
|
||||
Self::None => [F::ZERO, F::ZERO],
|
||||
Self::MerkleProofIndex(i) => [F::from_canonical_usize(*i), F::ZERO],
|
||||
Self::CustomPredVerifyIndex(i) => [F::ZERO, F::from_canonical_usize(*i)],
|
||||
};
|
||||
vec![fs[0], fs[1]]
|
||||
impl OperationAux {
|
||||
pub fn as_usizes(&self) -> [usize; 2] {
|
||||
match self {
|
||||
Self::None => [0, 0],
|
||||
Self::MerkleProofIndex(i) => [*i, 0],
|
||||
Self::CustomPredVerifyIndex(i) => [0, *i],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -828,8 +828,8 @@ impl Params {
|
|||
Self::predicate_size() + STATEMENT_ARG_F_LEN * self.max_statement_args
|
||||
}
|
||||
|
||||
pub fn operation_size(&self) -> usize {
|
||||
Self::operation_type_size() + OPERATION_ARG_F_LEN * self.max_operation_args
|
||||
pub fn operation_size(&self, operation_arg_f_len: usize) -> usize {
|
||||
Self::operation_type_size() + operation_arg_f_len * self.max_operation_args
|
||||
}
|
||||
|
||||
pub const fn statement_tmpl_size(&self) -> usize {
|
||||
|
|
@ -844,6 +844,14 @@ impl Params {
|
|||
self.max_custom_batch_size * self.custom_predicate_size()
|
||||
}
|
||||
|
||||
/// Total size of the statement table including None, input statements from signed pods and
|
||||
/// input recursive pods and new statements (public & private)
|
||||
pub fn statement_table_size(&self) -> usize {
|
||||
1 + self.max_input_signed_pods * self.max_signed_pod_values
|
||||
+ self.max_input_recursive_pods * self.max_input_pods_public_statements
|
||||
+ self.max_statements
|
||||
}
|
||||
|
||||
/// Parameters that define how the id is calculated
|
||||
pub fn id_params(&self) -> Vec<usize> {
|
||||
vec![
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@ pub const KEY_SIGNER: &str = "_signer";
|
|||
// hash(KEY_TYPE) = [17948789436443445142, 12513915140657440811, 15878361618879468769, 938231894693848619]
|
||||
pub const KEY_TYPE: &str = "_type";
|
||||
pub const STATEMENT_ARG_F_LEN: usize = 8;
|
||||
pub const OPERATION_ARG_F_LEN: usize = 1;
|
||||
pub const OPERATION_AUX_F_LEN: usize = 2;
|
||||
|
||||
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||
pub enum NativePredicate {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue