merkletree: reduce gate amount (-23%) by custom poseidon to use flag as initial state (#472)
* merkletree: custom poseidon to use flag as initial state. This allows to do the merkletree related hashing in 1 gate instead of 2, reducing ~23% of gates per merkle proof. | tree levels | 10 | 16 | 32 | 40 | 64 | 128 | 130 | 250 | 256 | |---------------|----|----|-----|-----|-----|-----|-----|------|------| | old num gates | 50 | 76 | 144 | 178 | 280 | 554 | 564 | 1076 | 1102 | | new num gates | 39 | 59 | 111 | 137 | 215 | 425 | 433 | 825 | 845 | * update docs with new tree hashing approach * add inline comment stating clear how the flag is used in the state permutation
This commit is contained in:
parent
641d8dabdd
commit
b04560c362
3 changed files with 101 additions and 15 deletions
|
|
@ -37,10 +37,12 @@ A Merkle tree with no entry at all is represented by the hash value
|
||||||
(With the Plonky2 backend, the hash function ```hash``` will output a 4-tuple of field elements.)
|
(With the Plonky2 backend, the hash function ```hash``` will output a 4-tuple of field elements.)
|
||||||
|
|
||||||
A Merkle tree with a single entry ```(key, value)``` is called a "leaf". It is represented by the hash value
|
A Merkle tree with a single entry ```(key, value)``` is called a "leaf". It is represented by the hash value
|
||||||
```root = hash((key, value, 1)).```
|
```root = hash(1, (key, value))```, where `1` is a flag indicating that it is a leaf, and it's used as the initial state of the hash (Poseidon) permutation.
|
||||||
|
|
||||||
A Merkle tree ```tree``` with more than one entry is required to have two subtrees, ```left``` and ```right```. It is then represented by the hash value
|
A Merkle tree ```tree``` with more than one entry is required to have two subtrees, ```left``` and ```right```. It is then represented by the hash value
|
||||||
```root = hash((left_root, right_root, 2)).```
|
```root = hash(2, (left_root, right_root))```, where `2` is a flag indicating that it is an intermediate node, and it's used as the initial state of the hash (Poseidon) permutation.
|
||||||
|
|
||||||
|
The flags are used as the initial state of the Poseidon permutation so that they don't account for extra inputs in the Poseidon gadget, needing only 1 gate for each node/leaf hash.
|
||||||
|
|
||||||
(The role of the constants 1 and 2 is to prevent collisions between leaves and non-leaf Merkle roots. If the constants were omitted, a large Merkle tree could be dishonestly interpreted as a leaf, leading to security vulnerabilities.)
|
(The role of the constants 1 and 2 is to prevent collisions between leaves and non-leaf Merkle roots. If the constants were omitted, a large Merkle tree could be dishonestly interpreted as a leaf, leading to security vulnerabilities.)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,14 +14,15 @@ use itertools::zip_eq;
|
||||||
use plonky2::{
|
use plonky2::{
|
||||||
field::types::Field,
|
field::types::Field,
|
||||||
hash::{
|
hash::{
|
||||||
hash_types::{HashOut, HashOutTarget},
|
hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS},
|
||||||
|
hashing::PlonkyPermutation,
|
||||||
poseidon::PoseidonHash,
|
poseidon::PoseidonHash,
|
||||||
},
|
},
|
||||||
iop::{
|
iop::{
|
||||||
target::{BoolTarget, Target},
|
target::{BoolTarget, Target},
|
||||||
witness::{PartialWitness, WitnessWrite},
|
witness::{PartialWitness, WitnessWrite},
|
||||||
},
|
},
|
||||||
plonk::circuit_builder::CircuitBuilder,
|
plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher},
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
|
@ -349,7 +350,7 @@ fn compute_root_from_leaf(
|
||||||
.map(|j| builder.select(path[i], h.elements[j], sibling.elements[j]))
|
.map(|j| builder.select(path[i], h.elements[j], sibling.elements[j]))
|
||||||
.collect();
|
.collect();
|
||||||
let new_h =
|
let new_h =
|
||||||
builder.hash_n_to_hash_no_pad::<PoseidonHash>([input_1, input_2, vec![two]].concat());
|
hash_with_flag_target::<PoseidonHash>(builder, two, [input_1, input_2].concat());
|
||||||
|
|
||||||
let h_targ: Vec<Target> = (0..HASH_SIZE)
|
let h_targ: Vec<Target> = (0..HASH_SIZE)
|
||||||
.map(|j| builder.select(*selector, new_h.elements[j], h.elements[j]))
|
.map(|j| builder.select(*selector, new_h.elements[j], h.elements[j]))
|
||||||
|
|
@ -401,9 +402,51 @@ fn kv_hash_target(
|
||||||
.iter()
|
.iter()
|
||||||
.chain(value.elements.iter())
|
.chain(value.elements.iter())
|
||||||
.cloned()
|
.cloned()
|
||||||
.chain(iter::once(builder.one()))
|
|
||||||
.collect();
|
.collect();
|
||||||
builder.hash_n_to_hash_no_pad::<PoseidonHash>(inputs)
|
let one = builder.one();
|
||||||
|
hash_with_flag_target::<PoseidonHash>(builder, one, inputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Circuit matching the `merkletree::hash_with_flag` function. Variation of
|
||||||
|
/// Poseidon hash which takes as input 1 Goldilock element as a flag, and 8
|
||||||
|
/// Goldilocks elements as inputs to the hash. Performs the hashing in a single
|
||||||
|
/// gate.
|
||||||
|
/// The function is a fork of
|
||||||
|
/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L118)
|
||||||
|
/// from plonky2.
|
||||||
|
fn hash_with_flag_target<H: AlgebraicHasher<F>>(
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
flag: Target,
|
||||||
|
inputs: Vec<Target>,
|
||||||
|
) -> HashOutTarget {
|
||||||
|
assert_eq!(inputs.len(), H::AlgebraicPermutation::RATE);
|
||||||
|
|
||||||
|
// here we set `state` to a `SPONGE_RATE+SPONGE_CAPACITY` (8+4=12) in our
|
||||||
|
// case to a vector of repeated `flag` value. Later at the absorption step,
|
||||||
|
// it will fit the inputs values at positions 0-8, keeping the flag values
|
||||||
|
// at positions 8-12.
|
||||||
|
let mut state = H::AlgebraicPermutation::new(core::iter::repeat(flag));
|
||||||
|
|
||||||
|
// Absorb all input chunks.
|
||||||
|
for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) {
|
||||||
|
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
|
||||||
|
// where we would xor or add in the inputs. This is a well-known variant, though,
|
||||||
|
// sometimes called "overwrite mode".
|
||||||
|
state.set_from_slice(input_chunk, 0);
|
||||||
|
state = builder.permute::<H>(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Squeeze until we have the desired number of outputs.
|
||||||
|
let mut outputs = Vec::with_capacity(NUM_HASH_OUT_ELTS);
|
||||||
|
loop {
|
||||||
|
for &s in state.squeeze() {
|
||||||
|
outputs.push(s);
|
||||||
|
if outputs.len() == NUM_HASH_OUT_ELTS {
|
||||||
|
return HashOutTarget::from_vec(outputs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state = builder.permute::<H>(state);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifies that the merkletree state transition (from old_root to new_root)
|
/// Verifies that the merkletree state transition (from old_root to new_root)
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,15 @@
|
||||||
use std::{collections::HashMap, fmt, iter::IntoIterator};
|
use std::{collections::HashMap, fmt, iter::IntoIterator};
|
||||||
|
|
||||||
use itertools::zip_eq;
|
use itertools::zip_eq;
|
||||||
use plonky2::field::types::Field;
|
use plonky2::{
|
||||||
|
field::types::Field,
|
||||||
|
hash::{
|
||||||
|
hash_types::NUM_HASH_OUT_ELTS, hashing::PlonkyPermutation, poseidon::PoseidonPermutation,
|
||||||
|
},
|
||||||
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::middleware::{hash_fields, Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F};
|
use crate::middleware::{Hash, RawValue, EMPTY_HASH, EMPTY_VALUE, F};
|
||||||
|
|
||||||
pub mod circuit;
|
pub mod circuit;
|
||||||
pub use circuit::*;
|
pub use circuit::*;
|
||||||
|
|
@ -376,10 +381,47 @@ impl MerkleTree {
|
||||||
/// mitigate fake proofs.
|
/// mitigate fake proofs.
|
||||||
pub fn kv_hash(key: &RawValue, value: Option<RawValue>) -> Hash {
|
pub fn kv_hash(key: &RawValue, value: Option<RawValue>) -> Hash {
|
||||||
value
|
value
|
||||||
.map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![F::ONE]].concat()))
|
.map(|v| hash_with_flag(F::ONE, &[key.0.to_vec(), v.0.to_vec()].concat()))
|
||||||
.unwrap_or(EMPTY_HASH)
|
.unwrap_or(EMPTY_HASH)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Variation of Poseidon hash which takes as input 1 Goldilock element as a
|
||||||
|
/// flag, and 8 Goldilocks elements as inputs to the hash. Performs the hashing
|
||||||
|
/// in a single gate.
|
||||||
|
/// The function is a fork of
|
||||||
|
/// [hash_n_to_m_no_pad](https://github.com/0xPolygonZero/plonky2/tree/5d9da5a65bbcba2c66eb29c035090eb2e9ccb05f/plonky2/src/hash/hashing.rs#L30)
|
||||||
|
/// from plonky2.
|
||||||
|
fn hash_with_flag(flag: F, inputs: &[F]) -> Hash {
|
||||||
|
assert_eq!(
|
||||||
|
inputs.len(),
|
||||||
|
<PoseidonPermutation<F> as PlonkyPermutation<F>>::RATE
|
||||||
|
);
|
||||||
|
|
||||||
|
// this will set `perm` to a `SPONGE_RATE+SPONGE_CAPACITY` (8+4=12) in our
|
||||||
|
// case to a vector of repeated `flag` value. Later at the absorption step,
|
||||||
|
// it will fit the inputs values at positions 0-8, keeping the flag values
|
||||||
|
// at positions 8-12.
|
||||||
|
let mut perm = <PoseidonPermutation<F> as PlonkyPermutation<F>>::new(core::iter::repeat(flag));
|
||||||
|
|
||||||
|
// Absorb all input chunks.
|
||||||
|
for input_chunk in inputs.chunks(<PoseidonPermutation<F> as PlonkyPermutation<F>>::RATE) {
|
||||||
|
perm.set_from_slice(input_chunk, 0);
|
||||||
|
perm.permute();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Squeeze until we have the desired number of outputs.
|
||||||
|
let mut outputs = Vec::new();
|
||||||
|
loop {
|
||||||
|
for &item in perm.squeeze() {
|
||||||
|
outputs.push(item);
|
||||||
|
if outputs.len() == NUM_HASH_OUT_ELTS {
|
||||||
|
return Hash(crate::middleware::HashOut::from_vec(outputs).elements);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
perm.permute();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> IntoIterator for &'a MerkleTree {
|
impl<'a> IntoIterator for &'a MerkleTree {
|
||||||
type Item = (&'a RawValue, &'a RawValue);
|
type Item = (&'a RawValue, &'a RawValue);
|
||||||
type IntoIter = Iter<'a>;
|
type IntoIter = Iter<'a>;
|
||||||
|
|
@ -437,13 +479,12 @@ impl MerkleProof {
|
||||||
fn compute_root_from_node(&self, node_hash: &Hash, path: Vec<bool>) -> TreeResult<Hash> {
|
fn compute_root_from_node(&self, node_hash: &Hash, path: Vec<bool>) -> TreeResult<Hash> {
|
||||||
let mut h = *node_hash;
|
let mut h = *node_hash;
|
||||||
for (i, sibling) in self.siblings.iter().enumerate().rev() {
|
for (i, sibling) in self.siblings.iter().enumerate().rev() {
|
||||||
let mut input: Vec<F> = if path[i] {
|
let input: Vec<F> = if path[i] {
|
||||||
[sibling.0, h.0].concat()
|
[sibling.0, h.0].concat()
|
||||||
} else {
|
} else {
|
||||||
[h.0, sibling.0].concat()
|
[h.0, sibling.0].concat()
|
||||||
};
|
};
|
||||||
input.push(F::TWO);
|
h = hash_with_flag(F::TWO, &input);
|
||||||
h = hash_fields(&input);
|
|
||||||
}
|
}
|
||||||
Ok(h)
|
Ok(h)
|
||||||
}
|
}
|
||||||
|
|
@ -876,8 +917,8 @@ impl Intermediate {
|
||||||
}
|
}
|
||||||
let l_hash = self.left.compute_hash();
|
let l_hash = self.left.compute_hash();
|
||||||
let r_hash = self.right.compute_hash();
|
let r_hash = self.right.compute_hash();
|
||||||
let input: Vec<F> = [l_hash.0.to_vec(), r_hash.0.to_vec(), vec![F::TWO]].concat();
|
let input: Vec<F> = [l_hash.0.to_vec(), r_hash.0.to_vec()].concat();
|
||||||
let h = hash_fields(&input);
|
let h = hash_with_flag(F::TWO, &input);
|
||||||
self.hash = Some(h);
|
self.hash = Some(h);
|
||||||
h
|
h
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue