Constraints for custom predicates (#227)

* add target types for custom predicates

* simplify

* fix clippy

* fix typo

* don't use ref for NativePredicate

* fix wrong len

* precalculate CustomPredicateBatch id

* wip

* wip

* move code back

* great progress

* wip

* code complete, hopefully; missing tests

* fill aux for custom predicate op

* fix clippy warnings

* fix typos

* fix test import

* fix missing assignment in lt_mask, test custom_operation_verify_gadget

* fix mistake

* wip

* fix

* debug revert except for let entry = CustomPredicateVerifyEntryTarget

* fix batch_id calculation by fixing padding

* oops

* remove completed TODOs
This commit is contained in:
Eduard S. 2025-05-13 11:00:45 +02:00 committed by GitHub
parent 4fa9e20ecd
commit 024ed8bd04
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1597 additions and 291 deletions

View file

@ -1,4 +1,4 @@
use std::{fmt, iter, sync::Arc};
use std::{fmt, iter};
use log::error;
use plonky2::field::types::Field;
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use crate::{
backends::plonky2::primitives::merkletree::MerkleProof,
middleware::{
custom::KeyOrWildcard, AnchoredKey, CustomPredicateBatch, CustomPredicateRef, Error,
custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg,
ToFields, Wildcard, WildcardValue, F, SELF,
},
@ -36,6 +36,9 @@ impl fmt::Display for OperationAux {
}
impl ToFields for OperationType {
/// Encoding:
/// - Native(native_op) => [1, [native_op], 0, 0, 0, 0]
/// - Custom(batch, index) => [3, [batch.id], index]
fn to_fields(&self, params: &Params) -> Vec<F> {
let mut fields: Vec<F> = match self {
Self::Native(p) => iter::once(F::from_canonical_u64(1))
@ -43,7 +46,7 @@ impl ToFields for OperationType {
.collect(),
Self::Custom(CustomPredicateRef { batch, index }) => {
iter::once(F::from_canonical_u64(3))
.chain(batch.id(params).0)
.chain(batch.id().0)
.chain(iter::once(F::from_canonical_usize(*index)))
.collect()
}
@ -321,7 +324,7 @@ impl Operation {
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
if batch == &cpr.batch && index == &cpr.index =>
{
check_custom_pred(params, batch, *index, args, s_args)
check_custom_pred(params, cpr, args, s_args)
}
_ => Err(Error::invalid_deduction(
self.clone(),
@ -360,7 +363,7 @@ pub fn check_st_tmpl(
(StatementTmplArg::None, StatementArg::None) => true,
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
(
StatementTmplArg::Key(pod_id_wc, key_or_wc),
StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc),
StatementArg::Key(AnchoredKey { pod_id, key }),
) => {
let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map);
@ -379,14 +382,46 @@ pub fn check_st_tmpl(
}
}
pub fn resolve_wildcard_values(
params: &Params,
pred: &CustomPredicate,
args: &[Statement],
) -> Option<Vec<WildcardValue>> {
// Check that all wildcard have consistent values as assigned in the statements while storing a
// map of their values.
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
// disjunctions we expect Statement::None for the unused statements.
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
// TODO: Better errors. Example:
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
// println!("{} doesn't match {}", st, st_tmpl);
return None;
}
}
}
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
// they are beyond the number of used wildcards in this custom predicate, or they could be
// private arguments that are unused in a particular disjunction.
Some(
wildcard_map
.into_iter()
.map(|opt| opt.unwrap_or(WildcardValue::None))
.collect(),
)
}
fn check_custom_pred(
params: &Params,
batch: &Arc<CustomPredicateBatch>,
index: usize,
custom_pred_ref: &CustomPredicateRef,
args: &[Statement],
s_args: &[WildcardValue],
) -> Result<bool> {
let pred = &batch.predicates[index];
let pred = custom_pred_ref.predicate();
if pred.statements.len() != args.len() {
return Err(Error::diff_amount(
"custom predicate operation".to_string(),
@ -404,26 +439,12 @@ fn check_custom_pred(
));
}
// Check that all wildcard have consistent values as assigned in the statements while storing a
// map of their values. Count the number of statements that match the templates by predicate.
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
// disjunctions we expect Statement::None for the unused statements.
// Count the number of statements that match the templates by predicate.
let mut num_matches = 0;
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
// TODO: Better errors. Example:
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
// println!("{} doesn't match {}", st, st_tmpl);
return Ok(false);
}
}
let st_tmpl_pred = match &st_tmpl.pred {
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
batch: batch.clone(),
batch: custom_pred_ref.batch.clone(),
index: *i,
}),
p => p.clone(),
@ -433,9 +454,14 @@ fn check_custom_pred(
}
}
let wildcard_map = match resolve_wildcard_values(params, pred, args) {
Some(wc_map) => wc_map,
None => return Ok(false),
};
// Check that the resolved wildcard match the statement arguments.
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
if !wc_value.as_ref().is_none_or(|wc_value| *wc_value == *s_arg) {
if *wc_value != *s_arg {
return Ok(false);
}
}