diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 5b5db73..c189725 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -33,7 +33,7 @@ use crate::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg, StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, - HASH_SIZE, OPERATION_ARG_F_LEN, OPERATION_AUX_F_LEN, STATEMENT_ARG_F_LEN, VALUE_SIZE, + HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE, }, }; @@ -255,9 +255,9 @@ impl OperationTypeTarget { #[derive(Clone, Serialize, Deserialize)] pub struct OperationTarget { pub op_type: OperationTypeTarget, - pub args: Vec<[Target; OPERATION_ARG_F_LEN]>, + pub args: Vec, #[serde(with = "serde_arrays")] - pub aux: [Target; OPERATION_AUX_F_LEN], + pub aux: [IndexTarget; 2], } impl OperationTarget { @@ -275,9 +275,12 @@ impl OperationTarget { .take(params.max_operation_args) .enumerate() { - pw.set_target_arr(&self.args[i], &arg.to_fields(params))?; + self.args[i].set_targets(pw, arg.as_usize())?; + } + let indexes = op.aux().as_usizes(); + for (index_target, index) in self.aux.iter().zip_eq(indexes.iter()) { + index_target.set_targets(pw, *index)?; } - pw.set_target_arr(&self.aux, &op.aux().to_fields(params))?; Ok(()) } } @@ -584,7 +587,7 @@ impl CustomPredicateEntryTarget { // Custom predicate verification table entry #[derive(Clone, Serialize, Deserialize)] pub struct CustomPredicateVerifyEntryTarget { - pub custom_predicate_table_index: Target, + pub custom_predicate_table_index: IndexTarget, pub custom_predicate: CustomPredicateEntryTarget, pub args: Vec, pub op_args: Vec, @@ -592,8 +595,13 @@ pub struct CustomPredicateVerifyEntryTarget { impl CustomPredicateVerifyEntryTarget { pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self { + let custom_predicate_table_len = + params.max_custom_predicate_batches * params.max_custom_batch_size; CustomPredicateVerifyEntryTarget { - custom_predicate_table_index: builder.add_virtual_target(), + custom_predicate_table_index: IndexTarget::new_virtual( + custom_predicate_table_len, + builder, + ), custom_predicate: builder.add_virtual_custom_predicate_entry(params), args: (0..params.max_custom_predicate_wildcards) .map(|_| builder.add_virtual_value()) @@ -609,10 +617,8 @@ impl CustomPredicateVerifyEntryTarget { params: &Params, cpv: &CustomPredicateVerification, ) -> Result<()> { - pw.set_target( - self.custom_predicate_table_index, - F::from_canonical_usize(cpv.custom_predicate_table_index), - )?; + self.custom_predicate_table_index + .set_targets(pw, cpv.custom_predicate_table_index)?; // Replace statement templates of batch-self with (id,index) self.custom_predicate .set_targets(pw, params, &cpv.custom_predicate)?; @@ -665,7 +671,7 @@ impl Flattenable for CustomPredicateVerifyQueryTarget { fn from_flattened(params: &Params, vs: &[Target]) -> Self { let (pos, size) = (0, params.statement_size()); let statement = StatementTarget::from_flattened(params, &vs[pos..pos + size]); - let (pos, size) = (pos + size, params.operation_size()); + let (pos, size) = (pos + size, params.operation_size(IndexTarget::f_len())); let op_type = OperationTypeTarget { elements: vs[pos..pos + size] .try_into() @@ -867,6 +873,41 @@ impl Flattenable for StatementTmplArgTarget { } } +/// Index to an array for random access +#[derive(Clone, Serialize, Deserialize)] +pub struct IndexTarget { + max_array_len: usize, + low: Target, + high: Target, +} + +impl IndexTarget { + // Length in field elements + pub const fn f_len() -> usize { + 2 + } + pub fn new_virtual(max_array_len: usize, builder: &mut CircuitBuilder) -> Self { + // Limit the maximum array length to avoid abusing `vec_ref` + assert!(max_array_len <= 256); + Self { + max_array_len, + low: builder.add_virtual_target(), + high: if max_array_len > 64 { + builder.add_virtual_target() + } else { + builder.zero() + }, + } + } + + pub fn set_targets(&self, pw: &mut PartialWitness, index: usize) -> Result<()> { + assert!(index == 0 || index < self.max_array_len); + pw.set_target(self.low, F::from_canonical_usize(index & ((1 << 6) - 1)))?; + pw.set_target(self.high, F::from_canonical_usize(index >> 6))?; + Ok(()) + } +} + pub trait CircuitBuilderPod, const D: usize> { fn connect_values(&mut self, x: ValueTarget, y: ValueTarget); fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); @@ -932,9 +973,14 @@ pub trait CircuitBuilderPod, const D: usize> { /// Creates value target that is a hash of two given values. fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget; - // Convenience methods for accessing and connecting elements of - // (vectors of) flattenables. - fn vec_ref(&mut self, params: &Params, ts: &[T], i: Target) -> T; + /// Like `random_access` but allows using longer arrays. + fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target; + + /// Convenience methods for accessing and connecting elements of + /// (vectors of) flattenables. + fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T; + /// Like `vec_ref` but only supports arrays up to 64 elements and the index is a simple `Target` + fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T; fn select_flattenable( &mut self, params: &Params, @@ -945,11 +991,11 @@ pub trait CircuitBuilderPod, const D: usize> { fn connect_flattenable(&mut self, xs: &T, ys: &T); fn is_equal_flattenable(&mut self, xs: &T, ys: &T) -> BoolTarget; - // Convenience methods for Boolean into-iters. + /// Convenience methods for Boolean into-iters. fn all(&mut self, xs: impl IntoIterator) -> BoolTarget; fn any(&mut self, xs: impl IntoIterator) -> BoolTarget; - // Return a bit-mask of size `len` that selects all positions lower than `n` + /// Return a bit-mask of size `len` that selects all positions lower than `n` fn lt_mask(&mut self, len: usize, n: Target) -> Vec; } @@ -1003,9 +1049,12 @@ impl CircuitBuilderPod for CircuitBuilder { OperationTarget { op_type: self.add_virtual_operation_type(), args: (0..params.max_operation_args) - .map(|_| self.add_virtual_target_arr()) + .map(|_| IndexTarget::new_virtual(params.statement_table_size(), self)) .collect(), - aux: self.add_virtual_target_arr(), + aux: [ + IndexTarget::new_virtual(params.max_merkle_proofs_containers, self), + IndexTarget::new_virtual(params.max_custom_predicate_verifications, self), + ], } } @@ -1293,18 +1342,41 @@ impl CircuitBuilderPod for CircuitBuilder { ) } + fn random_access_long(&mut self, i: &IndexTarget, array: &[Target]) -> Target { + const CHUNK_LEN: usize = 64; // Max size of a single gate native random access + assert!(array.len() <= i.max_array_len); + // Limit to 4 chunks (combination of 4 random_access of CHUNK_LEN elements) to avoid + // abusing this method. + assert!(array.len() <= 4 * CHUNK_LEN); + + // We do several random accesses over chunks of CHUNK_LEN using the lowest bits of the + // index. Then we combine them using the highest bits of the index. + let mut chunk_res = Vec::new(); + let num_chunks = array.len().div_ceil(CHUNK_LEN); + for chunk in array.chunks(CHUNK_LEN) { + let mut index_chunk = i.low; + // I we have several chunks and the last one is smaller (it's index needs less than 6 + // bits), make it zero except when it's used so that the range check over the index + // passes. + if chunk.len() <= CHUNK_LEN / 2 && num_chunks > 1 { + let last_chunk_index_high = self.constant(F::from_canonical_usize(num_chunks - 1)); + let selector = self.is_equal(i.high, last_chunk_index_high); + index_chunk = self.mul(index_chunk, selector.target); + } + let res = self.random_access(index_chunk, chunk.to_vec()); + chunk_res.push(res); + } + + self.random_access(i.high, chunk_res) + } + // TODO: Implement a version of vec_ref for types `T` which are big and support hashing. // The idea would be the following: Take the array `ts` and hash each element. Then do the // random access on the hash result. Finally "unhash" to recover the resolved element. // We don't want to hash each element from the array each time, so we should cache the hashed // result. For that we can create a wrapper over `T: Flattenable` that caches the hash, and // then do `ts: &[HashCache]`. - fn vec_ref(&mut self, params: &Params, ts: &[T], i: Target) -> T { - // TODO: Revisit this when we need more than 64 statements. - let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| { - assert!(v.len() <= 64); - builder.random_access(i, v.to_vec()) - }; + fn vec_ref(&mut self, params: &Params, ts: &[T], i: &IndexTarget) -> T { let matrix_row_ref = |builder: &mut CircuitBuilder, m: &[Vec], i| { let num_rows = m.len(); let num_columns = m @@ -1317,11 +1389,8 @@ impl CircuitBuilderPod for CircuitBuilder { .unwrap_or(0); (0..num_columns) .map(|j| { - vector_ref( - builder, - &(0..num_rows).map(|i| m[i][j]).collect::>(), - i, - ) + builder + .random_access_long(i, &(0..num_rows).map(|i| m[i][j]).collect::>()) }) .collect::>() }; @@ -1330,6 +1399,19 @@ impl CircuitBuilderPod for CircuitBuilder { T::from_flattened(params, &matrix_row_ref(self, &flattened_ts, i)) } + fn vec_ref_small(&mut self, params: &Params, ts: &[T], i: Target) -> T { + let zero = self.zero(); + self.vec_ref( + params, + ts, + &IndexTarget { + max_array_len: 64, + low: i, + high: zero, + }, + ) + } + fn select_flattenable( &mut self, params: &Params, @@ -1639,4 +1721,36 @@ pub(crate) mod tests { } }) } + + #[test] + fn test_random_access_long() -> Result<(), anyhow::Error> { + let lens: [usize; _] = [10, 60, 64, 96, 126, 159, 190, 256]; + + for len in &lens { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let array = builder.add_virtual_targets(*len); + let index_target = IndexTarget::new_virtual(*len, &mut builder); + let res = builder.random_access_long(&index_target, &array); + + let data = builder.build::(); + + for i in 0..3 { + let index = (len - 1) * i / 2; + println!("len={}, index={}", len, index); + let mut pw = PartialWitness::::new(); + for (j, elem) in array.iter().enumerate() { + pw.set_target(*elem, F::from_canonical_usize(j * 11))?; + } + index_target.set_targets(&mut pw, index)?; + pw.set_target(res, F::from_canonical_usize(index * 11))?; // Expected + + let proof = data.prove(pw)?; + data.verify(proof)?; + } + } + + Ok(()) + } } diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 332eecd..81bd5a6 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -95,8 +95,7 @@ impl StatementCache { // converting a length 1 array into a scalar. op.args .iter() - .flatten() - .map(|&i| builder.vec_ref(params, prev_statements, i)) + .map(|i| builder.vec_ref(params, prev_statements, i)) .collect::>() }; assert!(params.max_operation_args >= 3); @@ -224,7 +223,7 @@ fn verify_operation_circuit( // been verified, so we need only look up the claim. let measure_resolve_merkle_claim = measure_gates_begin!(builder, "ResolveMerkleClaim"); let resolved_merkle_claim = - (!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, op.aux[0])); + (!merkle_claims.is_empty()).then(|| builder.vec_ref(params, merkle_claims, &op.aux[0])); measure_gates_end!(builder, measure_resolve_merkle_claim); // Operations from custom statements will refer to one @@ -233,7 +232,7 @@ fn verify_operation_circuit( let measure_resolve_custom_pred_verification = measure_gates_begin!(builder, "ResolveCustomPredVerification"); let resolved_custom_pred_verification = (!custom_predicate_verification_table.is_empty()) - .then(|| builder.vec_ref(params, custom_predicate_verification_table, op.aux[1])); + .then(|| builder.vec_ref(params, custom_predicate_verification_table, &op.aux[1])); measure_gates_end!(builder, measure_resolve_custom_pred_verification); // The verification may require aux data which needs to be stored in the @@ -938,7 +937,7 @@ fn make_statement_arg_from_template_circuit( let first_index = ak_id_wc_index; let is_first_index_valid = builder.or(is_ak, is_wc_literal); let first_index = builder.select(is_first_index_valid, first_index, zero); - let resolved_ak_id = builder.vec_ref(params, &args, first_index); + let resolved_ak_id = builder.vec_ref_small(params, &args, first_index); let resolved_wc = resolved_ak_id; // If the index is not used, use a 0 instead to still pass the range constraints from @@ -946,7 +945,7 @@ fn make_statement_arg_from_template_circuit( let second_index = ak_key_wc_index; let is_second_index_valid = builder.and(is_ak, is_ak_key_wc); let second_index = builder.select(is_second_index_valid, second_index, zero); - let resolved_ak_key = builder.vec_ref(params, &args, second_index); + let resolved_ak_key = builder.vec_ref_small(params, &args, second_index); let ak_key = ak_key_lit; // is_ak_key_lit let ak_key = builder.select_flattenable(params, is_ak_key_wc, &resolved_ak_key, &ak_key); @@ -1244,7 +1243,7 @@ fn build_custom_predicate_verification_table_circuit( let table_query_hash = builder.vec_ref( params, custom_predicate_table, - entry.custom_predicate_table_index, + &entry.custom_predicate_table_index, ); let out_query_hash = entry.custom_predicate.hash(builder); builder.connect_array(table_query_hash.elements, out_query_hash.elements); diff --git a/src/backends/plonky2/mainpod/operation.rs b/src/backends/plonky2/mainpod/operation.rs index 6a91188..49f6770 100644 --- a/src/backends/plonky2/mainpod/operation.rs +++ b/src/backends/plonky2/mainpod/operation.rs @@ -1,6 +1,5 @@ use std::fmt; -use plonky2::field::types::Field; use serde::{Deserialize, Serialize}; use crate::{ @@ -9,7 +8,7 @@ use crate::{ mainpod::Statement, primitives::merkletree::MerkleClaimAndProof, }, - middleware::{self, OperationType, Params, ToFields, F}, + middleware::{self, OperationType}, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -18,20 +17,17 @@ pub enum OperationArg { Index(usize), } -impl ToFields for OperationArg { - fn to_fields(&self, _params: &Params) -> Vec { - let f = match self { - Self::None => F::ZERO, - Self::Index(i) => F::from_canonical_usize(*i), - }; - vec![f] - } -} - impl OperationArg { pub fn is_none(&self) -> bool { matches!(self, OperationArg::None) } + + pub fn as_usize(&self) -> usize { + match self { + Self::None => 0, + Self::Index(i) => *i, + } + } } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -41,14 +37,13 @@ pub enum OperationAux { CustomPredVerifyIndex(usize), } -impl ToFields for OperationAux { - fn to_fields(&self, _params: &Params) -> Vec { - let fs = match self { - Self::None => [F::ZERO, F::ZERO], - Self::MerkleProofIndex(i) => [F::from_canonical_usize(*i), F::ZERO], - Self::CustomPredVerifyIndex(i) => [F::ZERO, F::from_canonical_usize(*i)], - }; - vec![fs[0], fs[1]] +impl OperationAux { + pub fn as_usizes(&self) -> [usize; 2] { + match self { + Self::None => [0, 0], + Self::MerkleProofIndex(i) => [*i, 0], + Self::CustomPredVerifyIndex(i) => [0, *i], + } } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 415c078..6d75071 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -828,8 +828,8 @@ impl Params { Self::predicate_size() + STATEMENT_ARG_F_LEN * self.max_statement_args } - pub fn operation_size(&self) -> usize { - Self::operation_type_size() + OPERATION_ARG_F_LEN * self.max_operation_args + pub fn operation_size(&self, operation_arg_f_len: usize) -> usize { + Self::operation_type_size() + operation_arg_f_len * self.max_operation_args } pub const fn statement_tmpl_size(&self) -> usize { @@ -844,6 +844,14 @@ impl Params { self.max_custom_batch_size * self.custom_predicate_size() } + /// Total size of the statement table including None, input statements from signed pods and + /// input recursive pods and new statements (public & private) + pub fn statement_table_size(&self) -> usize { + 1 + self.max_input_signed_pods * self.max_signed_pod_values + + self.max_input_recursive_pods * self.max_input_pods_public_statements + + self.max_statements + } + /// Parameters that define how the id is calculated pub fn id_params(&self) -> Vec { vec![ diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 3cbbdb8..9b4bf3e 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -18,8 +18,6 @@ pub const KEY_SIGNER: &str = "_signer"; // hash(KEY_TYPE) = [17948789436443445142, 12513915140657440811, 15878361618879468769, 938231894693848619] pub const KEY_TYPE: &str = "_type"; pub const STATEMENT_ARG_F_LEN: usize = 8; -pub const OPERATION_ARG_F_LEN: usize = 1; -pub const OPERATION_AUX_F_LEN: usize = 2; #[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] pub enum NativePredicate {