support longer arrays in vec_ref (#367)

Support arrays up to 256 elements (hardcoded maximum just to avoid abuse) by combining multiple random_accesses.
The index is now split into low and high parts. It's a bit more inconvenient than using a single Target but this allows avoiding bit decomposition.
This commit is contained in:
Eduard S. 2025-07-30 16:07:25 -07:00 committed by GitHub
parent bde35369d3
commit 4fa285d9fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 175 additions and 61 deletions

View file

@ -33,7 +33,7 @@ use crate::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F,
HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE,
HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
},
};
@ -255,9 +255,9 @@ impl OperationTypeTarget {
#[derive(Clone, Serialize, Deserialize)]
pub struct OperationTarget {
pub op_type: OperationTypeTarget,
pub args: Vec<[Target; OPERATION_ARG_F_LEN]>,
pub args: Vec<IndexTarget>,
#[serde(with = "serde_arrays")]
pub aux: [Target; OPERATION_AUX_F_LEN],
pub aux: [IndexTarget; 2],
}
impl OperationTarget {
@ -275,9 +275,12 @@ impl OperationTarget {
.take(params.max_operation_args)
.enumerate()
{
pw.set_target_arr(&self.args[i], &arg.to_fields(params))?;
self.args[i].set_targets(pw, arg.as_usize())?;
}
let indexes = op.aux().as_usizes();
for (index_target, index) in self.aux.iter().zip_eq(indexes.iter()) {
index_target.set_targets(pw, *index)?;
}
pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?;
Ok(())
}
}
@ -584,7 +587,7 @@ impl CustomPredicateEntryTarget {
// Custom predicate verification table entry
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateVerifyEntryTarget {
pub custom_predicate_table_index: Target,
pub custom_predicate_table_index: IndexTarget,
pub custom_predicate: CustomPredicateEntryTarget,
pub args: Vec<ValueTarget>,
pub op_args: Vec<StatementTarget>,
@ -592,8 +595,13 @@ pub struct CustomPredicateVerifyEntryTarget {
impl CustomPredicateVerifyEntryTarget {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
let custom_predicate_table_len =
params.max_custom_predicate_batches * params.max_custom_batch_size;
CustomPredicateVerifyEntryTarget {
custom_predicate_table_index: builder.add_virtual_target(),
custom_predicate_table_index: IndexTarget::new_virtual(
custom_predicate_table_len,
builder,
),
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
@ -609,10 +617,8 @@ impl CustomPredicateVerifyEntryTarget {
params: &Params,
cpv: &CustomPredicateVerification,
) -> Result<()> {
pw.set_target(
self.custom_predicate_table_index,
F::from_canonical_usize(cpv.custom_predicate_table_index),
)?;
self.custom_predicate_table_index
.set_targets(pw, cpv.custom_predicate_table_index)?;
// Replace statement templates of batch-self with (id,index)
self.custom_predicate
.set_targets(pw, params, &cpv.custom_predicate)?;
@ -665,7 +671,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget {
fn from_flattened(params: &Params, vs: &[Target]) -> Self {
let (pos, size) = (0, params.statement_size());
let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]);
let (pos, size) = (pos + size, params.operation_size());
let (pos, size) = (pos + size, params.operation_size(IndexTarget::f_len()));
let op_type = OperationTypeTarget {
elements: vs[pos..pos + size]
.try_into()
@ -867,6 +873,41 @@ impl Flattenable for StatementTmplArgTarget {
}
}
/// Index to an array for random access
#[derive(Clone, Serialize, Deserialize)]
pub struct IndexTarget {
max_array_len: usize,
low: Target,
high: Target,
}
impl IndexTarget {
// Length in field elements
pub const fn f_len() -> usize {
2
}
pub fn new_virtual(max_array_len: usize, builder: &mut CircuitBuilder) -> Self {
// Limit the maximum array length to avoid abusing `vec_ref`
assert!(max_array_len <= 256);
Self {
max_array_len,
low: builder.add_virtual_target(),
high: if max_array_len > 64 {
builder.add_virtual_target()
} else {
builder.zero()
},
}
}
pub fn set_targets(&self, pw: &mut PartialWitness<F>, index: usize) -> Result<()> {
assert!(index == 0 || index < self.max_array_len);
pw.set_target(self.low, F::from_canonical_usize(index & ((1 << 6) - 1)))?;
pw.set_target(self.high, F::from_canonical_usize(index >> 6))?;
Ok(())
}
}
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget);
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
@ -932,9 +973,14 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
/// Creates value target that is a hash of two given values.
fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget;
// Convenience methods for accessing and connecting elements of
// (vectors of) flattenables.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
/// Like `random_access` but allows using longer arrays.
fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target;
/// Convenience methods for accessing and connecting elements of
/// (vectors of) flattenables.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
@ -945,11 +991,11 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
// Convenience methods for Boolean into-iters.
/// Convenience methods for Boolean into-iters.
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
// Return a bit-mask of size `len` that selects all positions lower than `n`
/// Return a bit-mask of size `len` that selects all positions lower than `n`
fn lt_mask(&mut self, len: usize, n: Target) -> Vec<BoolTarget>;
}
@ -1003,9 +1049,12 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
OperationTarget {
op_type: self.add_virtual_operation_type(),
args: (0..params.max_operation_args)
.map(|_| self.add_virtual_target_arr())
.map(|_| IndexTarget::new_virtual(params.statement_table_size(), self))
.collect(),
aux: self.add_virtual_target_arr(),
aux: [
IndexTarget::new_virtual(params.max_merkle_proofs_containers, self),
IndexTarget::new_virtual(params.max_custom_predicate_verifications, self),
],
}
}
@ -1293,18 +1342,41 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
)
}
fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target {
const CHUNK_LEN: usize = 64; // Max size of a single gate native random access
assert!(array.len() <= i.max_array_len);
// Limit to 4 chunks (combination of 4 random_access of CHUNK_LEN elements) to avoid
// abusing this method.
assert!(array.len() <= 4 * CHUNK_LEN);
// We do several random accesses over chunks of CHUNK_LEN using the lowest bits of the
// index. Then we combine them using the highest bits of the index.
let mut chunk_res = Vec::new();
let num_chunks = array.len().div_ceil(CHUNK_LEN);
for chunk in array.chunks(CHUNK_LEN) {
let mut index_chunk = i.low;
// I we have several chunks and the last one is smaller (it's index needs less than 6
// bits), make it zero except when it's used so that the range check over the index
// passes.
if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 {
let last_chunk_index_high = self.constant(F::from_canonical_usize(num_chunks - 1));
let selector = self.is_equal(i.high, last_chunk_index_high);
index_chunk = self.mul(index_chunk, selector.target);
}
let res = self.random_access(index_chunk, chunk.to_vec());
chunk_res.push(res);
}
self.random_access(i.high, chunk_res)
}
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
// The idea would be the following: Take the array `ts` and hash each element. Then do the
// random access on the hash result. Finally "unhash" to recover the resolved element.
// We don't want to hash each element from the array each time, so we should cache the hashed
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
// then do `ts: &[HashCache<T>]`.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
// TODO: Revisit this when we need more than 64 statements.
let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| {
assert!(v.len() <= 64);
builder.random_access(i, v.to_vec())
};
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
let num_rows = m.len();
let num_columns = m
@ -1317,11 +1389,8 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
.unwrap_or(0);
(0..num_columns)
.map(|j| {
vector_ref(
builder,
&(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(),
i,
)
builder
.random_access_long(i, &(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>())
})
.collect::<Vec<_>>()
};
@ -1330,6 +1399,19 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
}
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
let zero = self.zero();
self.vec_ref(
params,
ts,
&IndexTarget {
max_array_len: 64,
low: i,
high: zero,
},
)
}
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
@ -1639,4 +1721,36 @@ pub(crate) mod tests {
}
})
}
#[test]
fn test_random_access_long() -> Result<(), anyhow::Error> {
let lens: [usize; _] = [10, 60, 64, 96, 126, 159, 190, 256];
for len in &lens {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let array = builder.add_virtual_targets(*len);
let index_target = IndexTarget::new_virtual(*len, &mut builder);
let res = builder.random_access_long(&index_target, &array);
let data = builder.build::<PoseidonGoldilocksConfig>();
for i in 0..3 {
let index = (len - 1) * i / 2;
println!("len={}, index={}", len, index);
let mut pw = PartialWitness::<F>::new();
for (j, elem) in array.iter().enumerate() {
pw.set_target(*elem, F::from_canonical_usize(j * 11))?;
}
index_target.set_targets(&mut pw, index)?;
pw.set_target(res, F::from_canonical_usize(index * 11))?; // Expected
let proof = data.prove(pw)?;
data.verify(proof)?;
}
}
Ok(())
}
}

View file

@ -95,8 +95,7 @@ impl StatementCache {
// converting a length 1 array into a scalar.
op.args
.iter()
.flatten()
.map(|&i| builder.vec_ref(params, prev_statements, i))
.map(|i| builder.vec_ref(params, prev_statements, i))
.collect::<Vec<_>>()
};
assert!(params.max_operation_args >= 3);
@ -224,7 +223,7 @@ fn verify_operation_circuit(
// been verified, so we need only look up the claim.
let measure_resolve_merkle_claim = measure_gates_begin!(builder, "ResolveMerkleClaim");
let resolved_merkle_claim =
(!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, op.aux[0]));
(!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, &op.aux[0]));
measure_gates_end!(builder, measure_resolve_merkle_claim);
// Operations from custom statements will refer to one
@ -233,7 +232,7 @@ fn verify_operation_circuit(
let measure_resolve_custom_pred_verification =
measure_gates_begin!(builder, "ResolveCustomPredVerification");
let resolved_custom_pred_verification = (!custom_predicate_verification_table.is_empty())
.then(|| builder.vec_ref(params, custom_predicate_verification_table, op.aux[1]));
.then(|| builder.vec_ref(params, custom_predicate_verification_table, &op.aux[1]));
measure_gates_end!(builder, measure_resolve_custom_pred_verification);
// The verification may require aux data which needs to be stored in the
@ -938,7 +937,7 @@ fn make_statement_arg_from_template_circuit(
let first_index = ak_id_wc_index;
let is_first_index_valid = builder.or(is_ak, is_wc_literal);
let first_index = builder.select(is_first_index_valid, first_index, zero);
let resolved_ak_id = builder.vec_ref(params, &args, first_index);
let resolved_ak_id = builder.vec_ref_small(params, &args, first_index);
let resolved_wc = resolved_ak_id;
// If the index is not used, use a 0 instead to still pass the range constraints from
@ -946,7 +945,7 @@ fn make_statement_arg_from_template_circuit(
let second_index = ak_key_wc_index;
let is_second_index_valid = builder.and(is_ak, is_ak_key_wc);
let second_index = builder.select(is_second_index_valid, second_index, zero);
let resolved_ak_key = builder.vec_ref(params, &args, second_index);
let resolved_ak_key = builder.vec_ref_small(params, &args, second_index);
let ak_key = ak_key_lit; // is_ak_key_lit
let ak_key = builder.select_flattenable(params, is_ak_key_wc, &resolved_ak_key, &ak_key);
@ -1244,7 +1243,7 @@ fn build_custom_predicate_verification_table_circuit(
let table_query_hash = builder.vec_ref(
params,
custom_predicate_table,
entry.custom_predicate_table_index,
&entry.custom_predicate_table_index,
);
let out_query_hash = entry.custom_predicate.hash(builder);
builder.connect_array(table_query_hash.elements, out_query_hash.elements);

View file

@ -1,6 +1,5 @@
use std::fmt;
use plonky2::field::types::Field;
use serde::{Deserialize, Serialize};
use crate::{
@ -9,7 +8,7 @@ use crate::{
mainpod::Statement,
primitives::merkletree::MerkleClaimAndProof,
},
middleware::{self, OperationType, Params, ToFields, F},
middleware::{self, OperationType},
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@ -18,20 +17,17 @@ pub enum OperationArg {
Index(usize),
}
impl ToFields for OperationArg {
fn to_fields(&self, _params: &Params) -> Vec<F> {
let f = match self {
Self::None => F::ZERO,
Self::Index(i) => F::from_canonical_usize(*i),
};
vec![f]
}
}
impl OperationArg {
pub fn is_none(&self) -> bool {
matches!(self, OperationArg::None)
}
pub fn as_usize(&self) -> usize {
match self {
Self::None => 0,
Self::Index(i) => *i,
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@ -41,14 +37,13 @@ pub enum OperationAux {
CustomPredVerifyIndex(usize),
}
impl ToFields for OperationAux {
fn to_fields(&self, _params: &Params) -> Vec<F> {
let fs = match self {
Self::None => [F::ZERO, F::ZERO],
Self::MerkleProofIndex(i) => [F::from_canonical_usize(*i), F::ZERO],
Self::CustomPredVerifyIndex(i) => [F::ZERO, F::from_canonical_usize(*i)],
};
vec![fs[0], fs[1]]
impl OperationAux {
pub fn as_usizes(&self) -> [usize; 2] {
match self {
Self::None => [0, 0],
Self::MerkleProofIndex(i) => [*i, 0],
Self::CustomPredVerifyIndex(i) => [0, *i],
}
}
}