From 111b132a00aa9a32b26981a5d00ec7376e1c422a Mon Sep 17 00:00:00 2001 From: Rob Knight Date: Wed, 29 Apr 2026 00:56:39 -0700 Subject: [PATCH] Use projected statement lookup for op arg resolution (#503) * Use projected statement lookup for op arg resolution * Add projected op-arg index coverage test * Tidying and reorganising --- src/backends/plonky2/circuits/common.rs | 42 +++++++-- src/backends/plonky2/circuits/mainpod.rs | 105 ++++++++++++++++++--- src/backends/plonky2/circuits/mux_table.rs | 50 ++++++---- 3 files changed, 159 insertions(+), 38 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index de53ee5..dfee8a0 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::plonky2::{ basetypes::{CircuitBuilder, CommonCircuitData, D}, - circuits::mainpod::CustomPredicateVerification, + circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator}, error::Result, mainpod::{Operation, OperationArg, OperationAux, Statement}, primitives::merkletree::{ @@ -1362,6 +1362,18 @@ pub trait CircuitBuilderPod, const D: usize> { fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T; /// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target` fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T; + /// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then + /// materializes the selected row via a witness generator and constrains its hash. Cheaper than + /// `vec_ref` when each entry has many fields, since random access runs only over the 4-field + /// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and + /// reusing the same slices across multiple lookups. + fn vec_ref_projected( + &mut self, + params: &Params, + ts_flattened: &[Vec], + ts_hashes: &[HashOutTarget], + i: &IndexTarget, + ) -> T; fn select_flattenable( &mut self, params: &Params, @@ -1764,12 +1776,6 @@ impl CircuitBuilderPod for CircuitBuilder { self.random_access(i.high, chunk_res) } - // TODO: Implement a version of vec_ref for types `T` which are big and support hashing. - // The idea would be the following: Take the array `ts` and hash each element. Then do the - // random access on the hash result. Finally "unhash" to recover the resolved element. - // We don't want to hash each element from the array each time, so we should cache the hashed - // result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and - // then do `ts: &[HashCache]`. fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T { let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec], i| { let num_rows = m.len(); @@ -1793,6 +1799,28 @@ impl CircuitBuilderPod for CircuitBuilder { T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i)) } + fn vec_ref_projected( + &mut self, + params: &Params, + ts_flattened: &[Vec], + ts_hashes: &[HashOutTarget], + i: &IndexTarget, + ) -> T { + assert_eq!(ts_flattened.len(), ts_hashes.len()); + let selected_hash = self.vec_ref(params, ts_hashes, i); + let selected_flattened = self.add_virtual_targets(T::size(params)); + let selected_flattened_hash = + self.hash_n_to_hash_no_pad::(selected_flattened.clone()); + self.connect_hashes(selected_hash, selected_flattened_hash); + let result = T::from_flattened(params, &selected_flattened); + self.add_simple_generator(TableGetGenerator::new( + i.clone(), + ts_flattened.to_vec(), + selected_flattened, + )); + result + } + fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T { let zero = self.zero(); self.vec_ref( diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index c4b891a..0605e07 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -97,9 +97,10 @@ impl StatementCache { builder: &mut CircuitBuilder, op: &OperationTarget, st: &StatementTarget, - prev_statements: &[StatementTarget], + prev_statement_flatteneds: &[Vec], + prev_statement_hashes: &[HashOutTarget], ) -> Self { - let op_args = if prev_statements.is_empty() { + let op_args = if prev_statement_flatteneds.is_empty() { (0..max_operation_args) .map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[])) .collect_vec() @@ -109,7 +110,14 @@ impl StatementCache { op.args .iter() .take(max_operation_args) - .map(|i| builder.vec_ref(params, prev_statements, i)) + .map(|i| { + builder.vec_ref_projected( + params, + prev_statement_flatteneds, + prev_statement_hashes, + i, + ) + }) .collect::>() }; assert!(Params::max_statement_args() >= MAX_VALUE_ARGS); @@ -193,7 +201,8 @@ fn verify_operation_public_statement_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, - prev_statements: &[StatementTarget], + prev_statement_flatteneds: &[Vec], + prev_statement_hashes: &[HashOutTarget], ) -> Result<()> { let measure = measure_gates_begin!(builder, "OpVerifyPub"); @@ -203,7 +212,15 @@ fn verify_operation_public_statement_circuit( let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs"); // None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the // StatementCache requires. - let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements); + let cache = StatementCachePub::new( + params, + 1, + builder, + op, + st, + prev_statement_flatteneds, + prev_statement_hashes, + ); measure_gates_end!(builder, measure_resolve_op_args); let op_checks = vec![ @@ -434,7 +451,8 @@ fn verify_operation_circuit( builder: &mut CircuitBuilder, st: &StatementTarget, op: &OperationTarget, - prev_statements: &[StatementTarget], + prev_statement_flatteneds: &[Vec], + prev_statement_hashes: &[HashOutTarget], aux_table: &MuxTableTarget, ) -> Result<()> { let measure = measure_gates_begin!(builder, "OpVerifyPriv"); @@ -451,7 +469,8 @@ fn verify_operation_circuit( builder, op, st, - prev_statements, + prev_statement_flatteneds, + prev_statement_hashes, ); measure_gates_end!(builder, measure_resolve_op_args); @@ -1837,13 +1856,37 @@ fn verify_main_pod_circuit( // 2. Calculate the Pod Id from the public statements let sts_hash = calculate_statements_hash_circuit(builder, pub_statements); + // Precompute flattened statements and their hashes once, then resolve operation args using + // projected lookups. Reusing the flattened forms avoids re-flattening per op-arg lookup. + let statement_flatteneds: Vec> = statements.iter().map(|st| st.flatten()).collect(); + let statement_hashes = statement_flatteneds + .iter() + .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) + .collect_vec(); + // 5. Verify input statements for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() { - let prev_statements = &statements[..input_statements_offset + i]; + let prev_statement_flatteneds = &statement_flatteneds[..input_statements_offset + i]; + let prev_statement_hashes = &statement_hashes[..input_statements_offset + i]; if i < public_statements_offset { - verify_operation_circuit(params, builder, st, op, prev_statements, &aux_table)?; + verify_operation_circuit( + params, + builder, + st, + op, + prev_statement_flatteneds, + prev_statement_hashes, + &aux_table, + )?; } else { - verify_operation_public_statement_circuit(params, builder, st, op, prev_statements)?; + verify_operation_public_statement_circuit( + params, + builder, + st, + op, + prev_statement_flatteneds, + prev_statement_hashes, + )?; } } @@ -2221,6 +2264,14 @@ mod tests { let prev_statements_target: Vec<_> = (0..prev_statements.len()) .map(|_| builder.add_virtual_statement(false)) .collect(); + let prev_statement_flatteneds_target: Vec> = prev_statements_target + .iter() + .map(|st| st.flatten()) + .collect(); + let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target + .iter() + .map(|flat| builder.hash_n_to_hash_no_pad::(flat.clone())) + .collect(); let merkle_proofs_target: Vec<_> = aux .merkle_proofs @@ -2269,7 +2320,8 @@ mod tests { &mut builder, &st_target, &op_target, - &prev_statements_target, + &prev_statement_flatteneds_target, + &prev_statement_hashes_target, &aux_table, )?; @@ -2711,6 +2763,37 @@ mod tests { }) } + #[test] + fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> { + let local = dict!({ + "a" => 3, + "noise" => 99, + "sum" => 6, + }); + let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into(); + let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into(); + let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into(); + + let st: mainpod::Statement = Statement::sum_of( + AnchoredKey::from((&local, "sum")), + AnchoredKey::from((&local, "a")), + AnchoredKey::from((&local, "a")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::SumOf), + vec![ + // Non-monotonic and repeated indices to stress random-access resolution. + OperationArg::Index(2), + OperationArg::Index(0), + OperationArg::Index(0), + ], + OperationAux::None, + ); + let prev_statements = vec![st_a, st_noise, st_sum]; + operation_verify(st, op, prev_statements, Aux::default()) + } + #[test] fn test_operation_verify_productof() -> Result<()> { I64_TEST_PAIRS diff --git a/src/backends/plonky2/circuits/mux_table.rs b/src/backends/plonky2/circuits/mux_table.rs index 110dac9..c93d0e8 100644 --- a/src/backends/plonky2/circuits/mux_table.rs +++ b/src/backends/plonky2/circuits/mux_table.rs @@ -107,11 +107,11 @@ impl MuxTableTarget { rev_resolved_tagged_flattened.reverse(); let resolved_tagged_flattened = rev_resolved_tagged_flattened; - builder.add_simple_generator(TableGetGenerator { - index: index.clone(), - tagged_entries: self.tagged_entries.clone(), - get_tagged_entry: resolved_tagged_flattened.clone(), - }); + builder.add_simple_generator(TableGetGenerator::new( + index.clone(), + self.tagged_entries.clone(), + resolved_tagged_flattened.clone(), + )); measure_gates_end!(builder, measure); TableEntryTarget { params: self.params.clone(), @@ -123,8 +123,18 @@ impl MuxTableTarget { #[derive(Debug, Clone, Default)] pub struct TableGetGenerator { index: IndexTarget, - tagged_entries: Vec>, - get_tagged_entry: Vec, + entries: Vec>, + revealed_entry: Vec, +} + +impl TableGetGenerator { + pub fn new(index: IndexTarget, entries: Vec>, revealed_entry: Vec) -> Self { + Self { + index, + entries, + revealed_entry, + } + } } impl, const D: usize> SimpleGenerator for TableGetGenerator { @@ -135,7 +145,7 @@ impl, const D: usize> SimpleGenerator for Tab fn dependencies(&self) -> Vec { [self.index.low, self.index.high] .into_iter() - .chain(self.tagged_entries.iter().flatten().copied()) + .chain(self.entries.iter().flatten().copied()) .collect() } @@ -148,12 +158,12 @@ impl, const D: usize> SimpleGenerator for Tab let index_high = witness.get_target(self.index.high); let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64(); - let entry = witness.get_targets(&self.tagged_entries[index as usize]); + let entry = witness.get_targets(&self.entries[index as usize]); - for (target, value) in self.get_tagged_entry.iter().zip( + for (target, value) in self.revealed_entry.iter().zip( entry .iter() - .chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())), + .chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())), ) { out_buffer.set_target(*target, *value)?; } @@ -166,12 +176,12 @@ impl, const D: usize> SimpleGenerator for Tab dst.write_target(self.index.low)?; dst.write_target(self.index.high)?; - dst.write_usize(self.tagged_entries.len())?; - for tagged_entry in &self.tagged_entries { - dst.write_target_vec(tagged_entry)?; + dst.write_usize(self.entries.len())?; + for entry in &self.entries { + dst.write_target_vec(entry)?; } - dst.write_target_vec(&self.get_tagged_entry) + dst.write_target_vec(&self.revealed_entry) } fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { @@ -181,16 +191,16 @@ impl, const D: usize> SimpleGenerator for Tab high: src.read_target()?, }; let len = src.read_usize()?; - let mut tagged_entries = Vec::with_capacity(len); + let mut entries = Vec::with_capacity(len); for _ in 0..len { - tagged_entries.push(src.read_target_vec()?); + entries.push(src.read_target_vec()?); } - let get_tagged_entry = src.read_target_vec()?; + let revealed_entry = src.read_target_vec()?; Ok(Self { index, - tagged_entries, - get_tagged_entry, + entries, + revealed_entry, }) } }