From b04560c362deb9b625ff0a3d25139812efd66a22 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 4 Feb 2026 12:31:56 +0100 Subject: [PATCH] merkletree: reduce gate amount (-23%) by custom poseidon to use flag as initial state (#472) * merkletree: custom poseidon to use flag as initial state. This allows to do the merkletree related hashing in 1 gate instead of 2, reducing ~23% of gates per merkle proof. | tree levels | 10 | 16 | 32 | 40 | 64 | 128 | 130 | 250 | 256 | |---------------|----|----|-----|-----|-----|-----|-----|------|------| | old num gates | 50 | 76 | 144 | 178 | 280 | 554 | 564 | 1076 | 1102 | | new num gates | 39 | 59 | 111 | 137 | 215 | 425 | 433 | 825 | 845 | * update docs with new tree hashing approach * add inline comment stating clear how the flag is used in the state permutation --- book/src/merkletree.md | 6 +- .../plonky2/primitives/merkletree/circuit.rs | 53 +++++++++++++++-- .../plonky2/primitives/merkletree/mod.rs | 57 ++++++++++++++++--- 3 files changed, 101 insertions(+), 15 deletions(-) diff --git a/book/src/merkletree.md b/book/src/merkletree.md index 574d06f..b246455 100644 --- a/book/src/merkletree.md +++ b/book/src/merkletree.md @@ -37,10 +37,12 @@ A Merkle tree with no entry at all is represented by the hash value (With the Plonky2 backend, the hash function ```hash``` will output a 4-tuple of field elements.) A Merkle tree with a single entry ```(key, value)``` is called a "leaf". It is represented by the hash value -```root = hash((key, value, 1)).``` +```root = hash(1, (key, value))```, where `1` is a flag indicating that it is a leaf, and it's used as the initial state of the hash (Poseidon) permutation. A Merkle tree ```tree``` with more than one entry is required to have two subtrees, ```left``` and ```right```. 
It is then represented by the hash value -```root = hash((left_root, right_root, 2)).``` +```root = hash(2, (left_root, right_root))```, where `2` is a flag indicating that it is an intermediate node, and it's used as the initial state of the hash (Poseidon) permutation. + +The flags are used as the initial state of the Poseidon permutation so that they don't account for extra inputs in the Poseidon gadget, needing only 1 gate for each node/leaf hash. (The role of the constants 1 and 2 is to prevent collisions between leaves and non-leaf Merkle roots. If the constants were omitted, a large Merkle tree could be dishonestly interpreted as a leaf, leading to security vulnerabilities.) diff --git a/src/backends/plonky2/primitives/merkletree/circuit.rs b/src/backends/plonky2/primitives/merkletree/circuit.rs index 8b34999..0c5978f 100644 --- a/src/backends/plonky2/primitives/merkletree/circuit.rs +++ b/src/backends/plonky2/primitives/merkletree/circuit.rs @@ -14,14 +14,15 @@ use itertools::zip_eq; use plonky2::{ field::types::Field, hash::{ - hash_types::{HashOut, HashOutTarget}, + hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + hashing::PlonkyPermutation, poseidon::PoseidonHash, }, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, - plonk::circuit_builder::CircuitBuilder, + plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher}, }; use serde::{Deserialize, Serialize}; @@ -349,7 +350,7 @@ fn compute_root_from_leaf( .map(|j| builder.select(path[i], h.elements[j], sibling.elements[j])) .collect(); let new_h = - builder.hash_n_to_hash_no_pad::([input_1, input_2, vec![two]].concat()); + hash_with_flag_target::(builder, two, [input_1, input_2].concat()); let h_targ: Vec = (0..HASH_SIZE) .map(|j| builder.select(*selector, new_h.elements[j], h.elements[j])) @@ -401,9 +402,51 @@ fn kv_hash_target( .iter() .chain(value.elements.iter()) .cloned() - .chain(iter::once(builder.one())) .collect(); - 
builder.hash_n_to_hash_no_pad::(inputs) + let one = builder.one(); + hash_with_flag_target::(builder, one, inputs) +} + +/// Circuit matching the `merkletree::hash_with_flag` function. Variation of +/// Poseidon hash which takes as input 1 Goldilocks element as a flag, and 8 +/// Goldilocks elements as inputs to the hash. Performs the hashing in a single +/// gate. +/// The function is a fork of +/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L118) +/// from plonky2. +fn hash_with_flag_target>( + builder: &mut CircuitBuilder, + flag: Target, + inputs: Vec, +) -> HashOutTarget { + assert_eq!(inputs.len(), H::AlgebraicPermutation::RATE); + + // here we initialize `state`, which holds `SPONGE_RATE+SPONGE_CAPACITY` + // (8+4=12 in our case) elements, to repeated copies of the `flag` value. + // Later, at the absorption step, the inputs overwrite positions 0-7, + // while the flag values remain at positions 8-11. + let mut state = H::AlgebraicPermutation::new(core::iter::repeat(flag)); + + // Absorb all input chunks. + for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) { + // Overwrite the first r elements with the inputs. This differs from a standard sponge, + // where we would xor or add in the inputs. This is a well-known variant, though, + // sometimes called "overwrite mode". + state.set_from_slice(input_chunk, 0); + state = builder.permute::(state); + } + + // Squeeze until we have the desired number of outputs. 
+ let mut outputs = Vec::with_capacity(NUM_HASH_OUT_ELTS); + loop { + for &s in state.squeeze() { + outputs.push(s); + if outputs.len() == NUM_HASH_OUT_ELTS { + return HashOutTarget::from_vec(outputs); + } + } + state = builder.permute::(state); + } } /// Verifies that the merkletree state transition (from old_root to new_root) diff --git a/src/backends/plonky2/primitives/merkletree/mod.rs b/src/backends/plonky2/primitives/merkletree/mod.rs index 9b3609b..35c4c11 100644 --- a/src/backends/plonky2/primitives/merkletree/mod.rs +++ b/src/backends/plonky2/primitives/merkletree/mod.rs @@ -3,10 +3,15 @@ use std::{collections::HashMap, fmt, iter::IntoIterator}; use itertools::zip_eq; -use plonky2::field::types::Field; +use plonky2::{ + field::types::Field, + hash::{ + hash_types::NUM_HASH_OUT_ELTS, hashing::PlonkyPermutation, poseidon::PoseidonPermutation, + }, +}; use serde::{Deserialize, Serialize}; -use crate::middleware::{hash_fields, Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F}; +use crate::middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F}; pub mod circuit; pub use circuit::*; @@ -376,10 +381,47 @@ impl MerkleTree { /// mitigate fake proofs. pub fn kv_hash(key: &RawValue, value: Option) -> Hash { value - .map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![F::ONE]].concat())) + .map(|v| hash_with_flag(F::ONE, &[key.0.to_vec(), v.0.to_vec()].concat())) .unwrap_or(EMPTY_HASH) } +/// Variation of Poseidon hash which takes as input 1 Goldilocks element as a +/// flag, and 8 Goldilocks elements as inputs to the hash. Performs the hashing +/// with a single permutation (a single gate in the circuit counterpart). +/// The function is a fork of +/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L30) +/// from plonky2. 
+fn hash_with_flag(flag: F, inputs: &[F]) -> Hash { + assert_eq!( + inputs.len(), + as PlonkyPermutation>::RATE + ); + + // here we initialize `perm`, which holds `SPONGE_RATE+SPONGE_CAPACITY` + // (8+4=12 in our case) elements, to repeated copies of the `flag` value. + // Later, at the absorption step, the inputs overwrite positions 0-7, + // while the flag values remain at positions 8-11. + let mut perm = as PlonkyPermutation>::new(core::iter::repeat(flag)); + + // Absorb all input chunks. + for input_chunk in inputs.chunks( as PlonkyPermutation>::RATE) { + perm.set_from_slice(input_chunk, 0); + perm.permute(); + } + + // Squeeze until we have the desired number of outputs. + let mut outputs = Vec::new(); + loop { + for &item in perm.squeeze() { + outputs.push(item); + if outputs.len() == NUM_HASH_OUT_ELTS { + return Hash(crate::middleware::HashOut::from_vec(outputs).elements); + } + } + perm.permute(); + } +} + impl<'a> IntoIterator for &'a MerkleTree { type Item = (&'a RawValue, &'a RawValue); type IntoIter = Iter<'a>; @@ -437,13 +479,12 @@ impl MerkleProof { fn compute_root_from_node(&self, node_hash: &Hash, path: Vec) -> TreeResult { let mut h = *node_hash; for (i, sibling) in self.siblings.iter().enumerate().rev() { - let mut input: Vec = if path[i] { + let input: Vec = if path[i] { [sibling.0, h.0].concat() } else { [h.0, sibling.0].concat() }; - input.push(F::TWO); - h = hash_fields(&input); + h = hash_with_flag(F::TWO, &input); } Ok(h) } @@ -876,8 +917,8 @@ impl Intermediate { } let l_hash = self.left.compute_hash(); let r_hash = self.right.compute_hash(); - let input: Vec = [l_hash.0.to_vec(), r_hash.0.to_vec(), vec![F::TWO]].concat(); - let h = hash_fields(&input); + let input: Vec = [l_hash.0.to_vec(), r_hash.0.to_vec()].concat(); + let h = hash_with_flag(F::TWO, &input); self.hash = Some(h); h }