diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index 393654e..65368d8 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -38,8 +38,8 @@ pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); pub struct Value(pub [F; VALUE_SIZE]); impl ToFields for Value { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { - (self.0.to_vec(), VALUE_SIZE) + fn to_fields(&self, _params: &Params) -> Vec { + self.0.to_vec() } } @@ -143,8 +143,8 @@ impl Hash { } impl ToFields for Hash { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { - (self.0.to_vec(), VALUE_SIZE) + fn to_fields(&self, _params: &Params) -> Vec { + self.0.to_vec() } } diff --git a/src/backends/plonky2/mock_main/mod.rs b/src/backends/plonky2/mock_main/mod.rs index a4b4cad..e22c3bd 100644 --- a/src/backends/plonky2/mock_main/mod.rs +++ b/src/backends/plonky2/mock_main/mod.rs @@ -364,7 +364,7 @@ impl MockMainPod { pub fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash { let field_elems = statements .iter() - .flat_map(|statement| statement.clone().to_fields(_params).0) + .flat_map(|statement| statement.clone().to_fields(_params)) .collect::>(); Hash(PoseidonHash::hash_no_pad(&field_elems).elements) } diff --git a/src/backends/plonky2/mock_main/statement.rs b/src/backends/plonky2/mock_main/statement.rs index 452a29c..553b3d2 100644 --- a/src/backends/plonky2/mock_main/statement.rs +++ b/src/backends/plonky2/mock_main/statement.rs @@ -23,22 +23,10 @@ impl Statement { } impl ToFields for Statement { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { - let (native_statement_f, native_statement_f_len) = self.0.to_fields(_params); - let (vec_statementarg_f, vec_statementarg_f_len) = self - .1 - .clone() - .into_iter() - .map(|statement_arg| statement_arg.to_fields(_params)) - .fold((Vec::new(), 0), |mut acc, (f, l)| { - acc.0.extend(f); - acc.1 += l; - acc - }); - ( - [native_statement_f, vec_statementarg_f].concat(), - native_statement_f_len + vec_statementarg_f_len, - ) + fn to_fields(&self, _params: &Params) -> Vec { + let mut fields = self.0.to_fields(_params); + fields.extend(self.1.iter().flat_map(|arg| arg.to_fields(_params))); + fields } } diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 1c133cd..d4d0caa 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; -use std::{fmt, hash as h, iter::zip}; +use std::{fmt, hash as h, iter, iter::zip}; use anyhow::{anyhow, Result}; use plonky2::field::types::Field; @@ -48,18 +48,13 @@ impl fmt::Display for HashOrWildcard { } impl ToFields for HashOrWildcard { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { + fn to_fields(&self, _params: &Params) -> Vec { match self { HashOrWildcard::Hash(h) => h.to_fields(_params), - HashOrWildcard::Wildcard(w) => { - let mut usizes: Vec = vec![0; HASH_SIZE - 1]; - usizes.push(*w); - let fields: Vec = usizes - .iter() - .map(|x| F::from_canonical_u64(*x as u64)) - .collect(); - (fields, HASH_SIZE) - } + HashOrWildcard::Wildcard(w) => (0..HASH_SIZE - 1) + .chain(iter::once(*w)) + .map(|x| F::from_canonical_u64(x as u64)) + .collect(), } } } @@ -95,7 +90,7 @@ impl StatementTmplArg { } impl ToFields for StatementTmplArg { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { + fn to_fields(&self, _params: &Params) -> Vec { // None => (0, ...) // Literal(value) => (1, [value], 0, 0, 0, 0) // Key(hash_or_wildcard1, hash_or_wildcard2) @@ -104,24 +99,24 @@ impl ToFields for StatementTmplArg { let statement_tmpl_arg_size = 2 * HASH_SIZE + 1; match self { StatementTmplArg::None => { - let fields: Vec = std::iter::repeat_with(|| F::from_canonical_u64(0)) + let fields: Vec = iter::repeat_with(|| F::from_canonical_u64(0)) .take(statement_tmpl_arg_size) .collect(); - (fields, statement_tmpl_arg_size) + fields } StatementTmplArg::Literal(v) => { - let fields: Vec = std::iter::once(F::from_canonical_u64(1)) - .chain(v.to_fields(_params).0) - .chain(std::iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) + let fields: Vec = iter::once(F::from_canonical_u64(1)) + .chain(v.to_fields(_params)) + .chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .collect(); - (fields, statement_tmpl_arg_size) + fields } StatementTmplArg::Key(hw1, hw2) => { - let fields: Vec = std::iter::once(F::from_canonical_u64(2)) - .chain(hw1.to_fields(_params).0) - .chain(hw2.to_fields(_params).0) + let fields: Vec = iter::once(F::from_canonical_u64(2)) + .chain(hw1.to_fields(_params)) + .chain(hw2.to_fields(_params)) .collect(); - (fields, statement_tmpl_arg_size) + fields } } } @@ -181,7 +176,7 @@ impl StatementTmpl { } impl ToFields for StatementTmpl { - fn to_fields(&self, params: &Params) -> (Vec, usize) { + fn to_fields(&self, params: &Params) -> Vec { // serialize as: // predicate (6 field elements) // then the StatementTmplArgs @@ -196,12 +191,11 @@ impl ToFields for StatementTmpl { let mut fields: Vec = self .0 .to_fields(params) - .0 .into_iter() - .chain(self.1.iter().flat_map(|sta| sta.to_fields(params).0)) + .chain(self.1.iter().flat_map(|sta| sta.to_fields(params))) .collect(); fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0)); - (fields, params.statement_tmpl_size()) + fields } } @@ -244,7 +238,7 @@ impl CustomPredicate { } impl ToFields for CustomPredicate { - fn to_fields(&self, params: &Params) -> (Vec, usize) { + fn to_fields(&self, params: &Params) -> Vec { // serialize as: // conjunction (one field element) // args_len (one field element) @@ -259,12 +253,12 @@ impl ToFields for CustomPredicate { panic!("Custom predicate depends on too many statements"); } - let mut fields: Vec = std::iter::once(F::from_bool(self.conjunction)) - .chain(std::iter::once(F::from_canonical_usize(self.args_len))) - .chain(self.statements.iter().flat_map(|st| st.to_fields(params).0)) + let mut fields: Vec = iter::once(F::from_bool(self.conjunction)) + .chain(iter::once(F::from_canonical_usize(self.args_len))) + .chain(self.statements.iter().flat_map(|st| st.to_fields(params))) .collect(); fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0)); - (fields, params.custom_predicate_size()) + fields } } @@ -300,7 +294,7 @@ pub struct CustomPredicateBatch { } impl ToFields for CustomPredicateBatch { - fn to_fields(&self, params: &Params) -> (Vec, usize) { + fn to_fields(&self, params: &Params) -> Vec { // all the custom predicates in order // TODO think if this check should go into the StatementTmpl creation, @@ -313,19 +307,18 @@ impl ToFields for CustomPredicateBatch { let mut fields: Vec = self .predicates .iter() - .flat_map(|p| p.to_fields(params).0) + .flat_map(|p| p.to_fields(params)) .collect(); fields.resize_with(params.custom_predicate_batch_size_field_elts(), || { F::from_canonical_u64(0) }); - - (fields, params.custom_predicate_batch_size_field_elts()) + fields } } impl CustomPredicateBatch { pub fn hash(&self, _params: &Params) -> Hash { - let input = self.to_fields(_params).0; + let input = self.to_fields(_params); hash_fields(&input) } @@ -367,7 +360,7 @@ impl CustomPredicateRef { match custom_predicate.conjunction { true if custom_predicate.statements.len() == statements.len() => { // Match op args against statement templates - let match_bindings = std::iter::zip(custom_predicate.statements, statements).map( + let match_bindings = iter::zip(custom_predicate.statements, statements).map( |(s_tmpl, s)| s_tmpl.match_against(s) ).collect::>>() .map(|v| v.concat())?; @@ -404,7 +397,7 @@ impl From for Predicate { } impl ToFields for Predicate { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { + fn to_fields(&self, _params: &Params) -> Vec { // serialize: // NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize // BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize @@ -414,19 +407,19 @@ impl ToFields for Predicate { // in every case: pad to (hash_size + 2) field elements let mut fields: Vec = match self { - Self::Native(p) => std::iter::once(F::from_canonical_u64(1)) - .chain(p.to_fields(_params).0) + Self::Native(p) => iter::once(F::from_canonical_u64(1)) + .chain(p.to_fields(_params)) .collect(), - Self::BatchSelf(i) => std::iter::once(F::from_canonical_u64(2)) - .chain(std::iter::once(F::from_canonical_usize(*i))) + Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2)) + .chain(iter::once(F::from_canonical_usize(*i))) .collect(), - Self::Custom(CustomPredicateRef(pb, i)) => std::iter::once(F::from_canonical_u64(3)) + Self::Custom(CustomPredicateRef(pb, i)) => iter::once(F::from_canonical_u64(3)) .chain(pb.hash(_params).0) - .chain(std::iter::once(F::from_canonical_usize(*i))) + .chain(iter::once(F::from_canonical_usize(*i))) .collect(), }; fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0)); - (fields, Params::predicate_size()) + fields } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index af0c2cb..1032f1a 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -58,7 +58,7 @@ pub type Entry = (String, Value); pub struct PodId(pub Hash); impl ToFields for PodId { - fn to_fields(&self, params: &Params) -> (Vec, usize) { + fn to_fields(&self, params: &Params) -> Vec { self.0.to_fields(params) } } @@ -211,7 +211,6 @@ pub trait PodProver { } pub trait ToFields { - /// returns Vec representation of the type, and a usize indicating how many field elements - /// does the vector contain - fn to_fields(&self, params: &Params) -> (Vec, usize); + /// returns Vec representation of the type + fn to_fields(&self, params: &Params) -> Vec; } diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 6b03d51..a6bd627 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -1,9 +1,9 @@ use anyhow::{anyhow, Result}; use plonky2::field::types::Field; -use std::fmt; +use std::{fmt, iter}; use strum_macros::FromRepr; -use super::{AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F}; +use super::{AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, VALUE_SIZE}; pub const KEY_SIGNER: &str = "_signer"; pub const KEY_TYPE: &str = "_type"; @@ -25,8 +25,8 @@ pub enum NativePredicate { } impl ToFields for NativePredicate { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { - (vec![F::from_canonical_u64(*self as u64)], 1) + fn to_fields(&self, _params: &Params) -> Vec { + vec![F::from_canonical_u64(*self as u64)] } } @@ -182,21 +182,10 @@ impl Statement { } impl ToFields for Statement { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { - let (native_statement_f, native_statement_f_len) = self.code().to_fields(_params); - let (vec_statementarg_f, vec_statementarg_f_len) = self - .args() - .into_iter() - .map(|statement_arg| statement_arg.to_fields(_params)) - .fold((Vec::new(), 0), |mut acc, (f, l)| { - acc.0.extend(f); - acc.1 += l; - acc - }); - ( - [native_statement_f, vec_statementarg_f].concat(), - native_statement_f_len + vec_statementarg_f_len, - ) + fn to_fields(&self, _params: &Params) -> Vec { + let mut fields = self.code().to_fields(_params); + fields.extend(self.args().iter().flat_map(|arg| arg.to_fields(_params))); + fields } } @@ -250,7 +239,7 @@ impl StatementArg { } impl ToFields for StatementArg { - fn to_fields(&self, _params: &Params) -> (Vec, usize) { + fn to_fields(&self, _params: &Params) -> Vec { // NOTE: current version returns always the same amount of field elements in the returned // vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case // is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced @@ -261,20 +250,17 @@ impl ToFields for StatementArg { let f = match self { StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN], StatementArg::Literal(v) => { - let value_f = v.0.to_vec(); - [ - value_f.clone(), - vec![F::ZERO; STATEMENT_ARG_F_LEN - value_f.len()], - ] - .concat() + v.0.into_iter() + .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE)) + .collect() } StatementArg::Key(ak) => { - let (podid_f, _) = ak.0.to_fields(_params); - let (hash_f, _) = ak.1.to_fields(_params); - [podid_f, hash_f].concat() + let mut fields = ak.0.to_fields(_params); + fields.extend(ak.1.to_fields(_params)); + fields } }; assert_eq!(f.len(), STATEMENT_ARG_F_LEN); // sanity check - (f, STATEMENT_ARG_F_LEN) + f } }