Serialize and hash custom predicates (#90)

* Print pods from SignedPodBuilder

* Add additional print to test printing SignedPodBuilder

* Mock-prove and print MainPod

* Implement ToFields for custom predicates and dependencies

* Test: print serialization of a recursive batch

* Rearrange serialization of CustomPredicate so args_len is always in the same position

* Serialize predicates with first entry nonzero to avoid collision with padding

* Off by one error in ethdos test BatchSelf(2)

* cargo fmt

* not a typo

* Typos, trying again
This commit is contained in:
tideofwords 2025-02-26 11:28:27 -08:00 committed by GitHub
parent 05c21ebe6a
commit a37b96ab4f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 262 additions and 39 deletions

View file

@ -2,10 +2,14 @@ use std::sync::Arc;
use std::{fmt, hash as h, iter::zip};
use anyhow::{anyhow, Result};
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::Field;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher;
use super::{
hash_str, AnchoredKey, Hash, NativePredicate, PodId, Statement, StatementArg, ToFields, Value,
F,
hash_str, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg, ToFields,
Value, F,
};
// BEGIN Custom 1b
@ -38,6 +42,22 @@ impl fmt::Display for HashOrWildcard {
}
}
impl ToFields for HashOrWildcard {
    /// Serializes this entry as four field elements.
    ///
    /// * `Hash(h)` delegates to the hash's own serialization.
    /// * `Wildcard(w)` becomes `(0, 0, 0, w)`: three zero elements followed by
    ///   the wildcard index.
    ///
    /// Returns the field elements together with their count.
    fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
        match self {
            HashOrWildcard::Hash(hash) => hash.to_fields(params),
            HashOrWildcard::Wildcard(index) => {
                let zero = F::from_canonical_u64(0);
                let last = F::from_canonical_u64(*index as u64);
                (vec![zero, zero, zero, last], 4)
            }
        }
    }
}
#[derive(Clone, Debug, PartialEq, Eq, h::Hash)]
pub enum StatementTmplArg {
None,
@ -64,6 +84,40 @@ impl StatementTmplArg {
}
}
impl ToFields for StatementTmplArg {
    /// Serializes a statement template argument into a fixed-width run of
    /// `params.statement_tmpl_arg_size()` (= `2 * hash_size + 1`) field elements.
    ///
    /// Layout (the first element is a tag, the rest is zero-padded):
    ///
    /// * `None`            => `(0, 0, ..., 0)`
    /// * `Literal(value)`  => `(1, [value], 0, ..., 0)`
    /// * `Key(hw1, hw2)`   => `(2, [hw1], [hw2])`
    ///
    /// The width is taken from `params` rather than hard-coded so that it always
    /// agrees with `Params::statement_tmpl_size()`, which the enclosing
    /// statement template uses to lay out its arguments.
    ///
    /// Returns the field elements together with their count.
    fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
        let size = params.statement_tmpl_arg_size();
        let mut fields: Vec<F> = match self {
            // All zeros: produced entirely by the padding step below.
            StatementTmplArg::None => Vec::new(),
            StatementTmplArg::Literal(v) => std::iter::once(F::from_canonical_u64(1))
                .chain(v.to_fields(params).0)
                .collect(),
            StatementTmplArg::Key(hw1, hw2) => std::iter::once(F::from_canonical_u64(2))
                .chain(hw1.to_fields(params).0)
                .chain(hw2.to_fields(params).0)
                .collect(),
        };
        // `resize_with` would silently truncate if the payload were wider than the
        // configured slot; catch a mismatched `hash_size` in debug builds.
        debug_assert!(
            fields.len() <= size,
            "statement template argument does not fit in statement_tmpl_arg_size()"
        );
        fields.resize_with(size, || F::from_canonical_u64(0));
        (fields, size)
    }
}
impl fmt::Display for StatementTmplArg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
@ -117,6 +171,26 @@ impl StatementTmpl {
}
}
impl ToFields for StatementTmpl {
    /// Serializes a statement template as the predicate (field `0`) followed by
    /// each template argument (field `1`) in order, zero-padded to
    /// `params.statement_tmpl_size()` field elements.
    ///
    /// # Panics
    ///
    /// Panics if the template has more than `params.max_statement_args`
    /// arguments.
    fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
        let predicate = &self.0;
        let tmpl_args = &self.1;
        assert!(
            tmpl_args.len() <= params.max_statement_args,
            "Statement template has too many arguments"
        );

        let total_len = params.statement_tmpl_size();
        let mut out: Vec<F> = Vec::new();
        out.extend(predicate.to_fields(params).0);
        for arg in tmpl_args.iter() {
            out.extend(arg.to_fields(params).0);
        }
        // Unused argument slots are filled with zeros.
        out.resize_with(total_len, || F::from_canonical_u64(0));
        (out, total_len)
    }
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CustomPredicate {
/// true for "and", false for "or"
@ -128,10 +202,22 @@ pub struct CustomPredicate {
}
impl ToFields for CustomPredicate {
fn to_fields(self) -> (Vec<F>, usize) {
todo!()
// let f: Vec<F> = Vec::new();
// (self.conjunction.to_f(), 1)
/// Serializes this custom predicate as, in order:
///
/// 1. `conjunction` (one field element),
/// 2. `args_len` (one field element),
/// 3. every statement template in `statements`,
///
/// zero-padded to `params.custom_predicate_size()` field elements.
///
/// # Panics
///
/// Panics if there are more than `params.max_custom_predicate_arity`
/// statements.
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
    assert!(
        self.statements.len() <= params.max_custom_predicate_arity,
        "Custom predicate depends on too many statements"
    );

    let total_len = params.custom_predicate_size();
    let mut out: Vec<F> = vec![
        F::from_bool(self.conjunction),
        F::from_canonical_usize(self.args_len),
    ];
    for tmpl in self.statements.iter() {
        out.extend(tmpl.to_fields(params).0);
    }
    // Unused statement slots are filled with zeros.
    out.resize_with(total_len, || F::from_canonical_u64(0));
    (out, total_len)
}
}
@ -166,10 +252,30 @@ pub struct CustomPredicateBatch {
pub predicates: Vec<CustomPredicate>,
}
impl ToFields for CustomPredicateBatch {
    /// Serializes the batch as the concatenation of all its custom predicates,
    /// in order, zero-padded to
    /// `params.custom_predicate_batch_size_field_elts()` field elements.
    ///
    /// # Panics
    ///
    /// Panics if the batch holds more than `params.max_custom_batch_size`
    /// predicates.
    fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
        assert!(
            self.predicates.len() <= params.max_custom_batch_size,
            "Predicate batch exceeds maximum size"
        );

        let total_len = params.custom_predicate_batch_size_field_elts();
        let mut out: Vec<F> = Vec::new();
        for predicate in self.predicates.iter() {
            out.extend(predicate.to_fields(params).0);
        }
        // Unused predicate slots are filled with zeros.
        out.resize_with(total_len, || F::from_canonical_u64(0));
        (out, total_len)
    }
}
impl CustomPredicateBatch {
pub fn hash(&self) -> Hash {
// TODO
hash_str(&format!("{:?}", self))
/// Hashes the batch: serializes it to field elements with `to_fields` and
/// feeds them to `PoseidonHash::hash_no_pad`.
pub fn hash(&self, params: Params) -> Hash {
    let (input, _len) = self.to_fields(params);
    Hash(PoseidonHash::hash_no_pad(&input).elements)
}
}
@ -190,12 +296,36 @@ impl From<NativePredicate> for Predicate {
}
impl ToFields for Predicate {
fn to_fields(self) -> (Vec<F>, usize) {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
// serialize (the leading tag is always nonzero so that a serialized
// predicate can never collide with all-zero padding):
// NativePredicate(id) as (1, id, 0, 0, 0, 0) -- id: usize
// BatchSelf(i) as (2, i, 0, 0, 0, 0) -- i: usize
// CustomPredicateRef(pb, i) as
// (3, [hash of pb], i) -- pb hashes to 4 field elements
// -- i: usize
// in every case: pad to (hash_size + 2) field elements
let mut fields: Vec<F> = Vec::new();
match self {
Self::Native(p) => p.to_fields(),
Self::BatchSelf(i) => Value::from(i as i64).to_fields(),
Self::Custom(_) => todo!(), // TODO
Self::Native(p) => {
fields = std::iter::once(F::from_canonical_u64(1))
.chain(p.to_fields(params).0.into_iter())
.collect();
}
Self::BatchSelf(i) => {
fields = std::iter::once(F::from_canonical_u64(2))
.chain(std::iter::once(F::from_canonical_usize(*i)))
.collect();
}
Self::Custom(CustomPredicateRef(pb, i)) => {
fields = std::iter::once(F::from_canonical_u64(3))
.chain(pb.hash(params).0)
.chain(std::iter::once(F::from_canonical_usize(*i)))
.collect();
}
}
fields.resize_with(params.predicate_size(), || F::from_canonical_u64(0));
(fields, params.predicate_size())
}
}

View file

@ -42,6 +42,13 @@ impl AnchoredKey {
}
}
impl fmt::Display for AnchoredKey {
    /// Renders the key as its two components joined by a dot: `<0>.<1>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (origin, key) = (&self.0, &self.1);
        write!(f, "{}.{}", origin, key)
    }
}
/// An entry consists of a key-value pair.
pub type Entry = (String, Value);
@ -49,7 +56,7 @@ pub type Entry = (String, Value);
pub struct Value(pub [F; 4]);
impl ToFields for Value {
fn to_fields(self) -> (Vec<F>, usize) {
/// Returns the value's four field elements and their count (4).
///
/// `_params` is required by the `ToFields` signature but not needed here;
/// the underscore prefix avoids an unused-variable warning.
fn to_fields(&self, _params: Params) -> (Vec<F>, usize) {
    (self.0.to_vec(), 4)
}
}
@ -132,7 +139,7 @@ impl fmt::Display for Value {
pub struct Hash(pub [F; 4]);
impl ToFields for Hash {
fn to_fields(self) -> (Vec<F>, usize) {
/// Returns the hash's four field elements and their count (4).
///
/// `_params` is required by the `ToFields` signature but not needed here;
/// the underscore prefix avoids an unused-variable warning.
fn to_fields(&self, _params: Params) -> (Vec<F>, usize) {
    (self.0.to_vec(), 4)
}
}
@ -182,8 +189,8 @@ impl FromHex for Hash {
pub struct PodId(pub Hash);
impl ToFields for PodId {
fn to_fields(self) -> (Vec<F>, usize) {
self.0.to_fields()
/// Serializes the pod id by delegating to the wrapped `Hash` (field `0`).
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
self.0.to_fields(params)
}
}
@ -234,7 +241,7 @@ pub fn hash_str(s: &str) -> Hash {
Hash(PoseidonHash::hash_no_pad(&input).elements)
}
#[derive(Clone, Debug, Copy)]
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub struct Params {
pub max_input_signed_pods: usize,
pub max_input_main_pods: usize,
@ -243,12 +250,54 @@ pub struct Params {
pub max_public_statements: usize,
pub max_statement_args: usize,
pub max_operation_args: usize,
// max number of statements that can be ANDed or ORed together
// in a custom predicate
pub max_custom_predicate_arity: usize,
pub max_custom_batch_size: usize,
// number of field elements in a hash
pub hash_size: usize,
}
impl Params {
    /// Number of statements that are not public:
    /// `max_statements - max_public_statements`.
    pub fn max_priv_statements(&self) -> usize {
        self.max_statements - self.max_public_statements
    }

    /// Serialized size, in field elements, of one statement template argument:
    /// one leading element plus two hash-sized slots.
    pub fn statement_tmpl_arg_size(self) -> usize {
        2 * self.hash_size + 1
    }

    /// Serialized size, in field elements, of a predicate: one hash-sized slot
    /// plus two extra elements.
    pub fn predicate_size(self) -> usize {
        self.hash_size + 2
    }

    /// Serialized size, in field elements, of a statement template: a predicate
    /// followed by `max_statement_args` argument slots.
    pub fn statement_tmpl_size(self) -> usize {
        self.predicate_size() + self.max_statement_args * self.statement_tmpl_arg_size()
    }

    /// Serialized size, in field elements, of a custom predicate:
    /// `max_custom_predicate_arity` statement templates plus two extra elements.
    pub fn custom_predicate_size(self) -> usize {
        self.max_custom_predicate_arity * self.statement_tmpl_size() + 2
    }

    /// Serialized size, in field elements, of a custom predicate batch:
    /// `max_custom_batch_size` custom predicates.
    pub fn custom_predicate_batch_size_field_elts(self) -> usize {
        self.max_custom_batch_size * self.custom_predicate_size()
    }

    /// Prints every serialized size derived from these parameters to stdout,
    /// followed by a blank line.
    pub fn print_serialized_sizes(self) {
        println!("Parameter sizes:");
        println!(
            " Statement template argument: {}",
            self.statement_tmpl_arg_size()
        );
        println!(" Predicate: {}", self.predicate_size());
        println!(" Statement template: {}", self.statement_tmpl_size());
        println!(" Custom predicate: {}", self.custom_predicate_size());
        println!(
            " Custom predicate batch: {}",
            self.custom_predicate_batch_size_field_elts()
        );
        println!();
    }
}
impl Default for Params {
@ -261,6 +310,9 @@ impl Default for Params {
max_public_statements: 10,
max_statement_args: 5,
max_operation_args: 5,
max_custom_predicate_arity: 5,
max_custom_batch_size: 5,
hash_size: 4,
}
}
}
@ -328,5 +380,5 @@ pub trait PodProver {
pub trait ToFields {
/// Returns the `Vec<F>` representation of the type, together with a `usize` giving the
/// number of field elements the vector contains.
fn to_fields(self) -> (Vec<F>, usize);
fn to_fields(&self, params: Params) -> (Vec<F>, usize);
}

View file

@ -3,7 +3,7 @@ use plonky2::field::types::Field;
use std::{collections::HashMap, fmt};
use strum_macros::FromRepr;
use super::{AnchoredKey, CustomPredicateRef, Hash, Predicate, ToFields, Value, F};
use super::{AnchoredKey, CustomPredicateRef, Hash, Params, Predicate, ToFields, Value, F};
pub const KEY_SIGNER: &str = "_signer";
pub const KEY_TYPE: &str = "_type";
@ -25,8 +25,8 @@ pub enum NativePredicate {
}
impl ToFields for NativePredicate {
fn to_fields(self) -> (Vec<F>, usize) {
(vec![F::from_canonical_u64(self as u64)], 1)
/// Serializes the native predicate as a single field element holding its
/// discriminant (`*self as u64`).
///
/// `_params` is required by the `ToFields` signature but not needed here;
/// the underscore prefix avoids an unused-variable warning.
fn to_fields(&self, _params: Params) -> (Vec<F>, usize) {
    (vec![F::from_canonical_u64(*self as u64)], 1)
}
}
@ -88,12 +88,12 @@ impl Statement {
}
impl ToFields for Statement {
fn to_fields(self) -> (Vec<F>, usize) {
let (native_statement_f, native_statement_f_len) = self.code().to_fields();
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
let (native_statement_f, native_statement_f_len) = self.code().to_fields(params);
let (vec_statementarg_f, vec_statementarg_f_len) = self
.args()
.into_iter()
.map(|statement_arg| statement_arg.to_fields())
.map(|statement_arg| statement_arg.to_fields(params))
.fold((Vec::new(), 0), |mut acc, (f, l)| {
acc.0.extend(f);
acc.1 += l;
@ -156,7 +156,7 @@ impl StatementArg {
}
impl ToFields for StatementArg {
fn to_fields(self) -> (Vec<F>, usize) {
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
// NOTE: current version returns always the same amount of field elements in the returned
// vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case
// is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced
@ -175,8 +175,8 @@ impl ToFields for StatementArg {
.concat()
}
StatementArg::Key(ak) => {
let (podid_f, _) = ak.0.to_fields();
let (hash_f, _) = ak.1.to_fields();
let (podid_f, _) = ak.0.to_fields(params);
let (hash_f, _) = ak.1.to_fields(params);
[podid_f, hash_f].concat()
}
};