Featurize middleware types that are actually defined by the backend (#94)

At the middleware we were defining some types that actually are dependant on the
backend no matter how we define them in the middleware.

For example, we were hardcoding the `Hash` and `Value` types and their related
behaviour (eg. `.to_fields()`) to be based on the length of 4 field elements,
but that's not a choice of the middleware, and in fact this is determined by the
backend itself. On the same time, those types and related methods do not belong
to the backend, since conceptually they are part of the middleware reasoning.

The intention of this PR is not to prematurely abstract the library, but to
avoid inconsistencies where a type or parameter is defined in the middleware to
have certain carachteristic and later in the backend it gets used differently.
The idea is that those types and parameters (eg. lengths) have a single source
of truth in the code; and in the case of the "base types" (hash, value, etc)
this is determined by the backend being used under the hood, not by a choice of
the middleware parameters.

The idea with this approach, is that the frontend & middleware should not need
to import the proving library used by the backend (eg. plonky2, plonky3, etc).

As mentioned earlier, the `Hash` and `Value` types are types belonging at the
middleware, and is the middleware who reasons about them, but depending on the
backend being used, the `Hash` and `Value` types will have different sizes. So
it's the backend being used who actually defines their nature under the hood.
For example with a plonky2 backend, these types will have a length of 4 field
elements, whereas with a plonky3 backend they will have a length of 8 field
eleements.

Note that his approach does not introduce new traits or abstract code, just
makes use of rust features to define 'base types' that are being used in the
middleware.
This commit is contained in:
arnaucube 2025-02-27 14:15:31 +01:00 committed by GitHub
parent af46ab7a8d
commit 423605f867
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 359 additions and 278 deletions

View file

@ -1,36 +1,38 @@
//! The middleware includes the type definitions and the traits used to connect the frontend and
//! the backend.
mod basetypes;
pub mod containers;
mod custom;
mod operation;
mod statement;
pub use basetypes::*;
pub use custom::*;
pub use operation::*;
pub use statement::*;
use anyhow::{anyhow, Error, Result};
use anyhow::Result;
use dyn_clone::DynClone;
use hex::{FromHex, FromHexError};
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::{Field, PrimeField64};
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::{Hasher, PoseidonGoldilocksConfig};
use std::any::Any;
use std::cmp::{Ord, Ordering};
use std::collections::HashMap;
use std::fmt;
pub mod containers;
pub const SELF: PodId = PodId(SELF_ID_HASH);
/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2
pub type F = GoldilocksField;
/// C is the Plonky2 config used in POD2 to work with Plonky2 recursion.
pub type C = PoseidonGoldilocksConfig;
/// D defines the extension degree of the field used in the Plonky2 proofs (quadratic extension).
pub const D: usize = 2;
impl fmt::Display for PodId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if *self == SELF {
write!(f, "self")
} else if self.0 == NULL {
write!(f, "null")
} else {
write!(f, "{}", self.0)
}
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// AnchoredKey is a tuple containing (OriginId: PodId, key: Hash)
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AnchoredKey(pub PodId, pub Hash);
impl AnchoredKey {
@ -52,168 +54,15 @@ impl fmt::Display for AnchoredKey {
/// An entry consists of a key-value pair.
pub type Entry = (String, Value);
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub struct Value(pub [F; 4]);
impl ToFields for Value {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
(self.0.to_vec(), 4)
}
}
impl Value {
pub fn to_bytes(self) -> Vec<u8> {
self.0
.iter()
.flat_map(|e| e.to_canonical_u64().to_le_bytes())
.collect()
}
}
impl Ord for Value {
fn cmp(&self, other: &Self) -> Ordering {
for (lhs, rhs) in self.0.iter().zip(other.0.iter()).rev() {
let (lhs, rhs) = (lhs.to_canonical_u64(), rhs.to_canonical_u64());
if lhs < rhs {
return Ordering::Less;
} else if lhs > rhs {
return Ordering::Greater;
}
}
return Ordering::Equal;
}
}
impl PartialOrd for Value {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl From<i64> for Value {
fn from(v: i64) -> Self {
let lo = F::from_canonical_u64((v as u64) & 0xffffffff);
let hi = F::from_canonical_u64((v as u64) >> 32);
Value([lo, hi, F::ZERO, F::ZERO])
}
}
impl From<Hash> for Value {
fn from(h: Hash) -> Self {
Value(h.0)
}
}
impl TryInto<i64> for Value {
type Error = Error;
fn try_into(self) -> std::result::Result<i64, Self::Error> {
let value = self.0;
if &value[2..] != &[F::ZERO, F::ZERO]
|| value[..2]
.iter()
.all(|x| x.to_canonical_u64() > u32::MAX as u64)
{
Err(anyhow!("Value not an element of the i64 embedding."))
} else {
Ok((value[0].to_canonical_u64() + value[1].to_canonical_u64() << 32) as i64)
}
}
}
impl fmt::Display for Value {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.0[2].is_zero() && self.0[3].is_zero() {
// Assume this is an integer
let (l0, l1) = (self.0[0].to_canonical_u64(), self.0[1].to_canonical_u64());
assert!(l0 < (1 << 32));
assert!(l1 < (1 << 32));
write!(f, "{}", l0 + l1 * (1 << 32))
} else {
// Assume this is a hash
Hash(self.0).fmt(f)
}
}
}
#[derive(Clone, Copy, Debug, Default, Hash, Eq, PartialEq)]
pub struct Hash(pub [F; 4]);
impl ToFields for Hash {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
(self.0.to_vec(), 4)
}
}
impl Ord for Hash {
fn cmp(&self, other: &Self) -> Ordering {
Value(self.0).cmp(&Value(other.0))
}
}
impl PartialOrd for Hash {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
pub const EMPTY: Value = Value([F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
pub const NULL: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
impl fmt::Display for Hash {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let v0 = self.0[0].to_canonical_u64();
for i in 0..4 {
write!(f, "{:02x}", (v0 >> (i * 8)) & 0xff)?;
}
write!(f, "")
}
}
impl FromHex for Hash {
type Error = FromHexError;
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
// In little endian
let bytes = <[u8; 32]>::from_hex(hex)?;
let mut buf: [u8; 8] = [0; 8];
let mut inner = [F::ZERO; 4];
for i in 0..4 {
buf.copy_from_slice(&bytes[8 * i..8 * (i + 1)]);
inner[i] = F::from_canonical_u64(u64::from_le_bytes(buf));
}
Ok(Self(inner))
}
}
impl From<&str> for Hash {
fn from(s: &str) -> Self {
hash_str(s)
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub struct PodId(pub Hash);
impl ToFields for PodId {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
self.0.to_fields(params)
}
}
pub const SELF: PodId = PodId(Hash([F::ONE, F::ZERO, F::ZERO, F::ZERO]));
impl fmt::Display for PodId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if *self == SELF {
write!(f, "self")
} else if self.0 == NULL {
write!(f, "null")
} else {
write!(f, "{}", self.0)
}
}
}
pub enum PodType {
None = 0,
MockSigned = 1,
@ -228,26 +77,7 @@ impl From<PodType> for Value {
}
}
pub fn hash_str(s: &str) -> Hash {
let mut input = s.as_bytes().to_vec();
input.push(1); // padding
// Merge 7 bytes into 1 field, because the field is slightly below 64 bits
let input: Vec<F> = input
.chunks(7)
.map(|bytes| {
let mut v: u64 = 0;
for b in bytes.iter().rev() {
v <<= 8;
v += *b as u64;
}
F::from_canonical_u64(v)
})
.collect();
Hash(PoseidonHash::hash_no_pad(&input).elements)
}
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Params {
pub max_input_signed_pods: usize,
pub max_input_main_pods: usize,
@ -260,8 +90,6 @@ pub struct Params {
// in a custom predicate
pub max_custom_predicate_arity: usize,
pub max_custom_batch_size: usize,
// number of field elements in a hash
pub hash_size: usize,
}
impl Params {
@ -269,33 +97,33 @@ impl Params {
self.max_statements - self.max_public_statements
}
pub fn statement_tmpl_arg_size(self) -> usize {
2 * self.hash_size + 1
pub fn statement_tmpl_arg_size() -> usize {
2 * HASH_SIZE + 1
}
pub fn predicate_size(self) -> usize {
self.hash_size + 2
pub fn predicate_size() -> usize {
HASH_SIZE + 2
}
pub fn statement_tmpl_size(self) -> usize {
self.predicate_size() + self.max_statement_args * self.statement_tmpl_arg_size()
pub fn statement_tmpl_size(&self) -> usize {
Self::predicate_size() + self.max_statement_args * Self::statement_tmpl_arg_size()
}
pub fn custom_predicate_size(self) -> usize {
pub fn custom_predicate_size(&self) -> usize {
self.max_custom_predicate_arity * self.statement_tmpl_size() + 2
}
pub fn custom_predicate_batch_size_field_elts(self) -> usize {
pub fn custom_predicate_batch_size_field_elts(&self) -> usize {
self.max_custom_batch_size * self.custom_predicate_size()
}
pub fn print_serialized_sizes(self) -> () {
pub fn print_serialized_sizes(&self) -> () {
println!("Parameter sizes:");
println!(
" Statement template argument: {}",
self.statement_tmpl_arg_size()
Self::statement_tmpl_arg_size()
);
println!(" Predicate: {}", self.predicate_size());
println!(" Predicate: {}", Self::predicate_size());
println!(" Statement template: {}", self.statement_tmpl_size());
println!(" Custom predicate: {}", self.custom_predicate_size());
println!(
@ -318,7 +146,6 @@ impl Default for Params {
max_operation_args: 5,
max_custom_predicate_arity: 5,
max_custom_batch_size: 5,
hash_size: 4,
}
}
}
@ -386,5 +213,5 @@ pub trait PodProver {
pub trait ToFields {
/// returns Vec<F> representation of the type, and a usize indicating how many field elements
/// does the vector contain
fn to_fields(&self, params: Params) -> (Vec<F>, usize);
fn to_fields(&self, params: &Params) -> (Vec<F>, usize);
}