Feat/fst order pred part1 & part2 (#454)

Implement support for first order predicates in the backend.
Now a statement template can have a predicate hash or a wildcard.

## predicate <-> predicate hash constraints

To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization.  This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard).  Then in normalization we recalculate the predicate hash if it was a Batch Self.

This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard.  How this is achieved:
- Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard)
- After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf.  If it was a predicate but not batch self, the old value was used which was constrained via the constructor.

See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit`

## Wildcard predicate resolution

It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
This commit is contained in:
Eduard S. 2026-01-20 13:14:22 +01:00 committed by GitHub
parent 1724e7b146
commit 9c9a2c454c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 569 additions and 240 deletions

View file

@ -15,8 +15,9 @@ use crate::{
},
middleware::{
hash_values, AnchoredKey, CustomPredicate, CustomPredicateRef, Error, Hash, Key,
MiddlewareInnerError, NativePredicate, Params, Predicate, Result, Statement, StatementArg,
StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value, ValueRef, Wildcard, F,
MiddlewareInnerError, NativePredicate, Params, Predicate, PredicateOrWildcard, Result,
Statement, StatementArg, StatementTmpl, StatementTmplArg, ToFields, TypedValue, Value,
ValueRef, Wildcard, F,
},
};
@ -550,6 +551,22 @@ impl Operation {
}
}
// Check that the value `v` at wildcard `wc` exists in the map or set it.
fn wc_check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
if let Some(prev) = &wildcard_map[wc.index] {
if *prev != v {
return Err(Error::invalid_wildcard_assignment(
wc.clone(),
v,
prev.clone(),
));
}
} else {
wildcard_map[wc.index] = Some(v);
}
Ok(())
}
/// Check that a StatementArg follows a StatementTmplArg based on the currently mapped wildcards.
/// Update the wildcard map with newly found wildcards.
pub fn check_st_tmpl(
@ -558,22 +575,6 @@ pub fn check_st_tmpl(
// Map from wildcards to values that we have seen so far.
wildcard_map: &mut [Option<Value>],
) -> Result<()> {
// Check that the value `v` at wildcard `wc` exists in the map or set it.
fn check_or_set(v: Value, wc: &Wildcard, wildcard_map: &mut [Option<Value>]) -> Result<()> {
if let Some(prev) = &wildcard_map[wc.index] {
if *prev != v {
return Err(Error::invalid_wildcard_assignment(
wc.clone(),
v,
prev.clone(),
));
}
} else {
wildcard_map[wc.index] = Some(v);
}
Ok(())
}
match (st_tmpl_arg, st_arg) {
(StatementTmplArg::None, StatementArg::None) => Ok(()),
(StatementTmplArg::Literal(lhs), StatementArg::Literal(rhs)) if lhs == rhs => Ok(()),
@ -581,7 +582,7 @@ pub fn check_st_tmpl(
StatementTmplArg::AnchoredKey(root_wc, key_tmpl),
StatementArg::Key(AnchoredKey { root, key }),
) => {
let root_ok = check_or_set(Value::from(*root), root_wc, wildcard_map);
let root_ok = wc_check_or_set(Value::from(*root), root_wc, wildcard_map);
root_ok.and_then(|_| {
(key_tmpl == key).then_some(()).ok_or(
Error::mismatched_anchored_key_in_statement_tmpl_arg(
@ -594,7 +595,7 @@ pub fn check_st_tmpl(
})
}
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
check_or_set(v.clone(), wc, wildcard_map)
wc_check_or_set(v.clone(), wc, wildcard_map)
}
_ => Err(Error::mismatched_statement_tmpl_arg(
st_tmpl_arg.clone(),
@ -604,12 +605,16 @@ pub fn check_st_tmpl(
}
pub fn fill_wildcard_values(
params: &Params,
pred: &CustomPredicate,
args: &[Statement],
wildcard_map: &mut [Option<Value>],
) -> Result<()> {
for (st_tmpl, st) in pred.statements.iter().zip(args) {
let st_args = st.args();
if let PredicateOrWildcard::Wildcard(wc) = &st_tmpl.pred_or_wc {
wc_check_or_set(Value::from(st.predicate().hash(params)), wc, wildcard_map)?;
}
st_tmpl
.args
.iter()
@ -633,7 +638,7 @@ pub fn wildcard_values_from_op_st(
.chain(core::iter::repeat(None))
.take(params.max_custom_predicate_wildcards)
.collect_vec();
fill_wildcard_values(pred, op_args, &mut wildcard_map)?;
fill_wildcard_values(params, pred, op_args, &mut wildcard_map)?;
// NOTE: We set unresolved wildcard slots with an empty value. They can be unresolved because
// they are beyond the number of used wildcards in this custom predicate, or they could be
// private arguments that are unused in a particular disjunction.
@ -644,22 +649,38 @@ pub fn wildcard_values_from_op_st(
}
fn check_custom_pred_argument(
params: &Params,
custom_pred_ref: &CustomPredicateRef,
template: &StatementTmpl,
statement: &Statement,
wc_values: &[Value],
) -> Result<()> {
let template_pred = match &template.pred {
&Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
batch: custom_pred_ref.batch.clone(),
index: i,
}),
p => p.clone(),
};
if template_pred != statement.predicate() {
return Err(Error::mismatched_statement_type(
template_pred,
statement.predicate(),
));
match &template.pred_or_wc {
PredicateOrWildcard::Predicate(pred) => {
let template_pred = match pred {
&Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
batch: custom_pred_ref.batch.clone(),
index: i,
}),
p => p.clone(),
};
if template_pred != statement.predicate() {
return Err(Error::mismatched_statement_type(
template_pred,
statement.predicate(),
));
}
}
PredicateOrWildcard::Wildcard(wc) => {
let pred_hash = Value::from(statement.predicate().hash(params));
if wc_values[wc.index] != pred_hash {
return Err(Error::mismatched_statement_wc_pred(
wc_values[wc.index].clone(),
pred_hash,
statement.predicate(),
));
}
}
}
let st_args_len = statement.args().len();
if template.args.len() != st_args_len {
@ -697,17 +718,42 @@ pub(crate) fn check_custom_pred(
));
}
// Check that the resolved wildcards match the statement arguments.
let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) {
Ok(wc_values) => wc_values,
Err(Error::Inner { inner, backtrace }) => match *inner {
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
if wc.index <= s_args.len() =>
{
return Err(Error::mismatched_wildcard_value_and_statement_arg(
v,
prev,
wc.index,
pred.clone(),
))
}
_ => return Err(Error::Inner { inner, backtrace }),
},
_ => unreachable!(),
};
let mut match_exists = false;
for (st_tmpl, st) in pred.statements.iter().zip(args) {
// For `or` predicates, only one statement needs to match the template.
// The rest of the statements can be `None`.
if !pred.conjunction
&& matches!(st, Statement::None)
&& st_tmpl.pred != Predicate::Native(NativePredicate::None)
{
let expected_pred_is_none = match &st_tmpl.pred_or_wc {
PredicateOrWildcard::Predicate(st_tmpl_pred) => {
*st_tmpl_pred == Predicate::Native(NativePredicate::None)
}
PredicateOrWildcard::Wildcard(wc) => {
wc_values[wc.index]
== Value::from(Predicate::Native(NativePredicate::None).hash(params))
}
};
if !pred.conjunction && matches!(st, Statement::None) && !expected_pred_is_none {
continue;
}
check_custom_pred_argument(custom_pred_ref, st_tmpl, st)?;
check_custom_pred_argument(params, custom_pred_ref, st_tmpl, st, &wc_values)?;
match_exists = true;
}
@ -716,25 +762,7 @@ pub(crate) fn check_custom_pred(
pred.clone(),
));
}
// Check that the resolved wildcards match the statement arguments.
match wildcard_values_from_op_st(params, pred, args, s_args) {
Ok(_) => Ok(()),
Err(Error::Inner { inner, backtrace }) => match *inner {
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
if wc.index <= s_args.len() =>
{
Err(Error::mismatched_wildcard_value_and_statement_arg(
v,
prev,
wc.index,
pred.clone(),
))
}
_ => Err(Error::Inner { inner, backtrace }),
},
_ => unreachable!(),
}
Ok(())
}
impl ToFields for Operation {