Use projected statement lookup for op arg resolution (#503)

* Use projected statement lookup for op arg resolution

* Add projected op-arg index coverage test

* Tidying and reorganising
This commit is contained in:
Rob Knight 2026-04-29 00:56:39 -07:00 committed by GitHub
parent 8844fe124c
commit 111b132a00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 159 additions and 38 deletions

View file

@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::{
basetypes::{CircuitBuilder, CommonCircuitData, D},
circuits::mainpod::CustomPredicateVerification,
circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator},
error::Result,
mainpod::{Operation, OperationArg, OperationAux, Statement},
primitives::merkletree::{
@ -1362,6 +1362,18 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
/// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then
/// materializes the selected row via a witness generator and constrains its hash. Cheaper than
/// `vec_ref` when each entry has many fields, since random access runs only over the 4-field
/// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and
/// reusing the same slices across multiple lookups.
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T;
fn select_flattenable<T: Flattenable>(
&mut self,
params: &Params,
@ -1764,12 +1776,6 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
self.random_access(i.high, chunk_res)
}
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
// The idea would be the following: Take the array `ts` and hash each element. Then do the
// random access on the hash result. Finally "unhash" to recover the resolved element.
// We don't want to hash each element from the array each time, so we should cache the hashed
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
// then do `ts: &[HashCache<T>]`.
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
let num_rows = m.len();
@ -1793,6 +1799,28 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
}
fn vec_ref_projected<T: Flattenable>(
&mut self,
params: &Params,
ts_flattened: &[Vec<Target>],
ts_hashes: &[HashOutTarget],
i: &IndexTarget,
) -> T {
assert_eq!(ts_flattened.len(), ts_hashes.len());
let selected_hash = self.vec_ref(params, ts_hashes, i);
let selected_flattened = self.add_virtual_targets(T::size(params));
let selected_flattened_hash =
self.hash_n_to_hash_no_pad::<PoseidonHash>(selected_flattened.clone());
self.connect_hashes(selected_hash, selected_flattened_hash);
let result = T::from_flattened(params, &selected_flattened);
self.add_simple_generator(TableGetGenerator::new(
i.clone(),
ts_flattened.to_vec(),
selected_flattened,
));
result
}
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
let zero = self.zero();
self.vec_ref(

View file

@ -97,9 +97,10 @@ impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
builder: &mut CircuitBuilder,
op: &OperationTarget,
st: &StatementTarget,
prev_statements: &[StatementTarget],
prev_statement_flatteneds: &[Vec<Target>],
prev_statement_hashes: &[HashOutTarget],
) -> Self {
let op_args = if prev_statements.is_empty() {
let op_args = if prev_statement_flatteneds.is_empty() {
(0..max_operation_args)
.map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[]))
.collect_vec()
@ -109,7 +110,14 @@ impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
op.args
.iter()
.take(max_operation_args)
.map(|i| builder.vec_ref(params, prev_statements, i))
.map(|i| {
builder.vec_ref_projected(
params,
prev_statement_flatteneds,
prev_statement_hashes,
i,
)
})
.collect::<Vec<_>>()
};
assert!(Params::max_statement_args() >= MAX_VALUE_ARGS);
@ -193,7 +201,8 @@ fn verify_operation_public_statement_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op: &OperationTarget,
prev_statements: &[StatementTarget],
prev_statement_flatteneds: &[Vec<Target>],
prev_statement_hashes: &[HashOutTarget],
) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerifyPub");
@ -203,7 +212,15 @@ fn verify_operation_public_statement_circuit(
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
// None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the
// StatementCache requires.
let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements);
let cache = StatementCachePub::new(
params,
1,
builder,
op,
st,
prev_statement_flatteneds,
prev_statement_hashes,
);
measure_gates_end!(builder, measure_resolve_op_args);
let op_checks = vec![
@ -434,7 +451,8 @@ fn verify_operation_circuit(
builder: &mut CircuitBuilder,
st: &StatementTarget,
op: &OperationTarget,
prev_statements: &[StatementTarget],
prev_statement_flatteneds: &[Vec<Target>],
prev_statement_hashes: &[HashOutTarget],
aux_table: &MuxTableTarget,
) -> Result<()> {
let measure = measure_gates_begin!(builder, "OpVerifyPriv");
@ -451,7 +469,8 @@ fn verify_operation_circuit(
builder,
op,
st,
prev_statements,
prev_statement_flatteneds,
prev_statement_hashes,
);
measure_gates_end!(builder, measure_resolve_op_args);
@ -1837,13 +1856,37 @@ fn verify_main_pod_circuit(
// 2. Calculate the Pod Id from the public statements
let sts_hash = calculate_statements_hash_circuit(builder, pub_statements);
// Precompute flattened statements and their hashes once, then resolve operation args using
// projected lookups. Reusing the flattened forms avoids re-flattening per op-arg lookup.
let statement_flatteneds: Vec<Vec<Target>> = statements.iter().map(|st| st.flatten()).collect();
let statement_hashes = statement_flatteneds
.iter()
.map(|flat| builder.hash_n_to_hash_no_pad::<PoseidonHash>(flat.clone()))
.collect_vec();
// 5. Verify input statements
for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() {
let prev_statements = &statements[..input_statements_offset + i];
let prev_statement_flatteneds = &statement_flatteneds[..input_statements_offset + i];
let prev_statement_hashes = &statement_hashes[..input_statements_offset + i];
if i < public_statements_offset {
verify_operation_circuit(params, builder, st, op, prev_statements, &aux_table)?;
verify_operation_circuit(
params,
builder,
st,
op,
prev_statement_flatteneds,
prev_statement_hashes,
&aux_table,
)?;
} else {
verify_operation_public_statement_circuit(params, builder, st, op, prev_statements)?;
verify_operation_public_statement_circuit(
params,
builder,
st,
op,
prev_statement_flatteneds,
prev_statement_hashes,
)?;
}
}
@ -2221,6 +2264,14 @@ mod tests {
let prev_statements_target: Vec<_> = (0..prev_statements.len())
.map(|_| builder.add_virtual_statement(false))
.collect();
let prev_statement_flatteneds_target: Vec<Vec<Target>> = prev_statements_target
.iter()
.map(|st| st.flatten())
.collect();
let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target
.iter()
.map(|flat| builder.hash_n_to_hash_no_pad::<PoseidonHash>(flat.clone()))
.collect();
let merkle_proofs_target: Vec<_> = aux
.merkle_proofs
@ -2269,7 +2320,8 @@ mod tests {
&mut builder,
&st_target,
&op_target,
&prev_statements_target,
&prev_statement_flatteneds_target,
&prev_statement_hashes_target,
&aux_table,
)?;
@ -2711,6 +2763,37 @@ mod tests {
})
}
#[test]
fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> {
let local = dict!({
"a" => 3,
"noise" => 99,
"sum" => 6,
});
let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into();
let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into();
let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into();
let st: mainpod::Statement = Statement::sum_of(
AnchoredKey::from((&local, "sum")),
AnchoredKey::from((&local, "a")),
AnchoredKey::from((&local, "a")),
)
.into();
let op = mainpod::Operation(
OperationType::Native(NativeOperation::SumOf),
vec![
// Non-monotonic and repeated indices to stress random-access resolution.
OperationArg::Index(2),
OperationArg::Index(0),
OperationArg::Index(0),
],
OperationAux::None,
);
let prev_statements = vec![st_a, st_noise, st_sum];
operation_verify(st, op, prev_statements, Aux::default())
}
#[test]
fn test_operation_verify_productof() -> Result<()> {
I64_TEST_PAIRS

View file

@ -107,11 +107,11 @@ impl MuxTableTarget {
rev_resolved_tagged_flattened.reverse();
let resolved_tagged_flattened = rev_resolved_tagged_flattened;
builder.add_simple_generator(TableGetGenerator {
index: index.clone(),
tagged_entries: self.tagged_entries.clone(),
get_tagged_entry: resolved_tagged_flattened.clone(),
});
builder.add_simple_generator(TableGetGenerator::new(
index.clone(),
self.tagged_entries.clone(),
resolved_tagged_flattened.clone(),
));
measure_gates_end!(builder, measure);
TableEntryTarget {
params: self.params.clone(),
@ -123,8 +123,18 @@ impl MuxTableTarget {
#[derive(Debug, Clone, Default)]
pub struct TableGetGenerator {
index: IndexTarget,
tagged_entries: Vec<Vec<Target>>,
get_tagged_entry: Vec<Target>,
entries: Vec<Vec<Target>>,
revealed_entry: Vec<Target>,
}
impl TableGetGenerator {
pub fn new(index: IndexTarget, entries: Vec<Vec<Target>>, revealed_entry: Vec<Target>) -> Self {
Self {
index,
entries,
revealed_entry,
}
}
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator {
@ -135,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
fn dependencies(&self) -> Vec<Target> {
[self.index.low, self.index.high]
.into_iter()
.chain(self.tagged_entries.iter().flatten().copied())
.chain(self.entries.iter().flatten().copied())
.collect()
}
@ -148,12 +158,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
let index_high = witness.get_target(self.index.high);
let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64();
let entry = witness.get_targets(&self.tagged_entries[index as usize]);
let entry = witness.get_targets(&self.entries[index as usize]);
for (target, value) in self.get_tagged_entry.iter().zip(
for (target, value) in self.revealed_entry.iter().zip(
entry
.iter()
.chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())),
.chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())),
) {
out_buffer.set_target(*target, *value)?;
}
@ -166,12 +176,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
dst.write_target(self.index.low)?;
dst.write_target(self.index.high)?;
dst.write_usize(self.tagged_entries.len())?;
for tagged_entry in &self.tagged_entries {
dst.write_target_vec(tagged_entry)?;
dst.write_usize(self.entries.len())?;
for entry in &self.entries {
dst.write_target_vec(entry)?;
}
dst.write_target_vec(&self.get_tagged_entry)
dst.write_target_vec(&self.revealed_entry)
}
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
@ -181,16 +191,16 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
high: src.read_target()?,
};
let len = src.read_usize()?;
let mut tagged_entries = Vec::with_capacity(len);
let mut entries = Vec::with_capacity(len);
for _ in 0..len {
tagged_entries.push(src.read_target_vec()?);
entries.push(src.read_target_vec()?);
}
let get_tagged_entry = src.read_target_vec()?;
let revealed_entry = src.read_target_vec()?;
Ok(Self {
index,
tagged_entries,
get_tagged_entry,
entries,
revealed_entry,
})
}
}