Allow literals in statement templates (#287)

This PR is a continuation of the work done in #276 
- Fix PodType in MainPod (we were using `MockMain` instead of `Main`)
- Update anchored keys in statement template arguments to only support wildcards in the origin and literal keys as the key.
  - Update the pest grammar accordingly
  - Update the parser accordingly
- Rewrite the eth_dos example in a recursive manner so that we use one recursive pod for every distance increment of 1.
  - I've also used the podlang to define the eth_dos custom predicates.  Currently all predicates are in a single batch (previously `eth_friend` was in a different batch).  With #286 we could define `eth_friend` in a different batch again.
    - I was feeling a bit creative and used a format macro to pass `Value`s from rust to the podlang code.
  - The eth_dos is now written using literals.  This resolves https://github.com/0xPARC/pod2/issues/255
- Remove `StatementArg::WildcardValue` in favor of `StatementArg::Literal`.  The `WildcardValue` was just a way to have some kind of typing for values that would be used as arguments in custom predicates.  Now that we can have literals in any statement this value can be anything, so I just removed the `WildcardValue` and use `Literal` instead.  On the backend it was already the case that both cases were treated the same way (after all, `WildcardValue` and `Literal` were 4 fields in the backend).
  - Added a new type for Value: `PodId` so that we can use it for custom predicates that take a pod id to be used in a wildcard
- Add a mock vd_set that is empty for tests that don't use plonky2; this allows running those tests individually without paying for the expensive work of calculating the vd for various circuits.
- rename StatementTmplArg::WildcardValue to StatementTmplArg::Wildcard
This commit is contained in:
Eduard S. 2025-06-16 16:38:38 +02:00 committed by GitHub
parent 7d0d3ad769
commit 3c6930dfe6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 659 additions and 1111 deletions

View file

@ -54,8 +54,8 @@ statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* }
statement = { identifier ~ "(" ~ statement_arg_list? ~ ")" }
// Anchored Key: (SELF | ?Var)["key_literal" | ?KeyVar]
anchored_key = { ( self_keyword | wildcard ) ~ "[" ~ (wildcard | literal_string) ~ "]" }
// Anchored Key: ?Var["key_literal"]
anchored_key = { wildcard ~ "[" ~ literal_string ~ "]" }
// Literal Values (ordered to avoid ambiguity, e.g., string before int)
literal_value = {

View file

@ -29,9 +29,9 @@ mod tests {
use crate::{
lang::error::ProcessorError,
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, KeyOrWildcard,
NativePredicate, Params, PodType, Predicate, SelfOrWildcard, StatementTmpl,
StatementTmplArg, Value, Wildcard, SELF_ID_HASH,
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
Params, PodType, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard,
SELF_ID_HASH,
},
};
@ -40,23 +40,12 @@ mod tests {
Wildcard::new(name.to_string(), index)
}
fn k(name: &str) -> KeyOrWildcard {
KeyOrWildcard::Key(Key::new(name.to_string()))
fn sta_ak(pod_var: (&str, usize), key: &str) -> StatementTmplArg {
StatementTmplArg::AnchoredKey(wc(pod_var.0, pod_var.1), Key::from(key))
}
fn ko_wc(name: &str, index: usize) -> KeyOrWildcard {
KeyOrWildcard::Wildcard(Wildcard::new(name.to_string(), index))
}
fn sta_ak(pod_var: (&str, usize), key_or_wc: KeyOrWildcard) -> StatementTmplArg {
StatementTmplArg::AnchoredKey(
SelfOrWildcard::Wildcard(wc(pod_var.0, pod_var.1)),
key_or_wc,
)
}
fn sta_ak_self(key_or_wc: KeyOrWildcard) -> StatementTmplArg {
StatementTmplArg::AnchoredKey(SelfOrWildcard::SELF, key_or_wc)
fn sta_wc_lit(name: &str, index: usize) -> StatementTmplArg {
StatementTmplArg::Wildcard(wc(name, index))
}
fn sta_lit(value: impl Into<Value>) -> StatementTmplArg {
@ -89,8 +78,8 @@ mod tests {
let expected_statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("PodA", 0), k("the_key")), // ?PodA["the_key"] -> Wildcard(0), Key("the_key")
sta_ak(("PodB", 1), k("the_key")), // ?PodB["the_key"] -> Wildcard(1), Key("the_key")
sta_ak(("PodA", 0), "the_key"), // ?PodA["the_key"] -> Wildcard(0), Key("the_key")
sta_ak(("PodB", 1), "the_key"), // ?PodB["the_key"] -> Wildcard(1), Key("the_key")
],
}];
let expected_predicate = CustomPredicate::and(
@ -135,15 +124,15 @@ mod tests {
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("ConstPod", 0), k("my_val")), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val")
sta_ak(("ConstPod", 0), "my_val"), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val")
sta_lit(SELF_ID_HASH),
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
args: vec![
sta_ak(("GovPod", 1), k("dob")), // ?GovPod["dob"] -> Wildcard(1), Key("dob")
sta_ak(("ConstPod", 0), k("my_val")), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val")
sta_ak(("GovPod", 1), "dob"), // ?GovPod["dob"] -> Wildcard(1), Key("dob")
sta_ak(("ConstPod", 0), "my_val"), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val")
],
},
];
@ -177,15 +166,15 @@ mod tests {
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("A", 0), k("input_key")), // ?A["input_key"] -> Wildcard(0), Key("input_key")
sta_ak(("Temp", 1), k("const_key")), // ?Temp["const_key"] -> Wildcard(1), Key("const_key")
sta_ak(("A", 0), "input_key"), // ?A["input_key"] -> Wildcard(0), Key("input_key")
sta_ak(("Temp", 1), "const_key"), // ?Temp["const_key"] -> Wildcard(1), Key("const_key")
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("Temp", 1), k("const_key")), // ?Temp["const_key"] -> Wildcard(1), Key("const_key")
sta_lit("some_value"), // Literal("some_value")
sta_ak(("Temp", 1), "const_key"), // ?Temp["const_key"] -> Wildcard(1), Key("const_key")
sta_lit("some_value"), // Literal("some_value")
],
},
];
@ -234,8 +223,8 @@ mod tests {
let expected_pred_statements = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("X", 0), k("val")), // ?X["val"] -> Wildcard(0), Key("val")
sta_ak(("Y", 1), k("val")), // ?Y["val"] -> Wildcard(1), Key("val")
sta_ak(("X", 0), "val"), // ?X["val"] -> Wildcard(0), Key("val")
sta_ak(("Y", 1), "val"), // ?Y["val"] -> Wildcard(1), Key("val")
],
}];
let expected_predicate = CustomPredicate::and(
@ -258,8 +247,8 @@ mod tests {
let expected_request_templates = vec![StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(expected_batch, 0)),
args: vec![
StatementTmplArg::WildcardLiteral(wc("Pod1", 0)),
StatementTmplArg::WildcardLiteral(wc("Pod2", 1)),
StatementTmplArg::Wildcard(wc("Pod1", 0)),
StatementTmplArg::Wildcard(wc("Pod2", 1)),
],
}];
@ -271,7 +260,7 @@ mod tests {
#[test]
fn test_e2e_request_with_various_args() -> Result<(), LangError> {
let input = r#"
some_pred(A, B, C) = AND( Equal(?A["foo"], ?B["bar"]) )
some_pred(A, B, C) = AND( Equal(?A["foo"], ?B["bar"]) )
REQUEST(
some_pred(
@ -302,7 +291,7 @@ mod tests {
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(batch_result, 0)), // Refers to some_pred
args: vec![
StatementTmplArg::WildcardLiteral(wc("Var1", 0)), // ?Var1
StatementTmplArg::Wildcard(wc("Var1", 0)), // ?Var1
StatementTmplArg::Literal(Value::from(12345i64)), // 12345
StatementTmplArg::Literal(Value::from("hello_string")), // "hello_string"
],
@ -311,9 +300,9 @@ mod tests {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
// ?AnotherPod["another_key"] -> Wildcard(1), Key("another_key")
sta_ak(("AnotherPod", 1), k("another_key")),
sta_ak(("AnotherPod", 1), "another_key"),
// ?Var1["some_field"] -> Wildcard(0), Key("some_field")
sta_ak(("Var1", 0), k("some_field")),
sta_ak(("Var1", 0), "some_field"),
],
},
];
@ -348,30 +337,30 @@ mod tests {
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::LtEq),
args: vec![sta_ak(("B", 1), k("bar")), sta_ak(("A", 0), k("foo"))],
args: vec![sta_ak(("B", 1), "bar"), sta_ak(("A", 0), "foo")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
args: vec![sta_ak(("D", 3), k("qux")), sta_ak(("C", 2), k("baz"))],
args: vec![sta_ak(("D", 3), "qux"), sta_ak(("C", 2), "baz")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Contains),
args: vec![
sta_ak(("A", 0), k("foo")),
sta_ak(("B", 1), k("bar")),
sta_ak(("C", 2), k("baz")),
sta_ak(("A", 0), "foo"),
sta_ak(("B", 1), "bar"),
sta_ak(("C", 2), "baz"),
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::NotContains),
args: vec![sta_ak(("A", 0), k("foo")), sta_ak(("B", 1), k("bar"))],
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Contains),
args: vec![
sta_ak(("A", 0), k("foo")),
sta_ak(("B", 1), k("bar")),
sta_ak(("C", 2), k("baz")),
sta_ak(("A", 0), "foo"),
sta_ak(("B", 1), "bar"),
sta_ak(("C", 2), "baz"),
],
},
];
@ -386,12 +375,12 @@ mod tests {
let input = r#"
REQUEST(
// Order matters for comparison with the hardcoded templates
SetNotContains(?sanctions["sanctionList"], ?gov["idNumber"])
Lt(?gov["dateOfBirth"], ?SELF_HOLDER_18Y["const_18y"])
Equal(?pay["startDate"], ?SELF_HOLDER_1Y["const_1y"])
Equal(?gov["socialSecurityNumber"], ?pay["socialSecurityNumber"])
Equal(?SELF_HOLDER_18Y["const_18y"], 1169909388)
Equal(?SELF_HOLDER_1Y["const_1y"], 1706367566)
SetNotContains(?sanctions["sanctionList"], ?gov["idNumber"])
Lt(?gov["dateOfBirth"], ?SELF_HOLDER_18Y["const_18y"])
Equal(?pay["startDate"], ?SELF_HOLDER_1Y["const_1y"])
Equal(?gov["socialSecurityNumber"], ?pay["socialSecurityNumber"])
Equal(?SELF_HOLDER_18Y["const_18y"], 1169909388)
Equal(?SELF_HOLDER_1Y["const_1y"], 1706367566)
)
"#;
@ -412,13 +401,13 @@ mod tests {
let wc_pay = wc("pay", 3);
let wc_self_1y = wc("SELF_HOLDER_1Y", 4);
let id_num_key = k("idNumber");
let dob_key = k("dateOfBirth");
let const_18y_key = k("const_18y");
let start_date_key = k("startDate");
let const_1y_key = k("const_1y");
let ssn_key = k("socialSecurityNumber");
let sanction_list_key = k("sanctionList");
let id_num_key = "idNumber";
let dob_key = "dateOfBirth";
let const_18y_key = "const_18y";
let start_date_key = "startDate";
let const_1y_key = "const_1y";
let ssn_key = "socialSecurityNumber";
let sanction_list_key = "sanctionList";
// Define the request templates using wildcards for constants
let expected_templates = vec![
@ -428,19 +417,19 @@ mod tests {
args: vec![
sta_ak(
(wc_sanctions.name.as_str(), wc_sanctions.index),
sanction_list_key.clone(),
sanction_list_key,
),
sta_ak((wc_gov.name.as_str(), wc_gov.index), id_num_key.clone()),
sta_ak((wc_gov.name.as_str(), wc_gov.index), id_num_key),
],
},
// 2. Lt(?gov["dateOfBirth"], ?SELF_HOLDER_18Y["const_18y"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Lt),
args: vec![
sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key.clone()),
sta_ak((wc_gov.name.as_str(), wc_gov.index), dob_key),
sta_ak(
(wc_self_18y.name.as_str(), wc_self_18y.index),
const_18y_key.clone(),
const_18y_key,
),
],
},
@ -448,19 +437,16 @@ mod tests {
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key.clone()),
sta_ak(
(wc_self_1y.name.as_str(), wc_self_1y.index),
const_1y_key.clone(),
),
sta_ak((wc_pay.name.as_str(), wc_pay.index), start_date_key),
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
],
},
// 4. Equal(?gov["socialSecurityNumber"], ?pay["socialSecurityNumber"])
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key.clone()),
sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key.clone()),
sta_ak((wc_gov.name.as_str(), wc_gov.index), ssn_key),
sta_ak((wc_pay.name.as_str(), wc_pay.index), ssn_key),
],
},
// 5. Equal(?SELF_HOLDER_18Y["const_18y"], 1169909388)
@ -469,7 +455,7 @@ mod tests {
args: vec![
sta_ak(
(wc_self_18y.name.as_str(), wc_self_18y.index),
const_18y_key.clone(),
const_18y_key,
),
sta_lit(now_minus_18y_val.clone()),
],
@ -478,10 +464,7 @@ mod tests {
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(
(wc_self_1y.name.as_str(), wc_self_1y.index),
const_1y_key.clone(),
),
sta_ak((wc_self_1y.name.as_str(), wc_self_1y.index), const_1y_key),
sta_lit(now_minus_1y_val.clone()),
],
},
@ -517,27 +500,26 @@ mod tests {
};
let input = r#"
eth_friend(src_key, dst_key, private: attestation_pod) = AND(
eth_friend(src, dst, private: attestation_pod) = AND(
Equal(?attestation_pod["_type"], 1)
Equal(?attestation_pod["_signer"], SELF[?src_key])
Equal(?attestation_pod["attestation"], SELF[?dst_key])
Equal(?attestation_pod["_signer"], ?src)
Equal(?attestation_pod["attestation"], ?dst)
)
eth_dos_distance_base(src_key, dst_key, distance_key) = AND(
Equal(SELF[?src_key], SELF[?dst_key])
Equal(SELF[?distance_key], 0)
eth_dos_distance_base(src, dst, distance) = AND(
Equal(?src, ?dst)
Equal(?distance, 0)
)
eth_dos_distance_ind(src_key, dst_key, distance_key, private: one_key, shorter_distance_key, intermed_key) = AND(
eth_dos_distance(?src_key, ?dst_key, ?distance_key)
Equal(SELF[?one_key], 1)
SumOf(SELF[?distance_key], SELF[?shorter_distance_key], SELF[?one_key])
eth_friend(?intermed_key, ?dst_key)
eth_dos_distance_ind(src, dst, distance, private: shorter_distance, intermed) = AND(
eth_dos_distance(?src, ?dst, ?distance)
SumOf(?distance, ?shorter_distance, 1)
eth_friend(?intermed, ?dst)
)
eth_dos_distance(src_key, dst_key, distance_key) = OR(
eth_dos_distance_base(?src_key, ?dst_key, ?distance_key)
eth_dos_distance_ind(?src_key, ?dst_key, ?distance_key)
eth_dos_distance(src, dst, distance) = OR(
eth_dos_distance_base(?src, ?dst, ?distance)
eth_dos_distance_ind(?src, ?dst, ?distance)
)
"#;
@ -560,22 +542,22 @@ mod tests {
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("attestation_pod", 2), k("_type")), // Pub(0-1), Priv(2)
sta_ak(("attestation_pod", 2), "_type"), // Pub(0-1), Priv(2)
sta_lit(PodType::MockSigned),
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("attestation_pod", 2), k("_signer")),
sta_ak_self(ko_wc("src_key", 0)), // Pub arg 0
sta_ak(("attestation_pod", 2), "_signer"),
sta_wc_lit("src", 0), // Pub arg 0
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("attestation_pod", 2), k("attestation")),
sta_ak_self(ko_wc("dst_key", 1)), // Pub arg 1
sta_ak(("attestation_pod", 2), "attestation"),
sta_wc_lit("dst", 1), // Pub arg 1
],
},
];
@ -584,22 +566,19 @@ mod tests {
"eth_friend".to_string(),
true, // AND
expected_friend_stmts,
2, // public_args_len: src_key, dst_key
names(&["src_key", "dst_key", "attestation_pod"]),
2, // public_args_len: src, dst
names(&["src", "dst", "attestation_pod"]),
)?;
// eth_dos_distance_base (Index 1)
let expected_base_stmts = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak_self(ko_wc("src_key", 0)),
sta_ak_self(ko_wc("dst_key", 1)),
],
args: vec![sta_wc_lit("src", 0), sta_wc_lit("dst", 1)],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak_self(ko_wc("distance_key", 2)), sta_lit(0i64)],
args: vec![sta_wc_lit("distance", 2), sta_lit(0i64)],
},
];
let expected_base_pred = CustomPredicate::new(
@ -608,40 +587,36 @@ mod tests {
true, // AND
expected_base_stmts,
3, // public_args_len
names(&["src_key", "dst_key", "distance_key"]),
names(&["src", "dst", "distance"]),
)?;
// eth_dos_distance_ind (Index 2)
// Public args indices: 0-2
// Private args indices: 3-5 (one_key, shorter_distance_key, intermed_key)
// Private args indices: 3-4 (shorter_distance, intermed)
let expected_ind_stmts = vec![
StatementTmpl {
pred: Predicate::BatchSelf(3), // Calls eth_dos_distance (index 3)
args: vec![
// WildcardLiteral args
StatementTmplArg::WildcardLiteral(wc("src_key", 0)),
StatementTmplArg::WildcardLiteral(wc("dst_key", 1)), // private arg
StatementTmplArg::WildcardLiteral(wc("distance_key", 2)), // private arg
sta_wc_lit("src", 0),
sta_wc_lit("dst", 1), // private arg
sta_wc_lit("distance", 2), // private arg
],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak_self(ko_wc("one_key", 3)), sta_lit(1i64)], // private arg
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::SumOf),
args: vec![
sta_ak_self(ko_wc("distance_key", 2)), // public arg
sta_ak_self(ko_wc("shorter_distance_key", 4)), // private arg
sta_ak_self(ko_wc("one_key", 3)), // private arg
sta_wc_lit("distance", 2), // public arg
sta_wc_lit("shorter_distance", 3), // private arg
sta_lit(1),
],
},
StatementTmpl {
pred: Predicate::BatchSelf(0), // Calls eth_friend (index 0)
args: vec![
// WildcardLiteral args
StatementTmplArg::WildcardLiteral(wc("intermed_key", 5)), // private arg
StatementTmplArg::WildcardLiteral(wc("dst_key", 1)), // public arg
sta_wc_lit("intermed", 4), // private arg
sta_wc_lit("dst", 1), // public arg
],
},
];
@ -651,14 +626,7 @@ mod tests {
true, // AND
expected_ind_stmts,
3, // public_args_len
names(&[
"src_key",
"dst_key",
"distance_key",
"one_key",
"shorter_distance_key",
"intermed_key",
]),
names(&["src", "dst", "distance", "shorter_distance", "intermed"]),
)?;
// eth_dos_distance (Index 3)
@ -667,18 +635,18 @@ mod tests {
pred: Predicate::BatchSelf(1), // Calls eth_dos_distance_base (index 1)
args: vec![
// WildcardLiteral args
StatementTmplArg::WildcardLiteral(wc("src_key", 0)),
StatementTmplArg::WildcardLiteral(wc("dst_key", 1)),
StatementTmplArg::WildcardLiteral(wc("distance_key", 2)),
sta_wc_lit("src", 0),
sta_wc_lit("dst", 1),
sta_wc_lit("distance", 2),
],
},
StatementTmpl {
pred: Predicate::BatchSelf(2), // Calls eth_dos_distance_ind (index 2)
args: vec![
// WildcardLiteral args
StatementTmplArg::WildcardLiteral(wc("src_key", 0)),
StatementTmplArg::WildcardLiteral(wc("dst_key", 1)),
StatementTmplArg::WildcardLiteral(wc("distance_key", 2)),
sta_wc_lit("src", 0),
sta_wc_lit("dst", 1),
sta_wc_lit("distance", 2),
],
},
];
@ -688,7 +656,7 @@ mod tests {
false, // OR
expected_dist_stmts,
3, // public_args_len
names(&["src_key", "dst_key", "distance_key"]),
names(&["src", "dst", "distance"]),
)?;
let expected_batch = CustomPredicateBatch::new(
@ -718,8 +686,8 @@ mod tests {
let imported_pred_stmts = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![
sta_ak(("A", 0), k("foo")), // ?A["foo"]
sta_ak(("B", 1), k("bar")), // ?B["bar"]
sta_ak(("A", 0), "foo"), // ?A["foo"]
sta_ak(("B", 1), "bar"), // ?B["bar"]
],
}];
let imported_predicate = CustomPredicate::and(
@ -760,8 +728,8 @@ mod tests {
let expected_request_templates = vec![StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 0)),
args: vec![
StatementTmplArg::WildcardLiteral(wc("Pod1", 0)),
StatementTmplArg::WildcardLiteral(wc("Pod2", 1)),
StatementTmplArg::Wildcard(wc("Pod1", 0)),
StatementTmplArg::Wildcard(wc("Pod2", 1)),
],
}];
@ -808,11 +776,11 @@ mod tests {
let expected_templates = vec![
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
args: vec![StatementTmplArg::WildcardLiteral(wc("Pod1", 0))],
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
},
StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch, 2)),
args: vec![StatementTmplArg::WildcardLiteral(wc("Pod2", 1))],
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
},
];
@ -828,7 +796,7 @@ mod tests {
// 1. Create a batch with a predicate to be imported
let imported_pred_stmts = vec![StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("A", 0), k("foo")), sta_ak(("B", 1), k("bar"))],
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
}];
let imported_predicate = CustomPredicate::and(
&params,
@ -876,8 +844,8 @@ mod tests {
let expected_statement = StatementTmpl {
pred: Predicate::Custom(CustomPredicateRef::new(available_batch.clone(), 0)),
args: vec![
StatementTmplArg::WildcardLiteral(wc("X", 0)),
StatementTmplArg::WildcardLiteral(wc("Y", 1)),
StatementTmplArg::Wildcard(wc("X", 0)),
StatementTmplArg::Wildcard(wc("Y", 1)),
],
};

View file

@ -86,12 +86,10 @@ mod tests {
#[test]
fn test_parse_anchored_key() {
assert_parses(Rule::anchored_key, "?PodVar[\"literal_key\"]");
assert_parses(Rule::anchored_key, "?PodVar[?KeyVar]");
assert_parses(Rule::anchored_key, "SELF[?KeyVar]");
assert_parses(Rule::anchored_key, "SELF[\"literal_key\"]");
assert_fails(Rule::anchored_key, "PodVar[\"key\"]"); // Needs wildcard for pod
assert_fails(Rule::anchored_key, "?PodVar[invalid_key]"); // Key must be literal string or wildcard
assert_fails(Rule::anchored_key, "?PodVar[]"); // Key cannot be empty
assert_fails(Rule::anchored_key, "?PodVar[?key]"); // Key cannot be wildcard
}
#[test]
@ -179,7 +177,7 @@ mod tests {
Rule::test_custom_predicate_def,
// Trimmed leading/trailing whitespace
r#"pred_with_private(X, private: TempKey) = OR(
Equal(?X[?TempKey], ?X["other"])
Equal(?X["key"], 1234)
)"#,
);
assert_fails(

View file

@ -8,15 +8,11 @@ use plonky2::field::types::Field;
use super::error::ProcessorError;
use crate::{
frontend::{
BuilderArg, CustomPredicateBatchBuilder, KeyOrWildcardStr, SelfOrWildcardStr,
StatementTmplBuilder,
},
frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder},
lang::parser::Rule,
middleware::{
self, CustomPredicateBatch, CustomPredicateRef, Key, KeyOrWildcard, NativePredicate,
Params, Predicate, SelfOrWildcard as MiddlewareSelfOrWildcard, StatementTmpl,
StatementTmplArg, Value, Wildcard, F, VALUE_SIZE,
self, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, Params, Predicate,
StatementTmpl, StatementTmplArg, Value, Wildcard, F, VALUE_SIZE,
},
};
@ -305,34 +301,11 @@ fn pest_pair_to_builder_arg(arg_content_pair: &Pair<Rule>) -> Result<BuilderArg,
Rule::anchored_key => {
let mut inner_ak_pairs = arg_content_pair.clone().into_inner();
let pod_id_pair = inner_ak_pairs.next().unwrap();
let pod_self_or_wc_str = match pod_id_pair.as_rule() {
Rule::wildcard => {
let name = pod_id_pair.as_str().strip_prefix("?").unwrap();
SelfOrWildcardStr::Wildcard(name.to_string())
}
Rule::self_keyword => SelfOrWildcardStr::SELF,
_ => {
unreachable!("Unexpected rule: {:?}", pod_id_pair.as_rule());
}
};
let pod_id_wc_str = pod_id_pair.as_str().strip_prefix("?").unwrap();
let key_part_pair = inner_ak_pairs.next().unwrap();
let key_or_wildcard_str = match key_part_pair.as_rule() {
Rule::wildcard => {
let key_wildcard_name = key_part_pair.as_str().strip_prefix("?").unwrap();
KeyOrWildcardStr::Wildcard(key_wildcard_name.to_string())
}
Rule::literal_string => {
let key_str_literal = parse_pest_string_literal(&key_part_pair)?;
KeyOrWildcardStr::Key(key_str_literal)
}
_ => {
unreachable!("Unexpected rule: {:?}", key_part_pair.as_rule());
}
};
Ok(BuilderArg::Key(pod_self_or_wc_str, key_or_wildcard_str))
let key_str = parse_pest_string_literal(&key_part_pair)?;
Ok(BuilderArg::Key(pod_id_wc_str.to_string(), key_str))
}
_ => unreachable!("Unexpected rule: {:?}", arg_content_pair.as_rule()),
}
@ -377,23 +350,6 @@ fn validate_and_build_statement_template(
span: Some(stmt_name_span),
});
}
if expected_arity > 0 {
for (i, arg) in args.iter().enumerate() {
if !matches!(arg, BuilderArg::Key(..) | BuilderArg::Literal(..)) {
return Err(ProcessorError::TypeError {
expected: "Anchored Key".to_string(),
found: format!("{:?}", arg),
item: format!(
"argument {} of native predicate '{}'",
i + 1,
stmt_name_str
),
span: Some(stmt_span),
});
}
}
}
}
Predicate::Custom(custom_ref) => {
let expected_arity = custom_ref.predicate().args_len;
@ -635,13 +591,8 @@ fn process_statement_template(
for arg in &builder_args {
match arg {
BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()),
BuilderArg::Key(pod_id_str, key_wc_str) => {
if let SelfOrWildcardStr::Wildcard(name) = pod_id_str {
temp_stmt_wildcard_names.push(name.clone());
}
if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str {
temp_stmt_wildcard_names.push(key_wc_name.clone());
}
BuilderArg::Key(pod_id_wc_str, _key_str) => {
temp_stmt_wildcard_names.push(pod_id_wc_str.clone());
}
_ => {}
}
@ -873,19 +824,6 @@ fn resolve_wildcard(
})
}
fn resolve_key_or_wildcard_str(
ordered_scope_wildcard_names: &[String],
kows: &KeyOrWildcardStr,
) -> Result<KeyOrWildcard, ProcessorError> {
match kows {
KeyOrWildcardStr::Key(k_str) => Ok(KeyOrWildcard::Key(Key::new(k_str.clone()))),
KeyOrWildcardStr::Wildcard(wc_name_str) => {
let resolved_wc = resolve_wildcard(ordered_scope_wildcard_names, wc_name_str)?;
Ok(KeyOrWildcard::Wildcard(resolved_wc))
}
}
}
fn resolve_request_statement_builder(
stb: StatementTmplBuilder,
ordered_request_wildcard_names: &[String],
@ -897,20 +835,14 @@ fn resolve_request_statement_builder(
for builder_arg in stb.args {
let mw_arg = match builder_arg {
BuilderArg::Literal(v) => StatementTmplArg::Literal(v),
BuilderArg::Key(pod_id_str, key_wc_str) => {
let pod_sowc = match pod_id_str {
SelfOrWildcardStr::SELF => MiddlewareSelfOrWildcard::SELF,
SelfOrWildcardStr::Wildcard(name) => MiddlewareSelfOrWildcard::Wildcard(
resolve_wildcard(ordered_request_wildcard_names, &name)?,
),
};
let key_or_wc =
resolve_key_or_wildcard_str(ordered_request_wildcard_names, &key_wc_str)?;
StatementTmplArg::AnchoredKey(pod_sowc, key_or_wc)
BuilderArg::Key(pod_id_wc_str, key_str) => {
let pod_id_wc = resolve_wildcard(ordered_request_wildcard_names, &pod_id_wc_str)?;
let key = Key::from(key_str);
StatementTmplArg::AnchoredKey(pod_id_wc, key)
}
BuilderArg::WildcardLiteral(wc_name) => {
let pod_wc = resolve_wildcard(ordered_request_wildcard_names, &wc_name)?;
StatementTmplArg::WildcardLiteral(pod_wc)
let wc = resolve_wildcard(ordered_request_wildcard_names, &wc_name)?;
StatementTmplArg::Wildcard(wc)
}
};
middleware_args.push(mw_arg);
@ -1183,7 +1115,7 @@ mod processor_tests {
// Native predicate names are case-sensitive
let input = r#"
REQUEST(
EQUAL(?A[?B], ?C[?D])
EQUAL(?A["b"], ?C["d"])
)
"#;
let pairs = get_document_content_pairs(input)?;