Feat/fst order pred part1 & part2 (#454)

Implement support for first order predicates in the backend.
Now a statement template can have a predicate hash or a wildcard.

## predicate <-> predicate hash constraints

To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization.  This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard).  Then in normalization we recalculate the predicate hash if it was a Batch Self.

This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard.  How this is achieved:
- Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard)
- After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf.  If it was a predicate but not batch self, the old value was used which was constrained via the constructor.

See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit`

## Wildcard predicate resolution

It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
This commit is contained in:
Eduard S. 2026-01-20 13:14:22 +01:00 committed by GitHub
parent 1724e7b146
commit 9c9a2c454c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 569 additions and 240 deletions

View file

@ -32,9 +32,10 @@ use crate::{
},
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F,
HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard,
PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl,
StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE,
STATEMENT_ARG_F_LEN, VALUE_SIZE,
},
};
@ -46,6 +47,22 @@ pub struct ValueTarget {
pub elements: [Target; VALUE_SIZE],
}
impl From<ValueTarget> for HashOutTarget {
fn from(v: ValueTarget) -> HashOutTarget {
HashOutTarget {
elements: v.elements,
}
}
}
impl From<HashOutTarget> for ValueTarget {
fn from(h: HashOutTarget) -> ValueTarget {
ValueTarget {
elements: h.elements,
}
}
}
impl ValueTarget {
pub fn zero(builder: &mut CircuitBuilder) -> Self {
Self {
@ -524,18 +541,112 @@ impl StatementTmplArgTarget {
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct PredicateHashOrWildcardTarget {
/// layout: `prefix | [data]`, where data is predicate_hash or wildcard_index
pub elements: [Target; Params::pred_hash_or_wc_size()],
}
impl PredicateHashOrWildcardTarget {
pub fn new(prefix: Target, data: ValueTarget) -> Self {
let v = data.elements;
Self {
elements: [prefix, v[0], v[1], v[2], v[3]],
}
}
pub fn new_pred_hash(builder: &mut CircuitBuilder, pred_hash: HashOutTarget) -> Self {
Self::new(
builder.constant(F::from(PredicateOrWildcardPrefix::Predicate)),
ValueTarget::from(pred_hash),
)
}
pub fn is_pred(&self, builder: &mut CircuitBuilder) -> BoolTarget {
let prefix_pred = builder.constant(F::from(PredicateOrWildcardPrefix::Predicate));
builder.is_equal(self.elements[0], prefix_pred)
}
pub fn data(&self) -> ValueTarget {
ValueTarget {
elements: self.elements[1..].try_into().expect("4 elements"),
}
}
pub fn pred_hash(&self) -> HashOutTarget {
HashOutTarget::from(self.data())
}
pub fn wc_index(&self) -> Target {
self.elements[1]
}
pub fn set_targets_raw(
&self,
pw: &mut PartialWitness<F>,
prefix: PredicateOrWildcardPrefix,
data: RawValue,
) -> Result<()> {
pw.set_target(self.elements[0], F::from(prefix))?;
pw.set_target_arr(&self.elements[1..], &data.0)?;
Ok(())
}
pub fn set_targets(
&self,
pw: &mut PartialWitness<F>,
params: &Params,
pred: &PredicateOrWildcard,
) -> Result<()> {
match pred {
PredicateOrWildcard::Predicate(pred) => {
self.set_targets_raw(
pw,
PredicateOrWildcardPrefix::Predicate,
RawValue::from(pred.hash(params)),
)?;
}
PredicateOrWildcard::Wildcard(wc) => {
self.set_targets_raw(
pw,
PredicateOrWildcardPrefix::Wildcard,
RawValue([F::from_canonical_usize(wc.index), F::ZERO, F::ZERO, F::ZERO]),
)?;
}
}
Ok(())
}
}
impl Flattenable for PredicateHashOrWildcardTarget {
fn flatten(&self) -> Vec<Target> {
self.elements.to_vec()
}
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
Self {
elements: vs.try_into().expect("5 elements"),
}
}
fn size(_params: &Params) -> usize {
Params::pred_hash_or_wc_size()
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct StatementTmplTarget {
/// The preimage of the predicate_hash. This predicate is needed only to build the custom
/// predicate table because it needs to normalize statement templates with predicates that
/// refer to self into content-addressed predicates (using the batch id and index). The
/// predicate type is inspected to do this normalization. After the table is built we only use
/// the predicate hash for equality checks.
pred: Option<PredicateTarget>,
pred_hash: HashOutTarget,
/// This is constrained to be `hash(pred)` through the type constructor when we have `pred`
/// and the template uses a predicate and not a wildcard.
pred_hash_or_wc: PredicateHashOrWildcardTarget,
pub args: Vec<StatementTmplArgTarget>,
}
impl StatementTmplTarget {
pub fn new(pred_hash: HashOutTarget, args: Vec<StatementTmplArgTarget>) -> Self {
pub fn new(
pred_hash_or_wc: PredicateHashOrWildcardTarget,
args: Vec<StatementTmplArgTarget>,
) -> Self {
Self {
pred: None,
pred_hash,
pred_hash_or_wc,
args,
}
}
@ -546,9 +657,22 @@ impl StatementTmplTarget {
st_tmpl: &StatementTmpl,
) -> Result<()> {
if let Some(pred) = &self.pred {
pred.set_targets(pw, params, &st_tmpl.pred)?;
match &st_tmpl.pred_or_wc {
PredicateOrWildcard::Predicate(p) => {
// We store a predicate (not a wildcard) and we have it available. In this
// case the hash will be calculated by constraints later on and we should not
// rely on the original data.
pred.set_targets(pw, params, p)?
}
pw.set_hash_target(self.pred_hash, HashOut::from(st_tmpl.pred.hash(params)))?;
PredicateOrWildcard::Wildcard(_wc) => {
// Fill in with a recognizable constant for better debugging; this value is
// not supposed to be used.
pw.set_target_arr(&pred.elements, &[F(0xdead); Params::predicate_size()])?
}
}
}
self.pred_hash_or_wc
.set_targets(pw, params, &st_tmpl.pred_or_wc)?;
let arg_pad = StatementTmplArg::None;
for (i, arg) in st_tmpl
.args
@ -564,8 +688,8 @@ impl StatementTmplTarget {
pub fn pred(&self) -> Option<&PredicateTarget> {
self.pred.as_ref()
}
pub fn pred_hash(&self) -> &HashOutTarget {
&self.pred_hash
pub fn pred_hash_or_wc(&self) -> &PredicateHashOrWildcardTarget {
&self.pred_hash_or_wc
}
}
@ -603,6 +727,8 @@ impl CustomPredicateTarget {
}
}
/// This type is used to build the custom predicate table, which exposes the custom predicates with
/// normalized statement templates indexed by batch_id and custom_predicate_index.
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomPredicateBatchTarget {
pub predicates: Vec<CustomPredicateTarget>,
@ -660,15 +786,17 @@ impl CustomPredicateEntryTarget {
.clone()
.into_iter()
.map(|st_tmpl| {
let pred = match st_tmpl.pred {
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
let pred_or_wc = match st_tmpl.pred_or_wc {
PredicateOrWildcard::Predicate(Predicate::BatchSelf(i)) => {
PredicateOrWildcard::Predicate(Predicate::Custom(CustomPredicateRef {
batch: batch.clone(),
index: i,
}),
p => p,
}))
}
x => x.clone(),
};
StatementTmpl {
pred,
pred_or_wc,
args: st_tmpl.args,
}
})
@ -724,7 +852,7 @@ pub struct CustomPredicateVerifyEntryTarget {
}
impl CustomPredicateVerifyEntryTarget {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder, with_pred: bool) -> Self {
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
let custom_predicate_table_len =
params.max_custom_predicate_batches * params.max_custom_batch_size;
CustomPredicateVerifyEntryTarget {
@ -732,7 +860,7 @@ impl CustomPredicateVerifyEntryTarget {
custom_predicate_table_len,
builder,
),
custom_predicate: builder.add_virtual_custom_predicate_entry(params, with_pred),
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
args: (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect(),
@ -1062,7 +1190,7 @@ impl Flattenable for CustomPredicateTarget {
impl Flattenable for StatementTmplTarget {
fn flatten(&self) -> Vec<Target> {
self.pred_hash
self.pred_hash_or_wc
.flatten()
.into_iter()
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
@ -1071,24 +1199,27 @@ impl Flattenable for StatementTmplTarget {
fn from_flattened(params: &Params, v: &[Target]) -> Self {
assert_eq!(v.len(), Self::size(params));
let pred_hash_end = HASH_SIZE;
let pred_hash = HashOutTarget::from_flattened(params, &v[..pred_hash_end]);
let pred_hash_or_wc_end = Params::pred_hash_or_wc_size();
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]);
let sta_size = Params::statement_tmpl_arg_size();
let args = (0..params.max_statement_args)
.map(|i| {
let sta_v = &v[pred_hash_end + sta_size * i..pred_hash_end + sta_size * (i + 1)];
let sta_v = &v
[pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)];
StatementTmplArgTarget::from_flattened(params, sta_v)
})
.collect();
Self {
pred: None,
pred_hash,
pred_hash_or_wc,
args,
}
}
fn size(params: &Params) -> usize {
HASH_SIZE + params.max_statement_args * StatementTmplArgTarget::size(params)
Params::pred_hash_or_wc_size()
+ params.max_statement_args * StatementTmplArgTarget::size(params)
}
}
@ -1168,11 +1299,8 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
params: &Params,
with_pred: bool,
) -> CustomPredicateBatchTarget;
fn add_virtual_custom_predicate_entry(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateEntryTarget;
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
-> CustomPredicateEntryTarget;
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
fn select_statement_arg(
&mut self,
@ -1320,24 +1448,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
}
}
/// If `with_pred = true` a predicate is included and its hash constrained.
/// If `with_pred = true` a predicate is included.
/// If `with_pred = false` only the predicate hash is included.
/// The pred_hash is constrained to be hash(pred) conditionally on the template using a
/// predicate and not a wildcard.
fn add_virtual_statement_tmpl(
&mut self,
params: &Params,
with_pred: bool,
) -> StatementTmplTarget {
let (pred, pred_hash) = if with_pred {
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value());
let pred = if with_pred {
let pred = self.add_virtual_predicate();
let pred_hash = pred.hash(self);
(Some(pred), pred_hash)
let is_pred = pred_hash_or_wc.is_pred(self);
let data = pred_hash_or_wc.data();
for i in 0..VALUE_SIZE {
self.conditional_assert_eq(is_pred.target, data.elements[i], pred_hash.elements[i]);
}
Some(pred)
} else {
let pred_hash = self.add_virtual_hash();
(None, pred_hash)
None
};
StatementTmplTarget {
pred,
pred_hash,
pred_hash_or_wc,
args: (0..params.max_statement_args)
.map(|_| self.add_virtual_statement_tmpl_arg())
.collect(),
@ -1377,12 +1513,11 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
fn add_virtual_custom_predicate_entry(
&mut self,
params: &Params,
with_pred: bool,
) -> CustomPredicateEntryTarget {
CustomPredicateEntryTarget {
id: self.add_virtual_hash(),
index: self.add_virtual_target(),
predicate: self.add_virtual_custom_predicate(params, with_pred),
predicate: self.add_virtual_custom_predicate(params, false),
}
}

View file

@ -25,8 +25,8 @@ use crate::{
CustomPredicateTarget, CustomPredicateVerifyEntryTarget,
CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget,
MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget,
PredicateTarget, StatementArgTarget, StatementTarget, StatementTmplArgTarget,
StatementTmplTarget, ValueTarget,
PredicateHashOrWildcardTarget, PredicateTarget, StatementArgTarget,
StatementTarget, StatementTmplArgTarget, StatementTmplTarget, ValueTarget,
},
hash::{hash_from_state_circuit, precompute_hash_state},
mux_table::{MuxTableTarget, TableEntryTarget},
@ -341,12 +341,7 @@ fn build_operation_aux_table_circuit(
.chain(signed_by.pk.u.components)
.collect(),
);
let entry: MsgPubKeyTarget = HashPairTarget(
HashOutTarget {
elements: signed_by.msg.elements,
},
pk_hash,
);
let entry: MsgPubKeyTarget = HashPairTarget(HashOutTarget::from(signed_by.msg), pk_hash);
table.push(builder, OperationAuxTableTag::SignedBy as u32, &entry);
measure_gates_end!(builder, measure);
@ -1381,6 +1376,26 @@ fn make_statement_arg_from_template_circuit(
StatementArgTarget::new(first, second)
}
fn make_predicate_from_template_circuit(
params: &Params,
builder: &mut CircuitBuilder,
pred_hash_or_wc: &PredicateHashOrWildcardTarget,
args: &[ValueTarget],
) -> HashOutTarget {
let zero = builder.zero();
let is_pred = pred_hash_or_wc.is_pred(builder);
// If the index is not used, use a 0 instead to still pass the range constraints from
// vec_ref
let index = builder.select(is_pred, zero, pred_hash_or_wc.wc_index());
let resolved_pred_hash = HashOutTarget::from(builder.vec_ref_small(params, args, index));
builder.select_flattenable(
params,
is_pred,
&pred_hash_or_wc.pred_hash(),
&resolved_pred_hash,
)
}
fn make_statement_from_template_circuit(
params: &Params,
builder: &mut CircuitBuilder,
@ -1388,7 +1403,7 @@ fn make_statement_from_template_circuit(
args: &[ValueTarget],
) -> StatementTarget {
let measure = measure_gates_begin!(builder, "StArgFromTmpl");
let args = st_tmpl
let st_args = st_tmpl
.args
.iter()
.map(|st_tmpl_arg| {
@ -1396,7 +1411,11 @@ fn make_statement_from_template_circuit(
})
.collect();
measure_gates_end!(builder, measure);
StatementTarget::new(*st_tmpl.pred_hash(), args)
let measure = measure_gates_begin!(builder, "PredFromTmpl");
let pred_hash =
make_predicate_from_template_circuit(params, builder, st_tmpl.pred_hash_or_wc(), args);
measure_gates_end!(builder, measure);
StatementTarget::new(pred_hash, st_args)
}
/// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard
@ -1527,13 +1546,29 @@ fn normalize_st_tmpl_circuit(
st_tmpl: &StatementTmplTarget,
id: HashOutTarget,
) -> StatementTmplTarget {
let pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
// If the custom predicate is self, we normalize it and then hash it.
let old_pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
let is_batch_self = builder.is_equal(pred.elements[0], prefix_batch_self);
let pred_index = pred.elements[1];
let custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, pred);
StatementTmplTarget::new(pred.hash(builder), st_tmpl.args.clone())
let is_batch_self = builder.is_equal(old_pred.elements[0], prefix_batch_self);
let pred_index = old_pred.elements[1];
let normalized_custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
let normalized_custom_pred_hash = normalized_custom_pred.hash(builder);
// If the template is using a predicate and it is batch self we use the freshly computed
// normalized predicate hash, otherwise we keep the original data.
let old_data = st_tmpl.pred_hash_or_wc().data();
let is_pred = st_tmpl.pred_hash_or_wc().is_pred(builder);
let is_pred_batch_self = builder.and(is_pred, is_batch_self);
let data = builder.select_flattenable(
params,
is_pred_batch_self,
&ValueTarget::from(normalized_custom_pred_hash),
&old_data,
);
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone())
}
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
@ -1773,7 +1808,7 @@ impl MainPodVerifyTarget {
.map(|_| builder.add_virtual_custom_predicate_batch(params, true))
.collect(),
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder, false))
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
.collect(),
}
}
@ -2012,8 +2047,8 @@ mod tests {
dict,
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
middleware::{
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, RawValue, StatementArg,
StatementTmpl, StatementTmplArg, Wildcard,
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard,
RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard,
},
};
@ -3124,7 +3159,7 @@ mod tests {
let dict = Hash([F(6), F(7), F(8), F(9)]);
let st_tmpl = StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
StatementTmplArg::Literal(Value::from("value")),
@ -3137,6 +3172,21 @@ mod tests {
);
helper_statement_from_template(&params, st_tmpl, args, expected_st)?;
let st_tmpl = StatementTmpl {
pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)),
args: vec![
StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
StatementTmplArg::Literal(Value::from("value")),
],
};
let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(&params);
let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)];
let expected_st = Statement::not_equal(
AnchoredKey::new(dict, Key::from("key")),
Value::from("value"),
);
helper_statement_from_template(&params, st_tmpl, args, expected_st)?;
Ok(())
}
@ -3150,7 +3200,7 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params, false);
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params);
let op_args_target: Vec<_> = (0..args.len())
.map(|_| builder.add_virtual_statement(params, false))
.collect();

View file

@ -7,7 +7,7 @@ use crate::{
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
middleware::{
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
},
};
@ -217,7 +217,8 @@ impl CustomPredicateBatchBuilder {
})
.collect::<Result<_>>()?;
Ok(StatementTmpl {
pred: stb.predicate.clone(),
// TODO: Support wildcard
pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()),
args,
})
})
@ -319,7 +320,10 @@ mod tests {
// Check that the desugared predicate is the same as the one in the statement template
assert_eq!(
desugared_gt.predicate(),
*batch_clone.predicates()[0].statements[0].pred()
*batch_clone.predicates()[0].statements[0]
.pred_or_wc()
.as_pred()
.unwrap()
);
// Check that our custom predicate matches the statement template
@ -366,7 +370,10 @@ mod tests {
);
assert_eq!(
set_contains.predicate(),
*batch_clone.predicates()[0].statements[0].pred()
*batch_clone.predicates()[0].statements[0]
.pred_or_wc()
.as_pred()
.unwrap()
);
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);

View file

@ -3,7 +3,9 @@ use std::{collections::HashMap, fmt::Display};
use crate::{
frontend::{Error, Result},
lang::PrettyPrint,
middleware::{Pod, Statement, StatementArg, StatementTmpl, StatementTmplArg, Value},
middleware::{
Pod, PredicateOrWildcard, Statement, StatementArg, StatementTmpl, StatementTmplArg, Value,
},
};
/// Represents a request for a POD, in terms of a set of statement templates.
@ -76,7 +78,8 @@ impl PodRequest {
statement: &Statement,
current_bindings: &HashMap<String, Value>,
) -> Option<HashMap<String, Value>> {
if template.pred != statement.predicate() {
// TODO: Support wildcard
if template.pred_or_wc != PredicateOrWildcard::Predicate(statement.predicate()) {
return None;
}

View file

@ -18,8 +18,8 @@ use crate::{
},
middleware::{
self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params,
Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg,
Wildcard,
Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Wildcard,
},
};
@ -201,7 +201,8 @@ impl<'a> Lowerer<'a> {
}
Ok(MWStatementTmpl {
pred: predicate,
// TODO: Support wildcard
pred_or_wc: PredicateOrWildcard::Predicate(predicate),
args: mw_args,
})
}
@ -596,7 +597,10 @@ mod tests {
let stmt = &pred2.statements()[0];
// Should be BatchSelf(0) referring to pred1
assert!(matches!(stmt.pred, Predicate::BatchSelf(0)));
assert!(matches!(
stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
));
}
#[test]
@ -632,8 +636,8 @@ mod tests {
// Should desugar to the Contains predicate
assert!(matches!(
stmt.pred,
Predicate::Native(NativePredicate::Contains)
stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains))
));
}

View file

@ -63,8 +63,8 @@ mod tests {
backends::plonky2::primitives::ec::schnorr::SecretKey,
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
Params, Predicate, RawValue, StatementTmpl, StatementTmplArg, Value, Wildcard,
EMPTY_HASH,
Params, Predicate, PredicateOrWildcard, RawValue, StatementTmpl, StatementTmplArg,
Value, Wildcard, EMPTY_HASH,
},
};
@ -89,6 +89,10 @@ mod tests {
names.iter().map(|s| s.to_string()).collect()
}
fn pred_lit(pred: Predicate) -> PredicateOrWildcard {
PredicateOrWildcard::Predicate(pred)
}
#[test]
fn test_e2e_simple_predicate() -> Result<(), LangError> {
let input = r#"
@ -109,7 +113,7 @@ mod tests {
// Expected structure
let expected_statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("PodA", 0), "the_key"), // PodA["the_key"] -> Wildcard(0), Key("the_key")
sta_ak(("PodB", 1), "the_key"), // PodB["the_key"] -> Wildcard(1), Key("the_key")
@ -153,14 +157,14 @@ mod tests {
// Expected structure
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
sta_lit(RawValue::from(1)),
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
args: vec![
sta_ak(("GovPod", 1), "dob"), // GovPod["dob"] -> Wildcard(1), Key("dob")
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
@ -195,14 +199,14 @@ mod tests {
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
let expected_statements = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("A", 0), "input_key"), // A["input_key"] -> Wildcard(0), Key("input_key")
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
sta_lit("some_value"), // Literal("some_value")
@ -251,7 +255,7 @@ mod tests {
// Expected Batch structure
let expected_pred_statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val")
sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val")
@ -275,7 +279,10 @@ mod tests {
// Expected Request structure
// Pod1 -> Wildcard 0, Pod2 -> Wildcard 1
let expected_request_templates = vec![StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(expected_batch, 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
expected_batch,
0,
))),
args: vec![
StatementTmplArg::Wildcard(wc("Pod1", 0)),
StatementTmplArg::Wildcard(wc("Pod2", 1)),
@ -317,7 +324,7 @@ mod tests {
// Expected structure
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(batch_result, 0)), // Refers to some_pred
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch_result, 0))), // Refers to some_pred
args: vec![
StatementTmplArg::Wildcard(wc("Var1", 0)), // Var1
StatementTmplArg::Literal(Value::from(12345i64)), // 12345
@ -325,7 +332,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
// AnotherPod["another_key"] -> Wildcard(1), Key("another_key")
sta_ak(("AnotherPod", 1), "another_key"),
@ -362,15 +369,15 @@ mod tests {
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::LtEq),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::LtEq)),
args: vec![sta_ak(("B", 1), "bar"), sta_ak(("A", 0), "foo")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
args: vec![sta_ak(("D", 3), "qux"), sta_ak(("C", 2), "baz")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Contains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)),
args: vec![
sta_ak(("A", 0), "foo"),
sta_ak(("B", 1), "bar"),
@ -378,11 +385,11 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::NotContains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)),
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Contains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)),
args: vec![
sta_ak(("A", 0), "foo"),
sta_ak(("B", 1), "bar"),
@ -439,7 +446,7 @@ mod tests {
let expected_templates = vec![
// 1. NotContains(sanctions["sanctionList"], gov["idNumber"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::NotContains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)),
args: vec![
sta_ak(
(wc_sanctions.name.as_str(), wc_sanctions.index),
@ -450,7 +457,7 @@ mod tests {
},
// 2. Lt(gov["dateOfBirth"], SELF_HOLDER_18Y["const_18y"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
args: vec![
sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key),
sta_ak(
@ -461,7 +468,7 @@ mod tests {
},
// 3. Equal(pay["startDate"], SELF_HOLDER_1Y["const_1y"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key),
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
@ -469,7 +476,7 @@ mod tests {
},
// 4. Equal(gov["socialSecurityNumber"], pay["socialSecurityNumber"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key),
sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key),
@ -477,7 +484,7 @@ mod tests {
},
// 5. Equal(SELF_HOLDER_18Y["const_18y"], 1169909388)
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(
(wc_self_18y.name.as_str(), wc_self_18y.index),
@ -488,7 +495,7 @@ mod tests {
},
// 6. Equal(SELF_HOLDER_1Y["const_1y"], 1706367566)
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
sta_lit(now_minus_1y_val.clone()),
@ -563,11 +570,11 @@ mod tests {
// eth_friend (Index 0)
let expected_friend_stmts = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::SignedBy),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SignedBy)),
args: vec![sta_wc_lit("attestation_dict", 2), sta_wc_lit("src", 0)],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("attestation_dict", 2), "attestation"),
sta_wc_lit("dst", 1), // Pub arg 1
@ -586,11 +593,11 @@ mod tests {
// eth_dos_distance_base (Index 1)
let expected_base_stmts = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_wc_lit("src", 0), sta_wc_lit("dst", 1)],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_wc_lit("distance", 2), sta_lit(0i64)],
},
];
@ -608,7 +615,7 @@ mod tests {
// Private args indices: 3-4 (shorter_distance, intermed)
let expected_ind_stmts = vec![
StatementTmpl {
pred: Predicate::BatchSelf(3), // Calls eth_dos_distance (index 3)
pred_or_wc: pred_lit(Predicate::BatchSelf(3)), // Calls eth_dos_distance (index 3)
args: vec![
// WildcardLiteral args
sta_wc_lit("src", 0),
@ -617,7 +624,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::SumOf),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SumOf)),
args: vec![
sta_wc_lit("distance", 2), // public arg
sta_wc_lit("shorter_distance", 3), // private arg
@ -625,7 +632,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::BatchSelf(0), // Calls eth_friend (index 0)
pred_or_wc: pred_lit(Predicate::BatchSelf(0)), // Calls eth_friend (index 0)
args: vec![
// WildcardLiteral args
sta_wc_lit("intermed", 4), // private arg
@ -645,7 +652,7 @@ mod tests {
// eth_dos_distance (Index 3)
let expected_dist_stmts = vec![
StatementTmpl {
pred: Predicate::BatchSelf(1), // Calls eth_dos_distance_base (index 1)
pred_or_wc: pred_lit(Predicate::BatchSelf(1)), // Calls eth_dos_distance_base (index 1)
args: vec![
// WildcardLiteral args
sta_wc_lit("src", 0),
@ -654,7 +661,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::BatchSelf(2), // Calls eth_dos_distance_ind (index 2)
pred_or_wc: pred_lit(Predicate::BatchSelf(2)), // Calls eth_dos_distance_ind (index 2)
args: vec![
// WildcardLiteral args
sta_wc_lit("src", 0),
@ -697,7 +704,7 @@ mod tests {
// 1. Create a batch to be imported
let imported_pred_stmts = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("A", 0), "foo"), // A["foo"]
sta_ak(("B", 1), "bar"), // B["bar"]
@ -739,7 +746,10 @@ mod tests {
// 4. Check the resulting request template
let expected_request_templates = vec![StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch,
0,
))),
args: vec![
StatementTmplArg::Wildcard(wc("Pod1", 0)),
StatementTmplArg::Wildcard(wc("Pod2", 1)),
@ -788,11 +798,17 @@ mod tests {
// 4. Check the resulting request templates
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch.clone(),
0,
))),
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
},
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 2)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch,
2,
))),
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
},
];
@ -808,7 +824,7 @@ mod tests {
// 1. Create a batch with a predicate to be imported
let imported_pred_stmts = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
}];
let imported_predicate = CustomPredicate::and(
@ -855,7 +871,10 @@ mod tests {
assert_eq!(defined_pred.statements.len(), 1);
let expected_statement = StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch.clone(),
0,
))),
args: vec![
StatementTmplArg::Wildcard(wc("X", 0)),
StatementTmplArg::Wildcard(wc("Y", 1)),
@ -886,7 +905,9 @@ mod tests {
let request_templates = processed.request.templates();
assert_eq!(request_templates.len(), 1);
if let Predicate::Intro(intro_ref) = &request_templates[0].pred {
if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) =
&request_templates[0].pred_or_wc
{
assert_eq!(intro_ref.name, "empty");
assert_eq!(intro_ref.args_len, 0);
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
@ -944,27 +965,27 @@ mod tests {
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("B", 1), "raw"), sta_lit(Value::from(raw))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("C", 2), "string"), sta_lit(Value::from(string))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("D", 3), "int"), sta_lit(Value::from(int))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("E", 4), "bool"), sta_lit(Value::from(bool))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("F", 5), "sk"), sta_lit(Value::from(sk))],
},
];

View file

@ -5,7 +5,8 @@ use std::fmt::Write;
use crate::{
frontend::PodRequest,
middleware::{
CustomPredicate, CustomPredicateBatch, Predicate, StatementTmpl, StatementTmplArg, Value,
CustomPredicate, CustomPredicateBatch, Predicate, PredicateOrWildcard, StatementTmpl,
StatementTmplArg, Value,
},
};
@ -57,7 +58,8 @@ impl StatementTmpl {
w: &mut dyn Write,
batch_context: Option<&CustomPredicateBatch>,
) -> std::fmt::Result {
match &self.pred {
match &self.pred_or_wc {
PredicateOrWildcard::Predicate(pred) => match pred {
Predicate::Native(native_pred) => {
write!(w, "{}", native_pred)?;
}
@ -78,6 +80,11 @@ impl StatementTmpl {
write!(w, "batch_self_{}", index)?;
}
}
},
PredicateOrWildcard::Wildcard(wc) => {
// TODO: Decide the syntax for a wildcard predicate
write!(w, "?{}", wc.name)?;
}
}
write!(w, "(")?;
@ -223,13 +230,17 @@ mod tests {
Wildcard::new(name.to_string(), index)
}
fn pred_lit(pred: Predicate) -> PredicateOrWildcard {
PredicateOrWildcard::Predicate(pred)
}
#[test]
fn test_simple_predicate_pretty_print() {
let params = Params::default();
// Create a simple predicate: is_equal(PodA, PodB) = AND(Equal(PodA["key"], PodB["key"]))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("PodA", 0),
@ -265,7 +276,7 @@ mod tests {
// Create: uses_private(A, private: Temp) = AND(Equal(A["input"], Temp["const"]))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("A", 0),
@ -301,7 +312,7 @@ mod tests {
// Create: check_value(Pod) = AND(Equal(Pod["field"], 42))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("Pod", 0),
@ -335,7 +346,7 @@ mod tests {
// Create: either_or(A, B) = OR(Equal(A["x"], 1), Equal(B["y"], 2))
let statements = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("A", 0),
@ -345,7 +356,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("B", 1),

View file

@ -1,5 +1,6 @@
use std::{fmt, iter, sync::Arc};
use itertools::Itertools;
use plonky2::field::types::Field;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -70,36 +71,28 @@ impl ToFields for StatementTmplArg {
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
match self {
StatementTmplArg::None => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::None))
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect();
fields
}
StatementTmplArg::Literal(v) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
.collect_vec(),
StatementTmplArg::Literal(v) => iter::once(F::from(StatementTmplArgPrefix::Literal))
.chain(v.raw().to_fields(params))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect();
fields
}
.collect_vec(),
StatementTmplArg::AnchoredKey(wc1, kw2) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
.chain(wc1.to_fields(params))
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
.chain(kw2.to_fields(params))
.collect();
fields
.collect_vec()
}
StatementTmplArg::Wildcard(wc) => {
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
.chain(wc.to_fields(params))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect();
fields
.collect_vec()
}
}
}
@ -121,16 +114,79 @@ impl fmt::Display for StatementTmplArg {
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub enum PredicateOrWildcard {
Predicate(Predicate),
Wildcard(Wildcard),
}
impl PredicateOrWildcard {
pub fn as_pred(&self) -> Option<&Predicate> {
match self {
Self::Predicate(pred) => Some(pred),
_ => None,
}
}
pub fn as_wc(&self) -> Option<&Wildcard> {
match self {
Self::Wildcard(wc) => Some(wc),
_ => None,
}
}
}
impl fmt::Display for PredicateOrWildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Predicate(pred) => pred.fmt(f),
Self::Wildcard(wc) => {
write!(f, "?")?;
wc.fmt(f)
}
}
}
}
#[derive(Clone, Copy)]
pub enum PredicateOrWildcardPrefix {
Predicate = 0,
Wildcard = 1,
}
impl From<PredicateOrWildcardPrefix> for F {
fn from(prefix: PredicateOrWildcardPrefix) -> Self {
Self::from_canonical_usize(prefix as usize)
}
}
impl ToFields for PredicateOrWildcard {
fn to_fields(&self, params: &Params) -> Vec<F> {
// Encoding:
// Predicate(pred) => (0, [hash(pred) ])
// Wildcard(wc) => (1, wc_index, 0...)
match self {
Self::Predicate(pred) => iter::once(F::from(PredicateOrWildcardPrefix::Predicate))
.chain(pred.hash(params).to_fields(params))
.collect_vec(),
Self::Wildcard(wc) => iter::once(F::from(PredicateOrWildcardPrefix::Wildcard))
.chain(wc.to_fields(params))
.chain(iter::repeat(F::ZERO))
.take(Params::pred_hash_or_wc_size())
.collect_vec(),
}
}
}
/// Statement Template for a Custom Predicate
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub struct StatementTmpl {
pub pred: Predicate,
pub pred_or_wc: PredicateOrWildcard,
pub args: Vec<StatementTmplArg>,
}
impl StatementTmpl {
pub fn pred(&self) -> &Predicate {
&self.pred
pub fn pred_or_wc(&self) -> &PredicateOrWildcard {
&self.pred_or_wc
}
pub fn args(&self) -> &[StatementTmplArg] {
&self.args
@ -139,7 +195,7 @@ impl StatementTmpl {
impl fmt::Display for StatementTmpl {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.pred.fmt(f)?;
self.pred_or_wc.fmt(f)?;
write!(f, "(")?;
for (i, arg) in self.args.iter().enumerate() {
if i != 0 {
@ -154,7 +210,7 @@ impl fmt::Display for StatementTmpl {
impl ToFields for StatementTmpl {
fn to_fields(&self, params: &Params) -> Vec<F> {
// serialize as:
// predicate (6 field elements)
// predicate (4 field elements)
// then the StatementTmplArgs
// TODO think if this check should go into the StatementTmpl creation,
@ -168,15 +224,13 @@ impl ToFields for StatementTmpl {
);
}
let mut fields: Vec<F> = self
.pred
.hash(params)
self.pred_or_wc
.to_fields(params)
.into_iter()
.chain(self.args.iter().flat_map(|sta| sta.to_fields(params)))
.collect();
fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0));
fields
.chain(iter::repeat(F::ZERO))
.take(params.statement_tmpl_size())
.collect_vec()
}
}
@ -203,7 +257,9 @@ impl CustomPredicate {
name: "empty".to_string(),
conjunction: false,
statements: vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::None),
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(
NativePredicate::None,
)),
args: vec![],
}],
args_len: 0,
@ -276,11 +332,11 @@ impl CustomPredicate {
}
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
StatementTmpl {
pred: Predicate::Native(if self.conjunction {
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(if self.conjunction {
NativePredicate::None
} else {
NativePredicate::False
}),
})),
args: vec![],
}
}
@ -318,7 +374,7 @@ impl ToFields for CustomPredicate {
}
let pad_st = self.pad_statement_tmpl();
let fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
iter::once(F::from_bool(self.conjunction))
.chain(iter::once(F::from_canonical_usize(self.args_len)))
.chain(
self.statements
@ -327,8 +383,7 @@ impl ToFields for CustomPredicate {
.take(params.max_custom_predicate_arity)
.flat_map(|st| st.to_fields(params)),
)
.collect();
fields
.collect_vec()
}
}
@ -350,7 +405,7 @@ impl fmt::Display for CustomPredicate {
writeln!(f, ") = {}(", if self.conjunction { "AND" } else { "OR" })?;
for st in &self.statements {
write!(f, " ")?;
st.pred.fmt(f)?;
st.pred_or_wc.fmt(f)?;
write!(f, "(")?;
for (i, arg) in st.args.iter().enumerate() {
if i != 0 {
@ -382,14 +437,12 @@ impl ToFields for CustomPredicateBatch {
fn to_fields(&self, params: &Params) -> Vec<F> {
// all the custom predicates in order
let pad_pred = CustomPredicate::empty();
let fields: Vec<F> = self
.predicates
self.predicates
.iter()
.chain(iter::repeat(&pad_pred))
.take(params.max_custom_batch_size)
.flat_map(|p| p.to_fields(params))
.collect();
fields
.collect_vec()
}
}
@ -418,7 +471,6 @@ impl CustomPredicateBatch {
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
// but ideally we want to use the root of a merkle tree built from the custom predicates.
let input = self.to_fields(params);
hash_fields(&input)
}
@ -470,7 +522,10 @@ mod tests {
};
fn st(p: Predicate, args: Vec<StatementTmplArg>) -> StatementTmpl {
StatementTmpl { pred: p, args }
StatementTmpl {
pred_or_wc: PredicateOrWildcard::Predicate(p),
args,
}
}
fn key(name: &str) -> Key {

View file

@ -29,6 +29,8 @@ pub enum MiddlewareInnerError {
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
#[error("Expected a statement of type {0}, got {1}")]
MismatchedStatementType(Predicate, Predicate),
#[error("Expected a statement with hash(predicate) {0}, got {1} ({2})")]
MismatchedStatementWildcardPredicate(Value, Value, Predicate),
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
#[error(
@ -111,6 +113,15 @@ impl Error {
pub(crate) fn mismatched_statement_type(expected: Predicate, seen: Predicate) -> Self {
new!(MismatchedStatementType(expected, seen))
}
pub(crate) fn mismatched_statement_wc_pred(
expected: Value,
seen: Value,
seen_pred: Predicate,
) -> Self {
new!(MismatchedStatementWildcardPredicate(
expected, seen, seen_pred
))
}
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
wc_value: Value,
st_arg: Value,

View file

@ -820,8 +820,12 @@ impl Params {
HASH_SIZE + STATEMENT_ARG_F_LEN * self.max_statement_args
}
pub const fn pred_hash_or_wc_size() -> usize {
1 + HASH_SIZE
}
pub const fn statement_tmpl_size(&self) -> usize {
HASH_SIZE + self.max_statement_args * Self::statement_tmpl_arg_size()
Self::pred_hash_or_wc_size() + self.max_statement_args * Self::statement_tmpl_arg_size()
}
pub fn custom_predicate_size(&self) -> usize {

View file

@ -15,8 +15,9 @@ use crate::{
},
middleware::{
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg,
StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result,
Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value,
ValueRef, Wildcard, F,
},
};
@ -550,16 +551,8 @@ impl Operation {
}
}
/// Check that a StatementArg follows a StatementTmplArg based on the currently mapped wildcards.
/// Update the wildcard map with newly found wildcards.
pub fn check_st_tmpl(
st_tmpl_arg: &StatementTmplArg,
st_arg: &StatementArg,
// Map from wildcards to values that we have seen so far.
wildcard_map: &mut [Option<Value>],
) -> Result<()> {
// Check that the value `v` at wildcard `wc` exists in the map or set it.
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
fn wc_check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
if let Some(prev) = &wildcard_map[wc.index] {
if *prev != v {
return Err(Error::invalid_wildcard_assignment(
@ -574,6 +567,14 @@ pub fn check_st_tmpl(
Ok(())
}
/// Check that a StatementArg follows a StatementTmplArg based on the currently mapped wildcards.
/// Update the wildcard map with newly found wildcards.
pub fn check_st_tmpl(
st_tmpl_arg: &StatementTmplArg,
st_arg: &StatementArg,
// Map from wildcards to values that we have seen so far.
wildcard_map: &mut [Option<Value>],
) -> Result<()> {
match (st_tmpl_arg, st_arg) {
(StatementTmplArg::None, StatementArg::None) => Ok(()),
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()),
@ -581,7 +582,7 @@ pub fn check_st_tmpl(
StatementTmplArg::AnchoredKey(root_wc, key_tmpl),
StatementArg::Key(AnchoredKey { root, key }),
) => {
let root_ok = check_or_set(Value::from(*root), root_wc, wildcard_map);
let root_ok = wc_check_or_set(Value::from(*root), root_wc, wildcard_map);
root_ok.and_then(|_| {
(key_tmpl == key).then_some(()).ok_or(
Error::mismatched_anchored_key_in_statement_tmpl_arg(
@ -594,7 +595,7 @@ pub fn check_st_tmpl(
})
}
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
check_or_set(v.clone(), wc, wildcard_map)
wc_check_or_set(v.clone(), wc, wildcard_map)
}
_ => Err(Error::mismatched_statement_tmpl_arg(
st_tmpl_arg.clone(),
@ -604,12 +605,16 @@ pub fn check_st_tmpl(
}
pub fn fill_wildcard_values(
params: &Params,
pred: &CustomPredicate,
args: &[Statement],
wildcard_map: &mut [Option<Value>],
) -> Result<()> {
for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc {
wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?;
}
st_tmpl
.args
.iter()
@ -633,7 +638,7 @@ pub fn wildcard_values_from_op_st(
.chain(core::iter::repeat(None))
.take(params.max_custom_predicate_wildcards)
.collect_vec();
fill_wildcard_values(pred, op_args, &mut wildcard_map)?;
fill_wildcard_values(params, pred, op_args, &mut wildcard_map)?;
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
// they are beyond the number of used wildcards in this custom predicate, or they could be
// private arguments that are unused in a particular disjunction.
@ -644,11 +649,15 @@ pub fn wildcard_values_from_op_st(
}
fn check_custom_pred_argument(
params: &Params,
custom_pred_ref: &CustomPredicateRef,
template: &StatementTmpl,
statement: &Statement,
wc_values: &[Value],
) -> Result<()> {
let template_pred = match &template.pred {
match &template.pred_or_wc {
PredicateOrWildcard::Predicate(pred) => {
let template_pred = match pred {
&Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
batch: custom_pred_ref.batch.clone(),
index: i,
@ -661,6 +670,18 @@ fn check_custom_pred_argument(
statement.predicate(),
));
}
}
PredicateOrWildcard::Wildcard(wc) => {
let pred_hash = Value::from(statement.predicate().hash(params));
if wc_values[wc.index] != pred_hash {
return Err(Error::mismatched_statement_wc_pred(
wc_values[wc.index].clone(),
pred_hash,
statement.predicate(),
));
}
}
}
let st_args_len = statement.args().len();
if template.args.len() != st_args_len {
return Err(Error::diff_amount(
@ -697,17 +718,42 @@ pub(crate) fn check_custom_pred(
));
}
// Check that the resolved wildcards match the statement arguments.
let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) {
Ok(wc_values) => wc_values,
Err(Error::Inner { inner, backtrace }) => match *inner {
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
if wc.index <= s_args.len() =>
{
return Err(Error::mismatched_wildcard_value_and_statement_arg(
v,
prev,
wc.index,
pred.clone(),
))
}
_ => return Err(Error::Inner { inner, backtrace }),
},
_ => unreachable!(),
};
let mut match_exists = false;
for (st_tmpl, st) in pred.statements.iter().zip(args) {
// For `or` predicates, only one statement needs to match the template.
// The rest of the statements can be `None`.
if !pred.conjunction
&& matches!(st, Statement::None)
&& st_tmpl.pred != Predicate::Native(NativePredicate::None)
{
let expected_pred_is_none = match &st_tmpl.pred_or_wc {
PredicateOrWildcard::Predicate(st_tmpl_pred) => {
*st_tmpl_pred == Predicate::Native(NativePredicate::None)
}
PredicateOrWildcard::Wildcard(wc) => {
wc_values[wc.index]
== Value::from(Predicate::Native(NativePredicate::None).hash(params))
}
};
if !pred.conjunction && matches!(st, Statement::None) && !expected_pred_is_none {
continue;
}
check_custom_pred_argument(custom_pred_ref, st_tmpl, st)?;
check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?;
match_exists = true;
}
@ -716,25 +762,7 @@ pub(crate) fn check_custom_pred(
pred.clone(),
));
}
// Check that the resolved wildcards match the statement arguments.
match wildcard_values_from_op_st(params, pred, args, s_args) {
Ok(_) => Ok(()),
Err(Error::Inner { inner, backtrace }) => match *inner {
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
if wc.index <= s_args.len() =>
{
Err(Error::mismatched_wildcard_value_and_statement_arg(
v,
prev,
wc.index,
pred.clone(),
))
}
_ => Err(Error::Inner { inner, backtrace }),
},
_ => unreachable!(),
}
Ok(())
}
impl ToFields for Operation {