chore: simplify ToFields trait (#154)

This commit is contained in:
Eduard S. 2025-03-20 09:38:46 +01:00 committed by GitHub
parent b1689c5b37
commit 2a2628ccbf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 66 additions and 100 deletions

View file

@ -38,8 +38,8 @@ pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
pub struct Value(pub [F; VALUE_SIZE]); pub struct Value(pub [F; VALUE_SIZE]);
impl ToFields for Value { impl ToFields for Value {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
(self.0.to_vec(), VALUE_SIZE) self.0.to_vec()
} }
} }
@ -143,8 +143,8 @@ impl Hash {
} }
impl ToFields for Hash { impl ToFields for Hash {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
(self.0.to_vec(), VALUE_SIZE) self.0.to_vec()
} }
} }

View file

@ -364,7 +364,7 @@ impl MockMainPod {
pub fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash { pub fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash {
let field_elems = statements let field_elems = statements
.iter() .iter()
.flat_map(|statement| statement.clone().to_fields(_params).0) .flat_map(|statement| statement.clone().to_fields(_params))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Hash(PoseidonHash::hash_no_pad(&field_elems).elements) Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
} }

View file

@ -23,22 +23,10 @@ impl Statement {
} }
impl ToFields for Statement { impl ToFields for Statement {
fn to_fields(&self, _params: &Params) -> (Vec<middleware::F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<middleware::F> {
let (native_statement_f, native_statement_f_len) = self.0.to_fields(_params); let mut fields = self.0.to_fields(_params);
let (vec_statementarg_f, vec_statementarg_f_len) = self fields.extend(self.1.iter().flat_map(|arg| arg.to_fields(_params)));
.1 fields
.clone()
.into_iter()
.map(|statement_arg| statement_arg.to_fields(_params))
.fold((Vec::new(), 0), |mut acc, (f, l)| {
acc.0.extend(f);
acc.1 += l;
acc
});
(
[native_statement_f, vec_statementarg_f].concat(),
native_statement_f_len + vec_statementarg_f_len,
)
} }
} }

View file

@ -1,6 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::{fmt, hash as h, iter::zip}; use std::{fmt, hash as h, iter, iter::zip};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use plonky2::field::types::Field; use plonky2::field::types::Field;
@ -48,18 +48,13 @@ impl fmt::Display for HashOrWildcard {
} }
impl ToFields for HashOrWildcard { impl ToFields for HashOrWildcard {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
match self { match self {
HashOrWildcard::Hash(h) => h.to_fields(_params), HashOrWildcard::Hash(h) => h.to_fields(_params),
HashOrWildcard::Wildcard(w) => { HashOrWildcard::Wildcard(w) => (0..HASH_SIZE - 1)
let mut usizes: Vec<usize> = vec![0; HASH_SIZE - 1]; .chain(iter::once(*w))
usizes.push(*w); .map(|x| F::from_canonical_u64(x as u64))
let fields: Vec<F> = usizes .collect(),
.iter()
.map(|x| F::from_canonical_u64(*x as u64))
.collect();
(fields, HASH_SIZE)
}
} }
} }
} }
@ -95,7 +90,7 @@ impl StatementTmplArg {
} }
impl ToFields for StatementTmplArg { impl ToFields for StatementTmplArg {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
// None => (0, ...) // None => (0, ...)
// Literal(value) => (1, [value], 0, 0, 0, 0) // Literal(value) => (1, [value], 0, 0, 0, 0)
// Key(hash_or_wildcard1, hash_or_wildcard2) // Key(hash_or_wildcard1, hash_or_wildcard2)
@ -104,24 +99,24 @@ impl ToFields for StatementTmplArg {
let statement_tmpl_arg_size = 2 * HASH_SIZE + 1; let statement_tmpl_arg_size = 2 * HASH_SIZE + 1;
match self { match self {
StatementTmplArg::None => { StatementTmplArg::None => {
let fields: Vec<F> = std::iter::repeat_with(|| F::from_canonical_u64(0)) let fields: Vec<F> = iter::repeat_with(|| F::from_canonical_u64(0))
.take(statement_tmpl_arg_size) .take(statement_tmpl_arg_size)
.collect(); .collect();
(fields, statement_tmpl_arg_size) fields
} }
StatementTmplArg::Literal(v) => { StatementTmplArg::Literal(v) => {
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(1)) let fields: Vec<F> = iter::once(F::from_canonical_u64(1))
.chain(v.to_fields(_params).0) .chain(v.to_fields(_params))
.chain(std::iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE)) .chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE))
.collect(); .collect();
(fields, statement_tmpl_arg_size) fields
} }
StatementTmplArg::Key(hw1, hw2) => { StatementTmplArg::Key(hw1, hw2) => {
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(2)) let fields: Vec<F> = iter::once(F::from_canonical_u64(2))
.chain(hw1.to_fields(_params).0) .chain(hw1.to_fields(_params))
.chain(hw2.to_fields(_params).0) .chain(hw2.to_fields(_params))
.collect(); .collect();
(fields, statement_tmpl_arg_size) fields
} }
} }
} }
@ -181,7 +176,7 @@ impl StatementTmpl {
} }
impl ToFields for StatementTmpl { impl ToFields for StatementTmpl {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, params: &Params) -> Vec<F> {
// serialize as: // serialize as:
// predicate (6 field elements) // predicate (6 field elements)
// then the StatementTmplArgs // then the StatementTmplArgs
@ -196,12 +191,11 @@ impl ToFields for StatementTmpl {
let mut fields: Vec<F> = self let mut fields: Vec<F> = self
.0 .0
.to_fields(params) .to_fields(params)
.0
.into_iter() .into_iter()
.chain(self.1.iter().flat_map(|sta| sta.to_fields(params).0)) .chain(self.1.iter().flat_map(|sta| sta.to_fields(params)))
.collect(); .collect();
fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0)); fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0));
(fields, params.statement_tmpl_size()) fields
} }
} }
@ -244,7 +238,7 @@ impl CustomPredicate {
} }
impl ToFields for CustomPredicate { impl ToFields for CustomPredicate {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, params: &Params) -> Vec<F> {
// serialize as: // serialize as:
// conjunction (one field element) // conjunction (one field element)
// args_len (one field element) // args_len (one field element)
@ -259,12 +253,12 @@ impl ToFields for CustomPredicate {
panic!("Custom predicate depends on too many statements"); panic!("Custom predicate depends on too many statements");
} }
let mut fields: Vec<F> = std::iter::once(F::from_bool(self.conjunction)) let mut fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
.chain(std::iter::once(F::from_canonical_usize(self.args_len))) .chain(iter::once(F::from_canonical_usize(self.args_len)))
.chain(self.statements.iter().flat_map(|st| st.to_fields(params).0)) .chain(self.statements.iter().flat_map(|st| st.to_fields(params)))
.collect(); .collect();
fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0)); fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0));
(fields, params.custom_predicate_size()) fields
} }
} }
@ -300,7 +294,7 @@ pub struct CustomPredicateBatch {
} }
impl ToFields for CustomPredicateBatch { impl ToFields for CustomPredicateBatch {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, params: &Params) -> Vec<F> {
// all the custom predicates in order // all the custom predicates in order
// TODO think if this check should go into the StatementTmpl creation, // TODO think if this check should go into the StatementTmpl creation,
@ -313,19 +307,18 @@ impl ToFields for CustomPredicateBatch {
let mut fields: Vec<F> = self let mut fields: Vec<F> = self
.predicates .predicates
.iter() .iter()
.flat_map(|p| p.to_fields(params).0) .flat_map(|p| p.to_fields(params))
.collect(); .collect();
fields.resize_with(params.custom_predicate_batch_size_field_elts(), || { fields.resize_with(params.custom_predicate_batch_size_field_elts(), || {
F::from_canonical_u64(0) F::from_canonical_u64(0)
}); });
fields
(fields, params.custom_predicate_batch_size_field_elts())
} }
} }
impl CustomPredicateBatch { impl CustomPredicateBatch {
pub fn hash(&self, _params: &Params) -> Hash { pub fn hash(&self, _params: &Params) -> Hash {
let input = self.to_fields(_params).0; let input = self.to_fields(_params);
hash_fields(&input) hash_fields(&input)
} }
@ -367,7 +360,7 @@ impl CustomPredicateRef {
match custom_predicate.conjunction { match custom_predicate.conjunction {
true if custom_predicate.statements.len() == statements.len() => { true if custom_predicate.statements.len() == statements.len() => {
// Match op args against statement templates // Match op args against statement templates
let match_bindings = std::iter::zip(custom_predicate.statements, statements).map( let match_bindings = iter::zip(custom_predicate.statements, statements).map(
|(s_tmpl, s)| s_tmpl.match_against(s) |(s_tmpl, s)| s_tmpl.match_against(s)
).collect::<Result<Vec<_>>>() ).collect::<Result<Vec<_>>>()
.map(|v| v.concat())?; .map(|v| v.concat())?;
@ -404,7 +397,7 @@ impl From<NativePredicate> for Predicate {
} }
impl ToFields for Predicate { impl ToFields for Predicate {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
// serialize: // serialize:
// NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize // NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize
// BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize // BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize
@ -414,19 +407,19 @@ impl ToFields for Predicate {
// in every case: pad to (hash_size + 2) field elements // in every case: pad to (hash_size + 2) field elements
let mut fields: Vec<F> = match self { let mut fields: Vec<F> = match self {
Self::Native(p) => std::iter::once(F::from_canonical_u64(1)) Self::Native(p) => iter::once(F::from_canonical_u64(1))
.chain(p.to_fields(_params).0) .chain(p.to_fields(_params))
.collect(), .collect(),
Self::BatchSelf(i) => std::iter::once(F::from_canonical_u64(2)) Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2))
.chain(std::iter::once(F::from_canonical_usize(*i))) .chain(iter::once(F::from_canonical_usize(*i)))
.collect(), .collect(),
Self::Custom(CustomPredicateRef(pb, i)) => std::iter::once(F::from_canonical_u64(3)) Self::Custom(CustomPredicateRef(pb, i)) => iter::once(F::from_canonical_u64(3))
.chain(pb.hash(_params).0) .chain(pb.hash(_params).0)
.chain(std::iter::once(F::from_canonical_usize(*i))) .chain(iter::once(F::from_canonical_usize(*i)))
.collect(), .collect(),
}; };
fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0)); fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0));
(fields, Params::predicate_size()) fields
} }
} }

View file

@ -58,7 +58,7 @@ pub type Entry = (String, Value);
pub struct PodId(pub Hash); pub struct PodId(pub Hash);
impl ToFields for PodId { impl ToFields for PodId {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, params: &Params) -> Vec<F> {
self.0.to_fields(params) self.0.to_fields(params)
} }
} }
@ -211,7 +211,6 @@ pub trait PodProver {
} }
pub trait ToFields { pub trait ToFields {
/// returns Vec<F> representation of the type, and a usize indicating how many field elements /// returns Vec<F> representation of the type
/// does the vector contain fn to_fields(&self, params: &Params) -> Vec<F>;
fn to_fields(&self, params: &Params) -> (Vec<F>, usize);
} }

View file

@ -1,9 +1,9 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use plonky2::field::types::Field; use plonky2::field::types::Field;
use std::fmt; use std::{fmt, iter};
use strum_macros::FromRepr; use strum_macros::FromRepr;
use super::{AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F}; use super::{AnchoredKey, CustomPredicateRef, Params, Predicate, ToFields, Value, F, VALUE_SIZE};
pub const KEY_SIGNER: &str = "_signer"; pub const KEY_SIGNER: &str = "_signer";
pub const KEY_TYPE: &str = "_type"; pub const KEY_TYPE: &str = "_type";
@ -25,8 +25,8 @@ pub enum NativePredicate {
} }
impl ToFields for NativePredicate { impl ToFields for NativePredicate {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
(vec![F::from_canonical_u64(*self as u64)], 1) vec![F::from_canonical_u64(*self as u64)]
} }
} }
@ -182,21 +182,10 @@ impl Statement {
} }
impl ToFields for Statement { impl ToFields for Statement {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
let (native_statement_f, native_statement_f_len) = self.code().to_fields(_params); let mut fields = self.code().to_fields(_params);
let (vec_statementarg_f, vec_statementarg_f_len) = self fields.extend(self.args().iter().flat_map(|arg| arg.to_fields(_params)));
.args() fields
.into_iter()
.map(|statement_arg| statement_arg.to_fields(_params))
.fold((Vec::new(), 0), |mut acc, (f, l)| {
acc.0.extend(f);
acc.1 += l;
acc
});
(
[native_statement_f, vec_statementarg_f].concat(),
native_statement_f_len + vec_statementarg_f_len,
)
} }
} }
@ -250,7 +239,7 @@ impl StatementArg {
} }
impl ToFields for StatementArg { impl ToFields for StatementArg {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) { fn to_fields(&self, _params: &Params) -> Vec<F> {
// NOTE: current version returns always the same amount of field elements in the returned // NOTE: current version returns always the same amount of field elements in the returned
// vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case // vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case
// is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced // is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced
@ -261,20 +250,17 @@ impl ToFields for StatementArg {
let f = match self { let f = match self {
StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN], StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN],
StatementArg::Literal(v) => { StatementArg::Literal(v) => {
let value_f = v.0.to_vec(); v.0.into_iter()
[ .chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
value_f.clone(), .collect()
vec![F::ZERO; STATEMENT_ARG_F_LEN - value_f.len()],
]
.concat()
} }
StatementArg::Key(ak) => { StatementArg::Key(ak) => {
let (podid_f, _) = ak.0.to_fields(_params); let mut fields = ak.0.to_fields(_params);
let (hash_f, _) = ak.1.to_fields(_params); fields.extend(ak.1.to_fields(_params));
[podid_f, hash_f].concat() fields
} }
}; };
assert_eq!(f.len(), STATEMENT_ARG_F_LEN); // sanity check assert_eq!(f.len(), STATEMENT_ARG_F_LEN); // sanity check
(f, STATEMENT_ARG_F_LEN) f
} }
} }