Constraints for custom predicates (#227)
* add target types for custom predicates * simplify * fix clippy * fix typo * don't use ref for NativePredicate * fix wrong len * precalculate CustomPredicateBatch id * wip * wip * move code back * great progress * wip * code complete, hopefully; missing tests * fill aux for custom predicate op * fix clippy warnings * fix typos * fix test import * fix missing assignment in lt_mask, test custom_operation_verify_gadget * fix mistake * wip * fix * debug revert except for let entry = CustomPredicateVerifyEntryTarget * fix batch_id calculation by fixing padding * oops * remove completed TODOs
This commit is contained in:
parent
4fa9e20ecd
commit
024ed8bd04
12 changed files with 1597 additions and 291 deletions
|
|
@ -5,7 +5,8 @@ use schemars::JsonSchema;
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::middleware::{
|
||||
hash_fields, Error, Hash, Key, Params, Predicate, Result, ToFields, Value, F, HASH_SIZE,
|
||||
hash_fields, Error, Hash, Key, NativePredicate, Params, Predicate, Result, ToFields, Value,
|
||||
EMPTY_HASH, F, HASH_SIZE, VALUE_SIZE,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
|
|
@ -49,12 +50,15 @@ impl fmt::Display for KeyOrWildcard {
|
|||
}
|
||||
|
||||
impl ToFields for KeyOrWildcard {
|
||||
// Encoding:
|
||||
// - Key(k) => [[k]]
|
||||
// - Wildcard(index) => [[index], 0, 0, 0]
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
match self {
|
||||
KeyOrWildcard::Key(k) => k.hash().to_fields(params),
|
||||
KeyOrWildcard::Wildcard(wc) => iter::once(F::ZERO)
|
||||
.take(HASH_SIZE - 1)
|
||||
.chain(iter::once(F::from_canonical_u64(wc.index as u64)))
|
||||
KeyOrWildcard::Wildcard(wc) => iter::once(F::from_canonical_u64(wc.index as u64))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(HASH_SIZE)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
|
@ -66,7 +70,7 @@ pub enum StatementTmplArg {
|
|||
None,
|
||||
Literal(Value),
|
||||
// AnchoredKey
|
||||
Key(Wildcard, KeyOrWildcard),
|
||||
AnchoredKey(Wildcard, KeyOrWildcard),
|
||||
// TODO: This naming is a bit confusing: a WildcardLiteral that contains a Wildcard...
|
||||
// Could we merge WildcardValue and Value and allow wildcard value apart from pod_id and key?
|
||||
WildcardLiteral(Wildcard),
|
||||
|
|
@ -76,7 +80,7 @@ pub enum StatementTmplArg {
|
|||
pub enum StatementTmplArgPrefix {
|
||||
None = 0,
|
||||
Literal = 1,
|
||||
Key = 2,
|
||||
AnchoredKey = 2,
|
||||
WildcardLiteral = 3,
|
||||
}
|
||||
|
||||
|
|
@ -88,11 +92,11 @@ impl From<StatementTmplArgPrefix> for F {
|
|||
|
||||
impl ToFields for StatementTmplArg {
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// None => (0, ...)
|
||||
// Literal(value) => (1, [value], 0, 0, 0, 0)
|
||||
// Key(wildcard1_index, key_or_wildcard2)
|
||||
// => (2, [wildcard1_index], 0, 0, 0, [key_or_wildcard2])
|
||||
// WildcardLiteral(wildcard_index) => (3, [wildcard_index], 0, 0, 0, 0, 0, 0, 0)
|
||||
// Encoding:
|
||||
// None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
// Literal(v) => (1, [v ], 0, 0, 0, 0)
|
||||
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
|
||||
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
|
||||
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
|
||||
match self {
|
||||
StatementTmplArg::None => {
|
||||
|
|
@ -105,13 +109,15 @@ impl ToFields for StatementTmplArg {
|
|||
StatementTmplArg::Literal(v) => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
|
||||
.chain(v.raw().to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::statement_tmpl_arg_size())
|
||||
.collect();
|
||||
fields
|
||||
}
|
||||
StatementTmplArg::Key(wc1, kw2) => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Key))
|
||||
StatementTmplArg::AnchoredKey(wc1, kw2) => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
|
||||
.chain(wc1.to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
|
||||
.chain(kw2.to_fields(params))
|
||||
.collect();
|
||||
fields
|
||||
|
|
@ -119,7 +125,8 @@ impl ToFields for StatementTmplArg {
|
|||
StatementTmplArg::WildcardLiteral(wc) => {
|
||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
|
||||
.chain(wc.to_fields(params))
|
||||
.chain(iter::repeat(F::ZERO).take(HASH_SIZE))
|
||||
.chain(iter::repeat(F::ZERO))
|
||||
.take(Params::statement_tmpl_arg_size())
|
||||
.collect();
|
||||
fields
|
||||
}
|
||||
|
|
@ -132,7 +139,7 @@ impl fmt::Display for StatementTmplArg {
|
|||
match self {
|
||||
Self::None => write!(f, "none"),
|
||||
Self::Literal(v) => write!(f, "{}", v),
|
||||
Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key),
|
||||
Self::AnchoredKey(pod_id, key) => write!(f, "({}, {})", pod_id, key),
|
||||
Self::WildcardLiteral(v) => write!(f, "{}", v),
|
||||
}
|
||||
}
|
||||
|
|
@ -177,7 +184,11 @@ impl ToFields for StatementTmpl {
|
|||
// instead of at the `to_fields` method, where we should assume that the
|
||||
// values are already valid
|
||||
if self.args.len() > params.max_statement_args {
|
||||
panic!("Statement template has too many arguments");
|
||||
panic!(
|
||||
"Statement template has too many arguments {} > {}",
|
||||
self.args.len(),
|
||||
params.max_statement_args
|
||||
);
|
||||
}
|
||||
|
||||
let mut fields: Vec<F> = self
|
||||
|
|
@ -206,25 +217,36 @@ pub struct CustomPredicate {
|
|||
}
|
||||
|
||||
impl CustomPredicate {
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
name: "empty".to_string(),
|
||||
conjunction: false,
|
||||
statements: vec![StatementTmpl {
|
||||
pred: Predicate::Native(NativePredicate::None),
|
||||
args: vec![],
|
||||
}],
|
||||
args_len: 0,
|
||||
}
|
||||
}
|
||||
pub fn and(
|
||||
name: String,
|
||||
params: &Params,
|
||||
name: String,
|
||||
statements: Vec<StatementTmpl>,
|
||||
args_len: usize,
|
||||
) -> Result<Self> {
|
||||
Self::new(name, params, true, statements, args_len)
|
||||
Self::new(params, name, true, statements, args_len)
|
||||
}
|
||||
pub fn or(
|
||||
name: String,
|
||||
params: &Params,
|
||||
name: String,
|
||||
statements: Vec<StatementTmpl>,
|
||||
args_len: usize,
|
||||
) -> Result<Self> {
|
||||
Self::new(name, params, false, statements, args_len)
|
||||
Self::new(params, name, false, statements, args_len)
|
||||
}
|
||||
pub fn new(
|
||||
name: String,
|
||||
params: &Params,
|
||||
name: String,
|
||||
conjunction: bool,
|
||||
statements: Vec<StatementTmpl>,
|
||||
args_len: usize,
|
||||
|
|
@ -236,6 +258,13 @@ impl CustomPredicate {
|
|||
params.max_custom_predicate_arity,
|
||||
));
|
||||
}
|
||||
if args_len > params.max_statement_args {
|
||||
return Err(Error::max_length(
|
||||
"statement_args.len".to_string(),
|
||||
args_len,
|
||||
params.max_statement_args,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
name,
|
||||
|
|
@ -244,6 +273,16 @@ impl CustomPredicate {
|
|||
args_len,
|
||||
})
|
||||
}
|
||||
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
|
||||
StatementTmpl {
|
||||
pred: Predicate::Native(if self.conjunction {
|
||||
NativePredicate::False
|
||||
} else {
|
||||
NativePredicate::None
|
||||
}),
|
||||
args: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFields for CustomPredicate {
|
||||
|
|
@ -262,11 +301,17 @@ impl ToFields for CustomPredicate {
|
|||
panic!("Custom predicate depends on too many statements");
|
||||
}
|
||||
|
||||
let mut fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
|
||||
let pad_st = self.pad_statement_tmpl();
|
||||
let fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
|
||||
.chain(iter::once(F::from_canonical_usize(self.args_len)))
|
||||
.chain(self.statements.iter().flat_map(|st| st.to_fields(params)))
|
||||
.chain(
|
||||
self.statements
|
||||
.iter()
|
||||
.chain(iter::repeat(&pad_st))
|
||||
.take(params.max_custom_predicate_arity)
|
||||
.flat_map(|st| st.to_fields(params)),
|
||||
)
|
||||
.collect();
|
||||
fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0));
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
|
@ -298,8 +343,9 @@ impl fmt::Display for CustomPredicate {
|
|||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CustomPredicateBatch {
|
||||
id: Hash,
|
||||
pub name: String,
|
||||
pub predicates: Vec<CustomPredicate>,
|
||||
predicates: Vec<CustomPredicate>,
|
||||
}
|
||||
|
||||
impl ToFields for CustomPredicateBatch {
|
||||
|
|
@ -313,27 +359,45 @@ impl ToFields for CustomPredicateBatch {
|
|||
panic!("Predicate batch exceeds maximum size");
|
||||
}
|
||||
|
||||
let mut fields: Vec<F> = self
|
||||
let pad_pred = CustomPredicate::empty();
|
||||
let fields: Vec<F> = self
|
||||
.predicates
|
||||
.iter()
|
||||
.chain(iter::repeat(&pad_pred))
|
||||
.take(params.max_custom_batch_size)
|
||||
.flat_map(|p| p.to_fields(params))
|
||||
.collect();
|
||||
fields.resize_with(params.custom_predicate_batch_size_field_elts(), || {
|
||||
F::from_canonical_u64(0)
|
||||
});
|
||||
fields
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomPredicateBatch {
|
||||
pub fn new(params: &Params, name: String, predicates: Vec<CustomPredicate>) -> Arc<Self> {
|
||||
let mut cpb = Self {
|
||||
id: EMPTY_HASH,
|
||||
name,
|
||||
predicates,
|
||||
};
|
||||
let id = cpb.calculate_id(params);
|
||||
cpb.id = id;
|
||||
Arc::new(cpb)
|
||||
}
|
||||
|
||||
/// Cryptographic identifier for the batch.
|
||||
pub fn id(&self, params: &Params) -> Hash {
|
||||
fn calculate_id(&self, params: &Params) -> Hash {
|
||||
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
|
||||
// but ideally we want to use the root of a merkle tree built from the custom predicates.
|
||||
let input = self.to_fields(params);
|
||||
|
||||
hash_fields(&input)
|
||||
}
|
||||
|
||||
pub fn id(&self) -> Hash {
|
||||
self.id
|
||||
}
|
||||
pub fn predicates(&self) -> &[CustomPredicate] {
|
||||
&self.predicates
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
|
|
@ -347,13 +411,16 @@ impl CustomPredicateRef {
|
|||
Self { batch, index }
|
||||
}
|
||||
pub fn arg_len(&self) -> usize {
|
||||
self.batch.predicates[self.index].args_len
|
||||
self.predicate().args_len
|
||||
}
|
||||
pub fn predicate(&self) -> &CustomPredicate {
|
||||
&self.batch.predicates[self.index]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{array, sync::Arc};
|
||||
use std::array;
|
||||
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
|
||||
|
|
@ -392,28 +459,29 @@ mod tests {
|
|||
p:value_of(Constant, 2),
|
||||
p:product_of(S1, Constant, S2)
|
||||
*/
|
||||
let cust_pred_batch = Arc::new(CustomPredicateBatch {
|
||||
name: "is_double".to_string(),
|
||||
predicates: vec![CustomPredicate::and(
|
||||
"_".into(),
|
||||
let cust_pred_batch = CustomPredicateBatch::new(
|
||||
¶ms,
|
||||
"is_double".to_string(),
|
||||
vec![CustomPredicate::and(
|
||||
¶ms,
|
||||
"_".into(),
|
||||
vec![
|
||||
st(
|
||||
P::Native(NP::ValueOf),
|
||||
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(2.into())],
|
||||
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(2.into())],
|
||||
),
|
||||
st(
|
||||
P::Native(NP::ProductOf),
|
||||
vec![
|
||||
STA::Key(wc(0), kow_wc(1)),
|
||||
STA::Key(wc(4), kow_wc(5)),
|
||||
STA::Key(wc(2), kow_wc(3)),
|
||||
STA::AnchoredKey(wc(0), kow_wc(1)),
|
||||
STA::AnchoredKey(wc(4), kow_wc(5)),
|
||||
STA::AnchoredKey(wc(2), kow_wc(3)),
|
||||
],
|
||||
),
|
||||
],
|
||||
2,
|
||||
)?],
|
||||
});
|
||||
);
|
||||
|
||||
let custom_statement = Statement::Custom(
|
||||
CustomPredicateRef::new(cust_pred_batch.clone(), 0),
|
||||
|
|
@ -444,55 +512,57 @@ mod tests {
|
|||
fn ethdos_test() -> Result<()> {
|
||||
let params = Params {
|
||||
max_custom_predicate_wildcards: 12,
|
||||
max_statement_args: 6,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let eth_friend_cp = CustomPredicate::and(
|
||||
"eth_friend_cp".into(),
|
||||
¶ms,
|
||||
"eth_friend_cp".into(),
|
||||
vec![
|
||||
st(
|
||||
P::Native(NP::ValueOf),
|
||||
vec![
|
||||
STA::Key(wc(4), KeyOrWildcard::Key("type".into())),
|
||||
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("type".into())),
|
||||
STA::Literal(PodType::Signed.into()),
|
||||
],
|
||||
),
|
||||
st(
|
||||
P::Native(NP::Equal),
|
||||
vec![
|
||||
STA::Key(wc(4), KeyOrWildcard::Key("signer".into())),
|
||||
STA::Key(wc(0), kow_wc(1)),
|
||||
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("signer".into())),
|
||||
STA::AnchoredKey(wc(0), kow_wc(1)),
|
||||
],
|
||||
),
|
||||
st(
|
||||
P::Native(NP::Equal),
|
||||
vec![
|
||||
STA::Key(wc(4), KeyOrWildcard::Key("attestation".into())),
|
||||
STA::Key(wc(2), kow_wc(3)),
|
||||
STA::AnchoredKey(wc(4), KeyOrWildcard::Key("attestation".into())),
|
||||
STA::AnchoredKey(wc(2), kow_wc(3)),
|
||||
],
|
||||
),
|
||||
],
|
||||
4,
|
||||
)?;
|
||||
|
||||
let eth_friend_batch = Arc::new(CustomPredicateBatch {
|
||||
name: "eth_friend".to_string(),
|
||||
predicates: vec![eth_friend_cp],
|
||||
});
|
||||
let eth_friend_batch =
|
||||
CustomPredicateBatch::new(¶ms, "eth_friend".to_string(), vec![eth_friend_cp]);
|
||||
|
||||
// 0
|
||||
let eth_dos_base = CustomPredicate::and(
|
||||
"eth_dos_base".into(),
|
||||
¶ms,
|
||||
"eth_dos_base".into(),
|
||||
vec![
|
||||
st(
|
||||
P::Native(NP::Equal),
|
||||
vec![STA::Key(wc(0), kow_wc(1)), STA::Key(wc(2), kow_wc(3))],
|
||||
vec![
|
||||
STA::AnchoredKey(wc(0), kow_wc(1)),
|
||||
STA::AnchoredKey(wc(2), kow_wc(3)),
|
||||
],
|
||||
),
|
||||
st(
|
||||
P::Native(NP::ValueOf),
|
||||
vec![STA::Key(wc(4), kow_wc(5)), STA::Literal(0.into())],
|
||||
vec![STA::AnchoredKey(wc(4), kow_wc(5)), STA::Literal(0.into())],
|
||||
),
|
||||
],
|
||||
6,
|
||||
|
|
@ -500,8 +570,8 @@ mod tests {
|
|||
|
||||
// 1
|
||||
let eth_dos_ind = CustomPredicate::and(
|
||||
"eth_dos_ind".into(),
|
||||
¶ms,
|
||||
"eth_dos_ind".into(),
|
||||
vec![
|
||||
st(
|
||||
P::BatchSelf(2),
|
||||
|
|
@ -516,14 +586,14 @@ mod tests {
|
|||
),
|
||||
st(
|
||||
P::Native(NP::ValueOf),
|
||||
vec![STA::Key(wc(6), kow_wc(7)), STA::Literal(1.into())],
|
||||
vec![STA::AnchoredKey(wc(6), kow_wc(7)), STA::Literal(1.into())],
|
||||
),
|
||||
st(
|
||||
P::Native(NP::SumOf),
|
||||
vec![
|
||||
STA::Key(wc(4), kow_wc(5)),
|
||||
STA::Key(wc(8), kow_wc(9)),
|
||||
STA::Key(wc(6), kow_wc(7)),
|
||||
STA::AnchoredKey(wc(4), kow_wc(5)),
|
||||
STA::AnchoredKey(wc(8), kow_wc(9)),
|
||||
STA::AnchoredKey(wc(6), kow_wc(7)),
|
||||
],
|
||||
),
|
||||
st(
|
||||
|
|
@ -541,8 +611,8 @@ mod tests {
|
|||
|
||||
// 2
|
||||
let eth_dos_distance_either = CustomPredicate::or(
|
||||
"eth_dos_distance_either".into(),
|
||||
¶ms,
|
||||
"eth_dos_distance_either".into(),
|
||||
vec![
|
||||
st(
|
||||
P::BatchSelf(0),
|
||||
|
|
@ -570,10 +640,11 @@ mod tests {
|
|||
6,
|
||||
)?;
|
||||
|
||||
let eth_dos_distance_batch = Arc::new(CustomPredicateBatch {
|
||||
name: "ETHDoS_distance".to_string(),
|
||||
predicates: vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
|
||||
});
|
||||
let eth_dos_distance_batch = CustomPredicateBatch::new(
|
||||
¶ms,
|
||||
"ETHDoS_distance".to_string(),
|
||||
vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
|
||||
);
|
||||
|
||||
// Some POD IDs
|
||||
let pod_id1 = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64))));
|
||||
|
|
|
|||
|
|
@ -584,6 +584,10 @@ pub struct Params {
|
|||
pub max_public_statements: usize,
|
||||
pub max_statement_args: usize,
|
||||
pub max_operation_args: usize,
|
||||
// max number of custom predicates batches that a MainPod can use
|
||||
pub max_custom_predicate_batches: usize,
|
||||
// max number of operations using custom predicates that can be verified in the MainPod
|
||||
pub max_custom_predicate_verifications: usize,
|
||||
// max number of statements that can be ANDed or ORed together
|
||||
// in a custom predicate
|
||||
pub max_custom_predicate_arity: usize,
|
||||
|
|
@ -605,6 +609,8 @@ impl Default for Params {
|
|||
max_public_statements: 10,
|
||||
max_statement_args: 5,
|
||||
max_operation_args: 5,
|
||||
max_custom_predicate_batches: 2,
|
||||
max_custom_predicate_verifications: 5,
|
||||
max_custom_predicate_arity: 5,
|
||||
max_custom_predicate_wildcards: 10,
|
||||
max_custom_batch_size: 5,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use std::{fmt, iter, sync::Arc};
|
||||
use std::{fmt, iter};
|
||||
|
||||
use log::error;
|
||||
use plonky2::field::types::Field;
|
||||
|
|
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
|||
use crate::{
|
||||
backends::plonky2::primitives::merkletree::MerkleProof,
|
||||
middleware::{
|
||||
custom::KeyOrWildcard, AnchoredKey, CustomPredicateBatch, CustomPredicateRef, Error,
|
||||
custom::KeyOrWildcard, AnchoredKey, CustomPredicate, CustomPredicateRef, Error,
|
||||
NativePredicate, Params, Predicate, Result, Statement, StatementArg, StatementTmplArg,
|
||||
ToFields, Wildcard, WildcardValue, F, SELF,
|
||||
},
|
||||
|
|
@ -36,6 +36,9 @@ impl fmt::Display for OperationAux {
|
|||
}
|
||||
|
||||
impl ToFields for OperationType {
|
||||
/// Encoding:
|
||||
/// - Native(native_op) => [1, [native_op], 0, 0, 0, 0]
|
||||
/// - Custom(batch, index) => [3, [batch.id], index]
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
let mut fields: Vec<F> = match self {
|
||||
Self::Native(p) => iter::once(F::from_canonical_u64(1))
|
||||
|
|
@ -43,7 +46,7 @@ impl ToFields for OperationType {
|
|||
.collect(),
|
||||
Self::Custom(CustomPredicateRef { batch, index }) => {
|
||||
iter::once(F::from_canonical_u64(3))
|
||||
.chain(batch.id(params).0)
|
||||
.chain(batch.id().0)
|
||||
.chain(iter::once(F::from_canonical_usize(*index)))
|
||||
.collect()
|
||||
}
|
||||
|
|
@ -321,7 +324,7 @@ impl Operation {
|
|||
(Self::Custom(CustomPredicateRef { batch, index }, args), Custom(cpr, s_args))
|
||||
if batch == &cpr.batch && index == &cpr.index =>
|
||||
{
|
||||
check_custom_pred(params, batch, *index, args, s_args)
|
||||
check_custom_pred(params, cpr, args, s_args)
|
||||
}
|
||||
_ => Err(Error::invalid_deduction(
|
||||
self.clone(),
|
||||
|
|
@ -360,7 +363,7 @@ pub fn check_st_tmpl(
|
|||
(StatementTmplArg::None, StatementArg::None) => true,
|
||||
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => true,
|
||||
(
|
||||
StatementTmplArg::Key(pod_id_wc, key_or_wc),
|
||||
StatementTmplArg::AnchoredKey(pod_id_wc, key_or_wc),
|
||||
StatementArg::Key(AnchoredKey { pod_id, key }),
|
||||
) => {
|
||||
let pod_id_ok = check_or_set(WildcardValue::PodId(*pod_id), pod_id_wc, wildcard_map);
|
||||
|
|
@ -379,14 +382,46 @@ pub fn check_st_tmpl(
|
|||
}
|
||||
}
|
||||
|
||||
pub fn resolve_wildcard_values(
|
||||
params: &Params,
|
||||
pred: &CustomPredicate,
|
||||
args: &[Statement],
|
||||
) -> Option<Vec<WildcardValue>> {
|
||||
// Check that all wildcard have consistent values as assigned in the statements while storing a
|
||||
// map of their values.
|
||||
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
|
||||
// disjunctions we expect Statement::None for the unused statements.
|
||||
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
|
||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||
let st_args = st.args();
|
||||
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
||||
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
|
||||
// TODO: Better errors. Example:
|
||||
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
|
||||
// println!("{} doesn't match {}", st, st_tmpl);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
|
||||
// they are beyond the number of used wildcards in this custom predicate, or they could be
|
||||
// private arguments that are unused in a particular disjunction.
|
||||
Some(
|
||||
wildcard_map
|
||||
.into_iter()
|
||||
.map(|opt| opt.unwrap_or(WildcardValue::None))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn check_custom_pred(
|
||||
params: &Params,
|
||||
batch: &Arc<CustomPredicateBatch>,
|
||||
index: usize,
|
||||
custom_pred_ref: &CustomPredicateRef,
|
||||
args: &[Statement],
|
||||
s_args: &[WildcardValue],
|
||||
) -> Result<bool> {
|
||||
let pred = &batch.predicates[index];
|
||||
let pred = custom_pred_ref.predicate();
|
||||
if pred.statements.len() != args.len() {
|
||||
return Err(Error::diff_amount(
|
||||
"custom predicate operation".to_string(),
|
||||
|
|
@ -404,26 +439,12 @@ fn check_custom_pred(
|
|||
));
|
||||
}
|
||||
|
||||
// Check that all wildcard have consistent values as assigned in the statements while storing a
|
||||
// map of their values. Count the number of statements that match the templates by predicate.
|
||||
// NOTE: We assume the statements have the same order as defined in the custom predicate. For
|
||||
// disjunctions we expect Statement::None for the unused statements.
|
||||
// Count the number of statements that match the templates by predicate.
|
||||
let mut num_matches = 0;
|
||||
let mut wildcard_map = vec![None; params.max_custom_predicate_wildcards];
|
||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||
let st_args = st.args();
|
||||
for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) {
|
||||
if !check_st_tmpl(st_tmpl_arg, st_arg, &mut wildcard_map) {
|
||||
// TODO: Better errors. Example:
|
||||
// println!("{} doesn't match {}", st_arg, st_tmpl_arg);
|
||||
// println!("{} doesn't match {}", st, st_tmpl);
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
let st_tmpl_pred = match &st_tmpl.pred {
|
||||
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
||||
batch: batch.clone(),
|
||||
batch: custom_pred_ref.batch.clone(),
|
||||
index: *i,
|
||||
}),
|
||||
p => p.clone(),
|
||||
|
|
@ -433,9 +454,14 @@ fn check_custom_pred(
|
|||
}
|
||||
}
|
||||
|
||||
let wildcard_map = match resolve_wildcard_values(params, pred, args) {
|
||||
Some(wc_map) => wc_map,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
// Check that the resolved wildcard match the statement arguments.
|
||||
for (s_arg, wc_value) in s_args.iter().zip(wildcard_map.iter()) {
|
||||
if !wc_value.as_ref().is_none_or(|wc_value| *wc_value == *s_arg) {
|
||||
if *wc_value != *s_arg {
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use strum_macros::FromRepr;
|
|||
|
||||
use crate::middleware::{
|
||||
AnchoredKey, CustomPredicateRef, Error, Key, Params, PodId, RawValue, Result, ToFields, Value,
|
||||
F, VALUE_SIZE,
|
||||
EMPTY_VALUE, F, VALUE_SIZE,
|
||||
};
|
||||
|
||||
// TODO: Maybe store KEY_SIGNER and KEY_TYPE as Key with lazy_static
|
||||
|
|
@ -17,22 +17,23 @@ pub const KEY_SIGNER: &str = "_signer";
|
|||
pub const KEY_TYPE: &str = "_type";
|
||||
pub const STATEMENT_ARG_F_LEN: usize = 8;
|
||||
pub const OPERATION_ARG_F_LEN: usize = 1;
|
||||
pub const OPERATION_AUX_F_LEN: usize = 1;
|
||||
pub const OPERATION_AUX_F_LEN: usize = 2;
|
||||
|
||||
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||
pub enum NativePredicate {
|
||||
None = 0,
|
||||
ValueOf = 1,
|
||||
Equal = 2,
|
||||
NotEqual = 3,
|
||||
LtEq = 4,
|
||||
Lt = 5,
|
||||
Contains = 6,
|
||||
NotContains = 7,
|
||||
SumOf = 8,
|
||||
ProductOf = 9,
|
||||
MaxOf = 10,
|
||||
HashOf = 11,
|
||||
None = 0, // Always true
|
||||
False = 1, // Always false
|
||||
ValueOf = 2,
|
||||
Equal = 3,
|
||||
NotEqual = 4,
|
||||
LtEq = 5,
|
||||
Lt = 6,
|
||||
Contains = 7,
|
||||
NotContains = 8,
|
||||
SumOf = 9,
|
||||
ProductOf = 10,
|
||||
MaxOf = 11,
|
||||
HashOf = 12,
|
||||
|
||||
// Syntactic sugar predicates. These predicates are not supported by the backend. The
|
||||
// frontend compiler is responsible of translating these predicates into the predicates above.
|
||||
|
|
@ -53,6 +54,7 @@ impl ToFields for NativePredicate {
|
|||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub enum WildcardValue {
|
||||
None,
|
||||
PodId(PodId),
|
||||
Key(Key),
|
||||
}
|
||||
|
|
@ -60,6 +62,7 @@ pub enum WildcardValue {
|
|||
impl WildcardValue {
|
||||
pub fn raw(&self) -> RawValue {
|
||||
match self {
|
||||
WildcardValue::None => EMPTY_VALUE,
|
||||
WildcardValue::PodId(pod_id) => RawValue::from(pod_id.0),
|
||||
WildcardValue::Key(key) => key.raw(),
|
||||
}
|
||||
|
|
@ -69,6 +72,7 @@ impl WildcardValue {
|
|||
impl fmt::Display for WildcardValue {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
WildcardValue::None => write!(f, "none"),
|
||||
WildcardValue::PodId(pod_id) => write!(f, "{}", pod_id),
|
||||
WildcardValue::Key(key) => write!(f, "{}", key),
|
||||
}
|
||||
|
|
@ -77,10 +81,7 @@ impl fmt::Display for WildcardValue {
|
|||
|
||||
impl ToFields for WildcardValue {
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
match self {
|
||||
WildcardValue::PodId(pod_id) => pod_id.to_fields(params),
|
||||
WildcardValue::Key(key) => key.to_fields(params),
|
||||
}
|
||||
self.raw().to_fields(params)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -130,7 +131,7 @@ impl ToFields for Predicate {
|
|||
.collect(),
|
||||
Self::Custom(CustomPredicateRef { batch, index }) => {
|
||||
iter::once(F::from(PredicatePrefix::Custom))
|
||||
.chain(batch.id(params).0)
|
||||
.chain(batch.id().0)
|
||||
.chain(iter::once(F::from_canonical_usize(*index)))
|
||||
.collect()
|
||||
}
|
||||
|
|
@ -149,7 +150,9 @@ impl fmt::Display for Predicate {
|
|||
write!(
|
||||
f,
|
||||
"{}.{}[{}]",
|
||||
batch.name, index, batch.predicates[*index].name
|
||||
batch.name,
|
||||
index,
|
||||
batch.predicates()[*index].name
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -397,14 +400,14 @@ impl StatementArg {
|
|||
}
|
||||
|
||||
impl ToFields for StatementArg {
|
||||
fn to_fields(&self, _params: &Params) -> Vec<F> {
|
||||
// NOTE: current version returns always the same amount of field elements in the returned
|
||||
// vector, which means that the `None` case is padded with 8 zeroes, and the `Literal` case
|
||||
// is padded with 4 zeroes. Since the returned vector will mostly be hashed (and reproduced
|
||||
// in-circuit), we might be interested into reducing the length of it. If that's the case,
|
||||
// we can check if it makes sense to make it dependant on the concrete StatementArg; that
|
||||
// is, when dealing with a `None` it would be a single field element (zero value), and when
|
||||
// dealing with `Literal` it would be of length 4.
|
||||
/// Encoding:
|
||||
/// - None => [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
/// - Literal(v) => [[v], 0, 0, 0, 0]
|
||||
/// - Key(pod_id, key) => [[pod_id], [key]]
|
||||
/// - WildcardLiteral(v) => [[v], 0, 0, 0, 0]
|
||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||
// NOTE for @ax0: I removed the old comment because may `to_fields` implementations do
|
||||
// padding and we need fixed output length for the circuits.
|
||||
let f = match self {
|
||||
StatementArg::None => vec![F::ZERO; STATEMENT_ARG_F_LEN],
|
||||
StatementArg::Literal(v) => v
|
||||
|
|
@ -414,8 +417,8 @@ impl ToFields for StatementArg {
|
|||
.chain(iter::repeat(F::ZERO).take(STATEMENT_ARG_F_LEN - VALUE_SIZE))
|
||||
.collect(),
|
||||
StatementArg::Key(ak) => {
|
||||
let mut fields = ak.pod_id.to_fields(_params);
|
||||
fields.extend(ak.key.to_fields(_params));
|
||||
let mut fields = ak.pod_id.to_fields(params);
|
||||
fields.extend(ak.key.to_fields(params));
|
||||
fields
|
||||
}
|
||||
StatementArg::WildcardLiteral(v) => v
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue