merkletree: reduce gate amount (-23%) by custom poseidon to use flag as initial state (#472)

* merkletree: custom poseidon to use flag as initial state.

This allows to do the merkletree related hashing in 1 gate instead of 2,
reducing ~23% of gates per merkle proof.

| tree levels   | 10 | 16 | 32  | 40  | 64  | 128 | 130 | 250  | 256  |
|---------------|----|----|-----|-----|-----|-----|-----|------|------|
| old num gates | 50 | 76 | 144 | 178 | 280 | 554 | 564 | 1076 | 1102 |
| new num gates | 39 | 59 | 111 | 137 | 215 | 425 | 433 | 825  | 845  |

* update docs with new tree hashing approach

* add inline comment stating clear how the flag is used in the state permutation
This commit is contained in:
arnaucube 2026-02-04 12:31:56 +01:00 committed by GitHub
parent 641d8dabdd
commit b04560c362
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 15 deletions

View file

@ -37,10 +37,12 @@ A Merkle tree with no entry at all is represented by the hash value
(With the Plonky2 backend, the hash function ```hash``` will output a 4-tuple of field elements.)
A Merkle tree with a single entry ```(key, value)``` is called a "leaf". It is represented by the hash value
```root = hash((key, value, 1)).```
```root = hash(1, (key, value))```, where `1` is a flag indicating that it is a leaf, and it's used as the initial state of the hash (Poseidon) permutation.
A Merkle tree ```tree``` with more than one entry is required to have two subtrees, ```left``` and ```right```. It is then represented by the hash value
```root = hash((left_root, right_root, 2)).```
```root = hash(2, (left_root, right_root))```, where `2` is a flag indicating that it is an intermediate node, and it's used as the initial state of the hash (Poseidon) permutation.
The flags are used as the initial state of the Poseidon permutation so that they don't account for extra inputs in the Poseidon gadget, needing only 1 gate for each node/leaf hash.
(The role of the constants 1 and 2 is to prevent collisions between leaves and non-leaf Merkle roots. If the constants were omitted, a large Merkle tree could be dishonestly interpreted as a leaf, leading to security vulnerabilities.)

View file

@ -14,14 +14,15 @@ use itertools::zip_eq;
use plonky2::{
field::types::Field,
hash::{
hash_types::{HashOut, HashOutTarget},
hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS},
hashing::PlonkyPermutation,
poseidon::PoseidonHash,
},
iop::{
target::{BoolTarget, Target},
witness::{PartialWitness, WitnessWrite},
},
plonk::circuit_builder::CircuitBuilder,
plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher},
};
use serde::{Deserialize, Serialize};
@ -349,7 +350,7 @@ fn compute_root_from_leaf(
.map(|j| builder.select(path[i], h.elements[j], sibling.elements[j]))
.collect();
let new_h =
builder.hash_n_to_hash_no_pad::<PoseidonHash>([input_1, input_2, vec![two]].concat());
hash_with_flag_target::<PoseidonHash>(builder, two, [input_1, input_2].concat());
let h_targ: Vec<Target> = (0..HASH_SIZE)
.map(|j| builder.select(*selector, new_h.elements[j], h.elements[j]))
@ -401,9 +402,51 @@ fn kv_hash_target(
.iter()
.chain(value.elements.iter())
.cloned()
.chain(iter::once(builder.one()))
.collect();
builder.hash_n_to_hash_no_pad::<PoseidonHash>(inputs)
let one = builder.one();
hash_with_flag_target::<PoseidonHash>(builder, one, inputs)
}
/// Circuit matching the `merkletree::hash_with_flag` function. Variation of
/// Poseidon hash which takes as input 1 Goldilock element as a flag, and 8
/// Goldilocks elements as inputs to the hash. Performs the hashing in a single
/// gate.
/// The function is a fork of
/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L118)
/// from plonky2.
fn hash_with_flag_target<H: AlgebraicHasher<F>>(
builder: &mut CircuitBuilder<F, D>,
flag: Target,
inputs: Vec<Target>,
) -> HashOutTarget {
assert_eq!(inputs.len(), H::AlgebraicPermutation::RATE);
// here we set `state` to a `SPONGE_RATE+SPONGE_CAPACITY` (8+4=12) in our
// case to a vector of repeated `flag` value. Later at the absorption step,
// it will fit the inputs values at positions 0-8, keeping the flag values
// at positions 8-12.
let mut state = H::AlgebraicPermutation::new(core::iter::repeat(flag));
// Absorb all input chunks.
for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) {
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
// where we would xor or add in the inputs. This is a well-known variant, though,
// sometimes called "overwrite mode".
state.set_from_slice(input_chunk, 0);
state = builder.permute::<H>(state);
}
// Squeeze until we have the desired number of outputs.
let mut outputs = Vec::with_capacity(NUM_HASH_OUT_ELTS);
loop {
for &s in state.squeeze() {
outputs.push(s);
if outputs.len() == NUM_HASH_OUT_ELTS {
return HashOutTarget::from_vec(outputs);
}
}
state = builder.permute::<H>(state);
}
}
/// Verifies that the merkletree state transition (from old_root to new_root)

View file

@ -3,10 +3,15 @@
use std::{collections::HashMap, fmt, iter::IntoIterator};
use itertools::zip_eq;
use plonky2::field::types::Field;
use plonky2::{
field::types::Field,
hash::{
hash_types::NUM_HASH_OUT_ELTS, hashing::PlonkyPermutation, poseidon::PoseidonPermutation,
},
};
use serde::{Deserialize, Serialize};
use crate::middleware::{hash_fields, Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F};
use crate::middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F};
pub mod circuit;
pub use circuit::*;
@ -376,10 +381,47 @@ impl MerkleTree {
/// mitigate fake proofs.
pub fn kv_hash(key: &RawValue, value: Option<RawValue>) -> Hash {
value
.map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![F::ONE]].concat()))
.map(|v| hash_with_flag(F::ONE, &[key.0.to_vec(), v.0.to_vec()].concat()))
.unwrap_or(EMPTY_HASH)
}
/// Variation of Poseidon hash which takes as input 1 Goldilock element as a
/// flag, and 8 Goldilocks elements as inputs to the hash. Performs the hashing
/// in a single gate.
/// The function is a fork of
/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L30)
/// from plonky2.
fn hash_with_flag(flag: F, inputs: &[F]) -> Hash {
assert_eq!(
inputs.len(),
<PoseidonPermutation<F> as PlonkyPermutation<F>>::RATE
);
// this will set `perm` to a `SPONGE_RATE+SPONGE_CAPACITY` (8+4=12) in our
// case to a vector of repeated `flag` value. Later at the absorption step,
// it will fit the inputs values at positions 0-8, keeping the flag values
// at positions 8-12.
let mut perm = <PoseidonPermutation<F> as PlonkyPermutation<F>>::new(core::iter::repeat(flag));
// Absorb all input chunks.
for input_chunk in inputs.chunks(<PoseidonPermutation<F> as PlonkyPermutation<F>>::RATE) {
perm.set_from_slice(input_chunk, 0);
perm.permute();
}
// Squeeze until we have the desired number of outputs.
let mut outputs = Vec::new();
loop {
for &item in perm.squeeze() {
outputs.push(item);
if outputs.len() == NUM_HASH_OUT_ELTS {
return Hash(crate::middleware::HashOut::from_vec(outputs).elements);
}
}
perm.permute();
}
}
impl<'a> IntoIterator for &'a MerkleTree {
type Item = (&'a RawValue, &'a RawValue);
type IntoIter = Iter<'a>;
@ -437,13 +479,12 @@ impl MerkleProof {
fn compute_root_from_node(&self, node_hash: &Hash, path: Vec<bool>) -> TreeResult<Hash> {
let mut h = *node_hash;
for (i, sibling) in self.siblings.iter().enumerate().rev() {
let mut input: Vec<F> = if path[i] {
let input: Vec<F> = if path[i] {
[sibling.0, h.0].concat()
} else {
[h.0, sibling.0].concat()
};
input.push(F::TWO);
h = hash_fields(&input);
h = hash_with_flag(F::TWO, &input);
}
Ok(h)
}
@ -876,8 +917,8 @@ impl Intermediate {
}
let l_hash = self.left.compute_hash();
let r_hash = self.right.compute_hash();
let input: Vec<F> = [l_hash.0.to_vec(), r_hash.0.to_vec(), vec![F::TWO]].concat();
let h = hash_fields(&input);
let input: Vec<F> = [l_hash.0.to_vec(), r_hash.0.to_vec()].concat();
let h = hash_with_flag(F::TWO, &input);
self.hash = Some(h);
h
}