fix frontend::Operation::new_entry + doc improvements (#370)
This commit is contained in:
parent
4fa285d9fb
commit
7f120f026d
5 changed files with 206 additions and 222 deletions
|
|
@ -8,9 +8,9 @@ use serde::{Deserialize, Serialize};
|
||||||
pub use serialization::{SerializedMainPod, SerializedSignedPod};
|
pub use serialization::{SerializedMainPod, SerializedSignedPod};
|
||||||
|
|
||||||
use crate::middleware::{
|
use crate::middleware::{
|
||||||
self, check_st_tmpl, hash_op, hash_str, max_op, prod_op, sum_op, AnchoredKey, Key,
|
self, check_custom_pred, check_st_tmpl, hash_op, hash_str, max_op, prod_op, sum_op,
|
||||||
MainPodInputs, NativeOperation, OperationAux, OperationType, Params, PodId, PodProver,
|
AnchoredKey, Key, MainPodInputs, NativeOperation, OperationAux, OperationType, Params, PodId,
|
||||||
PodSigner, Statement, StatementArg, VDSet, Value, ValueRef, KEY_TYPE, SELF,
|
PodProver, PodSigner, Statement, StatementArg, VDSet, Value, ValueRef, KEY_TYPE, SELF,
|
||||||
};
|
};
|
||||||
|
|
||||||
mod custom;
|
mod custom;
|
||||||
|
|
@ -285,63 +285,48 @@ impl MainPodBuilder {
|
||||||
|
|
||||||
fn op_statement(&mut self, op: Operation) -> Result<Statement> {
|
fn op_statement(&mut self, op: Operation) -> Result<Statement> {
|
||||||
use NativeOperation::*;
|
use NativeOperation::*;
|
||||||
let arg_error = |s: &str| Error::op_invalid_args(s.to_string());
|
|
||||||
let st = match op.0 {
|
let st = match op.0 {
|
||||||
OperationType::Native(o) => match (o, &op.1.as_slice()) {
|
OperationType::Native(o) => {
|
||||||
|
let native_arg_error = move || Error::op_invalid_args(format!("{o:?}"));
|
||||||
|
match (o, &op.1.as_slice()) {
|
||||||
(None, &[]) => Statement::None,
|
(None, &[]) => Statement::None,
|
||||||
(NewEntry, &[OperationArg::Entry(k, v)]) => {
|
(NewEntry, &[OperationArg::Entry(k, v)]) => {
|
||||||
Statement::equal(AnchoredKey::from((SELF, k.as_str())), v.clone())
|
Statement::equal(AnchoredKey::from((SELF, k.as_str())), v.clone())
|
||||||
}
|
}
|
||||||
(EqualFromEntries, &[a1, a2]) => {
|
(EqualFromEntries, &[a1, a2]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("equal-from-entries"))?;
|
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("equal-from-entries"))?;
|
|
||||||
if v1 == v2 {
|
if v1 == v2 {
|
||||||
Statement::equal(r1, r2)
|
Statement::equal(r1, r2)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("equal-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(NotEqualFromEntries, &[a1, a2]) => {
|
(NotEqualFromEntries, &[a1, a2]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("not-equal-from-entries"))?;
|
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("not-equal-from-entries"))?;
|
|
||||||
if v1 != v2 {
|
if v1 != v2 {
|
||||||
Statement::not_equal(r1, r2)
|
Statement::not_equal(r1, r2)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("not-equal-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(LtFromEntries, &[a1, a2]) => {
|
(LtFromEntries, &[a1, a2]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("lt-from-entries"))?;
|
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("lt-from-entries"))?;
|
|
||||||
if v1 < v2 {
|
if v1 < v2 {
|
||||||
Statement::lt(r1, r2)
|
Statement::lt(r1, r2)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("lt-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(LtEqFromEntries, &[a1, a2]) => {
|
(LtEqFromEntries, &[a1, a2]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("lt-eq-from-entries"))?;
|
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("lt-eq-from-entries"))?;
|
|
||||||
if v1 <= v2 {
|
if v1 <= v2 {
|
||||||
Statement::not_equal(r1, r2)
|
Statement::not_equal(r1, r2)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("lt-eq-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CopyStatement, &[OperationArg::Statement(s)]) => s.clone(),
|
(CopyStatement, &[OperationArg::Statement(s)]) => s.clone(),
|
||||||
|
|
@ -352,110 +337,72 @@ impl MainPodBuilder {
|
||||||
if r2 == r3 {
|
if r2 == r3 {
|
||||||
Statement::Equal(r1.clone(), r4.clone())
|
Statement::Equal(r1.clone(), r4.clone())
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("transitive-eq"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(LtToNotEqual, &[OperationArg::Statement(Statement::Lt(r1, r2))]) => {
|
(LtToNotEqual, &[OperationArg::Statement(Statement::Lt(r1, r2))]) => {
|
||||||
Statement::NotEqual(r1.clone(), r2.clone())
|
Statement::NotEqual(r1.clone(), r2.clone())
|
||||||
}
|
}
|
||||||
(SumOf, &[a1, a2, a3]) => {
|
(SumOf, &[a1, a2, a3]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("sum-from-entries"))?;
|
let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("sum-from-entries"))?;
|
|
||||||
let (r3, v3) = a3
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("sum-from-entries"))?;
|
|
||||||
if middleware::Operation::check_int_fn(v1, v2, v3, sum_op)? {
|
if middleware::Operation::check_int_fn(v1, v2, v3, sum_op)? {
|
||||||
Statement::SumOf(r1, r2, r3)
|
Statement::SumOf(r1, r2, r3)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("sum-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(ProductOf, &[a1, a2, a3]) => {
|
(ProductOf, &[a1, a2, a3]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("prod-from-entries"))?;
|
let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("prod-from-entries"))?;
|
|
||||||
let (r3, v3) = a3
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("prod-from-entries"))?;
|
|
||||||
if middleware::Operation::check_int_fn(v1, v2, v3, prod_op)? {
|
if middleware::Operation::check_int_fn(v1, v2, v3, prod_op)? {
|
||||||
Statement::ProductOf(r1, r2, r3)
|
Statement::ProductOf(r1, r2, r3)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("prod-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(MaxOf, &[a1, a2, a3]) => {
|
(MaxOf, &[a1, a2, a3]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("max-from-entries"))?;
|
let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("max-from-entries"))?;
|
|
||||||
let (r3, v3) = a3
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("max-from-entries"))?;
|
|
||||||
if middleware::Operation::check_int_fn(v1, v2, v3, max_op)? {
|
if middleware::Operation::check_int_fn(v1, v2, v3, max_op)? {
|
||||||
Statement::MaxOf(r1, r2, r3)
|
Statement::MaxOf(r1, r2, r3)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("max-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(HashOf, &[a1, a2, a3]) => {
|
(HashOf, &[a1, a2, a3]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("hash-from-entries"))?;
|
let (r3, v3) = a3.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("hash-from-entries"))?;
|
|
||||||
let (r3, v3) = a3
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("hash-from-entries"))?;
|
|
||||||
if v1 == &hash_op(v2.clone(), v3.clone()) {
|
if v1 == &hash_op(v2.clone(), v3.clone()) {
|
||||||
Statement::HashOf(r1, r2, r3)
|
Statement::HashOf(r1, r2, r3)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("hash-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(ContainsFromEntries, &[a1, a2, a3]) => {
|
(ContainsFromEntries, &[a1, a2, a3]) => {
|
||||||
let (r1, _v1) = a1
|
let (r1, _v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, _v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("contains-from-entries"))?;
|
let (r3, _v3) = a3.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
let (r2, _v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("contains-from-entries"))?;
|
|
||||||
let (r3, _v3) = a3
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("contains-from-entries"))?;
|
|
||||||
// TODO: validate proof
|
// TODO: validate proof
|
||||||
Statement::Contains(r1, r2, r3)
|
Statement::Contains(r1, r2, r3)
|
||||||
}
|
}
|
||||||
(NotContainsFromEntries, &[a1, a2]) => {
|
(NotContainsFromEntries, &[a1, a2]) => {
|
||||||
let (r1, _v1) = a1
|
let (r1, _v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, _v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("contains-from-entries"))?;
|
|
||||||
let (r2, _v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("contains-from-entries"))?;
|
|
||||||
// TODO: validate proof
|
// TODO: validate proof
|
||||||
Statement::NotContains(r1, r2)
|
Statement::NotContains(r1, r2)
|
||||||
}
|
}
|
||||||
(PublicKeyOf, &[a1, a2]) => {
|
(PublicKeyOf, &[a1, a2]) => {
|
||||||
let (r1, v1) = a1
|
let (r1, v1) = a1.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.value_and_ref()
|
let (r2, v2) = a2.value_and_ref().ok_or_else(native_arg_error)?;
|
||||||
.ok_or_else(|| arg_error("public-key-from-entries"))?;
|
|
||||||
let (r2, v2) = a2
|
|
||||||
.value_and_ref()
|
|
||||||
.ok_or_else(|| arg_error("public-key-from-entries"))?;
|
|
||||||
if middleware::Operation::check_public_key(v1, v2)? {
|
if middleware::Operation::check_public_key(v1, v2)? {
|
||||||
Statement::PublicKeyOf(r1, r2)
|
Statement::PublicKeyOf(r1, r2)
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("public-key-from-entries"));
|
return Err(native_arg_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(t, _) => {
|
(t, _) => {
|
||||||
|
|
@ -465,10 +412,11 @@ impl MainPodBuilder {
|
||||||
t
|
t
|
||||||
)));
|
)));
|
||||||
} else {
|
} else {
|
||||||
return Err(arg_error("malformed operation"));
|
return Err(native_arg_error());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
OperationType::Custom(cpr) => {
|
OperationType::Custom(cpr) => {
|
||||||
let pred = &cpr.batch.predicates()[cpr.index];
|
let pred = &cpr.batch.predicates()[cpr.index];
|
||||||
if pred.statements.len() != op.1.len() {
|
if pred.statements.len() != op.1.len() {
|
||||||
|
|
@ -509,6 +457,7 @@ impl MainPodBuilder {
|
||||||
.take(pred.args_len)
|
.take(pred.args_len)
|
||||||
.map(|v| v.unwrap_or_else(|| v_default.clone()))
|
.map(|v| v.unwrap_or_else(|| v_default.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
|
check_custom_pred(&self.params, &cpr, &args, &st_args)?;
|
||||||
Statement::Custom(cpr, st_args)
|
Statement::Custom(cpr, st_args)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -164,10 +164,10 @@ macro_rules! op_impl_st {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Operation {
|
impl Operation {
|
||||||
pub fn new_entry(a1: impl Into<OperationArg>, a2: impl Into<Value>) -> Self {
|
pub fn new_entry(a1: impl Into<String>, a2: impl Into<Value>) -> Self {
|
||||||
Self(
|
Self(
|
||||||
OperationType::Native(NativeOperation::NewEntry),
|
OperationType::Native(NativeOperation::NewEntry),
|
||||||
vec![a1.into(), a2.into().into()],
|
vec![OperationArg::Entry(a1.into(), a2.into())],
|
||||||
OperationAux::None,
|
OperationAux::None,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -180,6 +180,12 @@ impl Operation {
|
||||||
op_impl_oa!(sum_of, SumOf, 3);
|
op_impl_oa!(sum_of, SumOf, 3);
|
||||||
op_impl_oa!(product_of, ProductOf, 3);
|
op_impl_oa!(product_of, ProductOf, 3);
|
||||||
op_impl_oa!(max_of, MaxOf, 3);
|
op_impl_oa!(max_of, MaxOf, 3);
|
||||||
|
/// Creates a custom operation.
|
||||||
|
///
|
||||||
|
/// `args` should contain the statements that are needed to prove the
|
||||||
|
/// custom statement. It should have the same length as
|
||||||
|
/// `cpr.predicate().statements()`. If `cpr` refers to an `or` predicate,
|
||||||
|
/// then all but one of the statements should be `Statement::None`.
|
||||||
pub fn custom(cpr: CustomPredicateRef, args: Vec<OperationArg>) -> Self {
|
pub fn custom(cpr: CustomPredicateRef, args: Vec<OperationArg>) -> Self {
|
||||||
Self(OperationType::Custom(cpr), args, OperationAux::None)
|
Self(OperationType::Custom(cpr), args, OperationAux::None)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -227,6 +227,14 @@ impl CustomPredicate {
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Self::new(params, name, false, statements, args_len, wildcard_names)
|
Self::new(params, name, false, statements, args_len, wildcard_names)
|
||||||
}
|
}
|
||||||
|
/// Creates a new custom predicate.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `name` - The name of the custom predicate.
|
||||||
|
/// * `conjunction` - `true` for an `and` predicate, `false` for an `or` predicate.
|
||||||
|
/// * `statements` - The statements required to apply the custom predicate.
|
||||||
|
/// * `args_len` - The number of public arguments.
|
||||||
|
/// * `wildcard_names` - The names of the arguments (public and private).
|
||||||
pub fn new(
|
pub fn new(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
name: String,
|
name: String,
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@
|
||||||
use std::{backtrace::Backtrace, fmt::Debug};
|
use std::{backtrace::Backtrace, fmt::Debug};
|
||||||
|
|
||||||
use crate::middleware::{
|
use crate::middleware::{
|
||||||
CustomPredicate, Key, Operation, PodId, Statement, StatementArg, StatementTmplArg, Value,
|
CustomPredicate, Key, Operation, PodId, Predicate, Statement, StatementArg, StatementTmplArg,
|
||||||
Wildcard,
|
Value, Wildcard,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
pub type Result<T, E = Error> = core::result::Result<T, E>;
|
||||||
|
|
@ -19,7 +19,7 @@ pub enum MiddlewareInnerError {
|
||||||
InvalidStatementArg(StatementArg, String),
|
InvalidStatementArg(StatementArg, String),
|
||||||
#[error("{0} {1} is over the limit {2}")]
|
#[error("{0} {1} is over the limit {2}")]
|
||||||
MaxLength(String, usize, usize),
|
MaxLength(String, usize, usize),
|
||||||
#[error("{0} amount of {1} should be {1} but it's {2}")]
|
#[error("{0} amount of {1} should be {2} but it's {3}")]
|
||||||
DiffAmount(String, String, usize, usize),
|
DiffAmount(String, String, usize, usize),
|
||||||
#[error("{0} should be assigned the value {1} but has previously been assigned {2}")]
|
#[error("{0} should be assigned the value {1} but has previously been assigned {2}")]
|
||||||
InvalidWildcardAssignment(Wildcard, Value, Value),
|
InvalidWildcardAssignment(Wildcard, Value, Value),
|
||||||
|
|
@ -27,12 +27,10 @@ pub enum MiddlewareInnerError {
|
||||||
MismatchedAnchoredKeyInStatementTmplArg(Wildcard, PodId, Key, Key),
|
MismatchedAnchoredKeyInStatementTmplArg(Wildcard, PodId, Key, Key),
|
||||||
#[error("{0} does not match against {1}")]
|
#[error("{0} does not match against {1}")]
|
||||||
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
|
MismatchedStatementTmplArg(StatementTmplArg, StatementArg),
|
||||||
|
#[error("Expected a statement of type {0}, got {1}")]
|
||||||
|
MismatchedStatementType(Predicate, Predicate),
|
||||||
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
|
#[error("Value {0} does not match argument {1} with index {2} in the following custom predicate:\n{3}")]
|
||||||
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
|
MismatchedWildcardValueAndStatementArg(Value, Value, usize, CustomPredicate),
|
||||||
#[error(
|
|
||||||
"Not all statement templates of the following custom predicate have been matched:\n{0}"
|
|
||||||
)]
|
|
||||||
UnsatisfiedCustomPredicateConjunction(CustomPredicate),
|
|
||||||
#[error(
|
#[error(
|
||||||
"None of the statement templates of the following custom predicate have been matched:\n{0}"
|
"None of the statement templates of the following custom predicate have been matched:\n{0}"
|
||||||
)]
|
)]
|
||||||
|
|
@ -110,6 +108,9 @@ impl Error {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
new!(MismatchedStatementTmplArg(st_tmpl_arg, st_arg))
|
new!(MismatchedStatementTmplArg(st_tmpl_arg, st_arg))
|
||||||
}
|
}
|
||||||
|
pub(crate) fn mismatched_statement_type(expected: Predicate, seen: Predicate) -> Self {
|
||||||
|
new!(MismatchedStatementType(expected, seen))
|
||||||
|
}
|
||||||
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
|
pub(crate) fn mismatched_wildcard_value_and_statement_arg(
|
||||||
wc_value: Value,
|
wc_value: Value,
|
||||||
st_arg: Value,
|
st_arg: Value,
|
||||||
|
|
@ -120,9 +121,6 @@ impl Error {
|
||||||
wc_value, st_arg, arg_index, pred
|
wc_value, st_arg, arg_index, pred
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
pub(crate) fn unsatisfied_custom_predicate_conjunction(pred: CustomPredicate) -> Self {
|
|
||||||
new!(UnsatisfiedCustomPredicateConjunction(pred))
|
|
||||||
}
|
|
||||||
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
|
pub(crate) fn unsatisfied_custom_predicate_disjunction(pred: CustomPredicate) -> Self {
|
||||||
new!(UnsatisfiedCustomPredicateDisjunction(pred))
|
new!(UnsatisfiedCustomPredicateDisjunction(pred))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,8 @@ use crate::{
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, NativePredicate,
|
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, NativePredicate,
|
||||||
Params, Predicate, Result, Statement, StatementArg, StatementTmplArg, ToFields, Value,
|
Params, Predicate, Result, Statement, StatementArg, StatementTmpl, StatementTmplArg,
|
||||||
ValueRef, Wildcard, F, SELF,
|
ToFields, Value, ValueRef, Wildcard, F, SELF,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -486,7 +486,37 @@ pub fn resolve_wildcard_values(
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_custom_pred(
|
fn check_custom_pred_argument(
|
||||||
|
custom_pred_ref: &CustomPredicateRef,
|
||||||
|
template: &StatementTmpl,
|
||||||
|
statement: &Statement,
|
||||||
|
) -> Result<()> {
|
||||||
|
let template_pred = match &template.pred {
|
||||||
|
&Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
||||||
|
batch: custom_pred_ref.batch.clone(),
|
||||||
|
index: i,
|
||||||
|
}),
|
||||||
|
p => p.clone(),
|
||||||
|
};
|
||||||
|
if template_pred != statement.predicate() {
|
||||||
|
return Err(Error::mismatched_statement_type(
|
||||||
|
template_pred,
|
||||||
|
statement.predicate(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let st_args_len = statement.args().len();
|
||||||
|
if template.args.len() != st_args_len {
|
||||||
|
return Err(Error::diff_amount(
|
||||||
|
"statement template in custom predicate".to_string(),
|
||||||
|
"arguments".to_string(),
|
||||||
|
st_args_len,
|
||||||
|
template.args.len(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn check_custom_pred(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
custom_pred_ref: &CustomPredicateRef,
|
custom_pred_ref: &CustomPredicateRef,
|
||||||
args: &[Statement],
|
args: &[Statement],
|
||||||
|
|
@ -510,19 +540,24 @@ fn check_custom_pred(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count the number of statements that match the templates by predicate.
|
let mut match_exists = false;
|
||||||
let mut num_matches = 0;
|
|
||||||
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
for (st_tmpl, st) in pred.statements.iter().zip(args) {
|
||||||
let st_tmpl_pred = match &st_tmpl.pred {
|
// For `or` predicates, only one statement needs to match the template.
|
||||||
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
|
// The rest of the statements can be `None`.
|
||||||
batch: custom_pred_ref.batch.clone(),
|
if !pred.conjunction
|
||||||
index: *i,
|
&& matches!(st, Statement::None)
|
||||||
}),
|
&& st_tmpl.pred != Predicate::Native(NativePredicate::None)
|
||||||
p => p.clone(),
|
{
|
||||||
};
|
continue;
|
||||||
if st_tmpl_pred == st.predicate() {
|
|
||||||
num_matches += 1;
|
|
||||||
}
|
}
|
||||||
|
check_custom_pred_argument(custom_pred_ref, st_tmpl, st)?;
|
||||||
|
match_exists = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pred.conjunction && !match_exists {
|
||||||
|
return Err(Error::unsatisfied_custom_predicate_disjunction(
|
||||||
|
pred.clone(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let wildcard_map = resolve_wildcard_values(params, pred, args)?;
|
let wildcard_map = resolve_wildcard_values(params, pred, args)?;
|
||||||
|
|
@ -539,18 +574,6 @@ fn check_custom_pred(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if pred.conjunction {
|
|
||||||
if num_matches != pred.statements.len() {
|
|
||||||
return Err(Error::unsatisfied_custom_predicate_conjunction(
|
|
||||||
pred.clone(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
} else if num_matches == 0 {
|
|
||||||
return Err(Error::unsatisfied_custom_predicate_disjunction(
|
|
||||||
pred.clone(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue