Featurize middleware types that are actually defined by the backend (#94)

At the middleware we were defining some types that actually are dependant on the
backend no matter how we define them in the middleware.

For example, we were hardcoding the `Hash` and `Value` types and their related
behaviour (eg. `.to_fields()`) to be based on the length of 4 field elements,
but that's not a choice of the middleware, and in fact this is determined by the
backend itself. On the same time, those types and related methods do not belong
to the backend, since conceptually they are part of the middleware reasoning.

The intention of this PR is not to prematurely abstract the library, but to
avoid inconsistencies where a type or parameter is defined in the middleware to
have certain carachteristic and later in the backend it gets used differently.
The idea is that those types and parameters (eg. lengths) have a single source
of truth in the code; and in the case of the "base types" (hash, value, etc)
this is determined by the backend being used under the hood, not by a choice of
the middleware parameters.

The idea with this approach, is that the frontend & middleware should not need
to import the proving library used by the backend (eg. plonky2, plonky3, etc).

As mentioned earlier, the `Hash` and `Value` types are types belonging at the
middleware, and is the middleware who reasons about them, but depending on the
backend being used, the `Hash` and `Value` types will have different sizes. So
it's the backend being used who actually defines their nature under the hood.
For example with a plonky2 backend, these types will have a length of 4 field
elements, whereas with a plonky3 backend they will have a length of 8 field
eleements.

Note that his approach does not introduce new traits or abstract code, just
makes use of rust features to define 'base types' that are being used in the
middleware.
This commit is contained in:
arnaucube 2025-02-27 14:15:31 +01:00 committed by GitHub
parent af46ab7a8d
commit 423605f867
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 359 additions and 278 deletions

View file

@ -3,14 +3,10 @@ use std::{fmt, hash as h, iter::zip};
use anyhow::{anyhow, Result};
use plonky2::field::types::Field;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher;
use crate::middleware::{Operation, SELF};
use super::{
hash_str, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg, ToFields,
Value, F,
hash_fields, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg,
ToFields, Value, F,
};
// BEGIN Custom 1b
@ -44,9 +40,9 @@ impl fmt::Display for HashOrWildcard {
}
impl ToFields for HashOrWildcard {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) {
match self {
HashOrWildcard::Hash(h) => h.to_fields(params),
HashOrWildcard::Hash(h) => h.to_fields(_params),
HashOrWildcard::Wildcard(w) => {
let usizes: Vec<usize> = vec![0, 0, 0, *w];
let fields: Vec<F> = usizes
@ -86,7 +82,7 @@ impl StatementTmplArg {
}
impl ToFields for StatementTmplArg {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) {
// None => (0, ...)
// Literal(value) => (1, [value], 0, 0, 0, 0)
// Key(hash_or_wildcard1, hash_or_wildcard2)
@ -103,15 +99,15 @@ impl ToFields for StatementTmplArg {
}
StatementTmplArg::Literal(v) => {
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(1))
.chain(v.to_fields(params).0.into_iter())
.chain(v.to_fields(_params).0.into_iter())
.chain(std::iter::repeat_with(|| F::from_canonical_u64(0)).take(hash_size))
.collect();
(fields, statement_tmpl_arg_size)
}
StatementTmplArg::Key(hw1, hw2) => {
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(2))
.chain(hw1.to_fields(params).0.into_iter())
.chain(hw2.to_fields(params).0.into_iter())
.chain(hw1.to_fields(_params).0.into_iter())
.chain(hw2.to_fields(_params).0.into_iter())
.collect();
(fields, statement_tmpl_arg_size)
}
@ -173,13 +169,18 @@ impl StatementTmpl {
}
impl ToFields for StatementTmpl {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
// serialize as:
// predicate (6 field elements)
// then the StatementTmplArgs
// TODO think if this check should go into the StatementTmpl creation,
// instead of at the `to_fields` method, where we should assume that the
// values are already valid
if self.1.len() > params.max_statement_args {
panic!("Statement template has too many arguments");
}
let mut fields: Vec<F> = self
.0
.to_fields(params)
@ -203,16 +204,21 @@ pub struct CustomPredicate {
}
impl ToFields for CustomPredicate {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
// serialize as:
// conjunction (one field element)
// args_len (one field element)
// statements
// (params.max_custom_predicate_arity * params.statement_tmpl_size())
// field elements
// TODO think if this check should go into the StatementTmpl creation,
// instead of at the `to_fields` method, where we should assume that the
// values are already valid
if self.statements.len() > params.max_custom_predicate_arity {
panic!("Custom predicate depends on too many statements");
}
let mut fields: Vec<F> = std::iter::once(F::from_bool(self.conjunction))
.chain(std::iter::once(F::from_canonical_usize(self.args_len)))
.chain(self.statements.iter().flat_map(|st| st.to_fields(params).0))
@ -254,11 +260,16 @@ pub struct CustomPredicateBatch {
}
impl ToFields for CustomPredicateBatch {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
// all the custom predicates in order
// TODO think if this check should go into the StatementTmpl creation,
// instead of at the `to_fields` method, where we should assume that the
// values are already valid
if self.predicates.len() > params.max_custom_batch_size {
panic!("Predicate batch exceeds maximum size");
}
let mut fields: Vec<F> = self
.predicates
.iter()
@ -273,9 +284,9 @@ impl ToFields for CustomPredicateBatch {
}
impl CustomPredicateBatch {
pub fn hash(&self, params: Params) -> Hash {
let input = self.to_fields(params).0;
let h = Hash(PoseidonHash::hash_no_pad(&input).elements);
pub fn hash(&self, _params: &Params) -> Hash {
let input = self.to_fields(_params).0;
let h = hash_fields(&input);
h
}
}
@ -297,7 +308,7 @@ impl From<NativePredicate> for Predicate {
}
impl ToFields for Predicate {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
fn to_fields(&self, _params: &Params) -> (Vec<F>, usize) {
// serialize:
// NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize
// BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize
@ -306,27 +317,20 @@ impl ToFields for Predicate {
// -- i: usize
// in every case: pad to (hash_size + 2) field elements
let mut fields: Vec<F> = Vec::new();
match self {
Self::Native(p) => {
fields = std::iter::once(F::from_canonical_u64(1))
.chain(p.to_fields(params).0.into_iter())
.collect();
}
Self::BatchSelf(i) => {
fields = std::iter::once(F::from_canonical_u64(2))
.chain(std::iter::once(F::from_canonical_usize(*i)))
.collect();
}
Self::Custom(CustomPredicateRef(pb, i)) => {
fields = std::iter::once(F::from_canonical_u64(3))
.chain(pb.hash(params).0)
.chain(std::iter::once(F::from_canonical_usize(*i)))
.collect();
}
}
fields.resize_with(params.predicate_size(), || F::from_canonical_u64(0));
(fields, params.predicate_size())
let mut fields: Vec<F> = match self {
Self::Native(p) => std::iter::once(F::from_canonical_u64(1))
.chain(p.to_fields(_params).0.into_iter())
.collect(),
Self::BatchSelf(i) => std::iter::once(F::from_canonical_u64(2))
.chain(std::iter::once(F::from_canonical_usize(*i)))
.collect(),
Self::Custom(CustomPredicateRef(pb, i)) => std::iter::once(F::from_canonical_u64(3))
.chain(pb.hash(_params).0)
.chain(std::iter::once(F::from_canonical_usize(*i)))
.collect(),
};
fields.resize_with(Params::predicate_size(), || F::from_canonical_u64(0));
(fields, Params::predicate_size())
}
}