calculate MainPod id in a dynamic-friendly way (#241)
* calculate MainPod id in a dynamic-friendly way The MainPod id is now calculated with front padding and a fixed size independent of max_public_statements so that introduction gadgets can be verified by a MainPod while paying only for the number of statements they use. This is because with front padding of none-statements we can precompute the poseidon state corresponding to absorbing all the padding statements and only pay constraints for the non-padding statements. The id is calculated as follows: `id = hash(serialize(reverse(statements || none-statements)))` * fix test
This commit is contained in:
parent
82481e88d7
commit
d3fef8392e
6 changed files with 245 additions and 26 deletions
|
|
@ -3,9 +3,16 @@ use std::{array, iter, sync::Arc};
|
|||
use itertools::{zip_eq, Itertools};
|
||||
use plonky2::{
|
||||
field::types::Field,
|
||||
hash::{hash_types::HashOutTarget, poseidon::PoseidonHash},
|
||||
iop::{target::BoolTarget, witness::PartialWitness},
|
||||
plonk::circuit_builder::CircuitBuilder,
|
||||
hash::{
|
||||
hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS},
|
||||
hashing::PlonkyPermutation,
|
||||
poseidon::{PoseidonHash, PoseidonPermutation},
|
||||
},
|
||||
iop::{
|
||||
target::{BoolTarget, Target},
|
||||
witness::PartialWitness,
|
||||
},
|
||||
plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -22,7 +29,7 @@ use crate::{
|
|||
signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget},
|
||||
},
|
||||
error::Result,
|
||||
mainpod,
|
||||
mainpod::{self, pad_statement},
|
||||
primitives::merkletree::{
|
||||
MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget,
|
||||
},
|
||||
|
|
@ -894,6 +901,88 @@ impl CustomOperationVerifyGadget {
|
|||
}
|
||||
}
|
||||
|
||||
struct CalculateIdGadget {
|
||||
params: Params,
|
||||
}
|
||||
|
||||
impl CalculateIdGadget {
|
||||
/// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder
|
||||
/// elements that didn't fit into a chunk.
|
||||
fn precompute_hash_state<F: RichField, P: PlonkyPermutation<F>>(inputs: &[F]) -> (P, &[F]) {
|
||||
let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE);
|
||||
let mut perm = P::new(core::iter::repeat(F::ZERO));
|
||||
|
||||
// Absorb all inputs up to the biggest multiple of RATE.
|
||||
for input_chunk in inputs.chunks(P::RATE) {
|
||||
perm.set_from_slice(input_chunk, 0);
|
||||
perm.permute();
|
||||
}
|
||||
|
||||
(perm, inputs_rem)
|
||||
}
|
||||
|
||||
/// Hash `inputs` starting from a circuit-constant `perm` state.
|
||||
fn hash_from_state<H: AlgebraicHasher<F>, P: PlonkyPermutation<F>>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
perm: P,
|
||||
inputs: &[Target],
|
||||
) -> HashOutTarget {
|
||||
let mut state =
|
||||
H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v)));
|
||||
|
||||
// Absorb all input chunks.
|
||||
for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) {
|
||||
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
|
||||
// where we would xor or add in the inputs. This is a well-known variant, though,
|
||||
// sometimes called "overwrite mode".
|
||||
state.set_from_slice(input_chunk, 0);
|
||||
state = builder.permute::<H>(state);
|
||||
}
|
||||
|
||||
let num_outputs = NUM_HASH_OUT_ELTS;
|
||||
// Squeeze until we have the desired number of outputs.
|
||||
let mut outputs = Vec::with_capacity(num_outputs);
|
||||
loop {
|
||||
for &s in state.squeeze() {
|
||||
outputs.push(s);
|
||||
if outputs.len() == num_outputs {
|
||||
return HashOutTarget::from_vec(outputs);
|
||||
}
|
||||
}
|
||||
state = builder.permute::<H>(state);
|
||||
}
|
||||
}
|
||||
|
||||
fn eval(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
statements: &[StatementTarget],
|
||||
) -> HashOutTarget {
|
||||
let measure = measure_gates_begin!(builder, "CalculateId");
|
||||
let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten());
|
||||
let mut none_st = mainpod::Statement::from(Statement::None);
|
||||
pad_statement(&self.params, &mut none_st);
|
||||
let front_pad_elts = iter::repeat(&none_st)
|
||||
.take(self.params.num_public_statements_id - self.params.max_public_statements)
|
||||
.flat_map(|s| s.to_fields(&self.params))
|
||||
.collect_vec();
|
||||
let (perm, front_pad_elts_rem) =
|
||||
Self::precompute_hash_state::<F, PoseidonPermutation<F>>(&front_pad_elts);
|
||||
|
||||
// Precompute the Poseidon state for the initial padding chunks
|
||||
let inputs = front_pad_elts_rem
|
||||
.iter()
|
||||
.map(|v| builder.constant(*v))
|
||||
.chain(statements_rev_flattened)
|
||||
.collect_vec();
|
||||
let id =
|
||||
Self::hash_from_state::<PoseidonHash, PoseidonPermutation<F>>(builder, perm, &inputs);
|
||||
|
||||
measure_gates_end!(builder, measure);
|
||||
id
|
||||
}
|
||||
}
|
||||
|
||||
struct MainPodVerifyGadget {
|
||||
params: Params,
|
||||
}
|
||||
|
|
@ -1089,10 +1178,10 @@ impl MainPodVerifyGadget {
|
|||
self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?;
|
||||
|
||||
// 2. Calculate the Pod Id from the public statements
|
||||
let measure_calc_id = measure_gates_begin!(builder, "MainPodId");
|
||||
let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect();
|
||||
let id = builder.hash_n_to_hash_no_pad::<PoseidonHash>(pub_statements_flattened);
|
||||
measure_gates_end!(builder, measure_calc_id);
|
||||
let id = CalculateIdGadget {
|
||||
params: self.params.clone(),
|
||||
}
|
||||
.eval(builder, pub_statements);
|
||||
|
||||
// 4. Verify type
|
||||
let type_statement = &pub_statements[0];
|
||||
|
|
@ -1266,10 +1355,12 @@ impl MainPodVerifyCircuit {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Not;
|
||||
use std::{iter, ops::Not};
|
||||
|
||||
use plonky2::{
|
||||
field::{goldilocks_field::GoldilocksField, types::Field},
|
||||
hash::hash_types::HashOut,
|
||||
iop::witness::WitnessWrite,
|
||||
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
|
||||
};
|
||||
|
||||
|
|
@ -1278,7 +1369,7 @@ mod tests {
|
|||
backends::plonky2::{
|
||||
basetypes::C,
|
||||
circuits::common::tests::I64_TEST_PAIRS,
|
||||
mainpod::{OperationArg, OperationAux},
|
||||
mainpod::{calculate_id, OperationArg, OperationAux},
|
||||
primitives::merkletree::{MerkleClaimAndProof, MerkleTree},
|
||||
},
|
||||
frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
|
||||
|
|
@ -2681,4 +2772,106 @@ mod tests {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn helper_calculate_id(params: &Params, statements: &[Statement]) -> Result<()> {
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
let gadget = CalculateIdGadget {
|
||||
params: params.clone(),
|
||||
};
|
||||
|
||||
let statements_target = (0..params.max_public_statements)
|
||||
.map(|_| builder.add_virtual_statement(params))
|
||||
.collect_vec();
|
||||
let id_target = gadget.eval(&mut builder, &statements_target);
|
||||
|
||||
let mut pw = PartialWitness::<F>::new();
|
||||
|
||||
// Input
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|st| {
|
||||
let mut st = mainpod::Statement::from(st.clone());
|
||||
pad_statement(params, &mut st);
|
||||
st
|
||||
})
|
||||
.collect_vec();
|
||||
for (st_target, st) in statements_target.iter().zip(statements.iter()) {
|
||||
st_target.set_targets(&mut pw, params, st)?;
|
||||
}
|
||||
// Expected Output
|
||||
let expected_id = calculate_id(&statements, params);
|
||||
pw.set_hash_target(
|
||||
id_target,
|
||||
HashOut {
|
||||
elements: expected_id.0,
|
||||
},
|
||||
)?;
|
||||
|
||||
// generate & verify proof
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw)?;
|
||||
Ok(data.verify(proof.clone())?)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_id() -> frontend::Result<()> {
|
||||
// Case with no public public statements
|
||||
let params = Params {
|
||||
max_public_statements: 0,
|
||||
num_public_statements_id: 8,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
helper_calculate_id(¶ms, &[]).unwrap();
|
||||
|
||||
// Case with number of statements for the id equal to number of public statements
|
||||
let params = Params {
|
||||
max_public_statements: 2,
|
||||
num_public_statements_id: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let statements = [
|
||||
Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)),
|
||||
Statement::Equal(
|
||||
AnchoredKey::from((SELF, "bar")),
|
||||
AnchoredKey::from((SELF, "baz")),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.chain(iter::repeat(Statement::None))
|
||||
.take(params.max_public_statements)
|
||||
.collect_vec();
|
||||
|
||||
helper_calculate_id(¶ms, &statements).unwrap();
|
||||
|
||||
// Case with more statements for the id than the number of public statements
|
||||
let params = Params {
|
||||
max_public_statements: 4,
|
||||
num_public_statements_id: 6,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let pod_id = PodId(hash_str("pod_id"));
|
||||
let statements = [
|
||||
Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)),
|
||||
Statement::Equal(
|
||||
AnchoredKey::from((SELF, "bar")),
|
||||
AnchoredKey::from((SELF, "baz")),
|
||||
),
|
||||
Statement::Lt(
|
||||
AnchoredKey::from((pod_id, "one")),
|
||||
AnchoredKey::from((pod_id, "two")),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.chain(iter::repeat(Statement::None))
|
||||
.take(params.max_public_statements)
|
||||
.collect_vec();
|
||||
|
||||
helper_calculate_id(¶ms, &statements).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
pub mod operation;
|
||||
pub mod statement;
|
||||
use std::{any::Any, sync::Arc};
|
||||
use std::{any::Any, iter, sync::Arc};
|
||||
|
||||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use itertools::Itertools;
|
||||
|
|
@ -35,11 +35,30 @@ use crate::{
|
|||
},
|
||||
};
|
||||
|
||||
/// Hash a list of public statements to derive the PodId
|
||||
pub(crate) fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash {
|
||||
let field_elems = statements
|
||||
/// Hash a list of public statements to derive the PodId. To make circuits with different number
|
||||
/// of `max_public_statements compatible we pad the statements up to `num_public_statements_id`.
|
||||
/// As an optimization we front pad with none-statements so that circuits with a small
|
||||
/// `max_public_statements` only pay for `max_public_statements` by starting the poseidon state
|
||||
/// with a precomputed constant corresponding to the front-padding part:
|
||||
/// `id = hash(serialize(reverse(statements || none-statements)))`
|
||||
pub(crate) fn calculate_id(statements: &[Statement], params: &Params) -> middleware::Hash {
|
||||
assert_eq!(params.max_public_statements, statements.len());
|
||||
assert!(params.max_public_statements <= params.num_public_statements_id);
|
||||
statements
|
||||
.iter()
|
||||
.flat_map(|statement| statement.clone().to_fields(_params))
|
||||
.for_each(|st| assert_eq!(params.max_statement_args, st.1.len()));
|
||||
|
||||
let mut none_st: Statement = middleware::Statement::None.into();
|
||||
pad_statement(params, &mut none_st);
|
||||
let statements_back_padded = statements
|
||||
.iter()
|
||||
.chain(iter::repeat(&none_st))
|
||||
.take(params.num_public_statements_id)
|
||||
.collect_vec();
|
||||
let field_elems = statements_back_padded
|
||||
.iter()
|
||||
.rev()
|
||||
.flat_map(|statement| statement.to_fields(params))
|
||||
.collect::<Vec<_>>();
|
||||
Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
|
||||
}
|
||||
|
|
@ -421,7 +440,7 @@ impl Prover {
|
|||
let public_statements =
|
||||
statements[statements.len() - params.max_public_statements..].to_vec();
|
||||
// get the id out of the public statements
|
||||
let id: PodId = PodId(hash_statements(&public_statements, params));
|
||||
let id: PodId = PodId(calculate_id(&public_statements, params));
|
||||
|
||||
let input = MainPodVerifyInput {
|
||||
signed_pods: signed_pods_input,
|
||||
|
|
@ -505,7 +524,7 @@ fn get_common_data(params: &Params) -> Result<CommonCircuitData<F, D>, Error> {
|
|||
impl MainPod {
|
||||
fn _verify(&self) -> Result<()> {
|
||||
// 2. get the id out of the public statements
|
||||
let id: PodId = PodId(hash_statements(&self.public_statements, &self.params));
|
||||
let id: PodId = PodId(calculate_id(&self.public_statements, &self.params));
|
||||
if id != self.id {
|
||||
return Err(Error::id_not_equal(self.id, id));
|
||||
}
|
||||
|
|
@ -700,6 +719,7 @@ pub mod tests {
|
|||
max_statements: 5,
|
||||
max_signed_pod_values: 2,
|
||||
max_public_statements: 2,
|
||||
num_public_statements_id: 4,
|
||||
max_statement_args: 2,
|
||||
max_operation_args: 3,
|
||||
max_custom_predicate_batches: 2,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use crate::{
|
|||
backends::plonky2::{
|
||||
error::{Error, Result},
|
||||
mainpod::{
|
||||
extract_merkle_proofs, hash_statements, layout_statements, normalize_statement,
|
||||
calculate_id, extract_merkle_proofs, layout_statements, normalize_statement,
|
||||
process_private_statements_operations, process_public_statements_operations, Operation,
|
||||
Statement,
|
||||
},
|
||||
|
|
@ -163,7 +163,7 @@ impl MockMainPod {
|
|||
statements[statements.len() - params.max_public_statements..].to_vec();
|
||||
|
||||
// get the id out of the public statements
|
||||
let id: PodId = PodId(hash_statements(&public_statements, params));
|
||||
let id: PodId = PodId(calculate_id(&public_statements, params));
|
||||
|
||||
Ok(Self {
|
||||
params: params.clone(),
|
||||
|
|
@ -197,7 +197,7 @@ impl MockMainPod {
|
|||
// get the input_statements from the self.statements
|
||||
let input_statements = &self.statements[input_statement_offset..];
|
||||
// 2. get the id out of the public statements, and ensure it is equal to self.id
|
||||
let ids_match = self.id == PodId(hash_statements(&self.public_statements, &self.params));
|
||||
let ids_match = self.id == PodId(calculate_id(&self.public_statements, &self.params));
|
||||
// find a ValueOf statement from the public statements with key=KEY_TYPE and check that the
|
||||
// value is PodType::MockMainPod
|
||||
let has_type_statement = self.public_statements.iter().any(|s| {
|
||||
|
|
@ -351,8 +351,7 @@ pub mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_mock_main_great_boy() -> frontend::Result<()> {
|
||||
let params = middleware::Params::default();
|
||||
let great_boy_builder = great_boy_pod_full_flow()?;
|
||||
let (params, great_boy_builder) = great_boy_pod_full_flow()?;
|
||||
|
||||
let mut prover = MockProver {};
|
||||
let great_boy_pod = great_boy_builder.prove(&mut prover, ¶ms)?;
|
||||
|
|
|
|||
|
|
@ -296,11 +296,12 @@ pub fn great_boy_pod_builder(
|
|||
Ok(great_boy)
|
||||
}
|
||||
|
||||
pub fn great_boy_pod_full_flow() -> Result<MainPodBuilder> {
|
||||
pub fn great_boy_pod_full_flow() -> Result<(Params, MainPodBuilder)> {
|
||||
let params = Params {
|
||||
max_input_signed_pods: 6,
|
||||
max_statements: 100,
|
||||
max_public_statements: 50,
|
||||
num_public_statements_id: 50,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
|
@ -349,7 +350,7 @@ pub fn great_boy_pod_full_flow() -> Result<MainPodBuilder> {
|
|||
good_boy_issuers.into_iter().map(Value::from).collect(),
|
||||
)?);
|
||||
|
||||
great_boy_pod_builder(
|
||||
let builder = great_boy_pod_builder(
|
||||
¶ms,
|
||||
[
|
||||
&bob_good_boys[0],
|
||||
|
|
@ -360,7 +361,9 @@ pub fn great_boy_pod_full_flow() -> Result<MainPodBuilder> {
|
|||
[&alice_friend_pods[0], &alice_friend_pods[1]],
|
||||
&good_boy_issuers,
|
||||
alice,
|
||||
)
|
||||
)?;
|
||||
|
||||
Ok((params, builder))
|
||||
}
|
||||
|
||||
// Tickets
|
||||
|
|
|
|||
|
|
@ -940,7 +940,7 @@ pub mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_front_great_boy() -> Result<()> {
|
||||
let great_boy = great_boy_pod_full_flow()?;
|
||||
let (_, great_boy) = great_boy_pod_full_flow()?;
|
||||
println!("{}", great_boy);
|
||||
|
||||
// TODO: prove great_boy with MockProver and print it
|
||||
|
|
|
|||
|
|
@ -582,6 +582,9 @@ pub struct Params {
|
|||
pub max_statements: usize,
|
||||
pub max_signed_pod_values: usize,
|
||||
pub max_public_statements: usize,
|
||||
// Number of public statements to hash to calculate the id. Must be equal or greater than
|
||||
// `max_public_statements`.
|
||||
pub num_public_statements_id: usize,
|
||||
pub max_statement_args: usize,
|
||||
pub max_operation_args: usize,
|
||||
// max number of custom predicates batches that a MainPod can use
|
||||
|
|
@ -607,6 +610,7 @@ impl Default for Params {
|
|||
max_statements: 20,
|
||||
max_signed_pod_values: 8,
|
||||
max_public_statements: 10,
|
||||
num_public_statements_id: 16,
|
||||
max_statement_args: 5,
|
||||
max_operation_args: 5,
|
||||
max_custom_predicate_batches: 2,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue