Constraints for custom predicates (#227)
* add target types for custom predicates * simplify * fix clippy * fix typo * don't use ref for NativePredicate * fix wrong len * precalculate CustomPredicateBatch id * wip * wip * move code back * great progress * wip * code complete, hopefully; missing tests * fill aux for custom predicate op * fix clippy warnings * fix typos * fix test import * fix missing assignment in lt_mask, test custom_operation_verify_gadget * fix mistake * wip * fix * debug revert except for let entry = CustomPredicateVerifyEntryTarget * fix batch_id calculation by fixing padding * oops * remove completed TODOs
This commit is contained in:
parent
4fa9e20ecd
commit
024ed8bd04
12 changed files with 1597 additions and 291 deletions
|
|
@ -1,4 +1,4 @@
|
|||
use std::{fmt, iter, sync::Arc};
|
||||
use std::{fmt, iter};
|
||||
|
||||
use log::error;
|
||||
use plonky2::field::types::Field;
|
||||
|
|
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
|||
use crate::{
|
||||
backends::plonky2::primitives::merkletree::MerkleProof,
|
||||
middleware::{
|
||||
custom::KeyOrWildcard, AnchoredKey, CustomPredicateBatch, CustomPredicateRef, Error,
|
||||
custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
|
||||
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg,
|
||||
ToFields, Wildcard, WildcardValue, F, SELF,
|
||||
},
|
||||
|
|
@ -36,6 +36,9 @@ impl fmt::Display for OperationAux {
|
|||
}
|
||||
|
||||
impl ToFields for OperationType {
|
||||
/// Encoding:
|
||||
/// - Native(native_op) => [1, [native_op], 0, 0, 0, 0]
|
||||
/// - Custom(batch, index) => [3, [batch.id], index]
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
let mut fields: Vec<F> = match self {
|
||||
Self::Native(p) => iter::once(F::from_canonical_u64(1))
|
||||
|
|
@ -43,7 +46,7 @@ impl ToFields for OperationType {
|
|||
.collect(),
|
||||
Self::Custom(CustomPredicateRef { batch, index }) => {
|
||||
iter::once(F::from_canonical_u64(3))
|
||||
.chain(batch.id(params).0)
|
||||
.chain(batch.id().0)
|
||||
.chain(iter::once(F::from_canonical_usize(*index)))
|
||||
.collect()
|
||||
}
|
||||
|
|
@ -321,7 +324,7 @@ impl Operation {
|
|||
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
|
||||
if batch == &cpr.batch && index == &cpr.index =>
|
||||
{
|
||||
check_custom_pred(params, batch, *index, args, s_args)
|
||||
check_custom_pred(params, cpr, args, s_args)
|
||||
}
|
||||
_ => Err(Error::invalid_deduction(
|
||||
self.clone(),
|
||||
|
|
@ -360,7 +363,7 @@ pub fn check_st_tmpl(
|
|||
(StatementTmplArg::None, StatementArg::None) => true,
|
||||
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
|
||||
(
|
||||
StatementTmplArg::Key(pod_id_wc, key_or_wc),
|
||||
StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc),
|
||||
StatementArg::Key(AnchoredKey { pod_id, key }),
|
||||
) => {
|
||||
let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map);
|
||||
|
|
@ -379,14 +382,46 @@ pub fn check_st_tmpl(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn resolve_wildcard_values(
|
||||
params: &Params,
|
||||
pred: &CustomPredicate,
|
||||
args: &[Statement],
|
||||
) -> Option<Vec<WildcardValue>> {
|
||||
// Check that all wildcard have consistent values as assigned in the statements while storing a
|
||||
// map of their values.
|
||||
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
|
||||
// disjunctions we expect Statement::None for the unused statements.
|
||||
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
|
||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||
let st_args = st.args();
|
||||
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
||||
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
|
||||
// TODO: Better errors. Example:
|
||||
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
|
||||
// println!("{} doesn't match {}", st, st_tmpl);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
|
||||
// they are beyond the number of used wildcards in this custom predicate, or they could be
|
||||
// private arguments that are unused in a particular disjunction.
|
||||
Some(
|
||||
wildcard_map
|
||||
.into_iter()
|
||||
.map(|opt| opt.unwrap_or(WildcardValue::None))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn check_custom_pred(
|
||||
params: &Params,
|
||||
batch: &Arc<CustomPredicateBatch>,
|
||||
index: usize,
|
||||
custom_pred_ref: &CustomPredicateRef,
|
||||
args: &[Statement],
|
||||
s_args: &[WildcardValue],
|
||||
) -> Result<bool> {
|
||||
let pred = &batch.predicates[index];
|
||||
let pred = custom_pred_ref.predicate();
|
||||
if pred.statements.len() != args.len() {
|
||||
return Err(Error::diff_amount(
|
||||
"custom predicate operation".to_string(),
|
||||
|
|
@ -404,26 +439,12 @@ fn check_custom_pred(
|
|||
));
|
||||
}
|
||||
|
||||
// Check that all wildcard have consistent values as assigned in the statements while storing a
|
||||
// map of their values. Count the number of statements that match the templates by predicate.
|
||||
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
|
||||
// disjunctions we expect Statement::None for the unused statements.
|
||||
// Count the number of statements that match the templates by predicate.
|
||||
let mut num_matches = 0;
|
||||
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
|
||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||
let st_args = st.args();
|
||||
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
||||
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
|
||||
// TODO: Better errors. Example:
|
||||
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
|
||||
// println!("{} doesn't match {}", st, st_tmpl);
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
let st_tmpl_pred = match &st_tmpl.pred {
|
||||
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
||||
batch: batch.clone(),
|
||||
batch: custom_pred_ref.batch.clone(),
|
||||
index: *i,
|
||||
}),
|
||||
p => p.clone(),
|
||||
|
|
@ -433,9 +454,14 @@ fn check_custom_pred(
|
|||
}
|
||||
}
|
||||
|
||||
let wildcard_map = match resolve_wildcard_values(params, pred, args) {
|
||||
Some(wc_map) => wc_map,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
// Check that the resolved wildcard match the statement arguments.
|
||||
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
|
||||
if !wc_value.as_ref().is_none_or(|wc_value| *wc_value == *s_arg) {
|
||||
if *wc_value != *s_arg {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue