Use projected statement lookup for op arg resolution (#503)

* Use projected statement lookup for op arg resolution

* Add projected op-arg index coverage test

* Tidying and reorganising
This commit is contained in:
Rob Knight 2026-04-29 00:56:39 -07:00 committed by GitHub
parent 8844fe124c
commit 111b132a00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 159 additions and 38 deletions

View file

@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
backends::plonky2::{ backends::plonky2::{
basetypes::{CircuitBuilder, CommonCircuitData, D}, basetypes::{CircuitBuilder, CommonCircuitData, D},
circuits::mainpod::CustomPredicateVerification, circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator},
error::Result, error::Result,
mainpod::{Operation, OperationArg, OperationAux, Statement}, mainpod::{Operation, OperationArg, OperationAux, Statement},
primitives::merkletree::{ primitives::merkletree::{
@ -1362,6 +1362,18 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T; fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target` /// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T; fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
/// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then
/// materializes the selected row via a witness generator and constrains its hash. Cheaper than
/// `vec_ref` when each entry has many fields, since random access runs only over the 4-field
/// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and
/// reusing the same slices across multiple lookups.
///
/// - `ts_flattened`: the flattened targets of every entry, one inner `Vec` per entry.
/// - `ts_hashes`: one hash per entry, in the same order as `ts_flattened`. Each is expected to
///   be the hash of the corresponding flattened entry; the selected row is only accepted if
///   its hash equals the hash selected at index `i`.
/// - `i`: index of the entry to resolve.
///
/// Returns the selected entry rebuilt as a `T` from freshly allocated targets.
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T;
fn select_flattenable<T: Flattenable>( fn select_flattenable<T: Flattenable>(
&mut self, &mut self,
params: &Params, params: &Params,
@ -1764,12 +1776,6 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
self.random_access(i.high, chunk_res) self.random_access(i.high, chunk_res)
} }
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
// The idea would be the following: Take the array `ts` and hash each element. Then do the
// random access on the hash result. Finally "unhash" to recover the resolved element.
// We don't want to hash each element from the array each time, so we should cache the hashed
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
// then do `ts: &[HashCache<T>]`.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T { fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| { let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
let num_rows = m.len(); let num_rows = m.len();
@ -1793,6 +1799,28 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i)) T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
} }
/// Resolves entry `i` of a table of wide rows by doing the random access over the per-row
/// hashes instead of over the rows themselves.
///
/// The returned `T` is built from fresh virtual targets. In-circuit, those targets are tied to
/// the table only through the constraint that their Poseidon hash equals the hash selected at
/// index `i`; the actual values are supplied at witness time by a `TableGetGenerator`.
///
/// `ts_hashes[k]` is expected to be the no-pad Poseidon hash of `ts_flattened[k]`
/// (NOTE(review): this function does not re-derive the hashes, so the caller must guarantee it).
///
/// # Panics
///
/// Panics if `ts_flattened` and `ts_hashes` have different lengths.
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T {
assert_eq!(ts_flattened.len(), ts_hashes.len());
// Random access over the hashes only, which is the cheap part of the lookup.
let selected_hash = self.vec_ref(params, ts_hashes, i);
// Unconstrained placeholders for the selected row; one target per flattened field of `T`.
let selected_flattened = self.add_virtual_targets(T::size(params));
// Constrain the placeholders: their hash must match the hash selected above.
let selected_flattened_hash =
self.hash_n_to_hash_no_pad::<PoseidonHash>(selected_flattened.clone());
self.connect_hashes(selected_hash, selected_flattened_hash);
let result = T::from_flattened(params, &selected_flattened);
// The generator fills the placeholders with the row at index `i` during witness generation.
// Note that it stores its own copy of the whole flattened table (`to_vec`).
self.add_simple_generator(TableGetGenerator::new(
i.clone(),
ts_flattened.to_vec(),
selected_flattened,
));
result
}
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T { fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
let zero = self.zero(); let zero = self.zero();
self.vec_ref( self.vec_ref(

View file

@ -97,9 +97,10 @@ impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
op: &OperationTarget, op: &OperationTarget,
st: &StatementTarget, st: &StatementTarget,
prev_statements: &[StatementTarget], prev_statement_flatteneds: &[Vec<Target>],
prev_statement_hashes: &[HashOutTarget],
) -> Self { ) -> Self {
let op_args = if prev_statements.is_empty() { let op_args = if prev_statement_flatteneds.is_empty() {
(0..max_operation_args) (0..max_operation_args)
.map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[]))
.collect_vec() .collect_vec()
@ -109,7 +110,14 @@ impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
op.args op.args
.iter() .iter()
.take(max_operation_args) .take(max_operation_args)
.map(|i| builder.vec_ref(params, prev_statements, i)) .map(|i| {
builder.vec_ref_projected(
params,
prev_statement_flatteneds,
prev_statement_hashes,
i,
)
})
.collect::<Vec<_>>() .collect::<Vec<_>>()
}; };
assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); assert!(Params::max_statement_args() >= MAX_VALUE_ARGS);
@ -193,7 +201,8 @@ fn verify_operation_public_statement_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op: &OperationTarget, op: &OperationTarget,
prev_statements: &[StatementTarget], prev_statement_flatteneds: &[Vec<Target>],
prev_statement_hashes: &[HashOutTarget],
) -> Result<()> { ) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerifyPub"); let measure = measure_gates_begin!(builder, "OpVerifyPub");
@ -203,7 +212,15 @@ fn verify_operation_public_statement_circuit(
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
// None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the // None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the
// StatementCache requires. // StatementCache requires.
let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements); let cache = StatementCachePub::new(
params,
1,
builder,
op,
st,
prev_statement_flatteneds,
prev_statement_hashes,
);
measure_gates_end!(builder, measure_resolve_op_args); measure_gates_end!(builder, measure_resolve_op_args);
let op_checks = vec![ let op_checks = vec![
@ -434,7 +451,8 @@ fn verify_operation_circuit(
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
st: &StatementTarget, st: &StatementTarget,
op: &OperationTarget, op: &OperationTarget,
prev_statements: &[StatementTarget], prev_statement_flatteneds: &[Vec<Target>],
prev_statement_hashes: &[HashOutTarget],
aux_table: &MuxTableTarget, aux_table: &MuxTableTarget,
) -> Result<()> { ) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerifyPriv"); let measure = measure_gates_begin!(builder, "OpVerifyPriv");
@ -451,7 +469,8 @@ fn verify_operation_circuit(
builder, builder,
op, op,
st, st,
prev_statements, prev_statement_flatteneds,
prev_statement_hashes,
); );
measure_gates_end!(builder, measure_resolve_op_args); measure_gates_end!(builder, measure_resolve_op_args);
@ -1837,13 +1856,37 @@ fn verify_main_pod_circuit(
// 2. Calculate the Pod Id from the public statements // 2. Calculate the Pod Id from the public statements
let sts_hash = calculate_statements_hash_circuit(builder, pub_statements); let sts_hash = calculate_statements_hash_circuit(builder, pub_statements);
// Precompute flattened statements and their hashes once, then resolve operation args using
// projected lookups. Reusing the flattened forms avoids re-flattening per op-arg lookup.
let statement_flatteneds: Vec<Vec<Target>> = statements.iter().map(|st| st.flatten()).collect();
let statement_hashes = statement_flatteneds
.iter()
.map(|flat| builder.hash_n_to_hash_no_pad::<PoseidonHash>(flat.clone()))
.collect_vec();
// 5. Verify input statements // 5. Verify input statements
for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() { for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() {
let prev_statements = &statements[..input_statements_offset + i]; let prev_statement_flatteneds = &statement_flatteneds[..input_statements_offset + i];
let prev_statement_hashes = &statement_hashes[..input_statements_offset + i];
if i < public_statements_offset { if i < public_statements_offset {
verify_operation_circuit(params, builder, st, op, prev_statements, &aux_table)?; verify_operation_circuit(
params,
builder,
st,
op,
prev_statement_flatteneds,
prev_statement_hashes,
&aux_table,
)?;
} else { } else {
verify_operation_public_statement_circuit(params, builder, st, op, prev_statements)?; verify_operation_public_statement_circuit(
params,
builder,
st,
op,
prev_statement_flatteneds,
prev_statement_hashes,
)?;
} }
} }
@ -2221,6 +2264,14 @@ mod tests {
let prev_statements_target: Vec<_> = (0..prev_statements.len()) let prev_statements_target: Vec<_> = (0..prev_statements.len())
.map(|_| builder.add_virtual_statement(false)) .map(|_| builder.add_virtual_statement(false))
.collect(); .collect();
let prev_statement_flatteneds_target: Vec<Vec<Target>> = prev_statements_target
.iter()
.map(|st| st.flatten())
.collect();
let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target
.iter()
.map(|flat| builder.hash_n_to_hash_no_pad::<PoseidonHash>(flat.clone()))
.collect();
let merkle_proofs_target: Vec<_> = aux let merkle_proofs_target: Vec<_> = aux
.merkle_proofs .merkle_proofs
@ -2269,7 +2320,8 @@ mod tests {
&mut builder, &mut builder,
&st_target, &st_target,
&op_target, &op_target,
&prev_statements_target, &prev_statement_flatteneds_target,
&prev_statement_hashes_target,
&aux_table, &aux_table,
)?; )?;
@ -2711,6 +2763,37 @@ mod tests {
}) })
} }
/// Verifies a `SumOf` operation whose arguments point at earlier statements out of order and
/// reuse one of them: the op args resolve to `sum`, `a`, `a`, i.e. `6 = 3 + 3`.
#[test]
fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> {
    let values = dict!({
        "a" => 3,
        "noise" => 99,
        "sum" => 6,
    });

    // Previous statements in table order: 0 = "a", 1 = "noise" (never referenced), 2 = "sum".
    let prev_statements: Vec<mainpod::Statement> = vec![
        Statement::contains(values.clone(), "a", 3).into(),
        Statement::contains(values.clone(), "noise", 99).into(),
        Statement::contains(values.clone(), "sum", 6).into(),
    ];

    // The statement being proven: sum = a + a.
    let claimed: mainpod::Statement = Statement::sum_of(
        AnchoredKey::from((&values, "sum")),
        AnchoredKey::from((&values, "a")),
        AnchoredKey::from((&values, "a")),
    )
    .into();

    // Non-monotonic and repeated indices to stress random-access resolution.
    let sum_op = mainpod::Operation(
        OperationType::Native(NativeOperation::SumOf),
        vec![OperationArg::Index(2), OperationArg::Index(0), OperationArg::Index(0)],
        OperationAux::None,
    );

    operation_verify(claimed, sum_op, prev_statements, Aux::default())
}
#[test] #[test]
fn test_operation_verify_productof() -> Result<()> { fn test_operation_verify_productof() -> Result<()> {
I64_TEST_PAIRS I64_TEST_PAIRS

View file

@ -107,11 +107,11 @@ impl MuxTableTarget {
rev_resolved_tagged_flattened.reverse(); rev_resolved_tagged_flattened.reverse();
let resolved_tagged_flattened = rev_resolved_tagged_flattened; let resolved_tagged_flattened = rev_resolved_tagged_flattened;
builder.add_simple_generator(TableGetGenerator { builder.add_simple_generator(TableGetGenerator::new(
index: index.clone(), index.clone(),
tagged_entries: self.tagged_entries.clone(), self.tagged_entries.clone(),
get_tagged_entry: resolved_tagged_flattened.clone(), resolved_tagged_flattened.clone(),
}); ));
measure_gates_end!(builder, measure); measure_gates_end!(builder, measure);
TableEntryTarget { TableEntryTarget {
params: self.params.clone(), params: self.params.clone(),
@ -123,8 +123,18 @@ impl MuxTableTarget {
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct TableGetGenerator { pub struct TableGetGenerator {
index: IndexTarget, index: IndexTarget,
tagged_entries: Vec<Vec<Target>>, entries: Vec<Vec<Target>>,
get_tagged_entry: Vec<Target>, revealed_entry: Vec<Target>,
}
impl TableGetGenerator {
pub fn new(index: IndexTarget, entries: Vec<Vec<Target>>, revealed_entry: Vec<Target>) -> Self {
Self {
index,
entries,
revealed_entry,
}
}
} }
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator { impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator {
@ -135,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
fn dependencies(&self) -> Vec<Target> { fn dependencies(&self) -> Vec<Target> {
[self.index.low, self.index.high] [self.index.low, self.index.high]
.into_iter() .into_iter()
.chain(self.tagged_entries.iter().flatten().copied()) .chain(self.entries.iter().flatten().copied())
.collect() .collect()
} }
@ -148,12 +158,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
let index_high = witness.get_target(self.index.high); let index_high = witness.get_target(self.index.high);
let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64(); let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64();
let entry = witness.get_targets(&self.tagged_entries[index as usize]); let entry = witness.get_targets(&self.entries[index as usize]);
for (target, value) in self.get_tagged_entry.iter().zip( for (target, value) in self.revealed_entry.iter().zip(
entry entry
.iter() .iter()
.chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())), .chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())),
) { ) {
out_buffer.set_target(*target, *value)?; out_buffer.set_target(*target, *value)?;
} }
@ -166,12 +176,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
dst.write_target(self.index.low)?; dst.write_target(self.index.low)?;
dst.write_target(self.index.high)?; dst.write_target(self.index.high)?;
dst.write_usize(self.tagged_entries.len())?; dst.write_usize(self.entries.len())?;
for tagged_entry in &self.tagged_entries { for entry in &self.entries {
dst.write_target_vec(tagged_entry)?; dst.write_target_vec(entry)?;
} }
dst.write_target_vec(&self.get_tagged_entry) dst.write_target_vec(&self.revealed_entry)
} }
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> { fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
@ -181,16 +191,16 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
high: src.read_target()?, high: src.read_target()?,
}; };
let len = src.read_usize()?; let len = src.read_usize()?;
let mut tagged_entries = Vec::with_capacity(len); let mut entries = Vec::with_capacity(len);
for _ in 0..len { for _ in 0..len {
tagged_entries.push(src.read_target_vec()?); entries.push(src.read_target_vec()?);
} }
let get_tagged_entry = src.read_target_vec()?; let revealed_entry = src.read_target_vec()?;
Ok(Self { Ok(Self {
index, index,
tagged_entries, entries,
get_tagged_entry, revealed_entry,
}) })
} }
} }