chore: simplify ToFields trait (#154)
This commit is contained in:
parent
b1689c5b37
commit
2a2628ccbf
6 changed files with 66 additions and 100 deletions
|
|
@ -1,6 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, hash as h, iter::zip};
|
||||
use std::{fmt, hash as h, iter, iter::zip};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use plonky2::field::types::Field;
|
||||
|
|
@ -48,18 +48,13 @@ impl fmt::Display for HashOrWildcard {
|
|||
}
|
||||
|
||||
impl ToFields for HashOrWildcard {
|
||||
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) {
|
||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
||||
match self {
|
||||
HashOrWildcard::Hash(h) => h.to_fields(_params),
|
||||
HashOrWildcard::Wildcard(w) => {
|
||||
let mut usizes: Vec<usize> = vec![0; HASH_SIZE - 1];
|
||||
usizes.push(*w);
|
||||
let fields: Vec<F> = usizes
|
||||
.iter()
|
||||
.map(|x| F::from_canonical_u64(*x as u64))
|
||||
.collect();
|
||||
(fields, HASH_SIZE)
|
||||
}
|
||||
HashOrWildcard::Wildcard(w) => (0..HASH_SIZE - 1)
|
||||
.chain(iter::once(*w))
|
||||
.map(|x| F::from_canonical_u64(x as u64))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -95,7 +90,7 @@ impl StatementTmplArg {
|
|||
}
|
||||
|
||||
impl ToFields for StatementTmplArg {
|
||||
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) {
|
||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
||||
// None => (0, ...)
|
||||
// Literal(value) => (1, [value], 0, 0, 0, 0)
|
||||
// Key(hash_or_wildcard1, hash_or_wildcard2)
|
||||
|
|
@ -104,24 +99,24 @@ impl ToFields for StatementTmplArg {
|
|||
let statement_tmpl_arg_size = 2 * HASH_SIZE + 1;
|
||||
match self {
|
||||
StatementTmplArg::None => {
|
||||
let fields: Vec<F> = std::iter::repeat_with(|| F::from_canonical_u64(0))
|
||||
let fields: Vec<F> = iter::repeat_with(|| F::from_canonical_u64(0))
|
||||
.take(statement_tmpl_arg_size)
|
||||
.collect();
|
||||
(fields, statement_tmpl_arg_size)
|
||||
fields
|
||||
}
|
||||
StatementTmplArg::Literal(v) => {
|
||||
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(1))
|
||||
.chain(v.to_fields(_params).0)
|
||||
.chain(std::iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE))
|
||||
let fields: Vec<F> = iter::once(F::from_canonical_u64(1))
|
||||
.chain(v.to_fields(_params))
|
||||
.chain(iter::repeat_with(|| F::from_canonical_u64(0)).take(HASH_SIZE))
|
||||
.collect();
|
||||
(fields, statement_tmpl_arg_size)
|
||||
fields
|
||||
}
|
||||
StatementTmplArg::Key(hw1, hw2) => {
|
||||
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(2))
|
||||
.chain(hw1.to_fields(_params).0)
|
||||
.chain(hw2.to_fields(_params).0)
|
||||
let fields: Vec<F> = iter::once(F::from_canonical_u64(2))
|
||||
.chain(hw1.to_fields(_params))
|
||||
.chain(hw2.to_fields(_params))
|
||||
.collect();
|
||||
(fields, statement_tmpl_arg_size)
|
||||
fields
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -181,7 +176,7 @@ impl StatementTmpl {
|
|||
}
|
||||
|
||||
impl ToFields for StatementTmpl {
|
||||
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// serialize as:
|
||||
// predicate (6 field elements)
|
||||
// then the StatementTmplArgs
|
||||
|
|
@ -196,12 +191,11 @@ impl ToFields for StatementTmpl {
|
|||
let mut fields: Vec<F> = self
|
||||
.0
|
||||
.to_fields(params)
|
||||
.0
|
||||
.into_iter()
|
||||
.chain(self.1.iter().flat_map(|sta| sta.to_fields(params).0))
|
||||
.chain(self.1.iter().flat_map(|sta| sta.to_fields(params)))
|
||||
.collect();
|
||||
fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0));
|
||||
(fields, params.statement_tmpl_size())
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -244,7 +238,7 @@ impl CustomPredicate {
|
|||
}
|
||||
|
||||
impl ToFields for CustomPredicate {
|
||||
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// serialize as:
|
||||
// conjunction (one field element)
|
||||
// args_len (one field element)
|
||||
|
|
@ -259,12 +253,12 @@ impl ToFields for CustomPredicate {
|
|||
panic!("Custom predicate depends on too many statements");
|
||||
}
|
||||
|
||||
let mut fields: Vec<F> = std::iter::once(F::from_bool(self.conjunction))
|
||||
.chain(std::iter::once(F::from_canonical_usize(self.args_len)))
|
||||
.chain(self.statements.iter().flat_map(|st| st.to_fields(params).0))
|
||||
let mut fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
|
||||
.chain(iter::once(F::from_canonical_usize(self.args_len)))
|
||||
.chain(self.statements.iter().flat_map(|st| st.to_fields(params)))
|
||||
.collect();
|
||||
fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0));
|
||||
(fields, params.custom_predicate_size())
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -300,7 +294,7 @@ pub struct CustomPredicateBatch {
|
|||
}
|
||||
|
||||
impl ToFields for CustomPredicateBatch {
|
||||
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// all the custom predicates in order
|
||||
|
||||
// TODO think if this check should go into the StatementTmpl creation,
|
||||
|
|
@ -313,19 +307,18 @@ impl ToFields for CustomPredicateBatch {
|
|||
let mut fields: Vec<F> = self
|
||||
.predicates
|
||||
.iter()
|
||||
.flat_map(|p| p.to_fields(params).0)
|
||||
.flat_map(|p| p.to_fields(params))
|
||||
.collect();
|
||||
fields.resize_with(params.custom_predicate_batch_size_field_elts(), || {
|
||||
F::from_canonical_u64(0)
|
||||
});
|
||||
|
||||
(fields, params.custom_predicate_batch_size_field_elts())
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomPredicateBatch {
|
||||
pub fn hash(&self, _params: &Params) -> Hash {
|
||||
let input = self.to_fields(_params).0;
|
||||
let input = self.to_fields(_params);
|
||||
|
||||
hash_fields(&input)
|
||||
}
|
||||
|
|
@ -367,7 +360,7 @@ impl CustomPredicateRef {
|
|||
match custom_predicate.conjunction {
|
||||
true if custom_predicate.statements.len() == statements.len() => {
|
||||
// Match op args against statement templates
|
||||
let match_bindings = std::iter::zip(custom_predicate.statements, statements).map(
|
||||
let match_bindings = iter::zip(custom_predicate.statements, statements).map(
|
||||
|(s_tmpl, s)| s_tmpl.match_against(s)
|
||||
).collect::<Result<Vec<_>>>()
|
||||
.map(|v| v.concat())?;
|
||||
|
|
@ -404,7 +397,7 @@ impl From<NativePredicate> for Predicate {
|
|||
}
|
||||
|
||||
impl ToFields for Predicate {
|
||||
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) {
|
||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
||||
// serialize:
|
||||
// NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize
|
||||
// BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize
|
||||
|
|
@ -414,19 +407,19 @@ impl ToFields for Predicate {
|
|||
|
||||
// in every case: pad to (hash_size + 2) field elements
|
||||
let mut fields: Vec<F> = match self {
|
||||
Self::Native(p) => std::iter::once(F::from_canonical_u64(1))
|
||||
.chain(p.to_fields(_params).0)
|
||||
Self::Native(p) => iter::once(F::from_canonical_u64(1))
|
||||
.chain(p.to_fields(_params))
|
||||
.collect(),
|
||||
Self::BatchSelf(i) => std::iter::once(F::from_canonical_u64(2))
|
||||
.chain(std::iter::once(F::from_canonical_usize(*i)))
|
||||
Self::BatchSelf(i) => iter::once(F::from_canonical_u64(2))
|
||||
.chain(iter::once(F::from_canonical_usize(*i)))
|
||||
.collect(),
|
||||
Self::Custom(CustomPredicateRef(pb, i)) => std::iter::once(F::from_canonical_u64(3))
|
||||
Self::Custom(CustomPredicateRef(pb, i)) => iter::once(F::from_canonical_u64(3))
|
||||
.chain(pb.hash(_params).0)
|
||||
.chain(std::iter::once(F::from_canonical_usize(*i)))
|
||||
.chain(iter::once(F::from_canonical_usize(*i)))
|
||||
.collect(),
|
||||
};
|
||||
fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0));
|
||||
(fields, Params::predicate_size())
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue