Integrate recursion into MainPod (#243)

* calculate MainPod id in a dynamic-friendly way

The MainPod id is now calculated with front padding and a fixed size
independent of max_public_statements so that introduction gadgets can be
verified by a MainPod while paying only for the number of statements
they use.  This is because with front padding of none-statements we can
precompute the poseidon state corresponding to absorbing all the padding
statements and only pay constraints for the non-padding statements.

The id is calculated as follows:
`id = hash(serialize(reverse(statements || none-statements)))`

* add time feature and disable timing by default

* apply suggestions from @arnaucube

* link issues in todos
This commit is contained in:
Eduard S. 2025-05-29 17:10:19 +02:00 committed by GitHub
parent d3fef8392e
commit 88a75986b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1405 additions and 729 deletions

View file

@ -10,14 +10,14 @@ use plonky2::{
},
iop::{
target::{BoolTarget, Target},
witness::PartialWitness,
witness::{PartialWitness, WitnessWrite},
},
plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher},
plonk::config::AlgebraicHasher,
};
use crate::{
backends::plonky2::{
basetypes::D,
basetypes::CircuitBuilder,
circuits::{
common::{
CircuitBuilderPod, CustomPredicateBatchTarget, CustomPredicateEntryTarget,
@ -28,18 +28,21 @@ use crate::{
},
signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget},
},
emptypod::EmptyPod,
error::Result,
mainpod::{self, pad_statement},
primitives::merkletree::{
MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget,
},
recursion::{InnerCircuit, VerifiedProofTarget},
signedpod::SignedPod,
},
measure_gates_begin, measure_gates_end,
middleware::{
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields,
Value, WildcardValue, F, KEY_TYPE, SELF, VALUE_SIZE,
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash,
NativeOperation, NativePredicate, Params, PodType, PredicatePrefix, Statement,
StatementArg, ToFields, Value, WildcardValue, EMPTY_VALUE, F, HASH_SIZE, KEY_TYPE, SELF,
VALUE_SIZE,
},
};
@ -47,6 +50,13 @@ use crate::{
// MainPod verification
//
/// Offset in public inputs where we store the pod id
pub const PI_OFFSET_ID: usize = 0;
/// Offset in public inputs where we store the verified data array root
pub const PI_OFFSET_VDSROOT: usize = 4;
pub const NUM_PUBLIC_INPUTS: usize = 8;
struct OperationVerifyGadget {
params: Params,
}
@ -58,7 +68,7 @@ impl OperationVerifyGadget {
/// argument.
fn first_n_args_as_values<const N: usize>(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
resolved_op_args: &[StatementTarget],
) -> (BoolTarget, [ValueTarget; N]) {
let arg_is_valueof = resolved_op_args[..N]
@ -80,7 +90,7 @@ impl OperationVerifyGadget {
fn eval(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op: &OperationTarget,
prev_statements: &[StatementTarget],
@ -212,7 +222,7 @@ impl OperationVerifyGadget {
fn eval_contains_from_entries(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_merkle_claim: MerkleClaimTarget,
@ -260,7 +270,7 @@ impl OperationVerifyGadget {
fn eval_not_contains_from_entries(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_merkle_claim: MerkleClaimTarget,
@ -306,7 +316,7 @@ impl OperationVerifyGadget {
fn eval_custom(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_custom_pred_verification: HashOutTarget,
@ -331,7 +341,7 @@ impl OperationVerifyGadget {
/// NotEqualFromEntries.
fn eval_eq_neq_from_entries(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -382,7 +392,7 @@ impl OperationVerifyGadget {
/// LtEqFromEntries.
fn eval_lt_lteq_from_entries(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -451,7 +461,7 @@ impl OperationVerifyGadget {
fn eval_hash_of(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -485,7 +495,7 @@ impl OperationVerifyGadget {
fn eval_sum_of(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -524,7 +534,7 @@ impl OperationVerifyGadget {
fn eval_product_of(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -563,7 +573,7 @@ impl OperationVerifyGadget {
fn eval_max_of(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -609,7 +619,7 @@ impl OperationVerifyGadget {
fn eval_transitive_eq(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -645,7 +655,7 @@ impl OperationVerifyGadget {
}
fn eval_none(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
) -> BoolTarget {
@ -663,7 +673,7 @@ impl OperationVerifyGadget {
fn eval_new_entry(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
prev_statements: &[StatementTarget],
@ -701,7 +711,7 @@ impl OperationVerifyGadget {
fn eval_lt_to_neq(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -730,7 +740,7 @@ impl OperationVerifyGadget {
fn eval_copy(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st: &StatementTarget,
op_type: &OperationTypeTarget,
resolved_op_args: &[StatementTarget],
@ -759,7 +769,7 @@ struct CustomOperationVerifyGadget {
impl CustomOperationVerifyGadget {
fn statement_arg_from_template(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st_tmpl_arg: &StatementTmplArgTarget,
args: &[ValueTarget],
) -> StatementArgTarget {
@ -812,7 +822,7 @@ impl CustomOperationVerifyGadget {
fn statement_from_template(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st_tmpl: &StatementTmplTarget,
args: &[ValueTarget],
) -> StatementTarget {
@ -836,7 +846,7 @@ impl CustomOperationVerifyGadget {
/// - Build the expected operation type
fn eval(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
custom_predicate: &CustomPredicateEntryTarget,
op_args: &[StatementTarget],
args: &[ValueTarget], // arguments to the custom predicate, public and private
@ -901,10 +911,49 @@ impl CustomOperationVerifyGadget {
}
}
struct CalculateIdGadget {
/// Replace references to SELF by `self_id` in a statement.
struct NormalizeStatementGadget {
params: Params,
}
impl NormalizeStatementGadget {
fn eval(
&self,
builder: &mut CircuitBuilder,
statement: &StatementTarget,
self_id: &ValueTarget,
) -> StatementTarget {
let zero_value = builder.constant_value(EMPTY_VALUE);
let self_value = builder.constant_value(SELF.0.into());
let args = statement
.args
.iter()
.map(|arg| {
let first = ValueTarget::from_slice(&arg.elements[..VALUE_SIZE]);
let second = ValueTarget::from_slice(&arg.elements[VALUE_SIZE..]);
let is_not_ak = builder.is_equal_flattenable(&zero_value, &second);
let is_ak = builder.not(is_not_ak);
let is_self = builder.is_equal_flattenable(&self_value, &first);
let normalize = builder.and(is_ak, is_self);
let first_normalized =
builder.select_flattenable(&self.params, normalize, self_id, &first);
StatementArgTarget::new(first_normalized, second)
})
.collect_vec();
StatementTarget {
predicate: statement.predicate.clone(),
args,
}
}
}
pub struct CalculateIdGadget {
/// `params.num_public_statements_id` is the total number of statements that will be hashed.
/// The id is calculated with front-padded none-statements and then the input statements
/// reversed. The part of the hash from the front-padded none-statements is precomputed.
pub params: Params,
}
impl CalculateIdGadget {
/// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder
/// elements that didn't fit into a chunk.
@ -923,7 +972,7 @@ impl CalculateIdGadget {
/// Hash `inputs` starting from a circuit-constant `perm` state.
fn hash_from_state<H: AlgebraicHasher<F>, P: PlonkyPermutation<F>>(
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
perm: P,
inputs: &[Target],
) -> HashOutTarget {
@ -953,17 +1002,19 @@ impl CalculateIdGadget {
}
}
fn eval(
pub fn eval(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
// These statements will be padded to reach `self.num_statements`
statements: &[StatementTarget],
) -> HashOutTarget {
assert!(statements.len() <= self.params.num_public_statements_id);
let measure = measure_gates_begin!(builder, "CalculateId");
let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten());
let mut none_st = mainpod::Statement::from(Statement::None);
pad_statement(&self.params, &mut none_st);
let front_pad_elts = iter::repeat(&none_st)
.take(self.params.num_public_statements_id - self.params.max_public_statements)
.take(self.params.num_public_statements_id - statements.len())
.flat_map(|s| s.to_fields(&self.params))
.collect_vec();
let (perm, front_pad_elts_rem) =
@ -992,7 +1043,7 @@ impl MainPodVerifyGadget {
// index
fn normalize_st_tmpl(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
st_tmpl: &StatementTmplTarget,
id: HashOutTarget,
) -> StatementTmplTarget {
@ -1012,7 +1063,7 @@ impl MainPodVerifyGadget {
/// calculate the id of each batch.
fn build_custom_predicate_table(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
) -> Result<(Vec<HashOutTarget>, Vec<CustomPredicateBatchTarget>)> {
let measure = measure_gates_begin!(builder, "BuildCustomPredicateTable");
let params = &self.params;
@ -1053,7 +1104,7 @@ impl MainPodVerifyGadget {
/// custom predicate against the operation and statement.
fn build_custom_predicate_verification_table(
&self,
builder: &mut CircuitBuilder<F, D>,
builder: &mut CircuitBuilder,
custom_predicate_table: &[HashOutTarget],
) -> Result<(Vec<HashOutTarget>, Vec<CustomPredicateVerifyEntryTarget>)> {
let measure = measure_gates_begin!(builder, "BuildCustomPredicateVerificationTable");
@ -1105,7 +1156,13 @@ impl MainPodVerifyGadget {
))
}
fn eval(&self, builder: &mut CircuitBuilder<F, D>) -> Result<MainPodVerifyTarget> {
fn eval(
&self,
builder: &mut CircuitBuilder,
verified_proofs: &[VerifiedProofTarget],
) -> Result<MainPodVerifyTarget> {
assert_eq!(self.params.max_input_recursive_pods, verified_proofs.len());
let measure = measure_gates_begin!(builder, "MainPodVerify");
let params = &self.params;
// 1. Verify all input signed pods
@ -1115,6 +1172,7 @@ impl MainPodVerifyGadget {
params: params.clone(),
}
.eval(builder)?;
builder.assert_one(signed_pod.signature.enabled.target);
signed_pods.push(signed_pod);
}
@ -1132,16 +1190,47 @@ impl MainPodVerifyGadget {
statements.len(),
1 + self.params.max_input_signed_pods * self.params.max_signed_pod_values
);
// TODO: Fill with input main pods
for _main_pod in 0..self.params.max_input_main_pods {
for _statement in 0..self.params.max_public_statements {
statements.push(StatementTarget::new_native(
builder,
&self.params,
NativePredicate::None,
&[],
))
let id_gadget = CalculateIdGadget {
params: params.clone(),
};
let mut input_pods_self_statements: Vec<Vec<StatementTarget>> = Vec::new();
let normalize_statement_gadget = NormalizeStatementGadget {
params: self.params.clone(),
};
for verified_proof in verified_proofs {
let expected_id = HashOutTarget::try_from(
&verified_proof.public_inputs[PI_OFFSET_ID..PI_OFFSET_ID + HASH_SIZE],
)
.expect("4 elements");
let id_value = ValueTarget {
elements: expected_id.elements,
};
let mut input_pod_self_statements = Vec::new();
for _ in 0..self.params.max_input_pods_public_statements {
let self_st = builder.add_virtual_statement(params);
let normalized_st = normalize_statement_gadget.eval(builder, &self_st, &id_value);
input_pod_self_statements.push(self_st);
statements.push(normalized_st);
}
let id = id_gadget.eval(builder, &input_pod_self_statements);
builder.connect_hashes(expected_id, id);
input_pods_self_statements.push(input_pod_self_statements);
}
let vds_root = builder.add_virtual_hash();
// TODO: verify that all input pod proofs use verifier data from the public input VD array
// This requires merkle proofs
// https://github.com/0xPARC/pod2/issues/250
// Verify that VD array that input pod uses is the same we use now.
for verified_proof in verified_proofs {
let verified_proof_vds_root = HashOutTarget::try_from(
&verified_proof.public_inputs[PI_OFFSET_VDSROOT..PI_OFFSET_VDSROOT + HASH_SIZE],
)
.expect("4 elements");
builder.connect_hashes(vds_root, verified_proof_vds_root);
}
// Add the input (private and public) statements and corresponding operations
@ -1220,8 +1309,10 @@ impl MainPodVerifyGadget {
measure_gates_end!(builder, measure);
Ok(MainPodVerifyTarget {
params: params.clone(),
vds_root,
id,
signed_pods,
input_pods_self_statements,
statements: input_statements.to_vec(),
operations,
merkle_proofs,
@ -1233,8 +1324,10 @@ impl MainPodVerifyGadget {
pub struct MainPodVerifyTarget {
params: Params,
vds_root: HashOutTarget,
id: HashOutTarget,
signed_pods: Vec<SignedPodVerifyTarget>,
input_pods_self_statements: Vec<Vec<StatementTarget>>,
// The KEY_TYPE statement must be the first public one
statements: Vec<StatementTarget>,
operations: Vec<OperationTarget>,
@ -1251,7 +1344,9 @@ pub struct CustomPredicateVerification {
}
pub struct MainPodVerifyInput {
pub vds_root: Hash,
pub signed_pods: Vec<SignedPod>,
pub recursive_pods_pub_self_statements: Vec<Vec<Statement>>,
pub statements: Vec<mainpod::Statement>,
pub operations: Vec<mainpod::Operation>,
pub merkle_proofs: Vec<MerkleClaimAndProof>,
@ -1259,18 +1354,44 @@ pub struct MainPodVerifyInput {
pub custom_predicate_verifications: Vec<CustomPredicateVerification>,
}
fn set_targets_input_pods_self_statements(
pw: &mut PartialWitness<F>,
params: &Params,
statements_target: &[StatementTarget],
statements: &[Statement],
) -> Result<()> {
assert_eq!(
statements_target.len(),
params.max_input_pods_public_statements
);
assert!(statements.len() <= params.num_public_statements_id);
for (i, statement) in statements.iter().enumerate() {
statements_target[i].set_targets(pw, params, &statement.clone().into())?;
}
// Padding
let mut none_st = mainpod::Statement::from(Statement::None);
pad_statement(params, &mut none_st);
for statement_target in statements_target.iter().skip(statements.len()) {
statement_target.set_targets(pw, params, &none_st)?;
}
Ok(())
}
impl MainPodVerifyTarget {
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
input: &MainPodVerifyInput,
) -> Result<()> {
pw.set_target_arr(&self.vds_root.elements, &input.vds_root.0)?;
assert!(input.signed_pods.len() <= self.params.max_input_signed_pods);
for (i, signed_pod) in input.signed_pods.iter().enumerate() {
self.signed_pods[i].set_targets(pw, signed_pod)?;
}
// Padding
if self.params.max_input_signed_pods > 0 {
if input.signed_pods.len() != self.params.max_input_signed_pods {
// TODO: Instead of using an input for padding, use a canonical minimal SignedPod,
// without it a MainPod configured to support input signed pods must have at least one
// input signed pod :(
@ -1279,6 +1400,34 @@ impl MainPodVerifyTarget {
self.signed_pods[i].set_targets(pw, pad_pod)?;
}
}
assert!(
input.recursive_pods_pub_self_statements.len() <= self.params.max_input_recursive_pods
);
for (i, pod_pub_statements) in input.recursive_pods_pub_self_statements.iter().enumerate() {
set_targets_input_pods_self_statements(
pw,
&self.params,
&self.input_pods_self_statements[i],
pod_pub_statements,
)?;
}
// Padding
if input.recursive_pods_pub_self_statements.len() != self.params.max_input_recursive_pods {
let empty_pod = EmptyPod::new_boxed(&self.params, input.vds_root);
let empty_pod_statements = empty_pod.pub_statements();
for i in
input.recursive_pods_pub_self_statements.len()..self.params.max_input_recursive_pods
{
set_targets_input_pods_self_statements(
pw,
&self.params,
&self.input_pods_self_statements[i],
&empty_pod_statements,
)?;
}
}
assert_eq!(input.statements.len(), self.params.max_statements);
for (i, (st, op)) in zip_eq(&input.statements, &input.operations).enumerate() {
self.statements[i].set_targets(pw, &self.params, st)?;
@ -1342,17 +1491,44 @@ pub struct MainPodVerifyCircuit {
pub params: Params,
}
// TODO: Remove this type and implement it's logic directly in `impl InnerCircuit for MainPodVerifyTarget`
impl MainPodVerifyCircuit {
pub fn eval(&self, builder: &mut CircuitBuilder<F, D>) -> Result<MainPodVerifyTarget> {
pub fn eval(
&self,
builder: &mut CircuitBuilder,
verified_proofs: &[VerifiedProofTarget],
) -> Result<MainPodVerifyTarget> {
let main_pod = MainPodVerifyGadget {
params: self.params.clone(),
}
.eval(builder)?;
.eval(builder, verified_proofs)?;
builder.register_public_inputs(&main_pod.id.elements);
builder.register_public_inputs(&main_pod.vds_root.elements);
Ok(main_pod)
}
}
impl InnerCircuit for MainPodVerifyTarget {
type Input = MainPodVerifyInput;
type Params = Params;
fn build(
builder: &mut CircuitBuilder,
params: &Self::Params,
verified_proofs: &[VerifiedProofTarget],
) -> Result<Self> {
MainPodVerifyCircuit {
params: params.clone(),
}
.eval(builder, verified_proofs)
}
/// assigns the values to the targets
fn set_targets(&self, pw: &mut PartialWitness<F>, input: &Self::Input) -> Result<()> {
self.set_targets(pw, input)
}
}
#[cfg(test)]
mod tests {
use std::{iter, ops::Not};
@ -1395,7 +1571,7 @@ mod tests {
};
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut builder = CircuitBuilder::new(config);
let st_target = builder.add_virtual_statement(&params);
let op_target = builder.add_virtual_operation(&params);
@ -2268,7 +2444,7 @@ mod tests {
expected_st_arg: StatementArg,
) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut builder = CircuitBuilder::new(config);
let gadget = CustomOperationVerifyGadget {
params: params.clone(),
};
@ -2369,7 +2545,7 @@ mod tests {
expected_st: Statement,
) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut builder = CircuitBuilder::new(config);
let gadget = CustomOperationVerifyGadget {
params: params.clone(),
};
@ -2433,7 +2609,7 @@ mod tests {
expected_st: Option<Statement>,
) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut builder = CircuitBuilder::new(config);
let gadget = CustomOperationVerifyGadget {
params: params.clone(),
};
@ -2775,7 +2951,7 @@ mod tests {
fn helper_calculate_id(params: &Params, statements: &[Statement]) -> Result<()> {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut builder = CircuitBuilder::new(config);
let gadget = CalculateIdGadget {
params: params.clone(),
};