Use projected statement lookup for op arg resolution (#503)
* Use projected statement lookup for op arg resolution * Add projected op-arg index coverage test * Tidying and reorganising
This commit is contained in:
parent
8844fe124c
commit
111b132a00
3 changed files with 159 additions and 38 deletions
|
|
@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use crate::{
|
use crate::{
|
||||||
backends::plonky2::{
|
backends::plonky2::{
|
||||||
basetypes::{CircuitBuilder, CommonCircuitData, D},
|
basetypes::{CircuitBuilder, CommonCircuitData, D},
|
||||||
circuits::mainpod::CustomPredicateVerification,
|
circuits::{mainpod::CustomPredicateVerification, mux_table::TableGetGenerator},
|
||||||
error::Result,
|
error::Result,
|
||||||
mainpod::{Operation, OperationArg, OperationAux, Statement},
|
mainpod::{Operation, OperationArg, OperationAux, Statement},
|
||||||
primitives::merkletree::{
|
primitives::merkletree::{
|
||||||
|
|
@ -1362,6 +1362,18 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
||||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
|
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T;
|
||||||
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
|
/// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target`
|
||||||
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
|
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T;
|
||||||
|
/// Like `vec_ref` but for wide rows: random-accesses a precomputed hash of each entry, then
|
||||||
|
/// materializes the selected row via a witness generator and constrains its hash. Cheaper than
|
||||||
|
/// `vec_ref` when each entry has many fields, since random access runs only over the 4-field
|
||||||
|
/// hashes. The caller is responsible for precomputing `ts_flattened` and `ts_hashes` once and
|
||||||
|
/// reusing the same slices across multiple lookups.
|
||||||
|
fn vec_ref_projected<T: Flattenable>(
|
||||||
|
&mut self,
|
||||||
|
params: &Params,
|
||||||
|
ts_flattened: &[Vec<Target>],
|
||||||
|
ts_hashes: &[HashOutTarget],
|
||||||
|
i: &IndexTarget,
|
||||||
|
) -> T;
|
||||||
fn select_flattenable<T: Flattenable>(
|
fn select_flattenable<T: Flattenable>(
|
||||||
&mut self,
|
&mut self,
|
||||||
params: &Params,
|
params: &Params,
|
||||||
|
|
@ -1764,12 +1776,6 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
||||||
self.random_access(i.high, chunk_res)
|
self.random_access(i.high, chunk_res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Implement a version of vec_ref for types `T` which are big and support hashing.
|
|
||||||
// The idea would be the following: Take the array `ts` and hash each element. Then do the
|
|
||||||
// random access on the hash result. Finally "unhash" to recover the resolved element.
|
|
||||||
// We don't want to hash each element from the array each time, so we should cache the hashed
|
|
||||||
// result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and
|
|
||||||
// then do `ts: &[HashCache<T>]`.
|
|
||||||
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
|
fn vec_ref<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T {
|
||||||
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
|
let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec<Target>], i| {
|
||||||
let num_rows = m.len();
|
let num_rows = m.len();
|
||||||
|
|
@ -1793,6 +1799,28 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
||||||
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
|
T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn vec_ref_projected<T: Flattenable>(
|
||||||
|
&mut self,
|
||||||
|
params: &Params,
|
||||||
|
ts_flattened: &[Vec<Target>],
|
||||||
|
ts_hashes: &[HashOutTarget],
|
||||||
|
i: &IndexTarget,
|
||||||
|
) -> T {
|
||||||
|
assert_eq!(ts_flattened.len(), ts_hashes.len());
|
||||||
|
let selected_hash = self.vec_ref(params, ts_hashes, i);
|
||||||
|
let selected_flattened = self.add_virtual_targets(T::size(params));
|
||||||
|
let selected_flattened_hash =
|
||||||
|
self.hash_n_to_hash_no_pad::<PoseidonHash>(selected_flattened.clone());
|
||||||
|
self.connect_hashes(selected_hash, selected_flattened_hash);
|
||||||
|
let result = T::from_flattened(params, &selected_flattened);
|
||||||
|
self.add_simple_generator(TableGetGenerator::new(
|
||||||
|
i.clone(),
|
||||||
|
ts_flattened.to_vec(),
|
||||||
|
selected_flattened,
|
||||||
|
));
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
|
fn vec_ref_small<T: Flattenable>(&mut self, params: &Params, ts: &[T], i: Target) -> T {
|
||||||
let zero = self.zero();
|
let zero = self.zero();
|
||||||
self.vec_ref(
|
self.vec_ref(
|
||||||
|
|
|
||||||
|
|
@ -97,9 +97,10 @@ impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
|
||||||
builder: &mut CircuitBuilder,
|
builder: &mut CircuitBuilder,
|
||||||
op: &OperationTarget,
|
op: &OperationTarget,
|
||||||
st: &StatementTarget,
|
st: &StatementTarget,
|
||||||
prev_statements: &[StatementTarget],
|
prev_statement_flatteneds: &[Vec<Target>],
|
||||||
|
prev_statement_hashes: &[HashOutTarget],
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let op_args = if prev_statements.is_empty() {
|
let op_args = if prev_statement_flatteneds.is_empty() {
|
||||||
(0..max_operation_args)
|
(0..max_operation_args)
|
||||||
.map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[]))
|
.map(|_| StatementTarget::new_native(builder, params, NativePredicate::None, &[]))
|
||||||
.collect_vec()
|
.collect_vec()
|
||||||
|
|
@ -109,7 +110,14 @@ impl<const MAX_EQS: usize> StatementCache<MAX_EQS> {
|
||||||
op.args
|
op.args
|
||||||
.iter()
|
.iter()
|
||||||
.take(max_operation_args)
|
.take(max_operation_args)
|
||||||
.map(|i| builder.vec_ref(params, prev_statements, i))
|
.map(|i| {
|
||||||
|
builder.vec_ref_projected(
|
||||||
|
params,
|
||||||
|
prev_statement_flatteneds,
|
||||||
|
prev_statement_hashes,
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
};
|
};
|
||||||
assert!(Params::max_statement_args() >= MAX_VALUE_ARGS);
|
assert!(Params::max_statement_args() >= MAX_VALUE_ARGS);
|
||||||
|
|
@ -193,7 +201,8 @@ fn verify_operation_public_statement_circuit(
|
||||||
builder: &mut CircuitBuilder,
|
builder: &mut CircuitBuilder,
|
||||||
st: &StatementTarget,
|
st: &StatementTarget,
|
||||||
op: &OperationTarget,
|
op: &OperationTarget,
|
||||||
prev_statements: &[StatementTarget],
|
prev_statement_flatteneds: &[Vec<Target>],
|
||||||
|
prev_statement_hashes: &[HashOutTarget],
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let measure = measure_gates_begin!(builder, "OpVerifyPub");
|
let measure = measure_gates_begin!(builder, "OpVerifyPub");
|
||||||
|
|
||||||
|
|
@ -203,7 +212,15 @@ fn verify_operation_public_statement_circuit(
|
||||||
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
|
let measure_resolve_op_args = measure_gates_begin!(builder, "ResolveOpArgs");
|
||||||
// None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the
|
// None takes 0 arguments, Copy takes 1, so we reduce the number of random accesses that the
|
||||||
// StatementCache requires.
|
// StatementCache requires.
|
||||||
let cache = StatementCachePub::new(params, 1, builder, op, st, prev_statements);
|
let cache = StatementCachePub::new(
|
||||||
|
params,
|
||||||
|
1,
|
||||||
|
builder,
|
||||||
|
op,
|
||||||
|
st,
|
||||||
|
prev_statement_flatteneds,
|
||||||
|
prev_statement_hashes,
|
||||||
|
);
|
||||||
measure_gates_end!(builder, measure_resolve_op_args);
|
measure_gates_end!(builder, measure_resolve_op_args);
|
||||||
|
|
||||||
let op_checks = vec![
|
let op_checks = vec![
|
||||||
|
|
@ -434,7 +451,8 @@ fn verify_operation_circuit(
|
||||||
builder: &mut CircuitBuilder,
|
builder: &mut CircuitBuilder,
|
||||||
st: &StatementTarget,
|
st: &StatementTarget,
|
||||||
op: &OperationTarget,
|
op: &OperationTarget,
|
||||||
prev_statements: &[StatementTarget],
|
prev_statement_flatteneds: &[Vec<Target>],
|
||||||
|
prev_statement_hashes: &[HashOutTarget],
|
||||||
aux_table: &MuxTableTarget,
|
aux_table: &MuxTableTarget,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let measure = measure_gates_begin!(builder, "OpVerifyPriv");
|
let measure = measure_gates_begin!(builder, "OpVerifyPriv");
|
||||||
|
|
@ -451,7 +469,8 @@ fn verify_operation_circuit(
|
||||||
builder,
|
builder,
|
||||||
op,
|
op,
|
||||||
st,
|
st,
|
||||||
prev_statements,
|
prev_statement_flatteneds,
|
||||||
|
prev_statement_hashes,
|
||||||
);
|
);
|
||||||
measure_gates_end!(builder, measure_resolve_op_args);
|
measure_gates_end!(builder, measure_resolve_op_args);
|
||||||
|
|
||||||
|
|
@ -1837,13 +1856,37 @@ fn verify_main_pod_circuit(
|
||||||
// 2. Calculate the Pod Id from the public statements
|
// 2. Calculate the Pod Id from the public statements
|
||||||
let sts_hash = calculate_statements_hash_circuit(builder, pub_statements);
|
let sts_hash = calculate_statements_hash_circuit(builder, pub_statements);
|
||||||
|
|
||||||
|
// Precompute flattened statements and their hashes once, then resolve operation args using
|
||||||
|
// projected lookups. Reusing the flattened forms avoids re-flattening per op-arg lookup.
|
||||||
|
let statement_flatteneds: Vec<Vec<Target>> = statements.iter().map(|st| st.flatten()).collect();
|
||||||
|
let statement_hashes = statement_flatteneds
|
||||||
|
.iter()
|
||||||
|
.map(|flat| builder.hash_n_to_hash_no_pad::<PoseidonHash>(flat.clone()))
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
// 5. Verify input statements
|
// 5. Verify input statements
|
||||||
for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() {
|
for (i, (st, op)) in izip!(&main_pod.input_statements, &main_pod.operations).enumerate() {
|
||||||
let prev_statements = &statements[..input_statements_offset + i];
|
let prev_statement_flatteneds = &statement_flatteneds[..input_statements_offset + i];
|
||||||
|
let prev_statement_hashes = &statement_hashes[..input_statements_offset + i];
|
||||||
if i < public_statements_offset {
|
if i < public_statements_offset {
|
||||||
verify_operation_circuit(params, builder, st, op, prev_statements, &aux_table)?;
|
verify_operation_circuit(
|
||||||
|
params,
|
||||||
|
builder,
|
||||||
|
st,
|
||||||
|
op,
|
||||||
|
prev_statement_flatteneds,
|
||||||
|
prev_statement_hashes,
|
||||||
|
&aux_table,
|
||||||
|
)?;
|
||||||
} else {
|
} else {
|
||||||
verify_operation_public_statement_circuit(params, builder, st, op, prev_statements)?;
|
verify_operation_public_statement_circuit(
|
||||||
|
params,
|
||||||
|
builder,
|
||||||
|
st,
|
||||||
|
op,
|
||||||
|
prev_statement_flatteneds,
|
||||||
|
prev_statement_hashes,
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2221,6 +2264,14 @@ mod tests {
|
||||||
let prev_statements_target: Vec<_> = (0..prev_statements.len())
|
let prev_statements_target: Vec<_> = (0..prev_statements.len())
|
||||||
.map(|_| builder.add_virtual_statement(false))
|
.map(|_| builder.add_virtual_statement(false))
|
||||||
.collect();
|
.collect();
|
||||||
|
let prev_statement_flatteneds_target: Vec<Vec<Target>> = prev_statements_target
|
||||||
|
.iter()
|
||||||
|
.map(|st| st.flatten())
|
||||||
|
.collect();
|
||||||
|
let prev_statement_hashes_target: Vec<_> = prev_statement_flatteneds_target
|
||||||
|
.iter()
|
||||||
|
.map(|flat| builder.hash_n_to_hash_no_pad::<PoseidonHash>(flat.clone()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
let merkle_proofs_target: Vec<_> = aux
|
let merkle_proofs_target: Vec<_> = aux
|
||||||
.merkle_proofs
|
.merkle_proofs
|
||||||
|
|
@ -2269,7 +2320,8 @@ mod tests {
|
||||||
&mut builder,
|
&mut builder,
|
||||||
&st_target,
|
&st_target,
|
||||||
&op_target,
|
&op_target,
|
||||||
&prev_statements_target,
|
&prev_statement_flatteneds_target,
|
||||||
|
&prev_statement_hashes_target,
|
||||||
&aux_table,
|
&aux_table,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
|
@ -2711,6 +2763,37 @@ mod tests {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_operation_verify_sumof_non_monotonic_repeated_indices() -> Result<()> {
|
||||||
|
let local = dict!({
|
||||||
|
"a" => 3,
|
||||||
|
"noise" => 99,
|
||||||
|
"sum" => 6,
|
||||||
|
});
|
||||||
|
let st_a: mainpod::Statement = Statement::contains(local.clone(), "a", 3).into();
|
||||||
|
let st_noise: mainpod::Statement = Statement::contains(local.clone(), "noise", 99).into();
|
||||||
|
let st_sum: mainpod::Statement = Statement::contains(local.clone(), "sum", 6).into();
|
||||||
|
|
||||||
|
let st: mainpod::Statement = Statement::sum_of(
|
||||||
|
AnchoredKey::from((&local, "sum")),
|
||||||
|
AnchoredKey::from((&local, "a")),
|
||||||
|
AnchoredKey::from((&local, "a")),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
let op = mainpod::Operation(
|
||||||
|
OperationType::Native(NativeOperation::SumOf),
|
||||||
|
vec![
|
||||||
|
// Non-monotonic and repeated indices to stress random-access resolution.
|
||||||
|
OperationArg::Index(2),
|
||||||
|
OperationArg::Index(0),
|
||||||
|
OperationArg::Index(0),
|
||||||
|
],
|
||||||
|
OperationAux::None,
|
||||||
|
);
|
||||||
|
let prev_statements = vec![st_a, st_noise, st_sum];
|
||||||
|
operation_verify(st, op, prev_statements, Aux::default())
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_operation_verify_productof() -> Result<()> {
|
fn test_operation_verify_productof() -> Result<()> {
|
||||||
I64_TEST_PAIRS
|
I64_TEST_PAIRS
|
||||||
|
|
|
||||||
|
|
@ -107,11 +107,11 @@ impl MuxTableTarget {
|
||||||
rev_resolved_tagged_flattened.reverse();
|
rev_resolved_tagged_flattened.reverse();
|
||||||
let resolved_tagged_flattened = rev_resolved_tagged_flattened;
|
let resolved_tagged_flattened = rev_resolved_tagged_flattened;
|
||||||
|
|
||||||
builder.add_simple_generator(TableGetGenerator {
|
builder.add_simple_generator(TableGetGenerator::new(
|
||||||
index: index.clone(),
|
index.clone(),
|
||||||
tagged_entries: self.tagged_entries.clone(),
|
self.tagged_entries.clone(),
|
||||||
get_tagged_entry: resolved_tagged_flattened.clone(),
|
resolved_tagged_flattened.clone(),
|
||||||
});
|
));
|
||||||
measure_gates_end!(builder, measure);
|
measure_gates_end!(builder, measure);
|
||||||
TableEntryTarget {
|
TableEntryTarget {
|
||||||
params: self.params.clone(),
|
params: self.params.clone(),
|
||||||
|
|
@ -123,8 +123,18 @@ impl MuxTableTarget {
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct TableGetGenerator {
|
pub struct TableGetGenerator {
|
||||||
index: IndexTarget,
|
index: IndexTarget,
|
||||||
tagged_entries: Vec<Vec<Target>>,
|
entries: Vec<Vec<Target>>,
|
||||||
get_tagged_entry: Vec<Target>,
|
revealed_entry: Vec<Target>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TableGetGenerator {
|
||||||
|
pub fn new(index: IndexTarget, entries: Vec<Vec<Target>>, revealed_entry: Vec<Target>) -> Self {
|
||||||
|
Self {
|
||||||
|
index,
|
||||||
|
entries,
|
||||||
|
revealed_entry,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator {
|
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for TableGetGenerator {
|
||||||
|
|
@ -135,7 +145,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
|
||||||
fn dependencies(&self) -> Vec<Target> {
|
fn dependencies(&self) -> Vec<Target> {
|
||||||
[self.index.low, self.index.high]
|
[self.index.low, self.index.high]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.chain(self.tagged_entries.iter().flatten().copied())
|
.chain(self.entries.iter().flatten().copied())
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -148,12 +158,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
|
||||||
let index_high = witness.get_target(self.index.high);
|
let index_high = witness.get_target(self.index.high);
|
||||||
let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64();
|
let index = (index_low + index_high * F::from_canonical_usize(1 << 6)).to_canonical_u64();
|
||||||
|
|
||||||
let entry = witness.get_targets(&self.tagged_entries[index as usize]);
|
let entry = witness.get_targets(&self.entries[index as usize]);
|
||||||
|
|
||||||
for (target, value) in self.get_tagged_entry.iter().zip(
|
for (target, value) in self.revealed_entry.iter().zip(
|
||||||
entry
|
entry
|
||||||
.iter()
|
.iter()
|
||||||
.chain(iter::repeat(&F::ZERO).take(self.get_tagged_entry.len())),
|
.chain(iter::repeat(&F::ZERO).take(self.revealed_entry.len())),
|
||||||
) {
|
) {
|
||||||
out_buffer.set_target(*target, *value)?;
|
out_buffer.set_target(*target, *value)?;
|
||||||
}
|
}
|
||||||
|
|
@ -166,12 +176,12 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
|
||||||
dst.write_target(self.index.low)?;
|
dst.write_target(self.index.low)?;
|
||||||
dst.write_target(self.index.high)?;
|
dst.write_target(self.index.high)?;
|
||||||
|
|
||||||
dst.write_usize(self.tagged_entries.len())?;
|
dst.write_usize(self.entries.len())?;
|
||||||
for tagged_entry in &self.tagged_entries {
|
for entry in &self.entries {
|
||||||
dst.write_target_vec(tagged_entry)?;
|
dst.write_target_vec(entry)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
dst.write_target_vec(&self.get_tagged_entry)
|
dst.write_target_vec(&self.revealed_entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
|
fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData<F, D>) -> IoResult<Self> {
|
||||||
|
|
@ -181,16 +191,16 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F, D> for Tab
|
||||||
high: src.read_target()?,
|
high: src.read_target()?,
|
||||||
};
|
};
|
||||||
let len = src.read_usize()?;
|
let len = src.read_usize()?;
|
||||||
let mut tagged_entries = Vec::with_capacity(len);
|
let mut entries = Vec::with_capacity(len);
|
||||||
for _ in 0..len {
|
for _ in 0..len {
|
||||||
tagged_entries.push(src.read_target_vec()?);
|
entries.push(src.read_target_vec()?);
|
||||||
}
|
}
|
||||||
let get_tagged_entry = src.read_target_vec()?;
|
let revealed_entry = src.read_target_vec()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
index,
|
index,
|
||||||
tagged_entries,
|
entries,
|
||||||
get_tagged_entry,
|
revealed_entry,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue