Use predicate hash in statements instead of the literal predicate

Resolve #448 

Previously a predicate was 6 elements.  Now it grows to 8 elements; and the hash is 4 elements.

Some parts of the circuit require only require equality checks with the predicate: that works with the predicate hash.  Other parts require inspecting or working with particular elements in the predicate, those need the preimage of the predicate hash.
Both `StatementTarget` and `StatementTmplTarget` have been updated to include the predicate hash and optionally the predicate.  When the predicate is included, constraints are automatically generated for `pred_hash = hash(pred)`.  We only include the predicate when needed.
This commit is contained in:
Eduard S. 2026-01-19 11:02:11 +01:00 committed by GitHub
parent 2eb1daeb92
commit 0fca00cc93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 319 additions and 159 deletions

View file

@ -9,7 +9,7 @@ use plonky2::{
types::{Field, PrimeField64},
},
hash::{
hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS},
hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS},
poseidon::PoseidonHash,
},
iop::{
@ -17,6 +17,7 @@ use plonky2::{
target::{BoolTarget, Target},
witness::{PartialWitness, PartitionWitness, Witness, WitnessWrite},
},
plonk::config::Hasher,
util::serialization::{Buffer, IoResult, Read, Write},
};
use serde::{Deserialize, Serialize};
@ -136,10 +137,99 @@ impl StatementArgTarget {
#[derive(Clone, Serialize, Deserialize)]
pub struct StatementTarget {
pub predicate: PredicateTarget,
// If the pred is Some, then the `pred_hash` is constrained to be the `hash(pred)`.
pred: Option<PredicateTarget>,
pred_hash: HashOutTarget,
pub args: Vec<StatementArgTarget>,
}
impl StatementTarget {
pub fn pred(&self) -> Option<&PredicateTarget> {
self.pred.as_ref()
}
pub fn pred_hash(&self) -> &HashOutTarget {
&self.pred_hash
}
pub fn new(pred_hash: HashOutTarget, args: Vec<StatementArgTarget>) -> Self {
Self {
pred: None,
pred_hash,
args,
}
}
pub fn new_with_pred(
builder: &mut CircuitBuilder,
params: &Params,
predicate: impl Build<PredicateTarget>,
args: &[StatementArgTarget],
) -> Self {
let pred = predicate.build(builder, params);
let pred_hash = pred.hash(builder);
Self {
pred: Some(pred),
pred_hash,
args: args
.iter()
.cloned()
.chain(iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(params.max_statement_args)
.collect(),
}
}
pub fn new_native(
builder: &mut CircuitBuilder,
params: &Params,
native_predicate: impl Build<NativePredicateTarget>,
args: &[StatementArgTarget],
) -> Self {
let pred = PredicateTarget::new_native(builder, params, native_predicate);
Self::new_with_pred(builder, params, pred, args)
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st: &Statement,
) -> Result<()> {
if let Some(pred) = &self.pred {
pred.set_targets(pw, params, &st.predicate())?;
}
pw.set_hash_target(self.pred_hash, HashOut::from(st.predicate().hash(params)))?;
for (i, arg) in st
.args()
.iter()
.chain(iter::repeat(&StatementArg::None))
.take(params.max_statement_args)
.enumerate()
{
self.args[i].set_targets(pw, params, arg)?;
}
Ok(())
}
pub fn pred_is_blank_intro(&self, builder: &mut CircuitBuilder) -> BoolTarget {
let zero_hash = builder.constant_hash(HashOut {
elements: [F::ZERO, F::ZERO, F::ZERO, F::ZERO],
});
let blank_intro = PredicateTarget::new_intro(builder, zero_hash).hash(builder);
builder.is_equal_flattenable(&self.pred_hash, &blank_intro)
}
pub fn has_native_type(
&self,
builder: &mut CircuitBuilder,
params: &Params,
t: NativePredicate,
) -> BoolTarget {
let expected_predicate_hash =
builder.constant_hash(HashOut::from(Predicate::Native(t).hash(params)));
builder.is_equal_flattenable(&self.pred_hash, &expected_predicate_hash)
}
}
pub trait Build<T> {
fn build(self, builder: &mut CircuitBuilder, params: &Params) -> T;
}
@ -156,57 +246,6 @@ impl<T> Build<T> for T {
}
}
impl StatementTarget {
/// Build a new native StatementTarget. Pads the arguments.
pub fn new_native(
builder: &mut CircuitBuilder,
params: &Params,
native_predicate: impl Build<NativePredicateTarget>,
args: &[StatementArgTarget],
) -> Self {
// if native_predicate is const then NativePredicate -> NativePredicateTarget
// else just use as is
Self {
predicate: PredicateTarget::new_native(builder, params, native_predicate),
args: args
.iter()
.cloned()
.chain(iter::repeat_with(|| StatementArgTarget::none(builder)))
.take(params.max_statement_args)
.collect(),
}
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st: &Statement,
) -> Result<()> {
self.predicate.set_targets(pw, params, st.predicate())?;
for (i, arg) in st
.args()
.iter()
.chain(iter::repeat(&StatementArg::None))
.take(params.max_statement_args)
.enumerate()
{
self.args[i].set_targets(pw, params, arg)?;
}
Ok(())
}
pub fn has_native_type(
&self,
builder: &mut CircuitBuilder,
params: &Params,
t: NativePredicate,
) -> BoolTarget {
let expected_predicate = PredicateTarget::new_native(builder, params, t);
builder.is_equal_flattenable(&self.predicate, &expected_predicate)
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct OperationTypeTarget {
#[serde(with = "serde_arrays")]
@ -336,7 +375,7 @@ impl PredicateTarget {
let id = native_predicate.build(builder, params).0;
let zero = builder.zero();
Self {
elements: [prefix, id, zero, zero, zero, zero],
elements: [prefix, id, zero, zero, zero, zero, zero, zero],
}
}
@ -344,7 +383,7 @@ impl PredicateTarget {
let prefix = builder.constant(F::from(PredicatePrefix::BatchSelf));
let zero = builder.zero();
Self {
elements: [prefix, index, zero, zero, zero, zero],
elements: [prefix, index, zero, zero, zero, zero, zero, zero],
}
}
@ -355,8 +394,9 @@ impl PredicateTarget {
) -> Self {
let prefix = builder.constant(F::from(PredicatePrefix::Custom));
let id = batch_id.elements;
let zero = builder.zero();
Self {
elements: [prefix, id[0], id[1], id[2], id[3], index],
elements: [prefix, id[0], id[1], id[2], id[3], index, zero, zero],
}
}
@ -365,7 +405,7 @@ impl PredicateTarget {
let vh = vd_hash.elements;
let zero = builder.zero();
Self {
elements: [prefix, vh[0], vh[1], vh[2], vh[3], zero],
elements: [prefix, vh[0], vh[1], vh[2], vh[3], zero, zero, zero],
}
}
@ -378,10 +418,30 @@ impl PredicateTarget {
&self,
pw: &mut PartialWitness<F>,
params: &Params,
predicate: Predicate,
predicate: &Predicate,
) -> Result<()> {
Ok(pw.set_target_arr(&self.elements, &predicate.to_fields(params))?)
}
pub fn hash(&self, builder: &mut CircuitBuilder) -> HashOutTarget {
// Optimization: if all the predicate values are constants we skip the hash circuit and
// return a hash constant
let mut predicate_values = [F::ZERO; Params::predicate_size()];
let mut predicate_constant = true;
for (i, target) in self.elements.iter().enumerate() {
if let Some(v) = builder.target_as_constant(*target) {
predicate_values[i] = v;
} else {
predicate_constant = false;
break;
}
}
if predicate_constant {
builder.constant_hash(PoseidonHash::hash_no_pad(&predicate_values))
} else {
builder.hash_n_to_hash_no_pad::<PoseidonHash>(self.elements.to_vec())
}
}
}
/// Mirrors `middleware::KeyOrWildcard`
@ -466,18 +526,46 @@ impl StatementTmplArgTarget {
#[derive(Clone, Serialize, Deserialize)]
pub struct StatementTmplTarget {
pub pred: PredicateTarget,
pred: Option<PredicateTarget>,
pred_hash: HashOutTarget,
pub args: Vec<StatementTmplArgTarget>,
}
impl StatementTmplTarget {
pub fn new(pred_hash: HashOutTarget, args: Vec<StatementTmplArgTarget>) -> Self {
Self {
pred: None,
pred_hash,
args,
}
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
st_tmpl: &StatementTmpl,
) -> Result<()> {
Ok(pw.set_target_arr(&self.flatten(), &st_tmpl.to_fields(params))?)
if let Some(pred) = &self.pred {
pred.set_targets(pw, params, &st_tmpl.pred)?;
}
pw.set_hash_target(self.pred_hash, HashOut::from(st_tmpl.pred.hash(params)))?;
let arg_pad = StatementTmplArg::None;
for (i, arg) in st_tmpl
.args
.iter()
.chain(iter::repeat(&arg_pad))
.take(params.max_statement_args)
.enumerate()
{
self.args[i].set_targets(pw, params, arg)?;
}
Ok(())
}
pub fn pred(&self) -> Option<&PredicateTarget> {
self.pred.as_ref()
}
pub fn pred_hash(&self) -> &HashOutTarget {
&self.pred_hash
}
}
@ -494,9 +582,24 @@ impl CustomPredicateTarget {
&self,
pw: &mut PartialWitness<F>,
params: &Params,
custom_predicate: &CustomPredicate,
custom_pred: &CustomPredicate,
) -> Result<()> {
Ok(pw.set_target_arr(&self.flatten(), &custom_predicate.to_fields(params))?)
pw.set_target(
self.conjunction.target,
F::from_bool(custom_pred.conjunction),
)?;
let st_tmpl_pad = custom_pred.pad_statement_tmpl();
for (i, st_tmpl) in custom_pred
.statements
.iter()
.chain(iter::repeat(&st_tmpl_pad))
.take(params.max_custom_predicate_arity)
.enumerate()
{
self.statements[i].set_targets(pw, params, st_tmpl)?;
}
pw.set_target(self.args_len, F::from_canonical_usize(custom_pred.args_len))?;
Ok(())
}
}
@ -507,7 +610,7 @@ pub struct CustomPredicateBatchTarget {
impl CustomPredicateBatchTarget {
pub fn id(&self, builder: &mut CircuitBuilder) -> HashOutTarget {
let flattened = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
let flattened: Vec<_> = self.predicates.iter().flat_map(|cp| cp.flatten()).collect();
builder.hash_n_to_hash_no_pad::<PoseidonHash>(flattened)
}
@ -621,7 +724,7 @@ pub struct CustomPredicateVerifyEntryTarget {
}
impl CustomPredicateVerifyEntryTarget {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder, with_pred: bool) -> Self {
let custom_predicate_table_len =
params.max_custom_predicate_batches * params.max_custom_batch_size;
CustomPredicateVerifyEntryTarget {
@ -629,12 +732,12 @@ impl CustomPredicateVerifyEntryTarget {
custom_predicate_table_len,
builder,
),
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
custom_predicate: builder.add_virtual_custom_predicate_entry(params, with_pred),
args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect(),
op_args: (0..params.max_operation_args)
.map(|_| builder.add_virtual_statement(params))
.map(|_| builder.add_virtual_statement(params, false))
.collect(),
}
}
@ -897,7 +1000,7 @@ impl Flattenable for PredicateTarget {
impl Flattenable for StatementTarget {
fn flatten(&self) -> Vec<Target> {
self.predicate
self.pred_hash
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|a| &a.elements).cloned())
@ -906,20 +1009,22 @@ impl Flattenable for StatementTarget {
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
let predicate = PredicateTarget::from_flattened(params, &v[..Params::predicate_size()]);
let predicate_hash = HashOutTarget::from_flattened(params, &v[..HASH_SIZE]);
let args = (0..params.max_statement_args)
.map(|i| StatementArgTarget {
elements: array::from_fn(|j| {
v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j]
}),
elements: array::from_fn(|j| v[HASH_SIZE + i * STATEMENT_ARG_F_LEN + j]),
})
.collect();
Self { predicate, args }
Self {
pred: None,
pred_hash: predicate_hash,
args,
}
}
fn size(params: &Params) -> usize {
PredicateTarget::size(params) + params.max_statement_args * StatementArgTarget::size(params)
HASH_SIZE + params.max_statement_args * StatementArgTarget::size(params)
}
}
@ -957,7 +1062,7 @@ impl Flattenable for CustomPredicateTarget {
impl Flattenable for StatementTmplTarget {
fn flatten(&self) -> Vec<Target> {
self.pred
self.pred_hash
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
@ -966,21 +1071,24 @@ impl Flattenable for StatementTmplTarget {
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
let pred_end = Params::predicate_size();
let pred = PredicateTarget::from_flattened(params, &v[..pred_end]);
let pred_hash_end = HASH_SIZE;
let pred_hash = HashOutTarget::from_flattened(params, &v[..pred_hash_end]);
let sta_size = Params::statement_tmpl_arg_size();
let args = (0..params.max_statement_args)
.map(|i| {
let sta_v = &v[pred_end + sta_size * i..pred_end + sta_size * (i + 1)];
let sta_v = &v[pred_hash_end + sta_size * i..pred_hash_end + sta_size * (i + 1)];
StatementTmplArgTarget::from_flattened(params, sta_v)
})
.collect();
Self { pred, args }
Self {
pred: None,
pred_hash,
args,
}
}
fn size(params: &Params) -> usize {
PredicateTarget::size(params)
+ params.max_statement_args * StatementTmplArgTarget::size(params)
HASH_SIZE + params.max_statement_args * StatementTmplArgTarget::size(params)
}
}
@ -1039,18 +1147,32 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
fn connect_values(&mut self, x: ValueTarget, y: ValueTarget);
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]);
fn add_virtual_value(&mut self) -> ValueTarget;
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget;
fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget;
fn add_virtual_statement_arg(&mut self) -> StatementArgTarget;
fn add_virtual_predicate(&mut self) -> PredicateTarget;
fn add_virtual_operation_type(&mut self) -> OperationTypeTarget;
fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget;
fn add_virtual_statement_tmpl_arg(&mut self) -> StatementTmplArgTarget;
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget;
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget;
fn add_virtual_custom_predicate_batch(&mut self, params: &Params)
-> CustomPredicateBatchTarget;
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
-> CustomPredicateEntryTarget;
fn add_virtual_statement_tmpl(
&mut self,
params: &Params,
with_pred: bool,
) -> StatementTmplTarget;
fn add_virtual_custom_predicate(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateTarget;
fn add_virtual_custom_predicate_batch(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateBatchTarget;
fn add_virtual_custom_predicate_entry(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateEntryTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_statement_arg(
&mut self,
@ -1144,10 +1266,20 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
}
}
fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget {
let predicate = self.add_virtual_predicate();
/// If `with_pred = true` a predicate is included and its hash constrained.
/// If `with_pred = false` only the predicate hash is included.
fn add_virtual_statement(&mut self, params: &Params, with_pred: bool) -> StatementTarget {
let (pred, pred_hash) = if with_pred {
let pred = self.add_virtual_predicate();
let pred_hash = pred.hash(self);
(Some(pred), pred_hash)
} else {
let pred_hash = self.add_virtual_hash();
(None, pred_hash)
};
StatementTarget {
predicate,
pred,
pred_hash,
args: (0..params.max_statement_args)
.map(|_| self.add_virtual_statement_arg())
.collect(),
@ -1188,19 +1320,38 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
}
}
fn add_virtual_statement_tmpl(&mut self, params: &Params) -> StatementTmplTarget {
let args = (0..params.max_statement_args)
.map(|_| self.add_virtual_statement_tmpl_arg())
.collect();
/// If `with_pred = true` a predicate is included and its hash constrained.
/// If `with_pred = false` only the predicate hash is included.
fn add_virtual_statement_tmpl(
&mut self,
params: &Params,
with_pred: bool,
) -> StatementTmplTarget {
let (pred, pred_hash) = if with_pred {
let pred = self.add_virtual_predicate();
let pred_hash = pred.hash(self);
(Some(pred), pred_hash)
} else {
let pred_hash = self.add_virtual_hash();
(None, pred_hash)
};
StatementTmplTarget {
pred: self.add_virtual_predicate(),
args,
pred,
pred_hash,
args: (0..params.max_statement_args)
.map(|_| self.add_virtual_statement_tmpl_arg())
.collect(),
}
}
fn add_virtual_custom_predicate(&mut self, params: &Params) -> CustomPredicateTarget {
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateTarget {
let statements = (0..params.max_custom_predicate_arity)
.map(|_| self.add_virtual_statement_tmpl(params))
.map(|_| self.add_virtual_statement_tmpl(params, with_pred))
.collect();
CustomPredicateTarget {
conjunction: self.add_virtual_bool_target_safe(),
@ -1209,25 +1360,29 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate_batch(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateBatchTarget {
CustomPredicateBatchTarget {
predicates: (0..params.max_custom_batch_size)
.map(|_| self.add_virtual_custom_predicate(params))
.map(|_| self.add_virtual_custom_predicate(params, with_pred))
.collect(),
}
}
/// See `add_virtual_statement_tmpl` for the meaning of `with_pred`.
fn add_virtual_custom_predicate_entry(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateEntryTarget {
CustomPredicateEntryTarget {
id: self.add_virtual_hash(),
index: self.add_virtual_target(),
predicate: self.add_virtual_custom_predicate(params),
predicate: self.add_virtual_custom_predicate(params, with_pred),
}
}
@ -1734,7 +1889,8 @@ pub(crate) mod tests {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let custom_predicate_batch_target = builder.add_virtual_custom_predicate_batch(params);
let custom_predicate_batch_target =
builder.add_virtual_custom_predicate_batch(params, false);
// Calculate the id in constraints and compare it against the id calculated natively
let id_target = custom_predicate_batch_target.id(&mut builder);