merkletree: reduce gate amount (-23%) by custom poseidon to use flag as initial state (#472)

* merkletree: custom poseidon to use flag as initial state.

This allows to do the merkletree related hashing in 1 gate instead of 2,
reducing ~23% of gates per merkle proof.

| tree levels   | 10 | 16 | 32  | 40  | 64  | 128 | 130 | 250  | 256  |
|---------------|----|----|-----|-----|-----|-----|-----|------|------|
| old num gates | 50 | 76 | 144 | 178 | 280 | 554 | 564 | 1076 | 1102 |
| new num gates | 39 | 59 | 111 | 137 | 215 | 425 | 433 | 825  | 845  |

* update docs with new tree hashing approach

* add inline comment stating clear how the flag is used in the state permutation
This commit is contained in:
arnaucube 2026-02-04 12:31:56 +01:00 committed by GitHub
parent 641d8dabdd
commit b04560c362
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 15 deletions

View file

@ -14,14 +14,15 @@ use itertools::zip_eq;
use plonky2::{
field::types::Field,
hash::{
hash_types::{HashOut, HashOutTarget},
hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS},
hashing::PlonkyPermutation,
poseidon::PoseidonHash,
},
iop::{
target::{BoolTarget, Target},
witness::{PartialWitness, WitnessWrite},
},
plonk::circuit_builder::CircuitBuilder,
plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher},
};
use serde::{Deserialize, Serialize};
@ -349,7 +350,7 @@ fn compute_root_from_leaf(
.map(|j| builder.select(path[i], h.elements[j], sibling.elements[j]))
.collect();
let new_h =
builder.hash_n_to_hash_no_pad::<PoseidonHash>([input_1, input_2, vec![two]].concat());
hash_with_flag_target::<PoseidonHash>(builder, two, [input_1, input_2].concat());
let h_targ: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(*selector, new_h.elements[j], h.elements[j]))
@ -401,9 +402,51 @@ fn kv_hash_target(
.iter()
.chain(value.elements.iter())
.cloned()
.chain(iter::once(builder.one()))
.collect();
builder.hash_n_to_hash_no_pad::<PoseidonHash>(inputs)
let one = builder.one();
hash_with_flag_target::<PoseidonHash>(builder, one, inputs)
}
/// Circuit matching the `merkletree::hash_with_flag` function. Variation of
/// Poseidon hash which takes as input 1 Goldilock element as a flag, and 8
/// Goldilocks elements as inputs to the hash. Performs the hashing in a single
/// gate.
/// The function is a fork of
/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L118)
/// from plonky2.
fn hash_with_flag_target<H: AlgebraicHasher<F>>(
builder: &mut CircuitBuilder<F, D>,
flag: Target,
inputs: Vec<Target>,
) -> HashOutTarget {
assert_eq!(inputs.len(), H::AlgebraicPermutation::RATE);
// here we set `state` to a `SPONGE_RATE+SPONGE_CAPACITY` (8+4=12) in our
// case to a vector of repeated `flag` value. Later at the absorption step,
// it will fit the inputs values at positions 0-8, keeping the flag values
// at positions 8-12.
let mut state = H::AlgebraicPermutation::new(core::iter::repeat(flag));
// Absorb all input chunks.
for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) {
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
// where we would xor or add in the inputs. This is a well-known variant, though,
// sometimes called "overwrite mode".
state.set_from_slice(input_chunk, 0);
state = builder.permute::<H>(state);
}
// Squeeze until we have the desired number of outputs.
let mut outputs = Vec::with_capacity(NUM_HASH_OUT_ELTS);
loop {
for &s in state.squeeze() {
outputs.push(s);
if outputs.len() == NUM_HASH_OUT_ELTS {
return HashOutTarget::from_vec(outputs);
}
}
state = builder.permute::<H>(state);
}
}
/// Verifies that the merkletree state transition (from old_root to new_root)

View file

@ -3,10 +3,15 @@
use std::{collections::HashMap, fmt, iter::IntoIterator};
use itertools::zip_eq;
use plonky2::field::types::Field;
use plonky2::{
field::types::Field,
hash::{
hash_types::NUM_HASH_OUT_ELTS, hashing::PlonkyPermutation, poseidon::PoseidonPermutation,
},
};
use serde::{Deserialize, Serialize};
use crate::middleware::{hash_fields, Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F};
use crate::middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F};
pub mod circuit;
pub use circuit::*;
@ -376,10 +381,47 @@ impl MerkleTree {
/// mitigate fake proofs.
pub fn kv_hash(key: &RawValue, value: Option<RawValue>) -> Hash {
value
.map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![F::ONE]].concat()))
.map(|v| hash_with_flag(F::ONE, &[key.0.to_vec(), v.0.to_vec()].concat()))
.unwrap_or(EMPTY_HASH)
}
/// Variation of Poseidon hash which takes as input 1 Goldilock element as a
/// flag, and 8 Goldilocks elements as inputs to the hash. Performs the hashing
/// in a single gate.
/// The function is a fork of
/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L30)
/// from plonky2.
fn hash_with_flag(flag: F, inputs: &[F]) -> Hash {
assert_eq!(
inputs.len(),
<PoseidonPermutation<F> as PlonkyPermutation<F>>::RATE
);
// this will set `perm` to a `SPONGE_RATE+SPONGE_CAPACITY` (8+4=12) in our
// case to a vector of repeated `flag` value. Later at the absorption step,
// it will fit the inputs values at positions 0-8, keeping the flag values
// at positions 8-12.
let mut perm = <PoseidonPermutation<F> as PlonkyPermutation<F>>::new(core::iter::repeat(flag));
// Absorb all input chunks.
for input_chunk in inputs.chunks(<PoseidonPermutation<F> as PlonkyPermutation<F>>::RATE) {
perm.set_from_slice(input_chunk, 0);
perm.permute();
}
// Squeeze until we have the desired number of outputs.
let mut outputs = Vec::new();
loop {
for &item in perm.squeeze() {
outputs.push(item);
if outputs.len() == NUM_HASH_OUT_ELTS {
return Hash(crate::middleware::HashOut::from_vec(outputs).elements);
}
}
perm.permute();
}
}
impl<'a> IntoIterator for &'a MerkleTree {
type Item = (&'a RawValue, &'a RawValue);
type IntoIter = Iter<'a>;
@ -437,13 +479,12 @@ impl MerkleProof {
fn compute_root_from_node(&self, node_hash: &Hash, path: Vec<bool>) -> TreeResult<Hash> {
let mut h = *node_hash;
for (i, sibling) in self.siblings.iter().enumerate().rev() {
let mut input: Vec<F> = if path[i] {
let input: Vec<F> = if path[i] {
[sibling.0, h.0].concat()
} else {
[h.0, sibling.0].concat()
};
input.push(F::TWO);
h = hash_fields(&input);
h = hash_with_flag(F::TWO, &input);
}
Ok(h)
}
@ -876,8 +917,8 @@ impl Intermediate {
}
let l_hash = self.left.compute_hash();
let r_hash = self.right.compute_hash();
let input: Vec<F> = [l_hash.0.to_vec(), r_hash.0.to_vec(), vec![F::TWO]].concat();
let h = hash_fields(&input);
let input: Vec<F> = [l_hash.0.to_vec(), r_hash.0.to_vec()].concat();
let h = hash_with_flag(F::TWO, &input);
self.hash = Some(h);
h
}