Feat/fst order pred part1 & part2 (#454)
Implement support for first order predicates in the backend. Now a statement template can have a predicate hash or a wildcard. ## predicate <-> predicate hash constraints To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization. This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard). Then in normalization we recalculate the predicate hash if it was a Batch Self. This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard. How this is achieved: - Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard) - After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf. If it was a predicate but not batch self, the old value was used which was constrained via the constructor. See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit` ## Wildcard predicate resolution It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
This commit is contained in:
parent
1724e7b146
commit
9c9a2c454c
11 changed files with 569 additions and 240 deletions
|
|
@ -32,9 +32,10 @@ use crate::{
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
||||||
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
|
NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard,
|
||||||
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F,
|
PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl,
|
||||||
HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE,
|
||||||
|
STATEMENT_ARG_F_LEN, VALUE_SIZE,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -46,6 +47,22 @@ pub struct ValueTarget {
|
||||||
pub elements: [Target; VALUE_SIZE],
|
pub elements: [Target; VALUE_SIZE],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<ValueTarget> for HashOutTarget {
|
||||||
|
fn from(v: ValueTarget) -> HashOutTarget {
|
||||||
|
HashOutTarget {
|
||||||
|
elements: v.elements,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<HashOutTarget> for ValueTarget {
|
||||||
|
fn from(h: HashOutTarget) -> ValueTarget {
|
||||||
|
ValueTarget {
|
||||||
|
elements: h.elements,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ValueTarget {
|
impl ValueTarget {
|
||||||
pub fn zero(builder: &mut CircuitBuilder) -> Self {
|
pub fn zero(builder: &mut CircuitBuilder) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -524,18 +541,112 @@ impl StatementTmplArgTarget {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PredicateHashOrWildcardTarget {
|
||||||
|
/// layout: `prefix | [data]`, where data is predicate_hash or wildcard_index
|
||||||
|
pub elements: [Target; Params::pred_hash_or_wc_size()],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PredicateHashOrWildcardTarget {
|
||||||
|
pub fn new(prefix: Target, data: ValueTarget) -> Self {
|
||||||
|
let v = data.elements;
|
||||||
|
Self {
|
||||||
|
elements: [prefix, v[0], v[1], v[2], v[3]],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn new_pred_hash(builder: &mut CircuitBuilder, pred_hash: HashOutTarget) -> Self {
|
||||||
|
Self::new(
|
||||||
|
builder.constant(F::from(PredicateOrWildcardPrefix::Predicate)),
|
||||||
|
ValueTarget::from(pred_hash),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
pub fn is_pred(&self, builder: &mut CircuitBuilder) -> BoolTarget {
|
||||||
|
let prefix_pred = builder.constant(F::from(PredicateOrWildcardPrefix::Predicate));
|
||||||
|
builder.is_equal(self.elements[0], prefix_pred)
|
||||||
|
}
|
||||||
|
pub fn data(&self) -> ValueTarget {
|
||||||
|
ValueTarget {
|
||||||
|
elements: self.elements[1..].try_into().expect("4 elements"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn pred_hash(&self) -> HashOutTarget {
|
||||||
|
HashOutTarget::from(self.data())
|
||||||
|
}
|
||||||
|
pub fn wc_index(&self) -> Target {
|
||||||
|
self.elements[1]
|
||||||
|
}
|
||||||
|
pub fn set_targets_raw(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
prefix: PredicateOrWildcardPrefix,
|
||||||
|
data: RawValue,
|
||||||
|
) -> Result<()> {
|
||||||
|
pw.set_target(self.elements[0], F::from(prefix))?;
|
||||||
|
pw.set_target_arr(&self.elements[1..], &data.0)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub fn set_targets(
|
||||||
|
&self,
|
||||||
|
pw: &mut PartialWitness<F>,
|
||||||
|
params: &Params,
|
||||||
|
pred: &PredicateOrWildcard,
|
||||||
|
) -> Result<()> {
|
||||||
|
match pred {
|
||||||
|
PredicateOrWildcard::Predicate(pred) => {
|
||||||
|
self.set_targets_raw(
|
||||||
|
pw,
|
||||||
|
PredicateOrWildcardPrefix::Predicate,
|
||||||
|
RawValue::from(pred.hash(params)),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
PredicateOrWildcard::Wildcard(wc) => {
|
||||||
|
self.set_targets_raw(
|
||||||
|
pw,
|
||||||
|
PredicateOrWildcardPrefix::Wildcard,
|
||||||
|
RawValue([F::from_canonical_usize(wc.index), F::ZERO, F::ZERO, F::ZERO]),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Flattenable for PredicateHashOrWildcardTarget {
|
||||||
|
fn flatten(&self) -> Vec<Target> {
|
||||||
|
self.elements.to_vec()
|
||||||
|
}
|
||||||
|
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
|
||||||
|
Self {
|
||||||
|
elements: vs.try_into().expect("5 elements"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn size(_params: &Params) -> usize {
|
||||||
|
Params::pred_hash_or_wc_size()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
pub struct StatementTmplTarget {
|
pub struct StatementTmplTarget {
|
||||||
|
/// The preimage of the predicate_hash. This predicate is needed only to build the custom
|
||||||
|
/// predicate table because it needs to normalize statement templates with predicates that
|
||||||
|
/// refer to self into content-addressed predicates (using the batch id and index). The
|
||||||
|
/// predicate type is inspected to do this normalization. After the table is built we only use
|
||||||
|
/// the predicate hash for equality checks.
|
||||||
pred: Option<PredicateTarget>,
|
pred: Option<PredicateTarget>,
|
||||||
pred_hash: HashOutTarget,
|
/// This is constrained to be `hash(pred)` through the type constructor when we have `pred`
|
||||||
|
/// and the template uses a predicate and not a wildcard.
|
||||||
|
pred_hash_or_wc: PredicateHashOrWildcardTarget,
|
||||||
pub args: Vec<StatementTmplArgTarget>,
|
pub args: Vec<StatementTmplArgTarget>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StatementTmplTarget {
|
impl StatementTmplTarget {
|
||||||
pub fn new(pred_hash: HashOutTarget, args: Vec<StatementTmplArgTarget>) -> Self {
|
pub fn new(
|
||||||
|
pred_hash_or_wc: PredicateHashOrWildcardTarget,
|
||||||
|
args: Vec<StatementTmplArgTarget>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
pred: None,
|
pred: None,
|
||||||
pred_hash,
|
pred_hash_or_wc,
|
||||||
args,
|
args,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -546,9 +657,22 @@ impl StatementTmplTarget {
|
||||||
st_tmpl: &StatementTmpl,
|
st_tmpl: &StatementTmpl,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if let Some(pred) = &self.pred {
|
if let Some(pred) = &self.pred {
|
||||||
pred.set_targets(pw, params, &st_tmpl.pred)?;
|
match &st_tmpl.pred_or_wc {
|
||||||
|
PredicateOrWildcard::Predicate(p) => {
|
||||||
|
// We store a predicate (not a wildcard) and we have it available. In this
|
||||||
|
// case the hash will be calculated by constraints later on and we should not
|
||||||
|
// rely on the original data.
|
||||||
|
pred.set_targets(pw, params, p)?
|
||||||
}
|
}
|
||||||
pw.set_hash_target(self.pred_hash, HashOut::from(st_tmpl.pred.hash(params)))?;
|
PredicateOrWildcard::Wildcard(_wc) => {
|
||||||
|
// Fill in with a recognizable constant for better debugging; this value is
|
||||||
|
// not supposed to be used.
|
||||||
|
pw.set_target_arr(&pred.elements, &[F(0xdead); Params::predicate_size()])?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.pred_hash_or_wc
|
||||||
|
.set_targets(pw, params, &st_tmpl.pred_or_wc)?;
|
||||||
let arg_pad = StatementTmplArg::None;
|
let arg_pad = StatementTmplArg::None;
|
||||||
for (i, arg) in st_tmpl
|
for (i, arg) in st_tmpl
|
||||||
.args
|
.args
|
||||||
|
|
@ -564,8 +688,8 @@ impl StatementTmplTarget {
|
||||||
pub fn pred(&self) -> Option<&PredicateTarget> {
|
pub fn pred(&self) -> Option<&PredicateTarget> {
|
||||||
self.pred.as_ref()
|
self.pred.as_ref()
|
||||||
}
|
}
|
||||||
pub fn pred_hash(&self) -> &HashOutTarget {
|
pub fn pred_hash_or_wc(&self) -> &PredicateHashOrWildcardTarget {
|
||||||
&self.pred_hash
|
&self.pred_hash_or_wc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -603,6 +727,8 @@ impl CustomPredicateTarget {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// This type is used to build the custom predicate table, which exposes the custom predicates with
|
||||||
|
/// normalized statement templates indexed by batch_id and custom_predicate_index.
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
pub struct CustomPredicateBatchTarget {
|
pub struct CustomPredicateBatchTarget {
|
||||||
pub predicates: Vec<CustomPredicateTarget>,
|
pub predicates: Vec<CustomPredicateTarget>,
|
||||||
|
|
@ -660,15 +786,17 @@ impl CustomPredicateEntryTarget {
|
||||||
.clone()
|
.clone()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|st_tmpl| {
|
.map(|st_tmpl| {
|
||||||
let pred = match st_tmpl.pred {
|
let pred_or_wc = match st_tmpl.pred_or_wc {
|
||||||
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
PredicateOrWildcard::Predicate(Predicate::BatchSelf(i)) => {
|
||||||
|
PredicateOrWildcard::Predicate(Predicate::Custom(CustomPredicateRef {
|
||||||
batch: batch.clone(),
|
batch: batch.clone(),
|
||||||
index: i,
|
index: i,
|
||||||
}),
|
}))
|
||||||
p => p,
|
}
|
||||||
|
x => x.clone(),
|
||||||
};
|
};
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred,
|
pred_or_wc,
|
||||||
args: st_tmpl.args,
|
args: st_tmpl.args,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -724,7 +852,7 @@ pub struct CustomPredicateVerifyEntryTarget {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CustomPredicateVerifyEntryTarget {
|
impl CustomPredicateVerifyEntryTarget {
|
||||||
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder, with_pred: bool) -> Self {
|
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
|
||||||
let custom_predicate_table_len =
|
let custom_predicate_table_len =
|
||||||
params.max_custom_predicate_batches * params.max_custom_batch_size;
|
params.max_custom_predicate_batches * params.max_custom_batch_size;
|
||||||
CustomPredicateVerifyEntryTarget {
|
CustomPredicateVerifyEntryTarget {
|
||||||
|
|
@ -732,7 +860,7 @@ impl CustomPredicateVerifyEntryTarget {
|
||||||
custom_predicate_table_len,
|
custom_predicate_table_len,
|
||||||
builder,
|
builder,
|
||||||
),
|
),
|
||||||
custom_predicate: builder.add_virtual_custom_predicate_entry(params, with_pred),
|
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
|
||||||
args: (0..params.max_custom_predicate_wildcards)
|
args: (0..params.max_custom_predicate_wildcards)
|
||||||
.map(|_| builder.add_virtual_value())
|
.map(|_| builder.add_virtual_value())
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
@ -1062,7 +1190,7 @@ impl Flattenable for CustomPredicateTarget {
|
||||||
|
|
||||||
impl Flattenable for StatementTmplTarget {
|
impl Flattenable for StatementTmplTarget {
|
||||||
fn flatten(&self) -> Vec<Target> {
|
fn flatten(&self) -> Vec<Target> {
|
||||||
self.pred_hash
|
self.pred_hash_or_wc
|
||||||
.flatten()
|
.flatten()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
|
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
|
||||||
|
|
@ -1071,24 +1199,27 @@ impl Flattenable for StatementTmplTarget {
|
||||||
|
|
||||||
fn from_flattened(params: &Params, v: &[Target]) -> Self {
|
fn from_flattened(params: &Params, v: &[Target]) -> Self {
|
||||||
assert_eq!(v.len(), Self::size(params));
|
assert_eq!(v.len(), Self::size(params));
|
||||||
let pred_hash_end = HASH_SIZE;
|
let pred_hash_or_wc_end = Params::pred_hash_or_wc_size();
|
||||||
let pred_hash = HashOutTarget::from_flattened(params, &v[..pred_hash_end]);
|
let pred_hash_or_wc =
|
||||||
|
PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]);
|
||||||
let sta_size = Params::statement_tmpl_arg_size();
|
let sta_size = Params::statement_tmpl_arg_size();
|
||||||
let args = (0..params.max_statement_args)
|
let args = (0..params.max_statement_args)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let sta_v = &v[pred_hash_end + sta_size * i..pred_hash_end + sta_size * (i + 1)];
|
let sta_v = &v
|
||||||
|
[pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)];
|
||||||
StatementTmplArgTarget::from_flattened(params, sta_v)
|
StatementTmplArgTarget::from_flattened(params, sta_v)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
Self {
|
Self {
|
||||||
pred: None,
|
pred: None,
|
||||||
pred_hash,
|
pred_hash_or_wc,
|
||||||
args,
|
args,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size(params: &Params) -> usize {
|
fn size(params: &Params) -> usize {
|
||||||
HASH_SIZE + params.max_statement_args * StatementTmplArgTarget::size(params)
|
Params::pred_hash_or_wc_size()
|
||||||
|
+ params.max_statement_args * StatementTmplArgTarget::size(params)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1168,11 +1299,8 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
|
||||||
params: &Params,
|
params: &Params,
|
||||||
with_pred: bool,
|
with_pred: bool,
|
||||||
) -> CustomPredicateBatchTarget;
|
) -> CustomPredicateBatchTarget;
|
||||||
fn add_virtual_custom_predicate_entry(
|
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
|
||||||
&mut self,
|
-> CustomPredicateEntryTarget;
|
||||||
params: &Params,
|
|
||||||
with_pred: bool,
|
|
||||||
) -> CustomPredicateEntryTarget;
|
|
||||||
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
|
||||||
fn select_statement_arg(
|
fn select_statement_arg(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|
@ -1320,24 +1448,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If `with_pred = true` a predicate is included and its hash constrained.
|
/// If `with_pred = true` a predicate is included.
|
||||||
/// If `with_pred = false` only the predicate hash is included.
|
/// If `with_pred = false` only the predicate hash is included.
|
||||||
|
/// The pred_hash is constrained to be hash(pred) conditionally on the template using a
|
||||||
|
/// predicate and not a wildcard.
|
||||||
fn add_virtual_statement_tmpl(
|
fn add_virtual_statement_tmpl(
|
||||||
&mut self,
|
&mut self,
|
||||||
params: &Params,
|
params: &Params,
|
||||||
with_pred: bool,
|
with_pred: bool,
|
||||||
) -> StatementTmplTarget {
|
) -> StatementTmplTarget {
|
||||||
let (pred, pred_hash) = if with_pred {
|
let pred_hash_or_wc =
|
||||||
|
PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value());
|
||||||
|
let pred = if with_pred {
|
||||||
let pred = self.add_virtual_predicate();
|
let pred = self.add_virtual_predicate();
|
||||||
let pred_hash = pred.hash(self);
|
let pred_hash = pred.hash(self);
|
||||||
(Some(pred), pred_hash)
|
let is_pred = pred_hash_or_wc.is_pred(self);
|
||||||
|
let data = pred_hash_or_wc.data();
|
||||||
|
for i in 0..VALUE_SIZE {
|
||||||
|
self.conditional_assert_eq(is_pred.target, data.elements[i], pred_hash.elements[i]);
|
||||||
|
}
|
||||||
|
Some(pred)
|
||||||
} else {
|
} else {
|
||||||
let pred_hash = self.add_virtual_hash();
|
None
|
||||||
(None, pred_hash)
|
|
||||||
};
|
};
|
||||||
StatementTmplTarget {
|
StatementTmplTarget {
|
||||||
pred,
|
pred,
|
||||||
pred_hash,
|
pred_hash_or_wc,
|
||||||
args: (0..params.max_statement_args)
|
args: (0..params.max_statement_args)
|
||||||
.map(|_| self.add_virtual_statement_tmpl_arg())
|
.map(|_| self.add_virtual_statement_tmpl_arg())
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
@ -1377,12 +1513,11 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
|
||||||
fn add_virtual_custom_predicate_entry(
|
fn add_virtual_custom_predicate_entry(
|
||||||
&mut self,
|
&mut self,
|
||||||
params: &Params,
|
params: &Params,
|
||||||
with_pred: bool,
|
|
||||||
) -> CustomPredicateEntryTarget {
|
) -> CustomPredicateEntryTarget {
|
||||||
CustomPredicateEntryTarget {
|
CustomPredicateEntryTarget {
|
||||||
id: self.add_virtual_hash(),
|
id: self.add_virtual_hash(),
|
||||||
index: self.add_virtual_target(),
|
index: self.add_virtual_target(),
|
||||||
predicate: self.add_virtual_custom_predicate(params, with_pred),
|
predicate: self.add_virtual_custom_predicate(params, false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,8 @@ use crate::{
|
||||||
CustomPredicateTarget, CustomPredicateVerifyEntryTarget,
|
CustomPredicateTarget, CustomPredicateVerifyEntryTarget,
|
||||||
CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget,
|
CustomPredicateVerifyQueryTarget, Flattenable, MerkleClaimTarget,
|
||||||
MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget,
|
MerkleTreeStateTransitionClaimTarget, OperationTarget, OperationTypeTarget,
|
||||||
PredicateTarget, StatementArgTarget, StatementTarget, StatementTmplArgTarget,
|
PredicateHashOrWildcardTarget, PredicateTarget, StatementArgTarget,
|
||||||
StatementTmplTarget, ValueTarget,
|
StatementTarget, StatementTmplArgTarget, StatementTmplTarget, ValueTarget,
|
||||||
},
|
},
|
||||||
hash::{hash_from_state_circuit, precompute_hash_state},
|
hash::{hash_from_state_circuit, precompute_hash_state},
|
||||||
mux_table::{MuxTableTarget, TableEntryTarget},
|
mux_table::{MuxTableTarget, TableEntryTarget},
|
||||||
|
|
@ -341,12 +341,7 @@ fn build_operation_aux_table_circuit(
|
||||||
.chain(signed_by.pk.u.components)
|
.chain(signed_by.pk.u.components)
|
||||||
.collect(),
|
.collect(),
|
||||||
);
|
);
|
||||||
let entry: MsgPubKeyTarget = HashPairTarget(
|
let entry: MsgPubKeyTarget = HashPairTarget(HashOutTarget::from(signed_by.msg), pk_hash);
|
||||||
HashOutTarget {
|
|
||||||
elements: signed_by.msg.elements,
|
|
||||||
},
|
|
||||||
pk_hash,
|
|
||||||
);
|
|
||||||
|
|
||||||
table.push(builder, OperationAuxTableTag::SignedBy as u32, &entry);
|
table.push(builder, OperationAuxTableTag::SignedBy as u32, &entry);
|
||||||
measure_gates_end!(builder, measure);
|
measure_gates_end!(builder, measure);
|
||||||
|
|
@ -1381,6 +1376,26 @@ fn make_statement_arg_from_template_circuit(
|
||||||
StatementArgTarget::new(first, second)
|
StatementArgTarget::new(first, second)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_predicate_from_template_circuit(
|
||||||
|
params: &Params,
|
||||||
|
builder: &mut CircuitBuilder,
|
||||||
|
pred_hash_or_wc: &PredicateHashOrWildcardTarget,
|
||||||
|
args: &[ValueTarget],
|
||||||
|
) -> HashOutTarget {
|
||||||
|
let zero = builder.zero();
|
||||||
|
let is_pred = pred_hash_or_wc.is_pred(builder);
|
||||||
|
// If the index is not used, use a 0 instead to still pass the range constraints from
|
||||||
|
// vec_ref
|
||||||
|
let index = builder.select(is_pred, zero, pred_hash_or_wc.wc_index());
|
||||||
|
let resolved_pred_hash = HashOutTarget::from(builder.vec_ref_small(params, args, index));
|
||||||
|
builder.select_flattenable(
|
||||||
|
params,
|
||||||
|
is_pred,
|
||||||
|
&pred_hash_or_wc.pred_hash(),
|
||||||
|
&resolved_pred_hash,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn make_statement_from_template_circuit(
|
fn make_statement_from_template_circuit(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
builder: &mut CircuitBuilder,
|
builder: &mut CircuitBuilder,
|
||||||
|
|
@ -1388,7 +1403,7 @@ fn make_statement_from_template_circuit(
|
||||||
args: &[ValueTarget],
|
args: &[ValueTarget],
|
||||||
) -> StatementTarget {
|
) -> StatementTarget {
|
||||||
let measure = measure_gates_begin!(builder, "StArgFromTmpl");
|
let measure = measure_gates_begin!(builder, "StArgFromTmpl");
|
||||||
let args = st_tmpl
|
let st_args = st_tmpl
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|st_tmpl_arg| {
|
.map(|st_tmpl_arg| {
|
||||||
|
|
@ -1396,7 +1411,11 @@ fn make_statement_from_template_circuit(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
measure_gates_end!(builder, measure);
|
measure_gates_end!(builder, measure);
|
||||||
StatementTarget::new(*st_tmpl.pred_hash(), args)
|
let measure = measure_gates_begin!(builder, "PredFromTmpl");
|
||||||
|
let pred_hash =
|
||||||
|
make_predicate_from_template_circuit(params, builder, st_tmpl.pred_hash_or_wc(), args);
|
||||||
|
measure_gates_end!(builder, measure);
|
||||||
|
StatementTarget::new(pred_hash, st_args)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard
|
/// Given a custom predicate, a list of operation arguments (statements) and a list of wildcard
|
||||||
|
|
@ -1527,13 +1546,29 @@ fn normalize_st_tmpl_circuit(
|
||||||
st_tmpl: &StatementTmplTarget,
|
st_tmpl: &StatementTmplTarget,
|
||||||
id: HashOutTarget,
|
id: HashOutTarget,
|
||||||
) -> StatementTmplTarget {
|
) -> StatementTmplTarget {
|
||||||
let pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
|
// If the custom predicate is self, we normalize it and then hash it.
|
||||||
|
let old_pred = st_tmpl.pred().expect("StatementTmpl contains predicate");
|
||||||
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
|
let prefix_batch_self = builder.constant(F::from(PredicatePrefix::BatchSelf));
|
||||||
let is_batch_self = builder.is_equal(pred.elements[0], prefix_batch_self);
|
let is_batch_self = builder.is_equal(old_pred.elements[0], prefix_batch_self);
|
||||||
let pred_index = pred.elements[1];
|
|
||||||
let custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
|
let pred_index = old_pred.elements[1];
|
||||||
let pred = builder.select_flattenable(params, is_batch_self, &custom_pred, pred);
|
let normalized_custom_pred = PredicateTarget::new_custom(builder, id, pred_index);
|
||||||
StatementTmplTarget::new(pred.hash(builder), st_tmpl.args.clone())
|
let normalized_custom_pred_hash = normalized_custom_pred.hash(builder);
|
||||||
|
|
||||||
|
// If the template is using a predicate and it is batch self we use the freshly computed
|
||||||
|
// normalized predicate hash, otherwise we keep the original data.
|
||||||
|
let old_data = st_tmpl.pred_hash_or_wc().data();
|
||||||
|
let is_pred = st_tmpl.pred_hash_or_wc().is_pred(builder);
|
||||||
|
let is_pred_batch_self = builder.and(is_pred, is_batch_self);
|
||||||
|
let data = builder.select_flattenable(
|
||||||
|
params,
|
||||||
|
is_pred_batch_self,
|
||||||
|
&ValueTarget::from(normalized_custom_pred_hash),
|
||||||
|
&old_data,
|
||||||
|
);
|
||||||
|
let pred_hash_or_wc =
|
||||||
|
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
|
||||||
|
StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
||||||
|
|
@ -1773,7 +1808,7 @@ impl MainPodVerifyTarget {
|
||||||
.map(|_| builder.add_virtual_custom_predicate_batch(params, true))
|
.map(|_| builder.add_virtual_custom_predicate_batch(params, true))
|
||||||
.collect(),
|
.collect(),
|
||||||
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
|
custom_predicate_verifications: (0..params.max_custom_predicate_verifications)
|
||||||
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder, false))
|
.map(|_| CustomPredicateVerifyEntryTarget::new_virtual(params, builder))
|
||||||
.collect(),
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -2012,8 +2047,8 @@ mod tests {
|
||||||
dict,
|
dict,
|
||||||
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
|
frontend::{self, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
|
||||||
middleware::{
|
middleware::{
|
||||||
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, RawValue, StatementArg,
|
hash_values, AnchoredKey, Hash, Key, OperationType, Predicate, PredicateOrWildcard,
|
||||||
StatementTmpl, StatementTmplArg, Wildcard,
|
RawValue, StatementArg, StatementTmpl, StatementTmplArg, Wildcard,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -3124,7 +3159,7 @@ mod tests {
|
||||||
let dict = Hash([F(6), F(7), F(8), F(9)]);
|
let dict = Hash([F(6), F(7), F(8), F(9)]);
|
||||||
|
|
||||||
let st_tmpl = StatementTmpl {
|
let st_tmpl = StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
|
StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
|
||||||
StatementTmplArg::Literal(Value::from("value")),
|
StatementTmplArg::Literal(Value::from("value")),
|
||||||
|
|
@ -3137,6 +3172,21 @@ mod tests {
|
||||||
);
|
);
|
||||||
helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?;
|
helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?;
|
||||||
|
|
||||||
|
let st_tmpl = StatementTmpl {
|
||||||
|
pred_or_wc: PredicateOrWildcard::Wildcard(Wildcard::new("x".to_string(), 2)),
|
||||||
|
args: vec![
|
||||||
|
StatementTmplArg::AnchoredKey(Wildcard::new("a".to_string(), 1), Key::from("key")),
|
||||||
|
StatementTmplArg::Literal(Value::from("value")),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
let pred_hash = Predicate::Native(NativePredicate::NotEqual).hash(¶ms);
|
||||||
|
let args = vec![Value::from(1), Value::from(dict), Value::from(pred_hash)];
|
||||||
|
let expected_st = Statement::not_equal(
|
||||||
|
AnchoredKey::new(dict, Key::from("key")),
|
||||||
|
Value::from("value"),
|
||||||
|
);
|
||||||
|
helper_statement_from_template(¶ms, st_tmpl, args, expected_st)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3150,7 +3200,7 @@ mod tests {
|
||||||
let config = CircuitConfig::standard_recursion_config();
|
let config = CircuitConfig::standard_recursion_config();
|
||||||
let mut builder = CircuitBuilder::new(config);
|
let mut builder = CircuitBuilder::new(config);
|
||||||
|
|
||||||
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params, false);
|
let custom_predicate_target = builder.add_virtual_custom_predicate_entry(params);
|
||||||
let op_args_target: Vec<_> = (0..args.len())
|
let op_args_target: Vec<_> = (0..args.len())
|
||||||
.map(|_| builder.add_virtual_statement(params, false))
|
.map(|_| builder.add_virtual_statement(params, false))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ use crate::{
|
||||||
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
|
frontend::{AnchoredKey, Error, Result, Statement, StatementArg},
|
||||||
middleware::{
|
middleware::{
|
||||||
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
|
self, hash_str, CustomPredicate, CustomPredicateBatch, Hash, Key, NativePredicate, Params,
|
||||||
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
|
Predicate, PredicateOrWildcard, StatementTmpl, StatementTmplArg, ToFields, Value, Wildcard,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -217,7 +217,8 @@ impl CustomPredicateBatchBuilder {
|
||||||
})
|
})
|
||||||
.collect::<Result<_>>()?;
|
.collect::<Result<_>>()?;
|
||||||
Ok(StatementTmpl {
|
Ok(StatementTmpl {
|
||||||
pred: stb.predicate.clone(),
|
// TODO: Support wildcard
|
||||||
|
pred_or_wc: PredicateOrWildcard::Predicate(stb.predicate.clone()),
|
||||||
args,
|
args,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
@ -319,7 +320,10 @@ mod tests {
|
||||||
// Check that the desugared predicate is the same as the one in the statement template
|
// Check that the desugared predicate is the same as the one in the statement template
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
desugared_gt.predicate(),
|
desugared_gt.predicate(),
|
||||||
*batch_clone.predicates()[0].statements[0].pred()
|
*batch_clone.predicates()[0].statements[0]
|
||||||
|
.pred_or_wc()
|
||||||
|
.as_pred()
|
||||||
|
.unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check that our custom predicate matches the statement template
|
// Check that our custom predicate matches the statement template
|
||||||
|
|
@ -366,7 +370,10 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
set_contains.predicate(),
|
set_contains.predicate(),
|
||||||
*batch_clone.predicates()[0].statements[0].pred()
|
*batch_clone.predicates()[0].statements[0]
|
||||||
|
.pred_or_wc()
|
||||||
|
.as_pred()
|
||||||
|
.unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);
|
let set_contains_custom_pred = CustomPredicateRef::new(batch, 0);
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@ use std::{collections::HashMap, fmt::Display};
|
||||||
use crate::{
|
use crate::{
|
||||||
frontend::{Error, Result},
|
frontend::{Error, Result},
|
||||||
lang::PrettyPrint,
|
lang::PrettyPrint,
|
||||||
middleware::{Pod, Statement, StatementArg, StatementTmpl, StatementTmplArg, Value},
|
middleware::{
|
||||||
|
Pod, PredicateOrWildcard, Statement, StatementArg, StatementTmpl, StatementTmplArg, Value,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Represents a request for a POD, in terms of a set of statement templates.
|
/// Represents a request for a POD, in terms of a set of statement templates.
|
||||||
|
|
@ -76,7 +78,8 @@ impl PodRequest {
|
||||||
statement: &Statement,
|
statement: &Statement,
|
||||||
current_bindings: &HashMap<String, Value>,
|
current_bindings: &HashMap<String, Value>,
|
||||||
) -> Option<HashMap<String, Value>> {
|
) -> Option<HashMap<String, Value>> {
|
||||||
if template.pred != statement.predicate() {
|
// TODO: Support wildcard
|
||||||
|
if template.pred_or_wc != PredicateOrWildcard::Predicate(statement.predicate()) {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,8 @@ use crate::{
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params,
|
self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params,
|
||||||
Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg,
|
Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
||||||
Wildcard,
|
StatementTmplArg as MWStatementTmplArg, Wildcard,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -201,7 +201,8 @@ impl<'a> Lowerer<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(MWStatementTmpl {
|
Ok(MWStatementTmpl {
|
||||||
pred: predicate,
|
// TODO: Support wildcard
|
||||||
|
pred_or_wc: PredicateOrWildcard::Predicate(predicate),
|
||||||
args: mw_args,
|
args: mw_args,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -596,7 +597,10 @@ mod tests {
|
||||||
let stmt = &pred2.statements()[0];
|
let stmt = &pred2.statements()[0];
|
||||||
|
|
||||||
// Should be BatchSelf(0) referring to pred1
|
// Should be BatchSelf(0) referring to pred1
|
||||||
assert!(matches!(stmt.pred, Predicate::BatchSelf(0)));
|
assert!(matches!(
|
||||||
|
stmt.pred_or_wc,
|
||||||
|
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -632,8 +636,8 @@ mod tests {
|
||||||
|
|
||||||
// Should desugar to the Contains predicate
|
// Should desugar to the Contains predicate
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
stmt.pred,
|
stmt.pred_or_wc,
|
||||||
Predicate::Native(NativePredicate::Contains)
|
PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains))
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
109
src/lang/mod.rs
109
src/lang/mod.rs
|
|
@ -63,8 +63,8 @@ mod tests {
|
||||||
backends::plonky2::primitives::ec::schnorr::SecretKey,
|
backends::plonky2::primitives::ec::schnorr::SecretKey,
|
||||||
middleware::{
|
middleware::{
|
||||||
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
|
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
|
||||||
Params, Predicate, RawValue, StatementTmpl, StatementTmplArg, Value, Wildcard,
|
Params, Predicate, PredicateOrWildcard, RawValue, StatementTmpl, StatementTmplArg,
|
||||||
EMPTY_HASH,
|
Value, Wildcard, EMPTY_HASH,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -89,6 +89,10 @@ mod tests {
|
||||||
names.iter().map(|s| s.to_string()).collect()
|
names.iter().map(|s| s.to_string()).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn pred_lit(pred: Predicate) -> PredicateOrWildcard {
|
||||||
|
PredicateOrWildcard::Predicate(pred)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_simple_predicate() -> Result<(), LangError> {
|
fn test_e2e_simple_predicate() -> Result<(), LangError> {
|
||||||
let input = r#"
|
let input = r#"
|
||||||
|
|
@ -109,7 +113,7 @@ mod tests {
|
||||||
|
|
||||||
// Expected structure
|
// Expected structure
|
||||||
let expected_statements = vec![StatementTmpl {
|
let expected_statements = vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("PodA", 0), "the_key"), // PodA["the_key"] -> Wildcard(0), Key("the_key")
|
sta_ak(("PodA", 0), "the_key"), // PodA["the_key"] -> Wildcard(0), Key("the_key")
|
||||||
sta_ak(("PodB", 1), "the_key"), // PodB["the_key"] -> Wildcard(1), Key("the_key")
|
sta_ak(("PodB", 1), "the_key"), // PodB["the_key"] -> Wildcard(1), Key("the_key")
|
||||||
|
|
@ -153,14 +157,14 @@ mod tests {
|
||||||
// Expected structure
|
// Expected structure
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
|
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
|
||||||
sta_lit(RawValue::from(1)),
|
sta_lit(RawValue::from(1)),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Lt),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("GovPod", 1), "dob"), // GovPod["dob"] -> Wildcard(1), Key("dob")
|
sta_ak(("GovPod", 1), "dob"), // GovPod["dob"] -> Wildcard(1), Key("dob")
|
||||||
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
|
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
|
||||||
|
|
@ -195,14 +199,14 @@ mod tests {
|
||||||
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
|
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
|
||||||
let expected_statements = vec![
|
let expected_statements = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("A", 0), "input_key"), // A["input_key"] -> Wildcard(0), Key("input_key")
|
sta_ak(("A", 0), "input_key"), // A["input_key"] -> Wildcard(0), Key("input_key")
|
||||||
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
|
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
|
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
|
||||||
sta_lit("some_value"), // Literal("some_value")
|
sta_lit("some_value"), // Literal("some_value")
|
||||||
|
|
@ -251,7 +255,7 @@ mod tests {
|
||||||
|
|
||||||
// Expected Batch structure
|
// Expected Batch structure
|
||||||
let expected_pred_statements = vec![StatementTmpl {
|
let expected_pred_statements = vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val")
|
sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val")
|
||||||
sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val")
|
sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val")
|
||||||
|
|
@ -275,7 +279,10 @@ mod tests {
|
||||||
// Expected Request structure
|
// Expected Request structure
|
||||||
// Pod1 -> Wildcard 0, Pod2 -> Wildcard 1
|
// Pod1 -> Wildcard 0, Pod2 -> Wildcard 1
|
||||||
let expected_request_templates = vec![StatementTmpl {
|
let expected_request_templates = vec![StatementTmpl {
|
||||||
pred: Predicate::Custom(CustomPredicateRef::new(expected_batch, 0)),
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
||||||
|
expected_batch,
|
||||||
|
0,
|
||||||
|
))),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::Wildcard(wc("Pod1", 0)),
|
StatementTmplArg::Wildcard(wc("Pod1", 0)),
|
||||||
StatementTmplArg::Wildcard(wc("Pod2", 1)),
|
StatementTmplArg::Wildcard(wc("Pod2", 1)),
|
||||||
|
|
@ -317,7 +324,7 @@ mod tests {
|
||||||
// Expected structure
|
// Expected structure
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Custom(CustomPredicateRef::new(batch_result, 0)), // Refers to some_pred
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch_result, 0))), // Refers to some_pred
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::Wildcard(wc("Var1", 0)), // Var1
|
StatementTmplArg::Wildcard(wc("Var1", 0)), // Var1
|
||||||
StatementTmplArg::Literal(Value::from(12345i64)), // 12345
|
StatementTmplArg::Literal(Value::from(12345i64)), // 12345
|
||||||
|
|
@ -325,7 +332,7 @@ mod tests {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
// AnotherPod["another_key"] -> Wildcard(1), Key("another_key")
|
// AnotherPod["another_key"] -> Wildcard(1), Key("another_key")
|
||||||
sta_ak(("AnotherPod", 1), "another_key"),
|
sta_ak(("AnotherPod", 1), "another_key"),
|
||||||
|
|
@ -362,15 +369,15 @@ mod tests {
|
||||||
|
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::LtEq),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::LtEq)),
|
||||||
args: vec![sta_ak(("B", 1), "bar"), sta_ak(("A", 0), "foo")],
|
args: vec![sta_ak(("B", 1), "bar"), sta_ak(("A", 0), "foo")],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Lt),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
|
||||||
args: vec![sta_ak(("D", 3), "qux"), sta_ak(("C", 2), "baz")],
|
args: vec![sta_ak(("D", 3), "qux"), sta_ak(("C", 2), "baz")],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Contains),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("A", 0), "foo"),
|
sta_ak(("A", 0), "foo"),
|
||||||
sta_ak(("B", 1), "bar"),
|
sta_ak(("B", 1), "bar"),
|
||||||
|
|
@ -378,11 +385,11 @@ mod tests {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::NotContains),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)),
|
||||||
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
|
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Contains),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("A", 0), "foo"),
|
sta_ak(("A", 0), "foo"),
|
||||||
sta_ak(("B", 1), "bar"),
|
sta_ak(("B", 1), "bar"),
|
||||||
|
|
@ -439,7 +446,7 @@ mod tests {
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
// 1. NotContains(sanctions["sanctionList"], gov["idNumber"])
|
// 1. NotContains(sanctions["sanctionList"], gov["idNumber"])
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::NotContains),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(
|
sta_ak(
|
||||||
(wc_sanctions.name.as_str(), wc_sanctions.index),
|
(wc_sanctions.name.as_str(), wc_sanctions.index),
|
||||||
|
|
@ -450,7 +457,7 @@ mod tests {
|
||||||
},
|
},
|
||||||
// 2. Lt(gov["dateOfBirth"], SELF_HOLDER_18Y["const_18y"])
|
// 2. Lt(gov["dateOfBirth"], SELF_HOLDER_18Y["const_18y"])
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Lt),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key),
|
sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key),
|
||||||
sta_ak(
|
sta_ak(
|
||||||
|
|
@ -461,7 +468,7 @@ mod tests {
|
||||||
},
|
},
|
||||||
// 3. Equal(pay["startDate"], SELF_HOLDER_1Y["const_1y"])
|
// 3. Equal(pay["startDate"], SELF_HOLDER_1Y["const_1y"])
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key),
|
sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key),
|
||||||
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
|
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
|
||||||
|
|
@ -469,7 +476,7 @@ mod tests {
|
||||||
},
|
},
|
||||||
// 4. Equal(gov["socialSecurityNumber"], pay["socialSecurityNumber"])
|
// 4. Equal(gov["socialSecurityNumber"], pay["socialSecurityNumber"])
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key),
|
sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key),
|
||||||
sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key),
|
sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key),
|
||||||
|
|
@ -477,7 +484,7 @@ mod tests {
|
||||||
},
|
},
|
||||||
// 5. Equal(SELF_HOLDER_18Y["const_18y"], 1169909388)
|
// 5. Equal(SELF_HOLDER_18Y["const_18y"], 1169909388)
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(
|
sta_ak(
|
||||||
(wc_self_18y.name.as_str(), wc_self_18y.index),
|
(wc_self_18y.name.as_str(), wc_self_18y.index),
|
||||||
|
|
@ -488,7 +495,7 @@ mod tests {
|
||||||
},
|
},
|
||||||
// 6. Equal(SELF_HOLDER_1Y["const_1y"], 1706367566)
|
// 6. Equal(SELF_HOLDER_1Y["const_1y"], 1706367566)
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
|
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
|
||||||
sta_lit(now_minus_1y_val.clone()),
|
sta_lit(now_minus_1y_val.clone()),
|
||||||
|
|
@ -563,11 +570,11 @@ mod tests {
|
||||||
// eth_friend (Index 0)
|
// eth_friend (Index 0)
|
||||||
let expected_friend_stmts = vec![
|
let expected_friend_stmts = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::SignedBy),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SignedBy)),
|
||||||
args: vec![sta_wc_lit("attestation_dict", 2), sta_wc_lit("src", 0)],
|
args: vec![sta_wc_lit("attestation_dict", 2), sta_wc_lit("src", 0)],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("attestation_dict", 2), "attestation"),
|
sta_ak(("attestation_dict", 2), "attestation"),
|
||||||
sta_wc_lit("dst", 1), // Pub arg 1
|
sta_wc_lit("dst", 1), // Pub arg 1
|
||||||
|
|
@ -586,11 +593,11 @@ mod tests {
|
||||||
// eth_dos_distance_base (Index 1)
|
// eth_dos_distance_base (Index 1)
|
||||||
let expected_base_stmts = vec![
|
let expected_base_stmts = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_wc_lit("src", 0), sta_wc_lit("dst", 1)],
|
args: vec![sta_wc_lit("src", 0), sta_wc_lit("dst", 1)],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_wc_lit("distance", 2), sta_lit(0i64)],
|
args: vec![sta_wc_lit("distance", 2), sta_lit(0i64)],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
@ -608,7 +615,7 @@ mod tests {
|
||||||
// Private args indices: 3-4 (shorter_distance, intermed)
|
// Private args indices: 3-4 (shorter_distance, intermed)
|
||||||
let expected_ind_stmts = vec![
|
let expected_ind_stmts = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::BatchSelf(3), // Calls eth_dos_distance (index 3)
|
pred_or_wc: pred_lit(Predicate::BatchSelf(3)), // Calls eth_dos_distance (index 3)
|
||||||
args: vec![
|
args: vec![
|
||||||
// WildcardLiteral args
|
// WildcardLiteral args
|
||||||
sta_wc_lit("src", 0),
|
sta_wc_lit("src", 0),
|
||||||
|
|
@ -617,7 +624,7 @@ mod tests {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::SumOf),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SumOf)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_wc_lit("distance", 2), // public arg
|
sta_wc_lit("distance", 2), // public arg
|
||||||
sta_wc_lit("shorter_distance", 3), // private arg
|
sta_wc_lit("shorter_distance", 3), // private arg
|
||||||
|
|
@ -625,7 +632,7 @@ mod tests {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::BatchSelf(0), // Calls eth_friend (index 0)
|
pred_or_wc: pred_lit(Predicate::BatchSelf(0)), // Calls eth_friend (index 0)
|
||||||
args: vec![
|
args: vec![
|
||||||
// WildcardLiteral args
|
// WildcardLiteral args
|
||||||
sta_wc_lit("intermed", 4), // private arg
|
sta_wc_lit("intermed", 4), // private arg
|
||||||
|
|
@ -645,7 +652,7 @@ mod tests {
|
||||||
// eth_dos_distance (Index 3)
|
// eth_dos_distance (Index 3)
|
||||||
let expected_dist_stmts = vec![
|
let expected_dist_stmts = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::BatchSelf(1), // Calls eth_dos_distance_base (index 1)
|
pred_or_wc: pred_lit(Predicate::BatchSelf(1)), // Calls eth_dos_distance_base (index 1)
|
||||||
args: vec![
|
args: vec![
|
||||||
// WildcardLiteral args
|
// WildcardLiteral args
|
||||||
sta_wc_lit("src", 0),
|
sta_wc_lit("src", 0),
|
||||||
|
|
@ -654,7 +661,7 @@ mod tests {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::BatchSelf(2), // Calls eth_dos_distance_ind (index 2)
|
pred_or_wc: pred_lit(Predicate::BatchSelf(2)), // Calls eth_dos_distance_ind (index 2)
|
||||||
args: vec![
|
args: vec![
|
||||||
// WildcardLiteral args
|
// WildcardLiteral args
|
||||||
sta_wc_lit("src", 0),
|
sta_wc_lit("src", 0),
|
||||||
|
|
@ -697,7 +704,7 @@ mod tests {
|
||||||
|
|
||||||
// 1. Create a batch to be imported
|
// 1. Create a batch to be imported
|
||||||
let imported_pred_stmts = vec![StatementTmpl {
|
let imported_pred_stmts = vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
sta_ak(("A", 0), "foo"), // A["foo"]
|
sta_ak(("A", 0), "foo"), // A["foo"]
|
||||||
sta_ak(("B", 1), "bar"), // B["bar"]
|
sta_ak(("B", 1), "bar"), // B["bar"]
|
||||||
|
|
@ -739,7 +746,10 @@ mod tests {
|
||||||
|
|
||||||
// 4. Check the resulting request template
|
// 4. Check the resulting request template
|
||||||
let expected_request_templates = vec![StatementTmpl {
|
let expected_request_templates = vec![StatementTmpl {
|
||||||
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 0)),
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
||||||
|
available_batch,
|
||||||
|
0,
|
||||||
|
))),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::Wildcard(wc("Pod1", 0)),
|
StatementTmplArg::Wildcard(wc("Pod1", 0)),
|
||||||
StatementTmplArg::Wildcard(wc("Pod2", 1)),
|
StatementTmplArg::Wildcard(wc("Pod2", 1)),
|
||||||
|
|
@ -788,11 +798,17 @@ mod tests {
|
||||||
// 4. Check the resulting request templates
|
// 4. Check the resulting request templates
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
||||||
|
available_batch.clone(),
|
||||||
|
0,
|
||||||
|
))),
|
||||||
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
|
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 2)),
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
||||||
|
available_batch,
|
||||||
|
2,
|
||||||
|
))),
|
||||||
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
|
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
@ -808,7 +824,7 @@ mod tests {
|
||||||
|
|
||||||
// 1. Create a batch with a predicate to be imported
|
// 1. Create a batch with a predicate to be imported
|
||||||
let imported_pred_stmts = vec![StatementTmpl {
|
let imported_pred_stmts = vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
|
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
|
||||||
}];
|
}];
|
||||||
let imported_predicate = CustomPredicate::and(
|
let imported_predicate = CustomPredicate::and(
|
||||||
|
|
@ -855,7 +871,10 @@ mod tests {
|
||||||
assert_eq!(defined_pred.statements.len(), 1);
|
assert_eq!(defined_pred.statements.len(), 1);
|
||||||
|
|
||||||
let expected_statement = StatementTmpl {
|
let expected_statement = StatementTmpl {
|
||||||
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
||||||
|
available_batch.clone(),
|
||||||
|
0,
|
||||||
|
))),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::Wildcard(wc("X", 0)),
|
StatementTmplArg::Wildcard(wc("X", 0)),
|
||||||
StatementTmplArg::Wildcard(wc("Y", 1)),
|
StatementTmplArg::Wildcard(wc("Y", 1)),
|
||||||
|
|
@ -886,7 +905,9 @@ mod tests {
|
||||||
let request_templates = processed.request.templates();
|
let request_templates = processed.request.templates();
|
||||||
assert_eq!(request_templates.len(), 1);
|
assert_eq!(request_templates.len(), 1);
|
||||||
|
|
||||||
if let Predicate::Intro(intro_ref) = &request_templates[0].pred {
|
if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) =
|
||||||
|
&request_templates[0].pred_or_wc
|
||||||
|
{
|
||||||
assert_eq!(intro_ref.name, "empty");
|
assert_eq!(intro_ref.name, "empty");
|
||||||
assert_eq!(intro_ref.args_len, 0);
|
assert_eq!(intro_ref.args_len, 0);
|
||||||
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
|
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
|
||||||
|
|
@ -944,27 +965,27 @@ mod tests {
|
||||||
|
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))],
|
args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("B", 1), "raw"), sta_lit(Value::from(raw))],
|
args: vec![sta_ak(("B", 1), "raw"), sta_lit(Value::from(raw))],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("C", 2), "string"), sta_lit(Value::from(string))],
|
args: vec![sta_ak(("C", 2), "string"), sta_lit(Value::from(string))],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("D", 3), "int"), sta_lit(Value::from(int))],
|
args: vec![sta_ak(("D", 3), "int"), sta_lit(Value::from(int))],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("E", 4), "bool"), sta_lit(Value::from(bool))],
|
args: vec![sta_ak(("E", 4), "bool"), sta_lit(Value::from(bool))],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("F", 5), "sk"), sta_lit(Value::from(sk))],
|
args: vec![sta_ak(("F", 5), "sk"), sta_lit(Value::from(sk))],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,8 @@ use std::fmt::Write;
|
||||||
use crate::{
|
use crate::{
|
||||||
frontend::PodRequest,
|
frontend::PodRequest,
|
||||||
middleware::{
|
middleware::{
|
||||||
CustomPredicate, CustomPredicateBatch, Predicate, StatementTmpl, StatementTmplArg, Value,
|
CustomPredicate, CustomPredicateBatch, Predicate, PredicateOrWildcard, StatementTmpl,
|
||||||
|
StatementTmplArg, Value,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -57,7 +58,8 @@ impl StatementTmpl {
|
||||||
w: &mut dyn Write,
|
w: &mut dyn Write,
|
||||||
batch_context: Option<&CustomPredicateBatch>,
|
batch_context: Option<&CustomPredicateBatch>,
|
||||||
) -> std::fmt::Result {
|
) -> std::fmt::Result {
|
||||||
match &self.pred {
|
match &self.pred_or_wc {
|
||||||
|
PredicateOrWildcard::Predicate(pred) => match pred {
|
||||||
Predicate::Native(native_pred) => {
|
Predicate::Native(native_pred) => {
|
||||||
write!(w, "{}", native_pred)?;
|
write!(w, "{}", native_pred)?;
|
||||||
}
|
}
|
||||||
|
|
@ -78,6 +80,11 @@ impl StatementTmpl {
|
||||||
write!(w, "batch_self_{}", index)?;
|
write!(w, "batch_self_{}", index)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
PredicateOrWildcard::Wildcard(wc) => {
|
||||||
|
// TODO: Decide the syntax for a wildcard predicate
|
||||||
|
write!(w, "?{}", wc.name)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
write!(w, "(")?;
|
write!(w, "(")?;
|
||||||
|
|
@ -223,13 +230,17 @@ mod tests {
|
||||||
Wildcard::new(name.to_string(), index)
|
Wildcard::new(name.to_string(), index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn pred_lit(pred: Predicate) -> PredicateOrWildcard {
|
||||||
|
PredicateOrWildcard::Predicate(pred)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_simple_predicate_pretty_print() {
|
fn test_simple_predicate_pretty_print() {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
|
|
||||||
// Create a simple predicate: is_equal(PodA, PodB) = AND(Equal(PodA["key"], PodB["key"]))
|
// Create a simple predicate: is_equal(PodA, PodB) = AND(Equal(PodA["key"], PodB["key"]))
|
||||||
let statements = vec![StatementTmpl {
|
let statements = vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::AnchoredKey(
|
StatementTmplArg::AnchoredKey(
|
||||||
create_test_wildcard("PodA", 0),
|
create_test_wildcard("PodA", 0),
|
||||||
|
|
@ -265,7 +276,7 @@ mod tests {
|
||||||
|
|
||||||
// Create: uses_private(A, private: Temp) = AND(Equal(A["input"], Temp["const"]))
|
// Create: uses_private(A, private: Temp) = AND(Equal(A["input"], Temp["const"]))
|
||||||
let statements = vec![StatementTmpl {
|
let statements = vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::AnchoredKey(
|
StatementTmplArg::AnchoredKey(
|
||||||
create_test_wildcard("A", 0),
|
create_test_wildcard("A", 0),
|
||||||
|
|
@ -301,7 +312,7 @@ mod tests {
|
||||||
|
|
||||||
// Create: check_value(Pod) = AND(Equal(Pod["field"], 42))
|
// Create: check_value(Pod) = AND(Equal(Pod["field"], 42))
|
||||||
let statements = vec![StatementTmpl {
|
let statements = vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::AnchoredKey(
|
StatementTmplArg::AnchoredKey(
|
||||||
create_test_wildcard("Pod", 0),
|
create_test_wildcard("Pod", 0),
|
||||||
|
|
@ -335,7 +346,7 @@ mod tests {
|
||||||
// Create: either_or(A, B) = OR(Equal(A["x"], 1), Equal(B["y"], 2))
|
// Create: either_or(A, B) = OR(Equal(A["x"], 1), Equal(B["y"], 2))
|
||||||
let statements = vec![
|
let statements = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::AnchoredKey(
|
StatementTmplArg::AnchoredKey(
|
||||||
create_test_wildcard("A", 0),
|
create_test_wildcard("A", 0),
|
||||||
|
|
@ -345,7 +356,7 @@ mod tests {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::Equal),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::AnchoredKey(
|
StatementTmplArg::AnchoredKey(
|
||||||
create_test_wildcard("B", 1),
|
create_test_wildcard("B", 1),
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
use std::{fmt, iter, sync::Arc};
|
use std::{fmt, iter, sync::Arc};
|
||||||
|
|
||||||
|
use itertools::Itertools;
|
||||||
use plonky2::field::types::Field;
|
use plonky2::field::types::Field;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -70,36 +71,28 @@ impl ToFields for StatementTmplArg {
|
||||||
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
|
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
|
||||||
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
|
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
|
||||||
match self {
|
match self {
|
||||||
StatementTmplArg::None => {
|
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
|
||||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::None))
|
|
||||||
.chain(iter::repeat(F::ZERO))
|
.chain(iter::repeat(F::ZERO))
|
||||||
.take(Params::statement_tmpl_arg_size())
|
.take(Params::statement_tmpl_arg_size())
|
||||||
.collect();
|
.collect_vec(),
|
||||||
fields
|
StatementTmplArg::Literal(v) => iter::once(F::from(StatementTmplArgPrefix::Literal))
|
||||||
}
|
|
||||||
StatementTmplArg::Literal(v) => {
|
|
||||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::Literal))
|
|
||||||
.chain(v.raw().to_fields(params))
|
.chain(v.raw().to_fields(params))
|
||||||
.chain(iter::repeat(F::ZERO))
|
.chain(iter::repeat(F::ZERO))
|
||||||
.take(Params::statement_tmpl_arg_size())
|
.take(Params::statement_tmpl_arg_size())
|
||||||
.collect();
|
.collect_vec(),
|
||||||
fields
|
|
||||||
}
|
|
||||||
StatementTmplArg::AnchoredKey(wc1, kw2) => {
|
StatementTmplArg::AnchoredKey(wc1, kw2) => {
|
||||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
|
iter::once(F::from(StatementTmplArgPrefix::AnchoredKey))
|
||||||
.chain(wc1.to_fields(params))
|
.chain(wc1.to_fields(params))
|
||||||
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
|
.chain(iter::repeat(F::ZERO).take(VALUE_SIZE - 1))
|
||||||
.chain(kw2.to_fields(params))
|
.chain(kw2.to_fields(params))
|
||||||
.collect();
|
.collect_vec()
|
||||||
fields
|
|
||||||
}
|
}
|
||||||
StatementTmplArg::Wildcard(wc) => {
|
StatementTmplArg::Wildcard(wc) => {
|
||||||
let fields: Vec<F> = iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
|
iter::once(F::from(StatementTmplArgPrefix::WildcardLiteral))
|
||||||
.chain(wc.to_fields(params))
|
.chain(wc.to_fields(params))
|
||||||
.chain(iter::repeat(F::ZERO))
|
.chain(iter::repeat(F::ZERO))
|
||||||
.take(Params::statement_tmpl_arg_size())
|
.take(Params::statement_tmpl_arg_size())
|
||||||
.collect();
|
.collect_vec()
|
||||||
fields
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -121,16 +114,79 @@ impl fmt::Display for StatementTmplArg {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
|
pub enum PredicateOrWildcard {
|
||||||
|
Predicate(Predicate),
|
||||||
|
Wildcard(Wildcard),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PredicateOrWildcard {
|
||||||
|
pub fn as_pred(&self) -> Option<&Predicate> {
|
||||||
|
match self {
|
||||||
|
Self::Predicate(pred) => Some(pred),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn as_wc(&self) -> Option<&Wildcard> {
|
||||||
|
match self {
|
||||||
|
Self::Wildcard(wc) => Some(wc),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for PredicateOrWildcard {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Predicate(pred) => pred.fmt(f),
|
||||||
|
Self::Wildcard(wc) => {
|
||||||
|
write!(f, "?")?;
|
||||||
|
wc.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum PredicateOrWildcardPrefix {
|
||||||
|
Predicate = 0,
|
||||||
|
Wildcard = 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<PredicateOrWildcardPrefix> for F {
|
||||||
|
fn from(prefix: PredicateOrWildcardPrefix) -> Self {
|
||||||
|
Self::from_canonical_usize(prefix as usize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToFields for PredicateOrWildcard {
|
||||||
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
|
// Encoding:
|
||||||
|
// Predicate(pred) => (0, [hash(pred) ])
|
||||||
|
// Wildcard(wc) => (1, wc_index, 0...)
|
||||||
|
match self {
|
||||||
|
Self::Predicate(pred) => iter::once(F::from(PredicateOrWildcardPrefix::Predicate))
|
||||||
|
.chain(pred.hash(params).to_fields(params))
|
||||||
|
.collect_vec(),
|
||||||
|
Self::Wildcard(wc) => iter::once(F::from(PredicateOrWildcardPrefix::Wildcard))
|
||||||
|
.chain(wc.to_fields(params))
|
||||||
|
.chain(iter::repeat(F::ZERO))
|
||||||
|
.take(Params::pred_hash_or_wc_size())
|
||||||
|
.collect_vec(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Statement Template for a Custom Predicate
|
/// Statement Template for a Custom Predicate
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||||
pub struct StatementTmpl {
|
pub struct StatementTmpl {
|
||||||
pub pred: Predicate,
|
pub pred_or_wc: PredicateOrWildcard,
|
||||||
pub args: Vec<StatementTmplArg>,
|
pub args: Vec<StatementTmplArg>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StatementTmpl {
|
impl StatementTmpl {
|
||||||
pub fn pred(&self) -> &Predicate {
|
pub fn pred_or_wc(&self) -> &PredicateOrWildcard {
|
||||||
&self.pred
|
&self.pred_or_wc
|
||||||
}
|
}
|
||||||
pub fn args(&self) -> &[StatementTmplArg] {
|
pub fn args(&self) -> &[StatementTmplArg] {
|
||||||
&self.args
|
&self.args
|
||||||
|
|
@ -139,7 +195,7 @@ impl StatementTmpl {
|
||||||
|
|
||||||
impl fmt::Display for StatementTmpl {
|
impl fmt::Display for StatementTmpl {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
self.pred.fmt(f)?;
|
self.pred_or_wc.fmt(f)?;
|
||||||
write!(f, "(")?;
|
write!(f, "(")?;
|
||||||
for (i, arg) in self.args.iter().enumerate() {
|
for (i, arg) in self.args.iter().enumerate() {
|
||||||
if i != 0 {
|
if i != 0 {
|
||||||
|
|
@ -154,7 +210,7 @@ impl fmt::Display for StatementTmpl {
|
||||||
impl ToFields for StatementTmpl {
|
impl ToFields for StatementTmpl {
|
||||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
// serialize as:
|
// serialize as:
|
||||||
// predicate (6 field elements)
|
// predicate (4 field elements)
|
||||||
// then the StatementTmplArgs
|
// then the StatementTmplArgs
|
||||||
|
|
||||||
// TODO think if this check should go into the StatementTmpl creation,
|
// TODO think if this check should go into the StatementTmpl creation,
|
||||||
|
|
@ -168,15 +224,13 @@ impl ToFields for StatementTmpl {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut fields: Vec<F> = self
|
self.pred_or_wc
|
||||||
.pred
|
|
||||||
.hash(params)
|
|
||||||
.to_fields(params)
|
.to_fields(params)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.chain(self.args.iter().flat_map(|sta| sta.to_fields(params)))
|
.chain(self.args.iter().flat_map(|sta| sta.to_fields(params)))
|
||||||
.collect();
|
.chain(iter::repeat(F::ZERO))
|
||||||
fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0));
|
.take(params.statement_tmpl_size())
|
||||||
fields
|
.collect_vec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -203,7 +257,9 @@ impl CustomPredicate {
|
||||||
name: "empty".to_string(),
|
name: "empty".to_string(),
|
||||||
conjunction: false,
|
conjunction: false,
|
||||||
statements: vec![StatementTmpl {
|
statements: vec![StatementTmpl {
|
||||||
pred: Predicate::Native(NativePredicate::None),
|
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(
|
||||||
|
NativePredicate::None,
|
||||||
|
)),
|
||||||
args: vec![],
|
args: vec![],
|
||||||
}],
|
}],
|
||||||
args_len: 0,
|
args_len: 0,
|
||||||
|
|
@ -276,11 +332,11 @@ impl CustomPredicate {
|
||||||
}
|
}
|
||||||
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
|
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred: Predicate::Native(if self.conjunction {
|
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(if self.conjunction {
|
||||||
NativePredicate::None
|
NativePredicate::None
|
||||||
} else {
|
} else {
|
||||||
NativePredicate::False
|
NativePredicate::False
|
||||||
}),
|
})),
|
||||||
args: vec![],
|
args: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -318,7 +374,7 @@ impl ToFields for CustomPredicate {
|
||||||
}
|
}
|
||||||
|
|
||||||
let pad_st = self.pad_statement_tmpl();
|
let pad_st = self.pad_statement_tmpl();
|
||||||
let fields: Vec<F> = iter::once(F::from_bool(self.conjunction))
|
iter::once(F::from_bool(self.conjunction))
|
||||||
.chain(iter::once(F::from_canonical_usize(self.args_len)))
|
.chain(iter::once(F::from_canonical_usize(self.args_len)))
|
||||||
.chain(
|
.chain(
|
||||||
self.statements
|
self.statements
|
||||||
|
|
@ -327,8 +383,7 @@ impl ToFields for CustomPredicate {
|
||||||
.take(params.max_custom_predicate_arity)
|
.take(params.max_custom_predicate_arity)
|
||||||
.flat_map(|st| st.to_fields(params)),
|
.flat_map(|st| st.to_fields(params)),
|
||||||
)
|
)
|
||||||
.collect();
|
.collect_vec()
|
||||||
fields
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -350,7 +405,7 @@ impl fmt::Display for CustomPredicate {
|
||||||
writeln!(f, ") = {}(", if self.conjunction { "AND" } else { "OR" })?;
|
writeln!(f, ") = {}(", if self.conjunction { "AND" } else { "OR" })?;
|
||||||
for st in &self.statements {
|
for st in &self.statements {
|
||||||
write!(f, " ")?;
|
write!(f, " ")?;
|
||||||
st.pred.fmt(f)?;
|
st.pred_or_wc.fmt(f)?;
|
||||||
write!(f, "(")?;
|
write!(f, "(")?;
|
||||||
for (i, arg) in st.args.iter().enumerate() {
|
for (i, arg) in st.args.iter().enumerate() {
|
||||||
if i != 0 {
|
if i != 0 {
|
||||||
|
|
@ -382,14 +437,12 @@ impl ToFields for CustomPredicateBatch {
|
||||||
fn to_fields(&self, params: &Params) -> Vec<F> {
|
fn to_fields(&self, params: &Params) -> Vec<F> {
|
||||||
// all the custom predicates in order
|
// all the custom predicates in order
|
||||||
let pad_pred = CustomPredicate::empty();
|
let pad_pred = CustomPredicate::empty();
|
||||||
let fields: Vec<F> = self
|
self.predicates
|
||||||
.predicates
|
|
||||||
.iter()
|
.iter()
|
||||||
.chain(iter::repeat(&pad_pred))
|
.chain(iter::repeat(&pad_pred))
|
||||||
.take(params.max_custom_batch_size)
|
.take(params.max_custom_batch_size)
|
||||||
.flat_map(|p| p.to_fields(params))
|
.flat_map(|p| p.to_fields(params))
|
||||||
.collect();
|
.collect_vec()
|
||||||
fields
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -418,7 +471,6 @@ impl CustomPredicateBatch {
|
||||||
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
|
// NOTE: This implementation just hashes the concatenation of all the custom predicates,
|
||||||
// but ideally we want to use the root of a merkle tree built from the custom predicates.
|
// but ideally we want to use the root of a merkle tree built from the custom predicates.
|
||||||
let input = self.to_fields(params);
|
let input = self.to_fields(params);
|
||||||
|
|
||||||
hash_fields(&input)
|
hash_fields(&input)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -470,7 +522,10 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
fn st(p: Predicate, args: Vec<StatementTmplArg>) -> StatementTmpl {
|
fn st(p: Predicate, args: Vec<StatementTmplArg>) -> StatementTmpl {
|
||||||
StatementTmpl { pred: p, args }
|
StatementTmpl {
|
||||||
|
pred_or_wc: PredicateOrWildcard::Predicate(p),
|
||||||
|
args,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn key(name: &str) -> Key {
|
fn key(name: &str) -> Key {
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,8 @@ pub enum MiddlewareInnerError {
|
||||||
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
|
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
|
||||||
#[error("Expected a statement of type {0}, got {1}")]
|
#[error("Expected a statement of type {0}, got {1}")]
|
||||||
MismatchedStatementType(Predicate, Predicate),
|
MismatchedStatementType(Predicate, Predicate),
|
||||||
|
#[error("Expected a statement with hash(predicate) {0}, got {1} ({2})")]
|
||||||
|
MismatchedStatementWildcardPredicate(Value, Value, Predicate),
|
||||||
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
|
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
|
||||||
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
|
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
|
||||||
#[error(
|
#[error(
|
||||||
|
|
@ -111,6 +113,15 @@ impl Error {
|
||||||
pub(crate) fn mismatched_statement_type(expected: Predicate, seen: Predicate) -> Self {
|
pub(crate) fn mismatched_statement_type(expected: Predicate, seen: Predicate) -> Self {
|
||||||
new!(MismatchedStatementType(expected, seen))
|
new!(MismatchedStatementType(expected, seen))
|
||||||
}
|
}
|
||||||
|
pub(crate) fn mismatched_statement_wc_pred(
|
||||||
|
expected: Value,
|
||||||
|
seen: Value,
|
||||||
|
seen_pred: Predicate,
|
||||||
|
) -> Self {
|
||||||
|
new!(MismatchedStatementWildcardPredicate(
|
||||||
|
expected, seen, seen_pred
|
||||||
|
))
|
||||||
|
}
|
||||||
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
|
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
|
||||||
wc_value: Value,
|
wc_value: Value,
|
||||||
st_arg: Value,
|
st_arg: Value,
|
||||||
|
|
|
||||||
|
|
@ -820,8 +820,12 @@ impl Params {
|
||||||
HASH_SIZE + STATEMENT_ARG_F_LEN * self.max_statement_args
|
HASH_SIZE + STATEMENT_ARG_F_LEN * self.max_statement_args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const fn pred_hash_or_wc_size() -> usize {
|
||||||
|
1 + HASH_SIZE
|
||||||
|
}
|
||||||
|
|
||||||
pub const fn statement_tmpl_size(&self) -> usize {
|
pub const fn statement_tmpl_size(&self) -> usize {
|
||||||
HASH_SIZE + self.max_statement_args * Self::statement_tmpl_arg_size()
|
Self::pred_hash_or_wc_size() + self.max_statement_args * Self::statement_tmpl_arg_size()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn custom_predicate_size(&self) -> usize {
|
pub fn custom_predicate_size(&self) -> usize {
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,9 @@ use crate::{
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
|
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
|
||||||
MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg,
|
MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result,
|
||||||
StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
|
Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value,
|
||||||
|
ValueRef, Wildcard, F,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -550,16 +551,8 @@ impl Operation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check that a StatementArg follows a StatementTmplArg based on the currently mapped wildcards.
|
|
||||||
/// Update the wildcard map with newly found wildcards.
|
|
||||||
pub fn check_st_tmpl(
|
|
||||||
st_tmpl_arg: &StatementTmplArg,
|
|
||||||
st_arg: &StatementArg,
|
|
||||||
// Map from wildcards to values that we have seen so far.
|
|
||||||
wildcard_map: &mut [Option<Value>],
|
|
||||||
) -> Result<()> {
|
|
||||||
// Check that the value `v` at wildcard `wc` exists in the map or set it.
|
// Check that the value `v` at wildcard `wc` exists in the map or set it.
|
||||||
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
|
fn wc_check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
|
||||||
if let Some(prev) = &wildcard_map[wc.index] {
|
if let Some(prev) = &wildcard_map[wc.index] {
|
||||||
if *prev != v {
|
if *prev != v {
|
||||||
return Err(Error::invalid_wildcard_assignment(
|
return Err(Error::invalid_wildcard_assignment(
|
||||||
|
|
@ -574,6 +567,14 @@ pub fn check_st_tmpl(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check that a StatementArg follows a StatementTmplArg based on the currently mapped wildcards.
|
||||||
|
/// Update the wildcard map with newly found wildcards.
|
||||||
|
pub fn check_st_tmpl(
|
||||||
|
st_tmpl_arg: &StatementTmplArg,
|
||||||
|
st_arg: &StatementArg,
|
||||||
|
// Map from wildcards to values that we have seen so far.
|
||||||
|
wildcard_map: &mut [Option<Value>],
|
||||||
|
) -> Result<()> {
|
||||||
match (st_tmpl_arg, st_arg) {
|
match (st_tmpl_arg, st_arg) {
|
||||||
(StatementTmplArg::None, StatementArg::None) => Ok(()),
|
(StatementTmplArg::None, StatementArg::None) => Ok(()),
|
||||||
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()),
|
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()),
|
||||||
|
|
@ -581,7 +582,7 @@ pub fn check_st_tmpl(
|
||||||
StatementTmplArg::AnchoredKey(root_wc, key_tmpl),
|
StatementTmplArg::AnchoredKey(root_wc, key_tmpl),
|
||||||
StatementArg::Key(AnchoredKey { root, key }),
|
StatementArg::Key(AnchoredKey { root, key }),
|
||||||
) => {
|
) => {
|
||||||
let root_ok = check_or_set(Value::from(*root), root_wc, wildcard_map);
|
let root_ok = wc_check_or_set(Value::from(*root), root_wc, wildcard_map);
|
||||||
root_ok.and_then(|_| {
|
root_ok.and_then(|_| {
|
||||||
(key_tmpl == key).then_some(()).ok_or(
|
(key_tmpl == key).then_some(()).ok_or(
|
||||||
Error::mismatched_anchored_key_in_statement_tmpl_arg(
|
Error::mismatched_anchored_key_in_statement_tmpl_arg(
|
||||||
|
|
@ -594,7 +595,7 @@ pub fn check_st_tmpl(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
|
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
|
||||||
check_or_set(v.clone(), wc, wildcard_map)
|
wc_check_or_set(v.clone(), wc, wildcard_map)
|
||||||
}
|
}
|
||||||
_ => Err(Error::mismatched_statement_tmpl_arg(
|
_ => Err(Error::mismatched_statement_tmpl_arg(
|
||||||
st_tmpl_arg.clone(),
|
st_tmpl_arg.clone(),
|
||||||
|
|
@ -604,12 +605,16 @@ pub fn check_st_tmpl(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fill_wildcard_values(
|
pub fn fill_wildcard_values(
|
||||||
|
params: &Params,
|
||||||
pred: &CustomPredicate,
|
pred: &CustomPredicate,
|
||||||
args: &[Statement],
|
args: &[Statement],
|
||||||
wildcard_map: &mut [Option<Value>],
|
wildcard_map: &mut [Option<Value>],
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||||
let st_args = st.args();
|
let st_args = st.args();
|
||||||
|
if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc {
|
||||||
|
wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?;
|
||||||
|
}
|
||||||
st_tmpl
|
st_tmpl
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -633,7 +638,7 @@ pub fn wildcard_values_from_op_st(
|
||||||
.chain(core::iter::repeat(None))
|
.chain(core::iter::repeat(None))
|
||||||
.take(params.max_custom_predicate_wildcards)
|
.take(params.max_custom_predicate_wildcards)
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
fill_wildcard_values(pred, op_args, &mut wildcard_map)?;
|
fill_wildcard_values(params, pred, op_args, &mut wildcard_map)?;
|
||||||
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
|
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
|
||||||
// they are beyond the number of used wildcards in this custom predicate, or they could be
|
// they are beyond the number of used wildcards in this custom predicate, or they could be
|
||||||
// private arguments that are unused in a particular disjunction.
|
// private arguments that are unused in a particular disjunction.
|
||||||
|
|
@ -644,11 +649,15 @@ pub fn wildcard_values_from_op_st(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_custom_pred_argument(
|
fn check_custom_pred_argument(
|
||||||
|
params: &Params,
|
||||||
custom_pred_ref: &CustomPredicateRef,
|
custom_pred_ref: &CustomPredicateRef,
|
||||||
template: &StatementTmpl,
|
template: &StatementTmpl,
|
||||||
statement: &Statement,
|
statement: &Statement,
|
||||||
|
wc_values: &[Value],
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let template_pred = match &template.pred {
|
match &template.pred_or_wc {
|
||||||
|
PredicateOrWildcard::Predicate(pred) => {
|
||||||
|
let template_pred = match pred {
|
||||||
&Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
&Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
||||||
batch: custom_pred_ref.batch.clone(),
|
batch: custom_pred_ref.batch.clone(),
|
||||||
index: i,
|
index: i,
|
||||||
|
|
@ -661,6 +670,18 @@ fn check_custom_pred_argument(
|
||||||
statement.predicate(),
|
statement.predicate(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
PredicateOrWildcard::Wildcard(wc) => {
|
||||||
|
let pred_hash = Value::from(statement.predicate().hash(params));
|
||||||
|
if wc_values[wc.index] != pred_hash {
|
||||||
|
return Err(Error::mismatched_statement_wc_pred(
|
||||||
|
wc_values[wc.index].clone(),
|
||||||
|
pred_hash,
|
||||||
|
statement.predicate(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
let st_args_len = statement.args().len();
|
let st_args_len = statement.args().len();
|
||||||
if template.args.len() != st_args_len {
|
if template.args.len() != st_args_len {
|
||||||
return Err(Error::diff_amount(
|
return Err(Error::diff_amount(
|
||||||
|
|
@ -697,17 +718,42 @@ pub(crate) fn check_custom_pred(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that the resolved wildcards match the statement arguments.
|
||||||
|
let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) {
|
||||||
|
Ok(wc_values) => wc_values,
|
||||||
|
Err(Error::Inner { inner, backtrace }) => match *inner {
|
||||||
|
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
|
||||||
|
if wc.index <= s_args.len() =>
|
||||||
|
{
|
||||||
|
return Err(Error::mismatched_wildcard_value_and_statement_arg(
|
||||||
|
v,
|
||||||
|
prev,
|
||||||
|
wc.index,
|
||||||
|
pred.clone(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
_ => return Err(Error::Inner { inner, backtrace }),
|
||||||
|
},
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
let mut match_exists = false;
|
let mut match_exists = false;
|
||||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||||
// For `or` predicates, only one statement needs to match the template.
|
// For `or` predicates, only one statement needs to match the template.
|
||||||
// The rest of the statements can be `None`.
|
// The rest of the statements can be `None`.
|
||||||
if !pred.conjunction
|
let expected_pred_is_none = match &st_tmpl.pred_or_wc {
|
||||||
&& matches!(st, Statement::None)
|
PredicateOrWildcard::Predicate(st_tmpl_pred) => {
|
||||||
&& st_tmpl.pred != Predicate::Native(NativePredicate::None)
|
*st_tmpl_pred == Predicate::Native(NativePredicate::None)
|
||||||
{
|
}
|
||||||
|
PredicateOrWildcard::Wildcard(wc) => {
|
||||||
|
wc_values[wc.index]
|
||||||
|
== Value::from(Predicate::Native(NativePredicate::None).hash(params))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if !pred.conjunction && matches!(st, Statement::None) && !expected_pred_is_none {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
check_custom_pred_argument(custom_pred_ref, st_tmpl, st)?;
|
check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?;
|
||||||
match_exists = true;
|
match_exists = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -716,25 +762,7 @@ pub(crate) fn check_custom_pred(
|
||||||
pred.clone(),
|
pred.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
// Check that the resolved wildcards match the statement arguments.
|
|
||||||
match wildcard_values_from_op_st(params, pred, args, s_args) {
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
Err(Error::Inner { inner, backtrace }) => match *inner {
|
|
||||||
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
|
|
||||||
if wc.index <= s_args.len() =>
|
|
||||||
{
|
|
||||||
Err(Error::mismatched_wildcard_value_and_statement_arg(
|
|
||||||
v,
|
|
||||||
prev,
|
|
||||||
wc.index,
|
|
||||||
pred.clone(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
_ => Err(Error::Inner { inner, backtrace }),
|
|
||||||
},
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToFields for Operation {
|
impl ToFields for Operation {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue