Feat/fst order pred part1 & part2 (#454)

Implement support for first order predicates in the backend.
Now a statement template can have a predicate hash or a wildcard.

## predicate <-> predicate hash constraints

To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization.  This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard).  Then in normalization we recalculate the predicate hash if it was a Batch Self.

This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard.  How this is achieved:
- Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard)
- After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf.  If it was a predicate but not batch self, the old value was used which was constrained via the constructor.

See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit`

## Wildcard predicate resolution

It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
This commit is contained in:
Eduard S. 2026-01-20 13:14:22 +01:00 committed by GitHub
parent 1724e7b146
commit 9c9a2c454c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 569 additions and 240 deletions

View file

@ -18,8 +18,8 @@ use crate::{
},
middleware::{
self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params,
Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg,
Wildcard,
Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Wildcard,
},
};
@ -201,7 +201,8 @@ impl<'a> Lowerer<'a> {
}
Ok(MWStatementTmpl {
pred: predicate,
// TODO: Support wildcard
pred_or_wc: PredicateOrWildcard::Predicate(predicate),
args: mw_args,
})
}
@ -596,7 +597,10 @@ mod tests {
let stmt = &pred2.statements()[0];
// Should be BatchSelf(0) referring to pred1
assert!(matches!(stmt.pred, Predicate::BatchSelf(0)));
assert!(matches!(
stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
));
}
#[test]
@ -632,8 +636,8 @@ mod tests {
// Should desugar to the Contains predicate
assert!(matches!(
stmt.pred,
Predicate::Native(NativePredicate::Contains)
stmt.pred_or_wc,
PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains))
));
}

View file

@ -63,8 +63,8 @@ mod tests {
backends::plonky2::primitives::ec::schnorr::SecretKey,
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
Params, Predicate, RawValue, StatementTmpl, StatementTmplArg, Value, Wildcard,
EMPTY_HASH,
Params, Predicate, PredicateOrWildcard, RawValue, StatementTmpl, StatementTmplArg,
Value, Wildcard, EMPTY_HASH,
},
};
@ -89,6 +89,10 @@ mod tests {
names.iter().map(|s| s.to_string()).collect()
}
fn pred_lit(pred: Predicate) -> PredicateOrWildcard {
PredicateOrWildcard::Predicate(pred)
}
#[test]
fn test_e2e_simple_predicate() -> Result<(), LangError> {
let input = r#"
@ -109,7 +113,7 @@ mod tests {
// Expected structure
let expected_statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("PodA", 0), "the_key"), // PodA["the_key"] -> Wildcard(0), Key("the_key")
sta_ak(("PodB", 1), "the_key"), // PodB["the_key"] -> Wildcard(1), Key("the_key")
@ -153,14 +157,14 @@ mod tests {
// Expected structure
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
sta_lit(RawValue::from(1)),
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
args: vec![
sta_ak(("GovPod", 1), "dob"), // GovPod["dob"] -> Wildcard(1), Key("dob")
sta_ak(("ConstPod", 0), "my_val"), // ConstPod["my_val"] -> Wildcard(0), Key("my_val")
@ -195,14 +199,14 @@ mod tests {
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
let expected_statements = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("A", 0), "input_key"), // A["input_key"] -> Wildcard(0), Key("input_key")
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("Temp", 1), "const_key"), // Temp["const_key"] -> Wildcard(1), Key("const_key")
sta_lit("some_value"), // Literal("some_value")
@ -251,7 +255,7 @@ mod tests {
// Expected Batch structure
let expected_pred_statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val")
sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val")
@ -275,7 +279,10 @@ mod tests {
// Expected Request structure
// Pod1 -> Wildcard 0, Pod2 -> Wildcard 1
let expected_request_templates = vec![StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(expected_batch, 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
expected_batch,
0,
))),
args: vec![
StatementTmplArg::Wildcard(wc("Pod1", 0)),
StatementTmplArg::Wildcard(wc("Pod2", 1)),
@ -317,7 +324,7 @@ mod tests {
// Expected structure
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(batch_result, 0)), // Refers to some_pred
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch_result, 0))), // Refers to some_pred
args: vec![
StatementTmplArg::Wildcard(wc("Var1", 0)), // Var1
StatementTmplArg::Literal(Value::from(12345i64)), // 12345
@ -325,7 +332,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
// AnotherPod["another_key"] -> Wildcard(1), Key("another_key")
sta_ak(("AnotherPod", 1), "another_key"),
@ -362,15 +369,15 @@ mod tests {
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::LtEq),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::LtEq)),
args: vec![sta_ak(("B", 1), "bar"), sta_ak(("A", 0), "foo")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
args: vec![sta_ak(("D", 3), "qux"), sta_ak(("C", 2), "baz")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Contains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)),
args: vec![
sta_ak(("A", 0), "foo"),
sta_ak(("B", 1), "bar"),
@ -378,11 +385,11 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::NotContains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)),
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Contains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Contains)),
args: vec![
sta_ak(("A", 0), "foo"),
sta_ak(("B", 1), "bar"),
@ -439,7 +446,7 @@ mod tests {
let expected_templates = vec![
// 1. NotContains(sanctions["sanctionList"], gov["idNumber"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::NotContains),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::NotContains)),
args: vec![
sta_ak(
(wc_sanctions.name.as_str(), wc_sanctions.index),
@ -450,7 +457,7 @@ mod tests {
},
// 2. Lt(gov["dateOfBirth"], SELF_HOLDER_18Y["const_18y"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Lt)),
args: vec![
sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key),
sta_ak(
@ -461,7 +468,7 @@ mod tests {
},
// 3. Equal(pay["startDate"], SELF_HOLDER_1Y["const_1y"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key),
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
@ -469,7 +476,7 @@ mod tests {
},
// 4. Equal(gov["socialSecurityNumber"], pay["socialSecurityNumber"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key),
sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key),
@ -477,7 +484,7 @@ mod tests {
},
// 5. Equal(SELF_HOLDER_18Y["const_18y"], 1169909388)
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(
(wc_self_18y.name.as_str(), wc_self_18y.index),
@ -488,7 +495,7 @@ mod tests {
},
// 6. Equal(SELF_HOLDER_1Y["const_1y"], 1706367566)
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
sta_lit(now_minus_1y_val.clone()),
@ -563,11 +570,11 @@ mod tests {
// eth_friend (Index 0)
let expected_friend_stmts = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::SignedBy),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SignedBy)),
args: vec![sta_wc_lit("attestation_dict", 2), sta_wc_lit("src", 0)],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("attestation_dict", 2), "attestation"),
sta_wc_lit("dst", 1), // Pub arg 1
@ -586,11 +593,11 @@ mod tests {
// eth_dos_distance_base (Index 1)
let expected_base_stmts = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_wc_lit("src", 0), sta_wc_lit("dst", 1)],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_wc_lit("distance", 2), sta_lit(0i64)],
},
];
@ -608,7 +615,7 @@ mod tests {
// Private args indices: 3-4 (shorter_distance, intermed)
let expected_ind_stmts = vec![
StatementTmpl {
pred: Predicate::BatchSelf(3), // Calls eth_dos_distance (index 3)
pred_or_wc: pred_lit(Predicate::BatchSelf(3)), // Calls eth_dos_distance (index 3)
args: vec![
// WildcardLiteral args
sta_wc_lit("src", 0),
@ -617,7 +624,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::SumOf),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::SumOf)),
args: vec![
sta_wc_lit("distance", 2), // public arg
sta_wc_lit("shorter_distance", 3), // private arg
@ -625,7 +632,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::BatchSelf(0), // Calls eth_friend (index 0)
pred_or_wc: pred_lit(Predicate::BatchSelf(0)), // Calls eth_friend (index 0)
args: vec![
// WildcardLiteral args
sta_wc_lit("intermed", 4), // private arg
@ -645,7 +652,7 @@ mod tests {
// eth_dos_distance (Index 3)
let expected_dist_stmts = vec![
StatementTmpl {
pred: Predicate::BatchSelf(1), // Calls eth_dos_distance_base (index 1)
pred_or_wc: pred_lit(Predicate::BatchSelf(1)), // Calls eth_dos_distance_base (index 1)
args: vec![
// WildcardLiteral args
sta_wc_lit("src", 0),
@ -654,7 +661,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::BatchSelf(2), // Calls eth_dos_distance_ind (index 2)
pred_or_wc: pred_lit(Predicate::BatchSelf(2)), // Calls eth_dos_distance_ind (index 2)
args: vec![
// WildcardLiteral args
sta_wc_lit("src", 0),
@ -697,7 +704,7 @@ mod tests {
// 1. Create a batch to be imported
let imported_pred_stmts = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("A", 0), "foo"), // A["foo"]
sta_ak(("B", 1), "bar"), // B["bar"]
@ -739,7 +746,10 @@ mod tests {
// 4. Check the resulting request template
let expected_request_templates = vec![StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch,
0,
))),
args: vec![
StatementTmplArg::Wildcard(wc("Pod1", 0)),
StatementTmplArg::Wildcard(wc("Pod2", 1)),
@ -788,11 +798,17 @@ mod tests {
// 4. Check the resulting request templates
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch.clone(),
0,
))),
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
},
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 2)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch,
2,
))),
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
},
];
@ -808,7 +824,7 @@ mod tests {
// 1. Create a batch with a predicate to be imported
let imported_pred_stmts = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
}];
let imported_predicate = CustomPredicate::and(
@ -855,7 +871,10 @@ mod tests {
assert_eq!(defined_pred.statements.len(), 1);
let expected_statement = StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch.clone(),
0,
))),
args: vec![
StatementTmplArg::Wildcard(wc("X", 0)),
StatementTmplArg::Wildcard(wc("Y", 1)),
@ -886,7 +905,9 @@ mod tests {
let request_templates = processed.request.templates();
assert_eq!(request_templates.len(), 1);
if let Predicate::Intro(intro_ref) = &request_templates[0].pred {
if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) =
&request_templates[0].pred_or_wc
{
assert_eq!(intro_ref.name, "empty");
assert_eq!(intro_ref.args_len, 0);
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
@ -944,27 +965,27 @@ mod tests {
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("B", 1), "raw"), sta_lit(Value::from(raw))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("C", 2), "string"), sta_lit(Value::from(string))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("D", 3), "int"), sta_lit(Value::from(int))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("E", 4), "bool"), sta_lit(Value::from(bool))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("F", 5), "sk"), sta_lit(Value::from(sk))],
},
];

View file

@ -5,7 +5,8 @@ use std::fmt::Write;
use crate::{
frontend::PodRequest,
middleware::{
CustomPredicate, CustomPredicateBatch, Predicate, StatementTmpl, StatementTmplArg, Value,
CustomPredicate, CustomPredicateBatch, Predicate, PredicateOrWildcard, StatementTmpl,
StatementTmplArg, Value,
},
};
@ -57,26 +58,32 @@ impl StatementTmpl {
w: &mut dyn Write,
batch_context: Option<&CustomPredicateBatch>,
) -> std::fmt::Result {
match &self.pred {
Predicate::Native(native_pred) => {
write!(w, "{}", native_pred)?;
}
Predicate::Custom(custom_ref) => {
write!(w, "{}", custom_ref.predicate().name)?;
}
Predicate::Intro(intro_ref) => {
write!(w, "{}", intro_ref.name)?;
}
Predicate::BatchSelf(index) => {
if let Some(batch) = batch_context {
if let Some(predicate) = batch.predicates.get(*index) {
write!(w, "{}", predicate.name)?;
match &self.pred_or_wc {
PredicateOrWildcard::Predicate(pred) => match pred {
Predicate::Native(native_pred) => {
write!(w, "{}", native_pred)?;
}
Predicate::Custom(custom_ref) => {
write!(w, "{}", custom_ref.predicate().name)?;
}
Predicate::Intro(intro_ref) => {
write!(w, "{}", intro_ref.name)?;
}
Predicate::BatchSelf(index) => {
if let Some(batch) = batch_context {
if let Some(predicate) = batch.predicates.get(*index) {
write!(w, "{}", predicate.name)?;
} else {
write!(w, "batch_self_{}", index)?;
}
} else {
write!(w, "batch_self_{}", index)?;
}
} else {
write!(w, "batch_self_{}", index)?;
}
},
PredicateOrWildcard::Wildcard(wc) => {
// TODO: Decide the syntax for a wildcard predicate
write!(w, "?{}", wc.name)?;
}
}
@ -223,13 +230,17 @@ mod tests {
Wildcard::new(name.to_string(), index)
}
fn pred_lit(pred: Predicate) -> PredicateOrWildcard {
PredicateOrWildcard::Predicate(pred)
}
#[test]
fn test_simple_predicate_pretty_print() {
let params = Params::default();
// Create a simple predicate: is_equal(PodA, PodB) = AND(Equal(PodA["key"], PodB["key"]))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("PodA", 0),
@ -265,7 +276,7 @@ mod tests {
// Create: uses_private(A, private: Temp) = AND(Equal(A["input"], Temp["const"]))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("A", 0),
@ -301,7 +312,7 @@ mod tests {
// Create: check_value(Pod) = AND(Equal(Pod["field"], 42))
let statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("Pod", 0),
@ -335,7 +346,7 @@ mod tests {
// Create: either_or(A, B) = OR(Equal(A["x"], 1), Equal(B["y"], 2))
let statements = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("A", 0),
@ -345,7 +356,7 @@ mod tests {
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
StatementTmplArg::AnchoredKey(
create_test_wildcard("B", 1),