Remove batch splitting system (#475)
* First pass at removing batch splitting * Refactor to separate module loading from request parsing * Consolidate module functionality * Tidy up comments * Use array of modules instead of HashMap * Formatting * Use module hashes when importing modules
This commit is contained in:
parent
5dab8195b4
commit
acab26e5c1
17 changed files with 1425 additions and 1938 deletions
|
|
@ -16,7 +16,7 @@ use pod2::{
|
||||||
primitives::ec::schnorr::SecretKey, signer::Signer,
|
primitives::ec::schnorr::SecretKey, signer::Signer,
|
||||||
},
|
},
|
||||||
frontend::{MainPodBuilder, Operation, SignedDictBuilder},
|
frontend::{MainPodBuilder, Operation, SignedDictBuilder},
|
||||||
lang::parse,
|
lang::load_module,
|
||||||
middleware::{MainPodProver, Params, VDSet},
|
middleware::{MainPodProver, Params, VDSet},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -88,10 +88,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
game_pk = game_pk,
|
game_pk = game_pk,
|
||||||
);
|
);
|
||||||
println!("# custom predicate batch:{}", input);
|
println!("# custom predicate batch:{}", input);
|
||||||
let batch = parse(&input, ¶ms, &[])?
|
let module = load_module(&input, "points_module", ¶ms, vec![])?;
|
||||||
.first_batch()
|
let batch = module.batch.clone();
|
||||||
.expect("Expected batch")
|
|
||||||
.clone();
|
|
||||||
let points_pred = batch.predicate_ref_by_name("points").unwrap();
|
let points_pred = batch.predicate_ref_by_name("points").unwrap();
|
||||||
let over_9000_pred = batch.predicate_ref_by_name("over_9000").unwrap();
|
let over_9000_pred = batch.predicate_ref_by_name("over_9000").unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -836,7 +836,7 @@ pub mod tests {
|
||||||
frontend::{
|
frontend::{
|
||||||
self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB,
|
self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB,
|
||||||
},
|
},
|
||||||
lang::parse,
|
lang::load_module,
|
||||||
middleware::{
|
middleware::{
|
||||||
self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _,
|
self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _,
|
||||||
DEFAULT_VD_LIST, DEFAULT_VD_SET,
|
DEFAULT_VD_LIST, DEFAULT_VD_SET,
|
||||||
|
|
@ -1165,7 +1165,7 @@ pub mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_undetermined_values() {
|
fn test_undetermined_values() {
|
||||||
let params = Default::default();
|
let params = Default::default();
|
||||||
let batch = parse(
|
let module = load_module(
|
||||||
r#"
|
r#"
|
||||||
two_equal(x,y,z) = OR(
|
two_equal(x,y,z) = OR(
|
||||||
Equal(x,y)
|
Equal(x,y)
|
||||||
|
|
@ -1173,13 +1173,12 @@ pub mod tests {
|
||||||
Equal(x,z)
|
Equal(x,z)
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
|
"test",
|
||||||
¶ms,
|
¶ms,
|
||||||
&[],
|
vec![],
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap();
|
||||||
.first_batch()
|
let batch = module.batch.clone();
|
||||||
.unwrap()
|
|
||||||
.clone();
|
|
||||||
let mut builder = MainPodBuilder::new(¶ms, &DEFAULT_VD_SET);
|
let mut builder = MainPodBuilder::new(¶ms, &DEFAULT_VD_SET);
|
||||||
let cpr = CustomPredicateRef { batch, index: 0 };
|
let cpr = CustomPredicateRef { batch, index: 0 };
|
||||||
let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap();
|
let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap();
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
use std::sync::Arc;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use hex::ToHex;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
frontend::{PodRequest, Result},
|
frontend::{PodRequest, Result},
|
||||||
lang::parse,
|
lang::{load_module, parse_request, Module},
|
||||||
middleware::{CustomPredicateBatch, Params},
|
middleware::{CustomPredicateBatch, Params},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -32,11 +30,8 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
eth_dos_ind(src, dst, distance)
|
eth_dos_ind(src, dst, distance)
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let batch = parse(input, params, &[])
|
let module = load_module(input, "eth_dos", params, vec![]).expect("lang parse");
|
||||||
.expect("lang parse")
|
let batch = module.batch.clone();
|
||||||
.first_batch()
|
|
||||||
.expect("Expected batch")
|
|
||||||
.clone();
|
|
||||||
println!("a.0. {}", batch.predicates()[0]);
|
println!("a.0. {}", batch.predicates()[0]);
|
||||||
println!("a.1. {}", batch.predicates()[1]);
|
println!("a.1. {}", batch.predicates()[1]);
|
||||||
println!("a.2. {}", batch.predicates()[2]);
|
println!("a.2. {}", batch.predicates()[2]);
|
||||||
|
|
@ -45,18 +40,26 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn eth_dos_request() -> Result<PodRequest> {
|
pub fn eth_dos_request() -> Result<PodRequest> {
|
||||||
|
use hex::ToHex;
|
||||||
|
|
||||||
let batch = eth_dos_batch(&Params::default())?;
|
let batch = eth_dos_batch(&Params::default())?;
|
||||||
let batch_id = batch.id().encode_hex::<String>();
|
let eth_dos_module = Arc::new(Module::new(batch, HashMap::new()));
|
||||||
|
let module_hash = eth_dos_module.id().encode_hex::<String>();
|
||||||
|
|
||||||
let input = format!(
|
let input = format!(
|
||||||
r#"
|
r#"
|
||||||
use batch _, _, _, eth_dos from 0x{batch_id}
|
use module 0x{} as eth_dos
|
||||||
REQUEST(
|
REQUEST(
|
||||||
eth_dos(src, dst, distance)
|
eth_dos::eth_dos(src, dst, distance)
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
|
module_hash
|
||||||
);
|
);
|
||||||
let parsed = parse(&input, &Params::default(), &[batch])?;
|
Ok(parse_request(
|
||||||
Ok(parsed.request)
|
&input,
|
||||||
|
&Params::default(),
|
||||||
|
&[eth_dos_module],
|
||||||
|
)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ use crate::{
|
||||||
frontend::{
|
frontend::{
|
||||||
MainPod, MainPodBuilder, Operation, PodRequest, Result, SignedDict, SignedDictBuilder,
|
MainPod, MainPodBuilder, Operation, PodRequest, Result, SignedDict, SignedDictBuilder,
|
||||||
},
|
},
|
||||||
lang::parse,
|
lang::parse_request,
|
||||||
middleware::{
|
middleware::{
|
||||||
self, containers::Set, hash_values, CustomPredicateRef, Params, Predicate, PublicKey,
|
self, containers::Set, hash_values, CustomPredicateRef, Params, Predicate, PublicKey,
|
||||||
Signer as _, Statement, StatementArg, TypedValue, VDSet, Value,
|
Signer as _, Statement, StatementArg, TypedValue, VDSet, Value,
|
||||||
|
|
@ -90,8 +90,7 @@ pub fn zu_kyc_pod_request(gov_signer: &Value, pay_signer: &Value) -> Result<PodR
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
);
|
);
|
||||||
let parsed = parse(&input, &Params::default(), &[])?;
|
Ok(parse_request(&input, &Params::default(), &[])?)
|
||||||
Ok(parsed.request)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ETHDoS
|
// ETHDoS
|
||||||
|
|
|
||||||
|
|
@ -798,7 +798,6 @@ impl MainPodCompiler {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub mod tests {
|
pub mod tests {
|
||||||
|
|
||||||
use num::BigUint;
|
use num::BigUint;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -813,7 +812,7 @@ pub mod tests {
|
||||||
tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request,
|
tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request,
|
||||||
zu_kyc_sign_dict_builders, EthDosHelper, MOCK_VD_SET,
|
zu_kyc_sign_dict_builders, EthDosHelper, MOCK_VD_SET,
|
||||||
},
|
},
|
||||||
lang::parse,
|
lang::load_module,
|
||||||
middleware::{
|
middleware::{
|
||||||
containers::{Array, Set},
|
containers::{Array, Set},
|
||||||
Signer as _, Value,
|
Signer as _, Value,
|
||||||
|
|
@ -1382,11 +1381,8 @@ pub mod tests {
|
||||||
Equal(b, 5)
|
Equal(b, 5)
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let batch = parse(input, ¶ms, &[])
|
let module = load_module(input, "test", ¶ms, vec![]).unwrap();
|
||||||
.unwrap()
|
let batch = module.batch.clone();
|
||||||
.first_batch()
|
|
||||||
.unwrap()
|
|
||||||
.clone();
|
|
||||||
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
||||||
|
|
||||||
// Try to build with wrong type in 1st arg
|
// Try to build with wrong type in 1st arg
|
||||||
|
|
@ -1434,11 +1430,8 @@ pub mod tests {
|
||||||
c(6, 3)
|
c(6, 3)
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let batch = parse(input, ¶ms, &[])
|
let module = load_module(input, "test", ¶ms, vec![]).unwrap();
|
||||||
.unwrap()
|
let batch = module.batch.clone();
|
||||||
.first_batch()
|
|
||||||
.unwrap()
|
|
||||||
.clone();
|
|
||||||
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
||||||
|
|
||||||
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
||||||
|
|
@ -1459,11 +1452,8 @@ pub mod tests {
|
||||||
c(6, 3)
|
c(6, 3)
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let batch = parse(input, ¶ms, &[])
|
let module = load_module(input, "test", ¶ms, vec![]).unwrap();
|
||||||
.unwrap()
|
let batch = module.batch.clone();
|
||||||
.first_batch()
|
|
||||||
.unwrap()
|
|
||||||
.clone();
|
|
||||||
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
|
||||||
|
|
||||||
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
let mut builder = MainPodBuilder::new(¶ms, vd_set);
|
||||||
|
|
@ -1501,12 +1491,11 @@ pub mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
// Parse and batch the predicate (this handles splitting internally)
|
// Parse and batch the predicate (this handles splitting internally)
|
||||||
let parsed = parse(input, ¶ms, &[])?;
|
let module = load_module(input, "test", ¶ms, vec![])?;
|
||||||
let batches = &parsed.custom_batches;
|
|
||||||
|
|
||||||
// Verify it was split
|
// Verify it was split
|
||||||
assert!(batches.split_chain("large_pred").is_some());
|
assert!(module.split_chains.contains_key("large_pred"));
|
||||||
let chain_info = batches.split_chain("large_pred").unwrap();
|
let chain_info = module.split_chains.get("large_pred").unwrap();
|
||||||
assert_eq!(chain_info.chain_pieces.len(), 2);
|
assert_eq!(chain_info.chain_pieces.len(), 2);
|
||||||
assert_eq!(chain_info.real_statement_count, 6);
|
assert_eq!(chain_info.real_statement_count, 6);
|
||||||
|
|
||||||
|
|
@ -1538,10 +1527,10 @@ pub mod tests {
|
||||||
let statements = vec![st_a, st_b, st_c, st_d, st_e, st_f];
|
let statements = vec![st_a, st_b, st_c, st_d, st_e, st_f];
|
||||||
|
|
||||||
// Use apply_predicate (primary API) to automatically wire the split chain
|
// Use apply_predicate (primary API) to automatically wire the split chain
|
||||||
let result = batches.apply_predicate(&mut builder, "large_pred", statements, true)?;
|
let result = module.apply_predicate(&mut builder, "large_pred", statements, true)?;
|
||||||
|
|
||||||
// The result should be a valid statement
|
// The result should be a valid statement
|
||||||
let predicate = batches.predicate_ref_by_name("large_pred").unwrap();
|
let predicate = module.predicate_ref_by_name("large_pred").unwrap();
|
||||||
match &result {
|
match &result {
|
||||||
Statement::Custom(pred_ref, _) => {
|
Statement::Custom(pred_ref, _) => {
|
||||||
assert_eq!(pred_ref, &predicate);
|
assert_eq!(pred_ref, &predicate);
|
||||||
|
|
|
||||||
|
|
@ -632,7 +632,7 @@ mod tests {
|
||||||
dict,
|
dict,
|
||||||
examples::MOCK_VD_SET,
|
examples::MOCK_VD_SET,
|
||||||
frontend::{Operation as FrontendOp, SignedDictBuilder},
|
frontend::{Operation as FrontendOp, SignedDictBuilder},
|
||||||
lang::parse,
|
lang::load_module,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -756,18 +756,17 @@ mod tests {
|
||||||
|
|
||||||
// pred_a accepts a Contains statement
|
// pred_a accepts a Contains statement
|
||||||
// pred_b accepts a pred_a statement (Custom statement from pred_a)
|
// pred_b accepts a pred_a statement (Custom statement from pred_a)
|
||||||
let parsed = parse(
|
let module = load_module(
|
||||||
r#"
|
r#"
|
||||||
pred_a(X) = AND(Contains(X, "k", 1))
|
pred_a(X) = AND(Contains(X, "k", 1))
|
||||||
pred_b(X) = AND(pred_a(X))
|
pred_b(X) = AND(pred_a(X))
|
||||||
"#,
|
"#,
|
||||||
|
"test",
|
||||||
¶ms,
|
¶ms,
|
||||||
&[],
|
vec![],
|
||||||
)
|
)
|
||||||
.expect("parse predicates");
|
.expect("load module");
|
||||||
let batch = parsed
|
let batch = &module.batch;
|
||||||
.first_batch()
|
|
||||||
.expect("parse predicates should have a batch");
|
|
||||||
|
|
||||||
let mut builder = MultiPodBuilder::new(¶ms, vd_set);
|
let mut builder = MultiPodBuilder::new(¶ms, vd_set);
|
||||||
|
|
||||||
|
|
@ -1484,20 +1483,19 @@ mod tests {
|
||||||
let vd_set = &*MOCK_VD_SET;
|
let vd_set = &*MOCK_VD_SET;
|
||||||
|
|
||||||
// Chain of predicates: each accepts the output of the previous
|
// Chain of predicates: each accepts the output of the previous
|
||||||
let parsed = parse(
|
let module = load_module(
|
||||||
r#"
|
r#"
|
||||||
pred_a(X) = AND(Contains(X, "k", 1))
|
pred_a(X) = AND(Contains(X, "k", 1))
|
||||||
pred_b(X) = AND(pred_a(X))
|
pred_b(X) = AND(pred_a(X))
|
||||||
pred_c(X) = AND(pred_b(X))
|
pred_c(X) = AND(pred_b(X))
|
||||||
pred_d(X) = AND(pred_c(X))
|
pred_d(X) = AND(pred_c(X))
|
||||||
"#,
|
"#,
|
||||||
|
"test",
|
||||||
¶ms,
|
¶ms,
|
||||||
&[],
|
vec![],
|
||||||
)
|
)
|
||||||
.expect("parse predicates");
|
.expect("load module");
|
||||||
let batch = parsed
|
let batch = &module.batch;
|
||||||
.first_batch()
|
|
||||||
.expect("parse predicates should have a batch");
|
|
||||||
|
|
||||||
let mut builder = MultiPodBuilder::new(¶ms, vd_set);
|
let mut builder = MultiPodBuilder::new(¶ms, vd_set);
|
||||||
|
|
||||||
|
|
@ -1612,7 +1610,7 @@ mod tests {
|
||||||
// pred_a takes TWO custom statement arguments (b_out and c_out)
|
// pred_a takes TWO custom statement arguments (b_out and c_out)
|
||||||
// pred_b and pred_c each take a Contains
|
// pred_b and pred_c each take a Contains
|
||||||
// Note: AND clauses are newline-separated, not comma-separated
|
// Note: AND clauses are newline-separated, not comma-separated
|
||||||
let parsed = parse(
|
let module = load_module(
|
||||||
r#"
|
r#"
|
||||||
pred_b(X) = AND(Contains(X, "k", 1))
|
pred_b(X) = AND(Contains(X, "k", 1))
|
||||||
pred_c(X) = AND(Contains(X, "k", 1))
|
pred_c(X) = AND(Contains(X, "k", 1))
|
||||||
|
|
@ -1621,13 +1619,12 @@ mod tests {
|
||||||
pred_c(Y)
|
pred_c(Y)
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
|
"test",
|
||||||
¶ms,
|
¶ms,
|
||||||
&[],
|
vec![],
|
||||||
)
|
)
|
||||||
.expect("parse predicates");
|
.expect("load module");
|
||||||
let batch = parsed
|
let batch = &module.batch;
|
||||||
.first_batch()
|
|
||||||
.expect("parse predicates should have a batch");
|
|
||||||
|
|
||||||
let mut builder = MultiPodBuilder::new(¶ms, vd_set);
|
let mut builder = MultiPodBuilder::new(¶ms, vd_set);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -185,7 +185,7 @@ mod tests {
|
||||||
zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, MOCK_VD_SET,
|
zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, MOCK_VD_SET,
|
||||||
},
|
},
|
||||||
frontend::{MainPodBuilder, Operation},
|
frontend::{MainPodBuilder, Operation},
|
||||||
lang::parse,
|
lang::parse_request,
|
||||||
middleware::{Params, Value},
|
middleware::{Params, Value},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -210,7 +210,7 @@ mod tests {
|
||||||
assert!(request.exact_match_pod(&*kyc.pod).is_ok());
|
assert!(request.exact_match_pod(&*kyc.pod).is_ok());
|
||||||
|
|
||||||
// This request does not match the POD, because the POD does not contain a NotEqual statement.
|
// This request does not match the POD, because the POD does not contain a NotEqual statement.
|
||||||
let non_matching_request = parse(
|
let non_matching_request = parse_request(
|
||||||
r#"
|
r#"
|
||||||
REQUEST(
|
REQUEST(
|
||||||
NotEqual(4, 5)
|
NotEqual(4, 5)
|
||||||
|
|
@ -219,8 +219,7 @@ mod tests {
|
||||||
¶ms,
|
¶ms,
|
||||||
&[],
|
&[],
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap();
|
||||||
.request;
|
|
||||||
assert!(non_matching_request.exact_match_pod(&*kyc.pod).is_err());
|
assert!(non_matching_request.exact_match_pod(&*kyc.pod).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -240,7 +239,7 @@ mod tests {
|
||||||
|
|
||||||
println!("{pod}");
|
println!("{pod}");
|
||||||
|
|
||||||
let request = parse(
|
let request = parse_request(
|
||||||
r#"
|
r#"
|
||||||
REQUEST(
|
REQUEST(
|
||||||
SumOf(a, b, c)
|
SumOf(a, b, c)
|
||||||
|
|
@ -252,7 +251,7 @@ mod tests {
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let bindings = request.request.exact_match_pod(&*pod.pod).unwrap();
|
let bindings = request.exact_match_pod(&*pod.pod).unwrap();
|
||||||
assert_eq!(*bindings.get("a").unwrap(), 10.into());
|
assert_eq!(*bindings.get("a").unwrap(), 10.into());
|
||||||
assert_eq!(*bindings.get("b").unwrap(), 9.into());
|
assert_eq!(*bindings.get("b").unwrap(), 9.into());
|
||||||
assert_eq!(*bindings.get("c").unwrap(), 1.into());
|
assert_eq!(*bindings.get("c").unwrap(), 1.into());
|
||||||
|
|
|
||||||
|
|
@ -50,8 +50,8 @@ pub enum ValidationError {
|
||||||
span: Option<Span>,
|
span: Option<Span>,
|
||||||
},
|
},
|
||||||
|
|
||||||
#[error("Batch not found: {id}")]
|
#[error("Module not found: {name}")]
|
||||||
BatchNotFound { id: String, span: Option<Span> },
|
ModuleNotFound { name: String, span: Option<Span> },
|
||||||
|
|
||||||
#[error("Undefined predicate: {name}")]
|
#[error("Undefined predicate: {name}")]
|
||||||
UndefinedPredicate { name: String, span: Option<Span> },
|
UndefinedPredicate { name: String, span: Option<Span> },
|
||||||
|
|
@ -91,6 +91,18 @@ pub enum ValidationError {
|
||||||
|
|
||||||
#[error("Wildcard '{name}' collides with a predicate name")]
|
#[error("Wildcard '{name}' collides with a predicate name")]
|
||||||
WildcardPredicateNameCollision { name: String },
|
WildcardPredicateNameCollision { name: String },
|
||||||
|
|
||||||
|
#[error("Predicate definitions are not allowed in requests")]
|
||||||
|
PredicatesNotAllowedInRequest { span: Option<Span> },
|
||||||
|
|
||||||
|
#[error("REQUEST block is not allowed in modules")]
|
||||||
|
RequestNotAllowedInModule { span: Option<Span> },
|
||||||
|
|
||||||
|
#[error("Modules must contain at least one predicate definition")]
|
||||||
|
NoPredicatesInModule,
|
||||||
|
|
||||||
|
#[error("Requests must contain a REQUEST block")]
|
||||||
|
NoRequestBlock,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lowering errors from frontend AST lowering to middleware
|
/// Lowering errors from frontend AST lowering to middleware
|
||||||
|
|
|
||||||
|
|
@ -18,17 +18,17 @@ pub struct Document {
|
||||||
/// Top-level items that can appear in a document
|
/// Top-level items that can appear in a document
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum DocumentItem {
|
pub enum DocumentItem {
|
||||||
UseBatchStatement(UseBatchStatement),
|
UseModuleStatement(UseModuleStatement),
|
||||||
UseIntroStatement(UseIntroStatement),
|
UseIntroStatement(UseIntroStatement),
|
||||||
CustomPredicateDef(CustomPredicateDef),
|
CustomPredicateDef(CustomPredicateDef),
|
||||||
RequestDef(RequestDef),
|
RequestDef(RequestDef),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Import statement: `use batch pred1, pred2, _ from 0x...`
|
/// Module import statement: `use module 0xHASH as alias`
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct UseBatchStatement {
|
pub struct UseModuleStatement {
|
||||||
pub imports: Vec<ImportName>,
|
pub hash: HashHex,
|
||||||
pub batch_ref: HashHex,
|
pub alias: Identifier,
|
||||||
pub span: Option<Span>,
|
pub span: Option<Span>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -40,19 +40,6 @@ pub struct UseIntroStatement {
|
||||||
pub intro_hash: HashHex,
|
pub intro_hash: HashHex,
|
||||||
pub span: Option<Span>,
|
pub span: Option<Span>,
|
||||||
}
|
}
|
||||||
/// Individual import name (identifier or unused "_")
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub enum ImportName {
|
|
||||||
Named(String),
|
|
||||||
Unused, // "_"
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Batch reference (hash)
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub struct BatchRef {
|
|
||||||
pub hash: HashHex,
|
|
||||||
pub span: Option<Span>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Intro predicate reference (hash)
|
/// Intro predicate reference (hash)
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
|
@ -96,11 +83,33 @@ pub enum ConjunctionType {
|
||||||
/// Statement template: predicate call with arguments
|
/// Statement template: predicate call with arguments
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct StatementTmpl {
|
pub struct StatementTmpl {
|
||||||
pub predicate: Identifier,
|
pub predicate: PredicateRef,
|
||||||
pub args: Vec<StatementTmplArg>,
|
pub args: Vec<StatementTmplArg>,
|
||||||
pub span: Option<Span>,
|
pub span: Option<Span>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reference to a predicate (local or qualified with module name)
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum PredicateRef {
|
||||||
|
/// Unqualified name (local or native predicate)
|
||||||
|
Local(Identifier),
|
||||||
|
/// Qualified name (module::predicate)
|
||||||
|
Qualified {
|
||||||
|
module: Identifier,
|
||||||
|
predicate: Identifier,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PredicateRef {
|
||||||
|
/// Get the predicate name (without module qualifier)
|
||||||
|
pub fn predicate_name(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
PredicateRef::Local(id) => &id.name,
|
||||||
|
PredicateRef::Qualified { predicate, .. } => &predicate.name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Arguments that can be passed to statements
|
/// Arguments that can be passed to statements
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum StatementTmplArg {
|
pub enum StatementTmplArg {
|
||||||
|
|
@ -256,7 +265,7 @@ impl fmt::Display for Document {
|
||||||
impl fmt::Display for DocumentItem {
|
impl fmt::Display for DocumentItem {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
DocumentItem::UseBatchStatement(u) => write!(f, "{}", u),
|
DocumentItem::UseModuleStatement(u) => write!(f, "{}", u),
|
||||||
DocumentItem::UseIntroStatement(u) => write!(f, "{}", u),
|
DocumentItem::UseIntroStatement(u) => write!(f, "{}", u),
|
||||||
DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c),
|
DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c),
|
||||||
DocumentItem::RequestDef(r) => write!(f, "{}", r),
|
DocumentItem::RequestDef(r) => write!(f, "{}", r),
|
||||||
|
|
@ -264,16 +273,9 @@ impl fmt::Display for DocumentItem {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for UseBatchStatement {
|
impl fmt::Display for UseModuleStatement {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
write!(f, "use batch ")?;
|
write!(f, "use module {} as {}", self.hash, self.alias)
|
||||||
for (i, import) in self.imports.iter().enumerate() {
|
|
||||||
if i > 0 {
|
|
||||||
write!(f, ", ")?;
|
|
||||||
}
|
|
||||||
write!(f, "{}", import)?;
|
|
||||||
}
|
|
||||||
write!(f, " from {}", self.batch_ref)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -290,19 +292,15 @@ impl fmt::Display for UseIntroStatement {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for ImportName {
|
impl fmt::Display for PredicateRef {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
ImportName::Named(name) => write!(f, "{}", name),
|
PredicateRef::Local(id) => write!(f, "{}", id),
|
||||||
ImportName::Unused => write!(f, "_"),
|
PredicateRef::Qualified { module, predicate } => {
|
||||||
|
write!(f, "{}::{}", module, predicate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for BatchRef {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "{}", self.hash)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for IntroPredicateRef {
|
impl fmt::Display for IntroPredicateRef {
|
||||||
|
|
@ -536,10 +534,10 @@ pub mod parse {
|
||||||
|
|
||||||
for inner_pair in pair.into_inner() {
|
for inner_pair in pair.into_inner() {
|
||||||
match inner_pair.as_rule() {
|
match inner_pair.as_rule() {
|
||||||
Rule::use_batch_statement => {
|
Rule::use_module_statement => {
|
||||||
items.push(DocumentItem::UseBatchStatement(parse_use_batch_statement(
|
items.push(DocumentItem::UseModuleStatement(
|
||||||
inner_pair,
|
parse_use_module_statement(inner_pair),
|
||||||
)));
|
));
|
||||||
}
|
}
|
||||||
Rule::use_intro_statement => {
|
Rule::use_intro_statement => {
|
||||||
items.push(DocumentItem::UseIntroStatement(parse_use_intro_statement(
|
items.push(DocumentItem::UseIntroStatement(parse_use_intro_statement(
|
||||||
|
|
@ -562,25 +560,17 @@ pub mod parse {
|
||||||
Ok(Document { items })
|
Ok(Document { items })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_use_batch_statement(pair: Pair<Rule>) -> UseBatchStatement {
|
fn parse_use_module_statement(pair: Pair<Rule>) -> UseModuleStatement {
|
||||||
assert_eq!(pair.as_rule(), Rule::use_batch_statement);
|
assert_eq!(pair.as_rule(), Rule::use_module_statement);
|
||||||
let span = get_span(&pair);
|
let span = get_span(&pair);
|
||||||
let mut inner = pair.into_inner();
|
let mut inner = pair.into_inner();
|
||||||
|
|
||||||
let use_list_pair = inner
|
let hash = parse_hash_hex(inner.next().unwrap());
|
||||||
.find(|p| p.as_rule() == Rule::use_predicate_list)
|
let alias = parse_identifier(inner.next().unwrap());
|
||||||
.unwrap();
|
|
||||||
let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap();
|
|
||||||
|
|
||||||
let imports = use_list_pair
|
UseModuleStatement {
|
||||||
.into_inner()
|
hash,
|
||||||
.filter(|p| p.as_rule() == Rule::import_name)
|
alias,
|
||||||
.map(parse_import_name)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
UseBatchStatement {
|
|
||||||
imports,
|
|
||||||
batch_ref: parse_hash_hex(batch_ref_pair.into_inner().next().unwrap()),
|
|
||||||
span: Some(span),
|
span: Some(span),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -622,16 +612,6 @@ pub mod parse {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_import_name(pair: Pair<Rule>) -> ImportName {
|
|
||||||
assert_eq!(pair.as_rule(), Rule::import_name);
|
|
||||||
let s = pair.as_str();
|
|
||||||
if s == "_" {
|
|
||||||
ImportName::Unused
|
|
||||||
} else {
|
|
||||||
ImportName::Named(s.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_hash_hex(pair: Pair<Rule>) -> HashHex {
|
fn parse_hash_hex(pair: Pair<Rule>) -> HashHex {
|
||||||
assert_eq!(pair.as_rule(), Rule::hash_hex);
|
assert_eq!(pair.as_rule(), Rule::hash_hex);
|
||||||
let span = get_span(&pair);
|
let span = get_span(&pair);
|
||||||
|
|
@ -748,7 +728,7 @@ pub mod parse {
|
||||||
let span = get_span(&pair);
|
let span = get_span(&pair);
|
||||||
let mut inner = pair.into_inner();
|
let mut inner = pair.into_inner();
|
||||||
|
|
||||||
let predicate = parse_identifier(inner.next().unwrap());
|
let predicate = parse_predicate_ref(inner.next().unwrap());
|
||||||
let mut args = Vec::new();
|
let mut args = Vec::new();
|
||||||
|
|
||||||
if let Some(arg_list) = inner.next() {
|
if let Some(arg_list) = inner.next() {
|
||||||
|
|
@ -768,6 +748,22 @@ pub mod parse {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_predicate_ref(pair: Pair<Rule>) -> PredicateRef {
|
||||||
|
assert_eq!(pair.as_rule(), Rule::predicate_ref);
|
||||||
|
let inner = pair.into_inner().next().unwrap();
|
||||||
|
|
||||||
|
match inner.as_rule() {
|
||||||
|
Rule::qualified_predicate_ref => {
|
||||||
|
let mut parts = inner.into_inner();
|
||||||
|
let module = parse_identifier(parts.next().unwrap());
|
||||||
|
let predicate = parse_identifier(parts.next().unwrap());
|
||||||
|
PredicateRef::Qualified { module, predicate }
|
||||||
|
}
|
||||||
|
Rule::identifier => PredicateRef::Local(parse_identifier(inner)),
|
||||||
|
_ => unreachable!("Unexpected predicate_ref rule: {:?}", inner.as_rule()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_statement_arg(pair: Pair<Rule>) -> Result<StatementTmplArg, parser::ParseError> {
|
fn parse_statement_arg(pair: Pair<Rule>) -> Result<StatementTmplArg, parser::ParseError> {
|
||||||
assert_eq!(pair.as_rule(), Rule::statement_arg);
|
assert_eq!(pair.as_rule(), Rule::statement_arg);
|
||||||
let inner = pair.into_inner().next().unwrap();
|
let inner = pair.into_inner().next().unwrap();
|
||||||
|
|
@ -1047,9 +1043,10 @@ mod tests {
|
||||||
fn clear_spans(doc: &mut Document) {
|
fn clear_spans(doc: &mut Document) {
|
||||||
for item in &mut doc.items {
|
for item in &mut doc.items {
|
||||||
match item {
|
match item {
|
||||||
DocumentItem::UseBatchStatement(u) => {
|
DocumentItem::UseModuleStatement(u) => {
|
||||||
u.span = None;
|
u.span = None;
|
||||||
u.batch_ref.span = None;
|
u.hash.span = None;
|
||||||
|
u.alias.span = None;
|
||||||
}
|
}
|
||||||
DocumentItem::UseIntroStatement(u) => {
|
DocumentItem::UseIntroStatement(u) => {
|
||||||
u.span = None;
|
u.span = None;
|
||||||
|
|
@ -1082,9 +1079,19 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_predicate_ref_spans(pred_ref: &mut PredicateRef) {
|
||||||
|
match pred_ref {
|
||||||
|
PredicateRef::Local(id) => id.span = None,
|
||||||
|
PredicateRef::Qualified { module, predicate } => {
|
||||||
|
module.span = None;
|
||||||
|
predicate.span = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn clear_statement_spans(stmt: &mut StatementTmpl) {
|
fn clear_statement_spans(stmt: &mut StatementTmpl) {
|
||||||
stmt.span = None;
|
stmt.span = None;
|
||||||
stmt.predicate.span = None;
|
clear_predicate_ref_spans(&mut stmt.predicate);
|
||||||
for arg in &mut stmt.args {
|
for arg in &mut stmt.args {
|
||||||
match arg {
|
match arg {
|
||||||
StatementTmplArg::Literal(lit) => clear_literal_spans(lit),
|
StatementTmplArg::Literal(lit) => clear_literal_spans(lit),
|
||||||
|
|
@ -1168,8 +1175,8 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_use_batch_statement() {
|
fn test_use_module_statement() {
|
||||||
let input = r#"use batch pred1, pred2, _ from 0x0000000000000000000000000000000000000000000000000000000000000000"#;
|
let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as helpers"#;
|
||||||
test_roundtrip(input);
|
test_roundtrip(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1223,11 +1230,11 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_complete_document() {
|
fn test_complete_document() {
|
||||||
let input = r#"use batch imported_pred from 0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd
|
let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as imported
|
||||||
|
|
||||||
is_valid(User, private: Config) = AND (
|
is_valid(User, private: Config) = AND (
|
||||||
Equal(User["age"], Config["min_age"])
|
Equal(User["age"], Config["min_age"])
|
||||||
imported_pred(User, Config)
|
imported::some_pred(User, Config)
|
||||||
)
|
)
|
||||||
|
|
||||||
check_both(A, B, C) = OR (
|
check_both(A, B, C) = OR (
|
||||||
|
|
@ -1306,7 +1313,7 @@ REQUEST(
|
||||||
// Check request structure
|
// Check request structure
|
||||||
if let DocumentItem::RequestDef(req) = &ast.items[1] {
|
if let DocumentItem::RequestDef(req) = &ast.items[1] {
|
||||||
assert_eq!(req.statements.len(), 1);
|
assert_eq!(req.statements.len(), 1);
|
||||||
assert_eq!(req.statements[0].predicate.name, "my_pred");
|
assert_eq!(req.statements[0].predicate.predicate_name(), "my_pred");
|
||||||
assert_eq!(req.statements[0].args.len(), 2);
|
assert_eq!(req.statements[0].args.len(), 2);
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected RequestDef");
|
panic!("Expected RequestDef");
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,52 +1,69 @@
|
||||||
//! Lowering from frontend AST to middleware structures
|
//! Lowering from frontend AST to middleware structures
|
||||||
//!
|
//!
|
||||||
//! This module converts validated frontend AST to middleware data structures.
|
//! This module converts validated frontend AST to middleware data structures.
|
||||||
//! Supports automatic predicate splitting and multi-batch packing.
|
//! Supports automatic predicate splitting.
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
str::FromStr,
|
str::FromStr,
|
||||||
sync::Arc,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
|
frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
|
||||||
lang::{
|
lang::{
|
||||||
frontend_ast::*,
|
frontend_ast::*,
|
||||||
frontend_ast_batch::{self, PredicateBatches},
|
|
||||||
frontend_ast_split,
|
frontend_ast_split,
|
||||||
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
|
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
|
||||||
|
module, Module,
|
||||||
},
|
},
|
||||||
middleware::{
|
middleware::{
|
||||||
self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
|
self, containers, CustomPredicateRef, IntroPredicateRef, Key, NativePredicate, Params,
|
||||||
NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl,
|
Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Value,
|
||||||
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
|
Wildcard,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Context for predicate resolution - determines how local custom predicates are resolved
|
/// Context for predicate resolution - determines how predicates are resolved
|
||||||
pub enum ResolutionContext<'a> {
|
pub enum ResolutionContext<'a> {
|
||||||
/// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches
|
/// Request context: predicates resolve via imports only (no local definitions)
|
||||||
Request {
|
Request,
|
||||||
batches: Option<&'a PredicateBatches>,
|
/// Module context: local predicates resolve to BatchSelf
|
||||||
},
|
Module {
|
||||||
/// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef
|
/// Maps predicate name to index within the module
|
||||||
Batch {
|
reference_map: &'a HashMap<String, usize>,
|
||||||
current_batch_idx: usize,
|
/// Name of the custom predicate being defined (for wildcard scope lookup)
|
||||||
reference_map: &'a HashMap<String, (usize, usize)>,
|
|
||||||
existing_batches: &'a [Arc<CustomPredicateBatch>],
|
|
||||||
custom_predicate_name: &'a str,
|
custom_predicate_name: &'a str,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolve a predicate reference to a Predicate using the symbol table
|
||||||
|
pub fn resolve_predicate_ref(
|
||||||
|
pred_ref: &PredicateRef,
|
||||||
|
symbols: &SymbolTable,
|
||||||
|
context: &ResolutionContext,
|
||||||
|
) -> Option<PredicateOrWildcard> {
|
||||||
|
match pred_ref {
|
||||||
|
PredicateRef::Qualified { module, predicate } => {
|
||||||
|
// Look up the module in the imported_modules
|
||||||
|
let imported_module = symbols.imported_modules.get(&module.name)?;
|
||||||
|
// Find the predicate index in the module
|
||||||
|
let idx = *imported_module.predicate_index.get(&predicate.name)?;
|
||||||
|
Some(PredicateOrWildcard::Predicate(Predicate::Custom(
|
||||||
|
CustomPredicateRef::new(imported_module.batch.clone(), idx),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
PredicateRef::Local(id) => resolve_predicate(&id.name, symbols, context),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Resolve a predicate name to a Predicate using the symbol table
|
/// Resolve a predicate name to a Predicate using the symbol table
|
||||||
pub fn resolve_predicate(
|
pub fn resolve_predicate(
|
||||||
pred_name: &str,
|
pred_name: &str,
|
||||||
symbols: &SymbolTable,
|
symbols: &SymbolTable,
|
||||||
context: &ResolutionContext,
|
context: &ResolutionContext,
|
||||||
) -> Option<PredicateOrWildcard> {
|
) -> Option<PredicateOrWildcard> {
|
||||||
// 0. Try wildcard first
|
// 0. Try wildcard first (only in module context where we're defining predicates)
|
||||||
if let ResolutionContext::Batch {
|
if let ResolutionContext::Module {
|
||||||
custom_predicate_name,
|
custom_predicate_name,
|
||||||
..
|
..
|
||||||
} = context
|
} = context
|
||||||
|
|
@ -69,28 +86,35 @@ pub fn resolve_predicate(
|
||||||
PredicateKind::Native(np) => Predicate::Native(*np),
|
PredicateKind::Native(np) => Predicate::Native(*np),
|
||||||
|
|
||||||
PredicateKind::Custom { .. } => match context {
|
PredicateKind::Custom { .. } => match context {
|
||||||
ResolutionContext::Request { batches } => {
|
ResolutionContext::Request => {
|
||||||
let batches = batches.as_ref()?;
|
// Requests can't define local predicates, so this shouldn't happen
|
||||||
let pred_ref = batches.predicate_ref_by_name(pred_name)?;
|
return None;
|
||||||
Predicate::Custom(pred_ref)
|
}
|
||||||
|
ResolutionContext::Module { reference_map, .. } => {
|
||||||
|
resolve_local_predicate(pred_name, reference_map)?
|
||||||
}
|
}
|
||||||
ResolutionContext::Batch {
|
|
||||||
current_batch_idx,
|
|
||||||
reference_map,
|
|
||||||
existing_batches,
|
|
||||||
..
|
|
||||||
} => resolve_local_predicate(
|
|
||||||
pred_name,
|
|
||||||
*current_batch_idx,
|
|
||||||
reference_map,
|
|
||||||
existing_batches,
|
|
||||||
)?,
|
|
||||||
},
|
},
|
||||||
|
|
||||||
PredicateKind::BatchImported { batch, index } => {
|
PredicateKind::BatchImported { batch, index } => {
|
||||||
Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index))
|
Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PredicateKind::ModuleImported {
|
||||||
|
module_name,
|
||||||
|
predicate_index,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
// Look up the module in the imported_modules
|
||||||
|
let module = symbols
|
||||||
|
.imported_modules
|
||||||
|
.get(module_name)
|
||||||
|
.expect("Module should exist if ModuleImported predicate kind exists");
|
||||||
|
Predicate::Custom(CustomPredicateRef::new(
|
||||||
|
module.batch.clone(),
|
||||||
|
*predicate_index,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
PredicateKind::IntroImported {
|
PredicateKind::IntroImported {
|
||||||
name,
|
name,
|
||||||
verifier_data_hash,
|
verifier_data_hash,
|
||||||
|
|
@ -103,22 +127,11 @@ pub fn resolve_predicate(
|
||||||
return Some(PredicateOrWildcard::Predicate(predicate));
|
return Some(PredicateOrWildcard::Predicate(predicate));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. In batch context, also check reference_map for split chain pieces
|
// 3. In module context, also check reference_map for split chain pieces
|
||||||
// (predicates created by splitting that aren't in the original symbol table)
|
// (predicates created by splitting that aren't in the original symbol table)
|
||||||
if let ResolutionContext::Batch {
|
if let ResolutionContext::Module { reference_map, .. } = context {
|
||||||
current_batch_idx,
|
|
||||||
reference_map,
|
|
||||||
existing_batches,
|
|
||||||
..
|
|
||||||
} = context
|
|
||||||
{
|
|
||||||
if reference_map.contains_key(pred_name) {
|
if reference_map.contains_key(pred_name) {
|
||||||
return resolve_local_predicate(
|
return resolve_local_predicate(pred_name, reference_map)
|
||||||
pred_name,
|
|
||||||
*current_batch_idx,
|
|
||||||
reference_map,
|
|
||||||
existing_batches,
|
|
||||||
)
|
|
||||||
.map(PredicateOrWildcard::Predicate);
|
.map(PredicateOrWildcard::Predicate);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -126,28 +139,13 @@ pub fn resolve_predicate(
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map
|
/// Resolve a local predicate (one in this module or a split chain piece) using the reference_map
|
||||||
fn resolve_local_predicate(
|
fn resolve_local_predicate(
|
||||||
pred_name: &str,
|
pred_name: &str,
|
||||||
current_batch_idx: usize,
|
reference_map: &HashMap<String, usize>,
|
||||||
reference_map: &HashMap<String, (usize, usize)>,
|
|
||||||
existing_batches: &[Arc<CustomPredicateBatch>],
|
|
||||||
) -> Option<Predicate> {
|
) -> Option<Predicate> {
|
||||||
let &(target_batch, target_idx) = reference_map.get(pred_name)?;
|
let &idx = reference_map.get(pred_name)?;
|
||||||
if target_batch == current_batch_idx {
|
Some(Predicate::BatchSelf(idx))
|
||||||
Some(Predicate::BatchSelf(target_idx))
|
|
||||||
} else if target_batch < current_batch_idx {
|
|
||||||
let batch = &existing_batches[target_batch];
|
|
||||||
Some(Predicate::Custom(CustomPredicateRef::new(
|
|
||||||
batch.clone(),
|
|
||||||
target_idx,
|
|
||||||
)))
|
|
||||||
} else {
|
|
||||||
unreachable!(
|
|
||||||
"Forward cross-batch reference should be impossible: {} -> {}",
|
|
||||||
current_batch_idx, target_batch
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
@ -155,7 +153,7 @@ fn resolve_local_predicate(
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// These functions convert AST types to middleware/builder types and are used
|
// These functions convert AST types to middleware/builder types and are used
|
||||||
// by both the request lowering (in this module) and predicate batching
|
// by both the request lowering (in this module) and predicate batching
|
||||||
// (in frontend_ast_batch).
|
// (in module.rs).
|
||||||
|
|
||||||
/// Lower a literal value from AST to middleware Value.
|
/// Lower a literal value from AST to middleware Value.
|
||||||
///
|
///
|
||||||
|
|
@ -215,38 +213,37 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of lowering: optional custom predicate batches and optional request
|
|
||||||
///
|
|
||||||
/// A Podlang file can contain:
|
|
||||||
/// - Just custom predicates (batches: Some, request: None)
|
|
||||||
/// - Just a request (batches: None, request: Some)
|
|
||||||
/// - Both (batches: Some, request: Some)
|
|
||||||
/// - Neither (batches: None, request: None) - just imports
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct LoweredOutput {
|
|
||||||
pub batches: Option<PredicateBatches>,
|
|
||||||
pub request: Option<crate::frontend::PodRequest>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub use crate::lang::error::LoweringError;
|
pub use crate::lang::error::LoweringError;
|
||||||
|
|
||||||
/// Lower a validated AST to middleware structures
|
/// Lower a validated module AST to a Module
|
||||||
///
|
///
|
||||||
/// Returns both the custom predicate batch (if any) and the request (if any).
|
/// The validated AST must have been validated in Module mode.
|
||||||
/// At least one will be Some if the document contains custom predicates or a request.
|
pub fn lower_module(
|
||||||
pub fn lower(
|
|
||||||
validated: ValidatedAST,
|
validated: ValidatedAST,
|
||||||
params: &Params,
|
params: &Params,
|
||||||
batch_name: String,
|
module_name: &str,
|
||||||
) -> Result<LoweredOutput, LoweringError> {
|
) -> Result<Module, LoweringError> {
|
||||||
if !validated.diagnostics().is_empty() {
|
if !validated.diagnostics().is_empty() {
|
||||||
// For now, treat any diagnostics as errors
|
|
||||||
// In future we could allow warnings
|
|
||||||
return Err(LoweringError::ValidationErrors);
|
return Err(LoweringError::ValidationErrors);
|
||||||
}
|
}
|
||||||
|
|
||||||
let lowerer = Lowerer::new(validated, params);
|
let lowerer = Lowerer::new(validated, params);
|
||||||
lowerer.lower(batch_name)
|
lowerer.lower_module(module_name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lower a validated request AST to a PodRequest
|
||||||
|
///
|
||||||
|
/// The validated AST must have been validated in Request mode.
|
||||||
|
pub fn lower_request(
|
||||||
|
validated: ValidatedAST,
|
||||||
|
params: &Params,
|
||||||
|
) -> Result<crate::frontend::PodRequest, LoweringError> {
|
||||||
|
if !validated.diagnostics().is_empty() {
|
||||||
|
return Err(LoweringError::ValidationErrors);
|
||||||
|
}
|
||||||
|
|
||||||
|
let lowerer = Lowerer::new(validated, params);
|
||||||
|
lowerer.lower_request()
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Lowerer<'a> {
|
struct Lowerer<'a> {
|
||||||
|
|
@ -259,52 +256,33 @@ impl<'a> Lowerer<'a> {
|
||||||
Self { validated, params }
|
Self { validated, params }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lower(self, batch_name: String) -> Result<LoweredOutput, LoweringError> {
|
fn lower_module(self, module_name: &str) -> Result<Module, LoweringError> {
|
||||||
// Lower custom predicates (if any) - now supports multiple batches
|
|
||||||
let batches = self.lower_batches(batch_name)?;
|
|
||||||
|
|
||||||
// Lower request (if any) - pass batches so refs can be resolved
|
|
||||||
let request = self.lower_request(batches.as_ref())?;
|
|
||||||
|
|
||||||
Ok(LoweredOutput { batches, request })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn lower_batches(&self, batch_name: String) -> Result<Option<PredicateBatches>, LoweringError> {
|
|
||||||
// Extract and split custom predicates from document
|
// Extract and split custom predicates from document
|
||||||
let custom_predicates = self.extract_and_split_predicates()?;
|
let custom_predicates = self.extract_and_split_predicates()?;
|
||||||
|
|
||||||
// If no custom predicates, return None
|
// Build the module from split predicates
|
||||||
if custom_predicates.is_empty() {
|
let module = module::build_module(
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the new batching module to pack predicates into batches
|
|
||||||
// Pass the symbol table for unified predicate resolution
|
|
||||||
let batches = frontend_ast_batch::batch_predicates(
|
|
||||||
custom_predicates,
|
custom_predicates,
|
||||||
self.params,
|
self.params,
|
||||||
&batch_name,
|
module_name,
|
||||||
self.validated.symbols(),
|
self.validated.symbols(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Some(batches))
|
Ok(module)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lower_request(
|
fn lower_request(self) -> Result<crate::frontend::PodRequest, LoweringError> {
|
||||||
&self,
|
|
||||||
batches: Option<&PredicateBatches>,
|
|
||||||
) -> Result<Option<crate::frontend::PodRequest>, LoweringError> {
|
|
||||||
let doc = self.validated.document();
|
let doc = self.validated.document();
|
||||||
|
|
||||||
// Find request definition (if any)
|
// Find request definition
|
||||||
let request_def = doc.items.iter().find_map(|item| match item {
|
let request_def = doc
|
||||||
|
.items
|
||||||
|
.iter()
|
||||||
|
.find_map(|item| match item {
|
||||||
DocumentItem::RequestDef(req) => Some(req),
|
DocumentItem::RequestDef(req) => Some(req),
|
||||||
_ => None,
|
_ => None,
|
||||||
});
|
})
|
||||||
|
.expect("Request mode validation ensures REQUEST block exists");
|
||||||
let Some(request_def) = request_def else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Build wildcard map from all wildcards used in the request statements
|
// Build wildcard map from all wildcards used in the request statements
|
||||||
let wildcard_map = self.build_request_wildcard_map(request_def);
|
let wildcard_map = self.build_request_wildcard_map(request_def);
|
||||||
|
|
@ -312,18 +290,17 @@ impl<'a> Lowerer<'a> {
|
||||||
// Lower each statement to middleware templates, resolving predicates
|
// Lower each statement to middleware templates, resolving predicates
|
||||||
let mut request_templates = Vec::new();
|
let mut request_templates = Vec::new();
|
||||||
for stmt in &request_def.statements {
|
for stmt in &request_def.statements {
|
||||||
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map, batches)?;
|
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map)?;
|
||||||
request_templates.push(mw_stmt);
|
request_templates.push(mw_stmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Some(crate::frontend::PodRequest::new(request_templates)))
|
Ok(crate::frontend::PodRequest::new(request_templates))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lower_request_statement(
|
fn lower_request_statement(
|
||||||
&self,
|
&self,
|
||||||
stmt: &StatementTmpl,
|
stmt: &StatementTmpl,
|
||||||
wildcard_map: &HashMap<String, usize>,
|
wildcard_map: &HashMap<String, usize>,
|
||||||
batches: Option<&PredicateBatches>,
|
|
||||||
) -> Result<MWStatementTmpl, LoweringError> {
|
) -> Result<MWStatementTmpl, LoweringError> {
|
||||||
// Enforce argument count limit for request statements
|
// Enforce argument count limit for request statements
|
||||||
if stmt.args.len() > Params::max_statement_args() {
|
if stmt.args.len() > Params::max_statement_args() {
|
||||||
|
|
@ -333,14 +310,14 @@ impl<'a> Lowerer<'a> {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let pred_name = &stmt.predicate.name;
|
|
||||||
let symbols = self.validated.symbols();
|
let symbols = self.validated.symbols();
|
||||||
|
|
||||||
// Resolve predicate using the unified resolution function
|
// Resolve predicate using the unified resolution function
|
||||||
let context = ResolutionContext::Request { batches };
|
let context = ResolutionContext::Request;
|
||||||
let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| {
|
let predicate =
|
||||||
|
resolve_predicate_ref(&stmt.predicate, symbols, &context).ok_or_else(|| {
|
||||||
LoweringError::PredicateNotFound {
|
LoweringError::PredicateNotFound {
|
||||||
name: pred_name.clone(),
|
name: format!("{}", stmt.predicate),
|
||||||
}
|
}
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
@ -453,31 +430,24 @@ impl<'a> Lowerer<'a> {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::lang::{
|
use crate::lang::{
|
||||||
frontend_ast::parse::parse_document, frontend_ast_validate::validate, parser::parse_podlang,
|
frontend_ast::parse::parse_document,
|
||||||
|
frontend_ast_validate::{validate, ParseMode},
|
||||||
|
parser::parse_podlang,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn parse_validate_and_lower(
|
fn parse_validate_and_lower_module(
|
||||||
input: &str,
|
input: &str,
|
||||||
params: &Params,
|
params: &Params,
|
||||||
) -> Result<LoweredOutput, LoweringError> {
|
) -> Result<Module, LoweringError> {
|
||||||
let parsed = parse_podlang(input).expect("Failed to parse");
|
let parsed = parse_podlang(input).expect("Failed to parse");
|
||||||
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
||||||
let validated = validate(document, &[]).expect("Failed to validate");
|
let validated =
|
||||||
lower(validated, params, "test_batch".to_string())
|
validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate");
|
||||||
}
|
lower_module(validated, params, "test_batch")
|
||||||
|
|
||||||
// Helper to get the first batch from the output (expecting it to exist)
|
|
||||||
fn expect_batch(
|
|
||||||
output: &LoweredOutput,
|
|
||||||
) -> &std::sync::Arc<crate::middleware::CustomPredicateBatch> {
|
|
||||||
output
|
|
||||||
.batches
|
|
||||||
.as_ref()
|
|
||||||
.expect("Expected batches to be present")
|
|
||||||
.first_batch()
|
|
||||||
.expect("Expected at least one batch")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -489,16 +459,16 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
if let Err(e) = &result {
|
if let Err(e) = &result {
|
||||||
eprintln!("Error: {:?}", e);
|
eprintln!("Error: {:?}", e);
|
||||||
}
|
}
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
assert_eq!(expect_batch(&lowered).predicates().len(), 1);
|
assert_eq!(module.batch.predicates().len(), 1);
|
||||||
|
|
||||||
let pred = &expect_batch(&lowered).predicates()[0];
|
let pred = &module.batch.predicates()[0];
|
||||||
assert_eq!(pred.name, "my_pred");
|
assert_eq!(pred.name, "my_pred");
|
||||||
assert_eq!(pred.args_len(), 2);
|
assert_eq!(pred.args_len(), 2);
|
||||||
assert_eq!(pred.wildcard_names().len(), 2);
|
assert_eq!(pred.wildcard_names().len(), 2);
|
||||||
|
|
@ -515,11 +485,11 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
let pred = &expect_batch(&lowered).predicates()[0];
|
let pred = &module.batch.predicates()[0];
|
||||||
assert_eq!(pred.args_len(), 1); // Only A is public
|
assert_eq!(pred.args_len(), 1); // Only A is public
|
||||||
assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total
|
assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total
|
||||||
}
|
}
|
||||||
|
|
@ -534,11 +504,11 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
let pred = &expect_batch(&lowered).predicates()[0];
|
let pred = &module.batch.predicates()[0];
|
||||||
assert!(pred.is_disjunction());
|
assert!(pred.is_disjunction());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -556,23 +526,22 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default(); // max_custom_predicate_arity = 5
|
let params = Params::default(); // max_custom_predicate_arity = 5
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
if let Err(e) = &result {
|
if let Err(e) = &result {
|
||||||
eprintln!("Splitting error: {:?}", e);
|
eprintln!("Splitting error: {:?}", e);
|
||||||
}
|
}
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
// Should be automatically split into 2 predicates (my_pred and my_pred_1)
|
// Should be automatically split into 2 predicates (my_pred and my_pred_1)
|
||||||
let batches = lowered.batches.as_ref().expect("Expected batches");
|
assert_eq!(module.batch.predicates().len(), 2);
|
||||||
assert_eq!(batches.total_predicate_count(), 2);
|
|
||||||
|
|
||||||
// With topological sorting, my_pred_1 comes first (since my_pred depends on it)
|
// With topological sorting, my_pred_1 comes first (since my_pred depends on it)
|
||||||
// my_pred_1 has 2 statements
|
// my_pred_1 has 2 statements
|
||||||
// my_pred has 5 statements (4 + chain call)
|
// my_pred has 5 statements (4 + chain call)
|
||||||
// Just verify we have the right total statement counts
|
// Just verify we have the right total statement counts
|
||||||
let batch = batches.first_batch().unwrap();
|
let total_statements: usize = module
|
||||||
let total_statements: usize = batch
|
.batch
|
||||||
.predicates()
|
.predicates()
|
||||||
.iter()
|
.iter()
|
||||||
.map(|p| p.statements().len())
|
.map(|p| p.statements().len())
|
||||||
|
|
@ -593,11 +562,11 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
assert_eq!(expect_batch(&lowered).predicates().len(), 2);
|
assert_eq!(module.batch.predicates().len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -613,11 +582,11 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
let pred2 = &expect_batch(&lowered).predicates()[1];
|
let pred2 = &module.batch.predicates()[1];
|
||||||
let stmt = &pred2.statements()[0];
|
let stmt = &pred2.statements()[0];
|
||||||
|
|
||||||
// Should be BatchSelf(0) referring to pred1
|
// Should be BatchSelf(0) referring to pred1
|
||||||
|
|
@ -638,7 +607,7 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -651,11 +620,11 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
let pred = &expect_batch(&lowered).predicates()[0];
|
let pred = &module.batch.predicates()[0];
|
||||||
let stmt = &pred.statements()[0];
|
let stmt = &pred.statements()[0];
|
||||||
|
|
||||||
// Should desugar to the Contains predicate
|
// Should desugar to the Contains predicate
|
||||||
|
|
@ -677,7 +646,7 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let result = parse_validate_and_lower(input, ¶ms);
|
let result = parse_validate_and_lower_module(input, ¶ms);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -706,18 +675,18 @@ mod tests {
|
||||||
let parsed = parse_podlang(&input).expect("Failed to parse");
|
let parsed = parse_podlang(&input).expect("Failed to parse");
|
||||||
let document =
|
let document =
|
||||||
parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document");
|
parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document");
|
||||||
let validated = validate(document, &[]).expect("Failed to validate");
|
let validated =
|
||||||
let result = lower(validated, ¶ms, "test_batch".to_string());
|
validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate");
|
||||||
|
let result = lower_module(validated, ¶ms, "test_batch");
|
||||||
|
|
||||||
assert!(result.is_ok(), "Lowering failed: {:?}", result.err());
|
assert!(result.is_ok(), "Lowering failed: {:?}", result.err());
|
||||||
|
|
||||||
let lowered = result.unwrap();
|
let module = result.unwrap();
|
||||||
let batch = expect_batch(&lowered);
|
|
||||||
|
|
||||||
// Should have one custom predicate
|
// Should have one custom predicate
|
||||||
assert_eq!(batch.predicates().len(), 1);
|
assert_eq!(module.batch.predicates().len(), 1);
|
||||||
|
|
||||||
let pred = &batch.predicates()[0];
|
let pred = &module.batch.predicates()[0];
|
||||||
assert_eq!(pred.name, "my_pred");
|
assert_eq!(pred.name, "my_pred");
|
||||||
// 2 statements: Equal and external_check
|
// 2 statements: Equal and external_check
|
||||||
assert_eq!(pred.statements().len(), 2);
|
assert_eq!(pred.statements().len(), 2);
|
||||||
|
|
|
||||||
|
|
@ -620,7 +620,7 @@ fn generate_chain_predicates(
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let chain_call = StatementTmpl {
|
let chain_call = StatementTmpl {
|
||||||
predicate: next_pred_name,
|
predicate: PredicateRef::Local(next_pred_name),
|
||||||
args: chain_call_args,
|
args: chain_call_args,
|
||||||
span: None,
|
span: None,
|
||||||
};
|
};
|
||||||
|
|
@ -832,7 +832,7 @@ mod tests {
|
||||||
let original = &chain[1];
|
let original = &chain[1];
|
||||||
assert_eq!(original.name.name, "complex");
|
assert_eq!(original.name.name, "complex");
|
||||||
let last_stmt = original.statements.last().unwrap();
|
let last_stmt = original.statements.last().unwrap();
|
||||||
assert_eq!(last_stmt.predicate.name, "complex_1");
|
assert_eq!(last_stmt.predicate.predicate_name(), "complex_1");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ use std::{
|
||||||
use hex::ToHex;
|
use hex::ToHex;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
lang::frontend_ast::*,
|
lang::{frontend_ast::*, Module},
|
||||||
middleware::{CustomPredicateBatch, Hash, NativePredicate},
|
middleware::{CustomPredicateBatch, Hash, NativePredicate},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -49,6 +49,8 @@ pub struct SymbolTable {
|
||||||
pub predicates: HashMap<String, PredicateInfo>,
|
pub predicates: HashMap<String, PredicateInfo>,
|
||||||
/// Wildcard scopes for each custom predicate
|
/// Wildcard scopes for each custom predicate
|
||||||
pub wildcard_scopes: HashMap<String, WildcardScope>,
|
pub wildcard_scopes: HashMap<String, WildcardScope>,
|
||||||
|
/// Imported modules (bound name → Module reference)
|
||||||
|
pub imported_modules: HashMap<String, Arc<Module>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Information about a predicate
|
/// Information about a predicate
|
||||||
|
|
@ -71,6 +73,11 @@ pub enum PredicateKind {
|
||||||
batch: Arc<CustomPredicateBatch>,
|
batch: Arc<CustomPredicateBatch>,
|
||||||
index: usize,
|
index: usize,
|
||||||
},
|
},
|
||||||
|
ModuleImported {
|
||||||
|
module_name: String,
|
||||||
|
predicate_name: String,
|
||||||
|
predicate_index: usize,
|
||||||
|
},
|
||||||
IntroImported {
|
IntroImported {
|
||||||
name: String,
|
name: String,
|
||||||
verifier_data_hash: Hash,
|
verifier_data_hash: Hash,
|
||||||
|
|
@ -107,39 +114,45 @@ pub enum DiagnosticLevel {
|
||||||
|
|
||||||
pub use crate::lang::error::ValidationError;
|
pub use crate::lang::error::ValidationError;
|
||||||
|
|
||||||
/// Validate an AST document
|
/// Mode for parsing/validation - determines what constructs are allowed
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ParseMode {
|
||||||
|
/// Module mode: predicate definitions allowed, REQUEST block not allowed
|
||||||
|
Module,
|
||||||
|
/// Request mode: REQUEST block required, predicate definitions not allowed
|
||||||
|
Request,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate an AST document in the given mode
|
||||||
pub fn validate(
|
pub fn validate(
|
||||||
document: Document,
|
document: Document,
|
||||||
available_batches: &[Arc<CustomPredicateBatch>],
|
available_modules: &HashMap<Hash, Arc<Module>>,
|
||||||
|
mode: ParseMode,
|
||||||
) -> Result<ValidatedAST, ValidationError> {
|
) -> Result<ValidatedAST, ValidationError> {
|
||||||
let validator = Validator::new(available_batches);
|
let validator = Validator::new(available_modules, mode);
|
||||||
validator.validate(document)
|
validator.validate(document)
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Validator {
|
struct Validator {
|
||||||
available_batches: HashMap<String, Arc<CustomPredicateBatch>>,
|
available_modules: HashMap<Hash, Arc<Module>>,
|
||||||
symbols: SymbolTable,
|
symbols: SymbolTable,
|
||||||
diagnostics: Vec<Diagnostic>,
|
diagnostics: Vec<Diagnostic>,
|
||||||
custom_predicate_count: usize,
|
custom_predicate_count: usize,
|
||||||
|
mode: ParseMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Validator {
|
impl Validator {
|
||||||
fn new(batches: &[Arc<CustomPredicateBatch>]) -> Self {
|
fn new(available_modules: &HashMap<Hash, Arc<Module>>, mode: ParseMode) -> Self {
|
||||||
let mut available_batches = HashMap::new();
|
|
||||||
for batch in batches {
|
|
||||||
// Store by hex ID for lookup
|
|
||||||
let id = format!("0x{}", batch.id().encode_hex::<String>());
|
|
||||||
available_batches.insert(id, batch.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
available_batches,
|
available_modules: available_modules.clone(),
|
||||||
symbols: SymbolTable {
|
symbols: SymbolTable {
|
||||||
predicates: HashMap::new(),
|
predicates: HashMap::new(),
|
||||||
wildcard_scopes: HashMap::new(),
|
wildcard_scopes: HashMap::new(),
|
||||||
|
imported_modules: HashMap::new(),
|
||||||
},
|
},
|
||||||
diagnostics: Vec::new(),
|
diagnostics: Vec::new(),
|
||||||
custom_predicate_count: 0,
|
custom_predicate_count: 0,
|
||||||
|
mode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -160,25 +173,36 @@ impl Validator {
|
||||||
fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> {
|
fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> {
|
||||||
// First process imports
|
// First process imports
|
||||||
for item in &document.items {
|
for item in &document.items {
|
||||||
if let DocumentItem::UseBatchStatement(use_stmt) = item {
|
if let DocumentItem::UseModuleStatement(use_stmt) = item {
|
||||||
self.process_use_batch_statement(use_stmt)?;
|
self.process_use_module_statement(use_stmt)?;
|
||||||
}
|
}
|
||||||
if let DocumentItem::UseIntroStatement(use_stmt) = item {
|
if let DocumentItem::UseIntroStatement(use_stmt) = item {
|
||||||
self.process_use_intro_statement(use_stmt)?;
|
self.process_use_intro_statement(use_stmt)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then process custom predicate definitions
|
// Check mode constraints for predicate definitions
|
||||||
|
let mut has_predicates = false;
|
||||||
for item in &document.items {
|
for item in &document.items {
|
||||||
if let DocumentItem::CustomPredicateDef(pred_def) = item {
|
if let DocumentItem::CustomPredicateDef(pred_def) = item {
|
||||||
|
if self.mode == ParseMode::Request {
|
||||||
|
return Err(ValidationError::PredicatesNotAllowedInRequest {
|
||||||
|
span: pred_def.span,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
has_predicates = true;
|
||||||
self.process_custom_predicate_def(pred_def)?;
|
self.process_custom_predicate_def(pred_def)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for multiple REQUEST definitions (only one allowed)
|
// Check mode constraints for REQUEST blocks
|
||||||
|
let mut has_request = false;
|
||||||
let mut first_request_span = None;
|
let mut first_request_span = None;
|
||||||
for item in &document.items {
|
for item in &document.items {
|
||||||
if let DocumentItem::RequestDef(req) = item {
|
if let DocumentItem::RequestDef(req) = item {
|
||||||
|
if self.mode == ParseMode::Module {
|
||||||
|
return Err(ValidationError::RequestNotAllowedInModule { span: req.span });
|
||||||
|
}
|
||||||
if let Some(first_span) = first_request_span {
|
if let Some(first_span) = first_request_span {
|
||||||
return Err(ValidationError::MultipleRequestDefinitions {
|
return Err(ValidationError::MultipleRequestDefinitions {
|
||||||
first_span: Some(first_span),
|
first_span: Some(first_span),
|
||||||
|
|
@ -186,61 +210,44 @@ impl Validator {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
first_request_span = req.span;
|
first_request_span = req.span;
|
||||||
|
has_request = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enforce that modules have predicates and requests have a REQUEST block
|
||||||
|
match self.mode {
|
||||||
|
ParseMode::Module if !has_predicates => {
|
||||||
|
return Err(ValidationError::NoPredicatesInModule);
|
||||||
|
}
|
||||||
|
ParseMode::Request if !has_request => {
|
||||||
|
return Err(ValidationError::NoRequestBlock);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process_use_batch_statement(
|
fn process_use_module_statement(
|
||||||
&mut self,
|
&mut self,
|
||||||
use_stmt: &UseBatchStatement,
|
use_stmt: &UseModuleStatement,
|
||||||
) -> Result<(), ValidationError> {
|
) -> Result<(), ValidationError> {
|
||||||
let batch_id = format!("0x{}", use_stmt.batch_ref.hash.encode_hex::<String>());
|
let alias = &use_stmt.alias.name;
|
||||||
|
let hash = &use_stmt.hash.hash;
|
||||||
|
|
||||||
let batch = self.available_batches.get(&batch_id).ok_or_else(|| {
|
// Check if the module is available by hash
|
||||||
ValidationError::BatchNotFound {
|
let module =
|
||||||
id: batch_id.clone(),
|
self.available_modules
|
||||||
span: use_stmt.batch_ref.span,
|
.get(hash)
|
||||||
}
|
.ok_or_else(|| ValidationError::ModuleNotFound {
|
||||||
|
name: hash.encode_hex::<String>(),
|
||||||
|
span: use_stmt.span,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if use_stmt.imports.len() != batch.predicates().len() {
|
// Store the module keyed by alias for later qualified name resolution
|
||||||
return Err(ValidationError::ImportArityMismatch {
|
self.symbols
|
||||||
expected: batch.predicates().len(),
|
.imported_modules
|
||||||
found: use_stmt.imports.len(),
|
.insert(alias.clone(), module.clone());
|
||||||
span: use_stmt.span,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
for (i, import) in use_stmt.imports.iter().enumerate() {
|
|
||||||
if let ImportName::Named(name) = import {
|
|
||||||
if self.symbols.predicates.contains_key(name) {
|
|
||||||
return Err(ValidationError::DuplicateImport {
|
|
||||||
name: name.clone(),
|
|
||||||
span: use_stmt.span,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let pred = &batch.predicates()[i];
|
|
||||||
// CustomPredicate has args_len (public args) and wildcard_names (total args)
|
|
||||||
let total_arity = pred.wildcard_names.len();
|
|
||||||
let public_arity = pred.args_len;
|
|
||||||
|
|
||||||
self.symbols.predicates.insert(
|
|
||||||
name.clone(),
|
|
||||||
PredicateInfo {
|
|
||||||
kind: PredicateKind::BatchImported {
|
|
||||||
batch: batch.clone(),
|
|
||||||
index: i,
|
|
||||||
},
|
|
||||||
arity: total_arity,
|
|
||||||
public_arity,
|
|
||||||
source_span: use_stmt.span,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -435,7 +442,11 @@ impl Validator {
|
||||||
stmt: &StatementTmpl,
|
stmt: &StatementTmpl,
|
||||||
wildcard_context: Option<(&str, &WildcardScope)>,
|
wildcard_context: Option<(&str, &WildcardScope)>,
|
||||||
) -> Result<(), ValidationError> {
|
) -> Result<(), ValidationError> {
|
||||||
let pred_name = &stmt.predicate.name;
|
let pred_name = stmt.predicate.predicate_name();
|
||||||
|
let pred_span = match &stmt.predicate {
|
||||||
|
PredicateRef::Local(id) => id.span,
|
||||||
|
PredicateRef::Qualified { predicate, .. } => predicate.span,
|
||||||
|
};
|
||||||
|
|
||||||
let wc_names = match wildcard_context {
|
let wc_names = match wildcard_context {
|
||||||
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
|
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
|
||||||
|
|
@ -444,7 +455,39 @@ impl Validator {
|
||||||
self.validate_wildcard_names(&wc_names)?;
|
self.validate_wildcard_names(&wc_names)?;
|
||||||
|
|
||||||
// Check if predicate exists
|
// Check if predicate exists
|
||||||
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
let pred_info = match &stmt.predicate {
|
||||||
|
PredicateRef::Qualified { module, predicate } => {
|
||||||
|
// Look up the predicate in the imported module
|
||||||
|
let module_name = &module.name;
|
||||||
|
if let Some(imported_module) = self.symbols.imported_modules.get(module_name) {
|
||||||
|
// Find the predicate in the module
|
||||||
|
if let Some(&idx) = imported_module.predicate_index.get(&predicate.name) {
|
||||||
|
let module_pred = &imported_module.batch.predicates()[idx];
|
||||||
|
Some(PredicateInfo {
|
||||||
|
kind: PredicateKind::ModuleImported {
|
||||||
|
module_name: module_name.clone(),
|
||||||
|
predicate_name: predicate.name.clone(),
|
||||||
|
predicate_index: idx,
|
||||||
|
},
|
||||||
|
arity: module_pred.wildcard_names.len(),
|
||||||
|
public_arity: module_pred.args_len,
|
||||||
|
source_span: None,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
return Err(ValidationError::UndefinedPredicate {
|
||||||
|
name: format!("{}::{}", module_name, predicate.name),
|
||||||
|
span: pred_span,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(ValidationError::ModuleNotFound {
|
||||||
|
name: module_name.clone(),
|
||||||
|
span: module.span,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PredicateRef::Local(_) => {
|
||||||
|
if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||||
// Native predicate
|
// Native predicate
|
||||||
Some(PredicateInfo {
|
Some(PredicateInfo {
|
||||||
kind: PredicateKind::Native(native),
|
kind: PredicateKind::Native(native),
|
||||||
|
|
@ -455,20 +498,22 @@ impl Validator {
|
||||||
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
|
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
|
||||||
// Custom or imported predicate
|
// Custom or imported predicate
|
||||||
Some(info.clone())
|
Some(info.clone())
|
||||||
} else if wc_names.contains(pred_name) {
|
} else if wc_names.contains(&pred_name.to_string()) {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
return Err(ValidationError::UndefinedPredicate {
|
return Err(ValidationError::UndefinedPredicate {
|
||||||
name: pred_name.clone(),
|
name: pred_name.to_string(),
|
||||||
span: stmt.predicate.span,
|
span: pred_span,
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ref pred_info) = pred_info {
|
if let Some(ref pred_info) = pred_info {
|
||||||
let expected_arity = pred_info.public_arity;
|
let expected_arity = pred_info.public_arity;
|
||||||
if stmt.args.len() != expected_arity {
|
if stmt.args.len() != expected_arity {
|
||||||
return Err(ValidationError::ArgumentCountMismatch {
|
return Err(ValidationError::ArgumentCountMismatch {
|
||||||
predicate: pred_name.clone(),
|
predicate: pred_name.to_string(),
|
||||||
expected: expected_arity,
|
expected: expected_arity,
|
||||||
found: stmt.args.len(),
|
found: stmt.args.len(),
|
||||||
span: stmt.span,
|
span: stmt.span,
|
||||||
|
|
@ -491,13 +536,15 @@ impl Validator {
|
||||||
// For custom predicates, only wildcards and literals are allowed
|
// For custom predicates, only wildcards and literals are allowed
|
||||||
if matches!(
|
if matches!(
|
||||||
pred_info.map(|i| &i.kind),
|
pred_info.map(|i| &i.kind),
|
||||||
Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. })
|
Some(PredicateKind::Custom { .. })
|
||||||
|
| Some(PredicateKind::BatchImported { .. })
|
||||||
|
| Some(PredicateKind::ModuleImported { .. })
|
||||||
) {
|
) {
|
||||||
for arg in &stmt.args {
|
for arg in &stmt.args {
|
||||||
match arg {
|
match arg {
|
||||||
StatementTmplArg::AnchoredKey(_) => {
|
StatementTmplArg::AnchoredKey(_) => {
|
||||||
return Err(ValidationError::InvalidArgumentType {
|
return Err(ValidationError::InvalidArgumentType {
|
||||||
predicate: stmt.predicate.name.clone(),
|
predicate: stmt.predicate.predicate_name().to_string(),
|
||||||
span: stmt.span,
|
span: stmt.span,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -552,25 +599,30 @@ impl Validator {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
lang::{frontend_ast::parse::parse_document, parser::parse_podlang},
|
lang::{frontend_ast::parse::parse_document, parser::parse_podlang, Module},
|
||||||
middleware::{CustomPredicate, Params, EMPTY_HASH},
|
middleware::{CustomPredicate, Params, EMPTY_HASH},
|
||||||
};
|
};
|
||||||
|
|
||||||
fn parse_and_validate(
|
fn parse_and_validate_module(
|
||||||
input: &str,
|
input: &str,
|
||||||
batches: &[Arc<CustomPredicateBatch>],
|
modules: &HashMap<Hash, Arc<Module>>,
|
||||||
) -> Result<ValidatedAST, ValidationError> {
|
) -> Result<ValidatedAST, ValidationError> {
|
||||||
let parsed = parse_podlang(input).expect("Failed to parse");
|
let parsed = parse_podlang(input).expect("Failed to parse");
|
||||||
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
||||||
validate(document, batches)
|
validate(document, modules, ParseMode::Module)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
fn parse_and_validate_request(
|
||||||
fn test_validate_empty() {
|
input: &str,
|
||||||
let result = parse_and_validate("", &[]);
|
modules: &HashMap<Hash, Arc<Module>>,
|
||||||
assert!(result.is_ok());
|
) -> Result<ValidatedAST, ValidationError> {
|
||||||
|
let parsed = parse_podlang(input).expect("Failed to parse");
|
||||||
|
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
||||||
|
validate(document, modules, ParseMode::Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -578,7 +630,7 @@ mod tests {
|
||||||
let input = r#"REQUEST(
|
let input = r#"REQUEST(
|
||||||
Equal(A["foo"], B["bar"])
|
Equal(A["foo"], B["bar"])
|
||||||
)"#;
|
)"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_request(input, &HashMap::new());
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -589,7 +641,7 @@ mod tests {
|
||||||
Equal(A["foo"], B["bar"])
|
Equal(A["foo"], B["bar"])
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_module(input, &HashMap::new());
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let validated = result.unwrap();
|
let validated = result.unwrap();
|
||||||
|
|
@ -602,7 +654,7 @@ mod tests {
|
||||||
let input = r#"REQUEST(
|
let input = r#"REQUEST(
|
||||||
UndefinedPred(A, B)
|
UndefinedPred(A, B)
|
||||||
)"#;
|
)"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_request(input, &HashMap::new());
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::UndefinedPredicate { .. })
|
Err(ValidationError::UndefinedPredicate { .. })
|
||||||
|
|
@ -616,7 +668,7 @@ mod tests {
|
||||||
Equal(A["foo"], B["bar"])
|
Equal(A["foo"], B["bar"])
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_module(input, &HashMap::new());
|
||||||
assert!(
|
assert!(
|
||||||
matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B")
|
matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B")
|
||||||
);
|
);
|
||||||
|
|
@ -627,7 +679,7 @@ mod tests {
|
||||||
let input = r#"REQUEST(
|
let input = r#"REQUEST(
|
||||||
Equal(A, B, C)
|
Equal(A, B, C)
|
||||||
)"#;
|
)"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_request(input, &HashMap::new());
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::ArgumentCountMismatch { .. })
|
Err(ValidationError::ArgumentCountMismatch { .. })
|
||||||
|
|
@ -640,7 +692,7 @@ mod tests {
|
||||||
my_pred(A) = AND (Equal(A["x"], 1))
|
my_pred(A) = AND (Equal(A["x"], 1))
|
||||||
my_pred(B) = AND (Equal(B["y"], 2))
|
my_pred(B) = AND (Equal(B["y"], 2))
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_module(input, &HashMap::new());
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::DuplicatePredicate { .. })
|
Err(ValidationError::DuplicatePredicate { .. })
|
||||||
|
|
@ -652,7 +704,7 @@ mod tests {
|
||||||
let input = r#"
|
let input = r#"
|
||||||
my_pred(A, A) = AND (Equal(A["x"], 1))
|
my_pred(A, A) = AND (Equal(A["x"], 1))
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_module(input, &HashMap::new());
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::DuplicateWildcard { .. })
|
Err(ValidationError::DuplicateWildcard { .. })
|
||||||
|
|
@ -664,7 +716,7 @@ mod tests {
|
||||||
let input = r#"
|
let input = r#"
|
||||||
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
|
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_module(input, &HashMap::new());
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::WildcardPredicateNameCollision { .. })
|
Err(ValidationError::WildcardPredicateNameCollision { .. })
|
||||||
|
|
@ -673,16 +725,36 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_custom_predicate_with_anchored_key() {
|
fn test_custom_predicate_with_anchored_key() {
|
||||||
let input = r#"
|
// First create a module with the predicate
|
||||||
my_pred(A, B) = AND (
|
let params = Params::default();
|
||||||
Equal(A["foo"], B["bar"])
|
let pred = CustomPredicate::and(
|
||||||
|
¶ms,
|
||||||
|
"my_pred".to_string(),
|
||||||
|
vec![],
|
||||||
|
2,
|
||||||
|
vec!["A".to_string(), "B".to_string()],
|
||||||
)
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
|
||||||
|
let test_module = Arc::new(Module::new(batch, HashMap::new()));
|
||||||
|
let module_hash = test_module.id().encode_hex::<String>();
|
||||||
|
|
||||||
|
let mut available_modules = HashMap::new();
|
||||||
|
available_modules.insert(test_module.id(), test_module);
|
||||||
|
|
||||||
|
// Test that passing anchored key to custom predicate fails
|
||||||
|
let input = format!(
|
||||||
|
r#"
|
||||||
|
use module 0x{} as testmod
|
||||||
|
|
||||||
REQUEST(
|
REQUEST(
|
||||||
my_pred(X["key"], Y)
|
testmod::my_pred(X["key"], Y)
|
||||||
)
|
)
|
||||||
"#;
|
"#,
|
||||||
let result = parse_and_validate(input, &[]);
|
module_hash
|
||||||
|
);
|
||||||
|
let result = parse_and_validate_request(&input, &available_modules);
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::InvalidArgumentType { .. })
|
Err(ValidationError::InvalidArgumentType { .. })
|
||||||
|
|
@ -700,7 +772,7 @@ mod tests {
|
||||||
Equal(B["x"], 1)
|
Equal(B["x"], 1)
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_module(input, &HashMap::new());
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -712,7 +784,7 @@ mod tests {
|
||||||
Equal(B["z"], C["w"])
|
Equal(B["z"], C["w"])
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_module(input, &HashMap::new());
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let validated = result.unwrap();
|
let validated = result.unwrap();
|
||||||
|
|
@ -743,7 +815,7 @@ mod tests {
|
||||||
span: None,
|
span: None,
|
||||||
})],
|
})],
|
||||||
};
|
};
|
||||||
let result = validate(document, &[]);
|
let result = validate(document, &HashMap::new(), ParseMode::Module);
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::EmptyStatementList { .. })
|
Err(ValidationError::EmptyStatementList { .. })
|
||||||
|
|
@ -756,7 +828,7 @@ mod tests {
|
||||||
REQUEST(Equal(A["x"], 1))
|
REQUEST(Equal(A["x"], 1))
|
||||||
REQUEST(Equal(B["y"], 2))
|
REQUEST(Equal(B["y"], 2))
|
||||||
"#;
|
"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_request(input, &HashMap::new());
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result,
|
result,
|
||||||
Err(ValidationError::MultipleRequestDefinitions { .. })
|
Err(ValidationError::MultipleRequestDefinitions { .. })
|
||||||
|
|
@ -764,10 +836,14 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_use_statement() {
|
fn test_use_module_statement() {
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use hex::ToHex;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
|
|
||||||
// Create a batch to import
|
// Create a module to import
|
||||||
let pred = CustomPredicate::and(
|
let pred = CustomPredicate::and(
|
||||||
¶ms,
|
¶ms,
|
||||||
"imported".to_string(),
|
"imported".to_string(),
|
||||||
|
|
@ -778,28 +854,33 @@ mod tests {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
|
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
|
||||||
|
let test_module = Arc::new(Module::new(batch, HashMap::new()));
|
||||||
|
let module_hash = test_module.id().encode_hex::<String>();
|
||||||
|
|
||||||
|
let mut available_modules = HashMap::new();
|
||||||
|
available_modules.insert(test_module.id(), test_module);
|
||||||
|
|
||||||
let batch_id = batch.id().encode_hex::<String>();
|
|
||||||
let input = format!(
|
let input = format!(
|
||||||
r#"
|
r#"
|
||||||
use batch imported_pred from 0x{}
|
use module 0x{} as testmod
|
||||||
use intro intro_pred() from 0x{}
|
use intro intro_pred() from 0x{}
|
||||||
|
|
||||||
REQUEST(
|
REQUEST(
|
||||||
imported_pred(A, B)
|
testmod::imported(A, B)
|
||||||
intro_pred()
|
intro_pred()
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
batch_id,
|
module_hash,
|
||||||
EMPTY_HASH.encode_hex::<String>()
|
EMPTY_HASH.encode_hex::<String>()
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = parse_and_validate(&input, &[batch]);
|
let result = parse_and_validate_request(&input, &available_modules);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
let validated = result.unwrap();
|
let validated = result.unwrap();
|
||||||
assert!(validated.symbols.predicates.contains_key("imported_pred"));
|
// Module predicates are accessed via qualified names, so no local binding
|
||||||
assert!(validated.symbols.predicates.contains_key("intro_pred"));
|
assert!(validated.symbols.predicates.contains_key("intro_pred"));
|
||||||
|
assert!(validated.symbols.imported_modules.contains_key("testmod"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -809,7 +890,7 @@ mod tests {
|
||||||
DictContains(D, K, V)
|
DictContains(D, K, V)
|
||||||
SetNotContains(S, E)
|
SetNotContains(S, E)
|
||||||
)"#;
|
)"#;
|
||||||
let result = parse_and_validate(input, &[]);
|
let result = parse_and_validate_request(input, &HashMap::new());
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,9 @@ arg_section = {
|
||||||
public_arg_list = { identifier ~ ("," ~ identifier)* }
|
public_arg_list = { identifier ~ ("," ~ identifier)* }
|
||||||
private_arg_list = { identifier ~ ("," ~ identifier)* }
|
private_arg_list = { identifier ~ ("," ~ identifier)* }
|
||||||
|
|
||||||
document = { SOI ~ (use_batch_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI }
|
document = { SOI ~ (use_module_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI }
|
||||||
|
|
||||||
use_batch_statement = { "use" ~ "batch" ~ use_predicate_list ~ "from" ~ batch_ref }
|
use_module_statement = { "use" ~ "module" ~ hash_hex ~ "as" ~ identifier }
|
||||||
use_predicate_list = { import_name ~ ("," ~ import_name)* }
|
|
||||||
import_name = { identifier | "_" }
|
|
||||||
batch_ref = { hash_hex }
|
|
||||||
|
|
||||||
use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ intro_predicate_ref }
|
use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ intro_predicate_ref }
|
||||||
use_intro_arg_list = { identifier ~ ("," ~ identifier)* }
|
use_intro_arg_list = { identifier ~ ("," ~ identifier)* }
|
||||||
|
|
@ -55,7 +52,11 @@ statement_list = { statement+ }
|
||||||
statement_arg = { literal_value | anchored_key | identifier }
|
statement_arg = { literal_value | anchored_key | identifier }
|
||||||
statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* }
|
statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* }
|
||||||
|
|
||||||
statement = { identifier ~ "(" ~ statement_arg_list? ~ ")" }
|
// Predicate reference: either qualified (module::predicate) or local (predicate)
|
||||||
|
predicate_ref = { qualified_predicate_ref | identifier }
|
||||||
|
qualified_predicate_ref = { identifier ~ "::" ~ identifier }
|
||||||
|
|
||||||
|
statement = { predicate_ref ~ "(" ~ statement_arg_list? ~ ")" }
|
||||||
|
|
||||||
// Anchored Key: Var["key_literal"] or Var.key_identifier
|
// Anchored Key: Var["key_literal"] or Var.key_identifier
|
||||||
anchored_key = {
|
anchored_key = {
|
||||||
|
|
|
||||||
465
src/lang/mod.rs
465
src/lang/mod.rs
|
|
@ -1,110 +1,118 @@
|
||||||
//! Podlang front-end: parsing, validation, lowering, and multi-batch output.
|
//! Podlang front-end: parsing, validation, and lowering.
|
||||||
//!
|
//!
|
||||||
//! This module is the high-level entrypoint to the Podlang pipeline. It:
|
//! This module is the high-level entrypoint to the Podlang pipeline.
|
||||||
//! - Parses a Podlang document (`parse_podlang`).
|
|
||||||
//! - Validates names, imports, and well-formedness (`frontend_ast_validate`).
|
|
||||||
//! - Lowers to middleware structures, including automatic predicate splitting and
|
|
||||||
//! dependency-aware packing into one or more custom predicate batches (`frontend_ast_split`,
|
|
||||||
//! `frontend_ast_batch`, `frontend_ast_lower`).
|
|
||||||
//!
|
//!
|
||||||
//! The result is a [`PodlangOutput`], which contains:
|
//! ## API
|
||||||
//! - `custom_batches`: a [`PredicateBatches`] container (possibly empty) with all custom
|
|
||||||
//! predicates defined in the document. Use
|
|
||||||
//! [`PredicateBatches::apply_predicate`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate)
|
|
||||||
//! to apply a predicate into a `MainPodBuilder` (recommended primary API), or
|
|
||||||
//! [`apply_predicate_with`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate_with)
|
|
||||||
//! for advanced control.
|
|
||||||
//! - `request`: a `PodRequest` containing the request templates defined by a `REQUEST(...)` block
|
|
||||||
//! in the document (or empty if none was provided).
|
|
||||||
//!
|
//!
|
||||||
//! Notes
|
//! - [`load_module`]: Load a module file containing predicate definitions.
|
||||||
//! - Predicate splitting: large predicates are automatically split into a chain of smaller
|
//! Returns a [`Module`] wrapping a `CustomPredicateBatch`.
|
||||||
//! predicates while preserving semantics; only the final chain result is public when applying a
|
//!
|
||||||
//! predicate as public.
|
//! - [`parse_request`]: Parse a request file containing a REQUEST block.
|
||||||
//! - Multi-batch packing: predicates are packed dependency-aware; cross-batch references always
|
//! Returns a [`PodRequest`] with statement templates.
|
||||||
//! point to earlier batches and forward references cannot occur.
|
//!
|
||||||
//! - Backwards compatibility: `PodlangOutput::first_batch()` is provided to ease migration of code
|
//! ## Module vs Request
|
||||||
//! that expects a single custom predicate batch.
|
//!
|
||||||
|
//! - **Modules** contain predicate definitions (`pred(A) = AND(...)`) and imports.
|
||||||
|
//! They cannot contain a REQUEST block.
|
||||||
|
//!
|
||||||
|
//! - **Requests** contain a REQUEST block and imports.
|
||||||
|
//! They cannot define predicates.
|
||||||
|
//!
|
||||||
|
//! ## Using Modules
|
||||||
|
//!
|
||||||
|
//! Use [`Module::apply_predicate`] to apply a predicate into a `MainPodBuilder`
|
||||||
|
//! (recommended), or [`Module::apply_predicate_with`] for advanced control.
|
||||||
|
//!
|
||||||
|
//! Large predicates are automatically split into chains of smaller predicates;
|
||||||
|
//! `apply_predicate` handles this transparently.
|
||||||
//!
|
//!
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod frontend_ast;
|
pub mod frontend_ast;
|
||||||
pub mod frontend_ast_batch;
|
|
||||||
pub mod frontend_ast_lower;
|
pub mod frontend_ast_lower;
|
||||||
pub mod frontend_ast_split;
|
pub mod frontend_ast_split;
|
||||||
pub mod frontend_ast_validate;
|
pub mod frontend_ast_validate;
|
||||||
|
pub mod module;
|
||||||
pub mod parser;
|
pub mod parser;
|
||||||
pub mod pretty_print;
|
pub mod pretty_print;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub use error::LangError;
|
pub use error::LangError;
|
||||||
pub use frontend_ast_batch::{MultiOperationError, PredicateBatches};
|
|
||||||
pub use frontend_ast_split::{SplitChainInfo, SplitChainPiece, SplitResult};
|
pub use frontend_ast_split::{SplitChainInfo, SplitChainPiece, SplitResult};
|
||||||
|
pub use module::{Module, MultiOperationError};
|
||||||
pub use parser::{parse_podlang, Pairs, ParseError, Rule};
|
pub use parser::{parse_podlang, Pairs, ParseError, Rule};
|
||||||
pub use pretty_print::PrettyPrint;
|
pub use pretty_print::PrettyPrint;
|
||||||
|
|
||||||
use crate::{
|
use crate::{frontend::PodRequest, middleware::Params};
|
||||||
frontend::PodRequest,
|
|
||||||
middleware::{CustomPredicateBatch, Params},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Final result of processing a Podlang document.
|
/// Load a module from Podlang source.
|
||||||
///
|
///
|
||||||
/// - `custom_batches`: all custom predicates defined in the document, possibly spanning multiple
|
/// Modules contain predicate definitions and imports, but no REQUEST block.
|
||||||
/// batches. Use [`PredicateBatches`] APIs to look up predicates by name and apply them.
|
|
||||||
/// - `request`: the request templates defined in the document (empty if not present).
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct PodlangOutput {
|
|
||||||
pub custom_batches: PredicateBatches,
|
|
||||||
pub request: PodRequest,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PodlangOutput {
|
|
||||||
/// Get the first batch, if any (for backwards compatibility).
|
|
||||||
///
|
///
|
||||||
/// Prefer using `custom_batches` directly if your code expects multiple batches.
|
/// - `source`: Podlang source code
|
||||||
pub fn first_batch(&self) -> Option<&Arc<CustomPredicateBatch>> {
|
/// - `name`: Name for the module (used in batch naming)
|
||||||
self.custom_batches.first_batch()
|
/// - `params`: Middleware parameters limiting sizes/arity
|
||||||
}
|
/// - `available_modules`: External modules available for `use module ...` imports
|
||||||
}
|
pub fn load_module(
|
||||||
|
source: &str,
|
||||||
/// Parse, validate, and lower a Podlang document into middleware structures.
|
name: &str,
|
||||||
///
|
|
||||||
/// - `input`: Podlang source.
|
|
||||||
/// - `params`: middleware parameters limiting sizes/arity and controlling lowering behavior.
|
|
||||||
/// - `available_batches`: external batches available for `use batch ... from 0x...` imports.
|
|
||||||
///
|
|
||||||
/// Returns a [`PodlangOutput`] containing custom predicate batches (if any) and a `PodRequest`
|
|
||||||
/// (possibly empty).
|
|
||||||
pub fn parse(
|
|
||||||
input: &str,
|
|
||||||
params: &Params,
|
params: &Params,
|
||||||
available_batches: &[Arc<CustomPredicateBatch>],
|
available_modules: Vec<Arc<Module>>,
|
||||||
) -> Result<PodlangOutput, LangError> {
|
) -> Result<Module, LangError> {
|
||||||
let pairs = parse_podlang(input)?;
|
let pairs = parse_podlang(source)?;
|
||||||
let document_pair = pairs
|
let document_pair = pairs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.expect("parse_podlang should always return at least one pair for a valid document");
|
.expect("parse_podlang should always return at least one pair for a valid document");
|
||||||
let document = frontend_ast::parse::parse_document(document_pair)?;
|
let document = frontend_ast::parse::parse_document(document_pair)?;
|
||||||
let validated = frontend_ast_validate::validate(document, available_batches)?;
|
let available_modules_map = available_modules
|
||||||
let lowered = frontend_ast_lower::lower(validated, params, "PodlangBatch".to_string())?;
|
.iter()
|
||||||
|
.map(|m| (m.id(), m.clone()))
|
||||||
|
.collect();
|
||||||
|
let validated = frontend_ast_validate::validate(
|
||||||
|
document,
|
||||||
|
&available_modules_map,
|
||||||
|
frontend_ast_validate::ParseMode::Module,
|
||||||
|
)?;
|
||||||
|
let module = frontend_ast_lower::lower_module(validated, params, name)?;
|
||||||
|
Ok(module)
|
||||||
|
}
|
||||||
|
|
||||||
let custom_batches = lowered.batches.unwrap_or_default();
|
/// Parse a request from Podlang source.
|
||||||
|
///
|
||||||
let request = lowered.request.unwrap_or_else(|| {
|
/// Requests contain a REQUEST block and imports, but no predicate definitions.
|
||||||
// If no request, create an empty one
|
///
|
||||||
PodRequest::new(vec![])
|
/// - `source`: Podlang source code
|
||||||
});
|
/// - `params`: Middleware parameters limiting sizes/arity
|
||||||
|
/// - `available_modules`: External modules available for `use module ...` imports
|
||||||
Ok(PodlangOutput {
|
pub fn parse_request(
|
||||||
custom_batches,
|
source: &str,
|
||||||
request,
|
params: &Params,
|
||||||
})
|
available_modules: &[Arc<Module>],
|
||||||
|
) -> Result<PodRequest, LangError> {
|
||||||
|
let pairs = parse_podlang(source)?;
|
||||||
|
let document_pair = pairs
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.expect("parse_podlang should always return at least one pair for a valid document");
|
||||||
|
let document = frontend_ast::parse::parse_document(document_pair)?;
|
||||||
|
let available_modules_map = available_modules
|
||||||
|
.iter()
|
||||||
|
.map(|m| (m.id(), m.clone()))
|
||||||
|
.collect();
|
||||||
|
let validated = frontend_ast_validate::validate(
|
||||||
|
document,
|
||||||
|
&available_modules_map,
|
||||||
|
frontend_ast_validate::ParseMode::Request,
|
||||||
|
)?;
|
||||||
|
let request = frontend_ast_lower::lower_request(validated, params)?;
|
||||||
|
Ok(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use hex::ToHex;
|
use hex::ToHex;
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
|
|
@ -143,11 +151,6 @@ mod tests {
|
||||||
PredicateOrWildcard::Predicate(pred)
|
PredicateOrWildcard::Predicate(pred)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to get the first batch from the output
|
|
||||||
fn first_batch(output: &super::PodlangOutput) -> &Arc<CustomPredicateBatch> {
|
|
||||||
output.first_batch().expect("Expected at least one batch")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_simple_predicate() -> Result<(), LangError> {
|
fn test_e2e_simple_predicate() -> Result<(), LangError> {
|
||||||
let input = r#"
|
let input = r#"
|
||||||
|
|
@ -157,12 +160,9 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let processed = parse(input, ¶ms, &[])?;
|
let module = load_module(input, "test_module", ¶ms, vec![])?;
|
||||||
let batch_result = first_batch(&processed);
|
|
||||||
let request_result = processed.request.templates();
|
|
||||||
|
|
||||||
assert_eq!(request_result.len(), 0);
|
assert_eq!(module.batch.predicates().len(), 1);
|
||||||
assert_eq!(batch_result.predicates().len(), 1);
|
|
||||||
|
|
||||||
// Expected structure
|
// Expected structure
|
||||||
let expected_statements = vec![StatementTmpl {
|
let expected_statements = vec![StatementTmpl {
|
||||||
|
|
@ -180,9 +180,9 @@ mod tests {
|
||||||
names(&["PodA", "PodB"]),
|
names(&["PodA", "PodB"]),
|
||||||
)?;
|
)?;
|
||||||
let expected_batch =
|
let expected_batch =
|
||||||
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
|
CustomPredicateBatch::new("test_module".to_string(), vec![expected_predicate]);
|
||||||
|
|
||||||
assert_eq!(*batch_result, expected_batch);
|
assert_eq!(&*module.batch, &*expected_batch);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -197,10 +197,9 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let processed = parse(input, ¶ms, &[])?;
|
let request = parse_request(input, ¶ms, &[])?;
|
||||||
let request_templates = processed.request.templates();
|
let request_templates = request.templates();
|
||||||
|
|
||||||
assert!(processed.custom_batches.is_empty());
|
|
||||||
assert!(!request_templates.is_empty());
|
assert!(!request_templates.is_empty());
|
||||||
|
|
||||||
// Expected structure
|
// Expected structure
|
||||||
|
|
@ -236,12 +235,9 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let processed = parse(input, ¶ms, &[])?;
|
let module = load_module(input, "test_module", ¶ms, vec![])?;
|
||||||
let batch_result = first_batch(&processed);
|
|
||||||
let request_result = processed.request.templates();
|
|
||||||
|
|
||||||
assert_eq!(request_result.len(), 0);
|
assert_eq!(module.batch.predicates().len(), 1);
|
||||||
assert_eq!(batch_result.predicates().len(), 1);
|
|
||||||
|
|
||||||
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
|
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
|
||||||
let expected_statements = vec![
|
let expected_statements = vec![
|
||||||
|
|
@ -268,58 +264,51 @@ mod tests {
|
||||||
names(&["A", "Temp"]),
|
names(&["A", "Temp"]),
|
||||||
)?;
|
)?;
|
||||||
let expected_batch =
|
let expected_batch =
|
||||||
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
|
CustomPredicateBatch::new("test_module".to_string(), vec![expected_predicate]);
|
||||||
|
|
||||||
assert_eq!(*batch_result, expected_batch);
|
assert_eq!(&*module.batch, &*expected_batch);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_request_with_custom_call() -> Result<(), LangError> {
|
fn test_e2e_request_with_custom_call() -> Result<(), LangError> {
|
||||||
let input = r#"
|
// First, load the module
|
||||||
|
let module_input = r#"
|
||||||
my_pred(X, Y) = AND(
|
my_pred(X, Y) = AND(
|
||||||
Equal(X["val"], Y["val"])
|
Equal(X["val"], Y["val"])
|
||||||
)
|
)
|
||||||
|
|
||||||
REQUEST(
|
|
||||||
my_pred(Pod1, Pod2)
|
|
||||||
)
|
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let processed = parse(input, ¶ms, &[])?;
|
let module = Arc::new(load_module(module_input, "my_module", ¶ms, vec![])?);
|
||||||
let batch_result = first_batch(&processed);
|
|
||||||
let request_templates = processed.request.templates();
|
assert_eq!(module.batch.predicates().len(), 1);
|
||||||
|
|
||||||
|
let module_hash = module.id().encode_hex::<String>();
|
||||||
|
|
||||||
|
// Then, parse the request using the module
|
||||||
|
let request_input = format!(
|
||||||
|
r#"
|
||||||
|
use module 0x{} as my_module
|
||||||
|
|
||||||
|
REQUEST(
|
||||||
|
my_module::my_pred(Pod1, Pod2)
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
module_hash
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = parse_request(&request_input, ¶ms, std::slice::from_ref(&module))?;
|
||||||
|
let request_templates = request.templates();
|
||||||
|
|
||||||
assert_eq!(batch_result.predicates().len(), 1);
|
|
||||||
assert!(!request_templates.is_empty());
|
assert!(!request_templates.is_empty());
|
||||||
|
|
||||||
// Expected Batch structure
|
|
||||||
let expected_pred_statements = vec![StatementTmpl {
|
|
||||||
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
|
||||||
args: vec![
|
|
||||||
sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val")
|
|
||||||
sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val")
|
|
||||||
],
|
|
||||||
}];
|
|
||||||
let expected_predicate = CustomPredicate::and(
|
|
||||||
¶ms,
|
|
||||||
"my_pred".to_string(),
|
|
||||||
expected_pred_statements,
|
|
||||||
2, // args_len (X, Y)
|
|
||||||
names(&["X", "Y"]),
|
|
||||||
)?;
|
|
||||||
let expected_batch =
|
|
||||||
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
|
|
||||||
|
|
||||||
assert_eq!(*batch_result, expected_batch);
|
|
||||||
|
|
||||||
// Expected Request structure
|
// Expected Request structure
|
||||||
// Pod1 -> Wildcard 0, Pod2 -> Wildcard 1
|
// Pod1 -> Wildcard 0, Pod2 -> Wildcard 1
|
||||||
let expected_request_templates = vec![StatementTmpl {
|
let expected_request_templates = vec![StatementTmpl {
|
||||||
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
||||||
expected_batch,
|
module.batch.clone(),
|
||||||
0,
|
0,
|
||||||
))),
|
))),
|
||||||
args: vec![
|
args: vec![
|
||||||
|
|
@ -335,25 +324,36 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_request_with_various_args() -> Result<(), LangError> {
|
fn test_e2e_request_with_various_args() -> Result<(), LangError> {
|
||||||
let input = r#"
|
// First, load the module
|
||||||
|
let module_input = r#"
|
||||||
some_pred(A, B, C) = AND( Equal(A["foo"], B["bar"]) )
|
some_pred(A, B, C) = AND( Equal(A["foo"], B["bar"]) )
|
||||||
|
|
||||||
REQUEST(
|
|
||||||
some_pred(
|
|
||||||
Var1, // Wildcard
|
|
||||||
12345, // Int Literal
|
|
||||||
"hello_string" // String Literal (Removed invalid AK args)
|
|
||||||
)
|
|
||||||
Equal(AnotherPod["another_key"], Var1["some_field"])
|
|
||||||
)
|
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let processed = parse(input, ¶ms, &[])?;
|
let module = Arc::new(load_module(module_input, "some_module", ¶ms, vec![])?);
|
||||||
let batch_result = first_batch(&processed);
|
|
||||||
let request_templates = processed.request.templates();
|
let module_hash = module.id().encode_hex::<String>();
|
||||||
|
|
||||||
|
// Then, parse the request
|
||||||
|
let request_input = format!(
|
||||||
|
r#"
|
||||||
|
use module 0x{} as some_module
|
||||||
|
|
||||||
|
REQUEST(
|
||||||
|
some_module::some_pred(
|
||||||
|
Var1, // Wildcard
|
||||||
|
12345, // Int Literal
|
||||||
|
"hello_string" // String Literal
|
||||||
|
)
|
||||||
|
Equal(AnotherPod["another_key"], Var1["some_field"])
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
module_hash
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = parse_request(&request_input, ¶ms, std::slice::from_ref(&module))?;
|
||||||
|
let request_templates = request.templates();
|
||||||
|
|
||||||
assert_eq!(batch_result.predicates().len(), 1); // some_pred is defined
|
|
||||||
assert!(!request_templates.is_empty());
|
assert!(!request_templates.is_empty());
|
||||||
|
|
||||||
// Expected Wildcard Indices in Request Scope:
|
// Expected Wildcard Indices in Request Scope:
|
||||||
|
|
@ -364,7 +364,7 @@ mod tests {
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
||||||
batch_result.clone(),
|
module.batch.clone(),
|
||||||
0,
|
0,
|
||||||
))), // Refers to some_pred
|
))), // Refers to some_pred
|
||||||
args: vec![
|
args: vec![
|
||||||
|
|
@ -402,10 +402,9 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let processed = parse(input, ¶ms, &[])?;
|
let request = parse_request(input, ¶ms, &[])?;
|
||||||
let request_templates = processed.request.templates();
|
let request_templates = request.templates();
|
||||||
|
|
||||||
assert!(processed.custom_batches.is_empty());
|
|
||||||
assert!(!request_templates.is_empty());
|
assert!(!request_templates.is_empty());
|
||||||
|
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
|
|
@ -459,8 +458,8 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
// Parse the input string
|
// Parse the input string
|
||||||
let processed = super::parse(input, &Params::default(), &[])?;
|
let request = parse_request(input, &Params::default(), &[])?;
|
||||||
let parsed_templates = processed.request.templates();
|
let parsed_templates = request.templates();
|
||||||
|
|
||||||
// Define Expected Templates (Copied from prover/mod.rs)
|
// Define Expected Templates (Copied from prover/mod.rs)
|
||||||
let now_minus_18y_val = Value::from(1169909388_i64);
|
let now_minus_18y_val = Value::from(1169909388_i64);
|
||||||
|
|
@ -549,11 +548,6 @@ mod tests {
|
||||||
"Parsed ZuKYC request templates do not match the expected hard-coded version"
|
"Parsed ZuKYC request templates do not match the expected hard-coded version"
|
||||||
);
|
);
|
||||||
|
|
||||||
assert!(
|
|
||||||
processed.custom_batches.is_empty(),
|
|
||||||
"Expected no custom predicates for a REQUEST only input"
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -591,14 +585,10 @@ mod tests {
|
||||||
)
|
)
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let processed = super::parse(input, ¶ms, &[])?;
|
let module = load_module(input, "ethdos", ¶ms, vec![])?;
|
||||||
|
|
||||||
assert!(
|
|
||||||
processed.request.templates().is_empty(),
|
|
||||||
"Expected no request templates"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
first_batch(&processed).predicates().len(),
|
module.batch.predicates().len(),
|
||||||
4,
|
4,
|
||||||
"Expected 4 custom predicates"
|
"Expected 4 custom predicates"
|
||||||
);
|
);
|
||||||
|
|
@ -718,7 +708,7 @@ mod tests {
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let expected_batch = CustomPredicateBatch::new(
|
let expected_batch = CustomPredicateBatch::new(
|
||||||
"PodlangBatch".to_string(),
|
"ethdos".to_string(),
|
||||||
vec![
|
vec![
|
||||||
expected_friend_pred,
|
expected_friend_pred,
|
||||||
expected_base_pred,
|
expected_base_pred,
|
||||||
|
|
@ -728,8 +718,7 @@ mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
*first_batch(&processed),
|
&*module.batch, &*expected_batch,
|
||||||
expected_batch,
|
|
||||||
"Processed ETHDoS predicates do not match expected structure"
|
"Processed ETHDoS predicates do not match expected structure"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -737,10 +726,10 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_use_statement() -> Result<(), LangError> {
|
fn test_e2e_use_module_statement() -> Result<(), LangError> {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
|
|
||||||
// 1. Create a batch to be imported
|
// 1. Create a module with a predicate to be imported
|
||||||
let imported_pred_stmts = vec![StatementTmpl {
|
let imported_pred_stmts = vec![StatementTmpl {
|
||||||
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![
|
args: vec![
|
||||||
|
|
@ -755,98 +744,75 @@ mod tests {
|
||||||
2,
|
2,
|
||||||
names(&["A", "B"]),
|
names(&["A", "B"]),
|
||||||
)?;
|
)?;
|
||||||
let available_batch =
|
let batch = CustomPredicateBatch::new("my_module".to_string(), vec![imported_predicate]);
|
||||||
CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]);
|
let module = Arc::new(Module::new(batch.clone(), HashMap::new()));
|
||||||
let available_batches = vec![available_batch.clone()];
|
let module_hash = module.id().encode_hex::<String>();
|
||||||
|
|
||||||
// 2. Create the input string that uses the batch
|
// 2. Create the input string that uses the module
|
||||||
let batch_id_str = available_batch.id().encode_hex::<String>();
|
|
||||||
let input = format!(
|
let input = format!(
|
||||||
r#"
|
r#"
|
||||||
use batch imported_pred from 0x{}
|
use module 0x{} as my_module
|
||||||
|
|
||||||
REQUEST(
|
REQUEST(
|
||||||
imported_pred(Pod1, Pod2)
|
my_module::imported_equal(Pod1, Pod2)
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
batch_id_str
|
module_hash
|
||||||
);
|
);
|
||||||
|
|
||||||
// 3. Parse the input
|
// 3. Parse the request
|
||||||
let processed = parse(&input, ¶ms, &available_batches)?;
|
let request = parse_request(&input, ¶ms, std::slice::from_ref(&module))?;
|
||||||
let request_templates = processed.request.templates();
|
let request_templates = request.templates();
|
||||||
|
|
||||||
assert!(
|
|
||||||
processed.custom_batches.is_empty(),
|
|
||||||
"No custom predicates should be defined in the main input"
|
|
||||||
);
|
|
||||||
assert_eq!(request_templates.len(), 1, "Expected one request template");
|
assert_eq!(request_templates.len(), 1, "Expected one request template");
|
||||||
|
|
||||||
// 4. Check the resulting request template
|
// 4. Check the resulting request template uses the imported predicate
|
||||||
let expected_request_templates = vec![StatementTmpl {
|
let template = &request_templates[0];
|
||||||
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
assert_eq!(template.args.len(), 2);
|
||||||
available_batch,
|
|
||||||
0,
|
|
||||||
))),
|
|
||||||
args: vec![
|
|
||||||
StatementTmplArg::Wildcard(wc("Pod1", 0)),
|
|
||||||
StatementTmplArg::Wildcard(wc("Pod2", 1)),
|
|
||||||
],
|
|
||||||
}];
|
|
||||||
|
|
||||||
assert_eq!(request_templates, expected_request_templates);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_use_statement_complex() -> Result<(), LangError> {
|
fn test_e2e_use_module_complex() -> Result<(), LangError> {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
|
|
||||||
// 1. Create a batch with multiple predicates
|
// 1. Create a module with multiple predicates
|
||||||
let pred1 = CustomPredicate::and(¶ms, "p1".into(), vec![], 1, names(&["A"]))?;
|
let pred1 = CustomPredicate::and(¶ms, "p1".into(), vec![], 1, names(&["A"]))?;
|
||||||
let pred2 = CustomPredicate::and(¶ms, "p2".into(), vec![], 2, names(&["B", "C"]))?;
|
let pred2 = CustomPredicate::and(¶ms, "p2".into(), vec![], 2, names(&["B", "C"]))?;
|
||||||
let pred3 = CustomPredicate::and(¶ms, "p3".into(), vec![], 1, names(&["D"]))?;
|
let pred3 = CustomPredicate::and(¶ms, "p3".into(), vec![], 1, names(&["D"]))?;
|
||||||
|
|
||||||
let available_batch =
|
let batch = CustomPredicateBatch::new("mymodule".to_string(), vec![pred1, pred2, pred3]);
|
||||||
CustomPredicateBatch::new("MyBatch".to_string(), vec![pred1, pred2, pred3]);
|
let mymodule = Arc::new(Module::new(batch.clone(), HashMap::new()));
|
||||||
let available_batches = vec![available_batch.clone()];
|
let module_hash = mymodule.id().encode_hex::<String>();
|
||||||
|
|
||||||
// 2. Create the input string that uses the batch with skips
|
|
||||||
let batch_id_str = available_batch.id().encode_hex::<String>();
|
|
||||||
|
|
||||||
|
// 2. Create the input string that uses qualified predicate access
|
||||||
let input = format!(
|
let input = format!(
|
||||||
r#"
|
r#"
|
||||||
use batch pred_one, _, pred_three from 0x{}
|
use module 0x{} as mymodule
|
||||||
|
|
||||||
REQUEST(
|
REQUEST(
|
||||||
pred_one(Pod1)
|
mymodule::p1(Pod1)
|
||||||
pred_three(Pod2)
|
mymodule::p3(Pod2)
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
batch_id_str
|
module_hash
|
||||||
);
|
);
|
||||||
|
|
||||||
// 3. Parse the input
|
// 3. Parse the request
|
||||||
let processed = parse(&input, ¶ms, &available_batches)?;
|
let request = parse_request(&input, ¶ms, std::slice::from_ref(&mymodule))?;
|
||||||
let request_templates = processed.request.templates();
|
let request_templates = request.templates();
|
||||||
|
|
||||||
assert_eq!(request_templates.len(), 2, "Expected two request templates");
|
assert_eq!(request_templates.len(), 2, "Expected two request templates");
|
||||||
|
|
||||||
// 4. Check the resulting request templates
|
// 4. Check the resulting request templates
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch.clone(), 0))),
|
||||||
available_batch.clone(),
|
|
||||||
0,
|
|
||||||
))),
|
|
||||||
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
|
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
|
||||||
},
|
},
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch, 2))),
|
||||||
available_batch,
|
|
||||||
2,
|
|
||||||
))),
|
|
||||||
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
|
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
@ -857,10 +823,10 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_custom_predicate_uses_import() -> Result<(), LangError> {
|
fn test_e2e_custom_predicate_uses_module() -> Result<(), LangError> {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
|
|
||||||
// 1. Create a batch with a predicate to be imported
|
// 1. Create a module with a predicate to be imported
|
||||||
let imported_pred_stmts = vec![StatementTmpl {
|
let imported_pred_stmts = vec![StatementTmpl {
|
||||||
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
|
||||||
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
|
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
|
||||||
|
|
@ -872,47 +838,38 @@ mod tests {
|
||||||
2,
|
2,
|
||||||
names(&["A", "B"]),
|
names(&["A", "B"]),
|
||||||
)?;
|
)?;
|
||||||
let available_batch =
|
let batch = CustomPredicateBatch::new("extmod".to_string(), vec![imported_predicate]);
|
||||||
CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]);
|
let extmod = Arc::new(Module::new(batch.clone(), HashMap::new()));
|
||||||
let available_batches = vec![available_batch.clone()];
|
let extmod_hash = extmod.id().encode_hex::<String>();
|
||||||
|
|
||||||
// 2. Create the input string that defines a new predicate using the imported one
|
// 2. Create the input string that defines a new predicate using the imported one
|
||||||
let batch_id_str = available_batch.id().encode_hex::<String>();
|
|
||||||
|
|
||||||
let input = format!(
|
let input = format!(
|
||||||
r#"
|
r#"
|
||||||
use batch imported_eq from 0x{}
|
use module 0x{} as extmod
|
||||||
|
|
||||||
wrapper_pred(X, Y) = AND(
|
wrapper_pred(X, Y) = AND(
|
||||||
imported_eq(X, Y)
|
extmod::imported_equal(X, Y)
|
||||||
)
|
)
|
||||||
"#,
|
"#,
|
||||||
batch_id_str
|
extmod_hash
|
||||||
);
|
);
|
||||||
|
|
||||||
// 3. Parse the input
|
// 3. Load as module
|
||||||
let processed = parse(&input, ¶ms, &available_batches)?;
|
let module = load_module(&input, "test", ¶ms, vec![extmod])?;
|
||||||
|
|
||||||
assert!(
|
|
||||||
processed.request.templates().is_empty(),
|
|
||||||
"No request should be defined"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
first_batch(&processed).predicates().len(),
|
module.batch.predicates().len(),
|
||||||
1,
|
1,
|
||||||
"Expected one custom predicate to be defined"
|
"Expected one custom predicate to be defined"
|
||||||
);
|
);
|
||||||
|
|
||||||
// 4. Check the resulting predicate definition
|
// 4. Check the resulting predicate definition
|
||||||
let defined_pred = &first_batch(&processed).predicates()[0];
|
let defined_pred = &module.batch.predicates()[0];
|
||||||
assert_eq!(defined_pred.name, "wrapper_pred");
|
assert_eq!(defined_pred.name, "wrapper_pred");
|
||||||
assert_eq!(defined_pred.statements.len(), 1);
|
assert_eq!(defined_pred.statements.len(), 1);
|
||||||
|
|
||||||
let expected_statement = StatementTmpl {
|
let expected_statement = StatementTmpl {
|
||||||
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
|
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch, 0))),
|
||||||
available_batch.clone(),
|
|
||||||
0,
|
|
||||||
))),
|
|
||||||
args: vec![
|
args: vec![
|
||||||
StatementTmplArg::Wildcard(wc("X", 0)),
|
StatementTmplArg::Wildcard(wc("X", 0)),
|
||||||
StatementTmplArg::Wildcard(wc("Y", 1)),
|
StatementTmplArg::Wildcard(wc("Y", 1)),
|
||||||
|
|
@ -939,8 +896,8 @@ mod tests {
|
||||||
"#,
|
"#,
|
||||||
);
|
);
|
||||||
|
|
||||||
let processed = parse(&input, ¶ms, &[])?;
|
let request = parse_request(&input, ¶ms, &[])?;
|
||||||
let request_templates = processed.request.templates();
|
let request_templates = request.templates();
|
||||||
assert_eq!(request_templates.len(), 1);
|
assert_eq!(request_templates.len(), 1);
|
||||||
|
|
||||||
if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) =
|
if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) =
|
||||||
|
|
@ -998,8 +955,8 @@ mod tests {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let processed = parse(&input, ¶ms, &[])?;
|
let request = parse_request(&input, ¶ms, &[])?;
|
||||||
let request_templates = processed.request.templates();
|
let request_templates = request.templates();
|
||||||
|
|
||||||
let expected_templates = vec![
|
let expected_templates = vec![
|
||||||
StatementTmpl {
|
StatementTmpl {
|
||||||
|
|
@ -1034,29 +991,33 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_use_unknown_batch() {
|
fn test_e2e_use_unknown_module() {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let available_batches = &[];
|
|
||||||
|
|
||||||
let unknown_batch_id = format!("0x{}", "a".repeat(64));
|
|
||||||
|
|
||||||
|
// Use a hash that doesn't correspond to any loaded module
|
||||||
|
let fake_hash = EMPTY_HASH.encode_hex::<String>();
|
||||||
let input = format!(
|
let input = format!(
|
||||||
r#"
|
r#"
|
||||||
use batch some_pred from {}
|
use module 0x{} as unknown_module
|
||||||
|
|
||||||
|
REQUEST(
|
||||||
|
Equal(A["x"], 1)
|
||||||
|
)
|
||||||
"#,
|
"#,
|
||||||
unknown_batch_id
|
fake_hash
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = parse(&input, ¶ms, available_batches);
|
let result = parse_request(&input, ¶ms, &[]);
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
match result.err().unwrap() {
|
match result.err().unwrap() {
|
||||||
LangError::Validation(e) => match *e {
|
LangError::Validation(e) => match *e {
|
||||||
frontend_ast_validate::ValidationError::BatchNotFound { id, .. } => {
|
frontend_ast_validate::ValidationError::ModuleNotFound { name, .. } => {
|
||||||
assert_eq!(id, unknown_batch_id);
|
// The error now carries the hex-formatted hash
|
||||||
|
assert_eq!(name, fake_hash);
|
||||||
}
|
}
|
||||||
_ => panic!("Expected BatchNotFound error, but got {:?}", e),
|
_ => panic!("Expected ModuleNotFound error, but got {:?}", e),
|
||||||
},
|
},
|
||||||
e => panic!("Expected LangError::Validation, but got {:?}", e),
|
e => panic!("Expected LangError::Validation, but got {:?}", e),
|
||||||
}
|
}
|
||||||
|
|
@ -1065,17 +1026,15 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_e2e_undefined_wildcard() {
|
fn test_e2e_undefined_wildcard() {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let available_batches = &[];
|
|
||||||
|
|
||||||
let input = r#"
|
let input = r#"
|
||||||
identity_verified(username, private: identity_dict) = AND(
|
identity_verified(username, private: identity_dict) = AND(
|
||||||
Equal(identity_dict["username"], username)
|
Equal(identity_dict["username"], username)
|
||||||
Equal(identity_dict["user_public_key"], user_public_key)
|
Equal(identity_dict["user_public_key"], user_public_key)
|
||||||
)
|
)
|
||||||
"#
|
"#;
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let result = parse(&input, ¶ms, available_batches);
|
let result = load_module(input, "test", ¶ms, vec![]);
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
|
|
||||||
671
src/lang/module.rs
Normal file
671
src/lang/module.rs
Normal file
|
|
@ -0,0 +1,671 @@
|
||||||
|
//! Podlang Module: definition, construction, and predicate application.
|
||||||
|
//!
|
||||||
|
//! A [`Module`] wraps a middleware `CustomPredicateBatch` with name resolution
|
||||||
|
//! and split chain metadata. Use [`build_module`] to construct a Module from
|
||||||
|
//! validated and split predicates.
|
||||||
|
|
||||||
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
frontend::{CustomPredicateBatchBuilder, Operation, OperationArg, StatementTmplBuilder},
|
||||||
|
lang::{
|
||||||
|
error::BatchingError,
|
||||||
|
frontend_ast::{ConjunctionType, CustomPredicateDef},
|
||||||
|
frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext},
|
||||||
|
frontend_ast_split::{SplitChainInfo, SplitResult},
|
||||||
|
frontend_ast_validate::SymbolTable,
|
||||||
|
},
|
||||||
|
middleware::{CustomPredicateBatch, CustomPredicateRef, Hash, Params, Statement},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Errors that can occur when applying predicates
|
||||||
|
#[derive(Debug, Clone, thiserror::Error)]
|
||||||
|
pub enum MultiOperationError {
|
||||||
|
#[error("Predicate not found: {0}")]
|
||||||
|
PredicateNotFound(String),
|
||||||
|
|
||||||
|
#[error("Chain piece not found: {0}")]
|
||||||
|
ChainPieceNotFound(String),
|
||||||
|
|
||||||
|
#[error(
|
||||||
|
"Wrong statement count for predicate '{predicate}': expected {expected}, got {actual}"
|
||||||
|
)]
|
||||||
|
WrongStatementCount {
|
||||||
|
predicate: String,
|
||||||
|
expected: usize,
|
||||||
|
actual: usize,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("No operation steps to apply")]
|
||||||
|
NoSteps,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A Podlang module wrapping a middleware CustomPredicateBatch with name resolution info.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Module {
|
||||||
|
/// The middleware representation (CustomPredicateBatch)
|
||||||
|
pub batch: Arc<CustomPredicateBatch>,
|
||||||
|
|
||||||
|
/// Map from predicate name to index in batch
|
||||||
|
pub predicate_index: HashMap<String, usize>,
|
||||||
|
|
||||||
|
/// Split chain info for predicates that were split
|
||||||
|
pub split_chains: HashMap<String, SplitChainInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module {
|
||||||
|
/// Create a new Module from a batch, building the predicate_index automatically
|
||||||
|
pub fn new(
|
||||||
|
batch: Arc<CustomPredicateBatch>,
|
||||||
|
split_chains: HashMap<String, SplitChainInfo>,
|
||||||
|
) -> Self {
|
||||||
|
let predicate_index = batch
|
||||||
|
.predicates()
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, p)| (p.name.clone(), i))
|
||||||
|
.collect();
|
||||||
|
Self {
|
||||||
|
batch,
|
||||||
|
predicate_index,
|
||||||
|
split_chains,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Root hash of the module's Merkle tree
|
||||||
|
pub fn id(&self) -> Hash {
|
||||||
|
self.batch.id()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a reference to a predicate by name
|
||||||
|
pub fn predicate_ref_by_name(&self, name: &str) -> Option<CustomPredicateRef> {
|
||||||
|
let idx = self.predicate_index.get(name)?;
|
||||||
|
Some(CustomPredicateRef::new(self.batch.clone(), *idx))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the module contains any predicates
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.batch.predicates().is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply a predicate directly into a `MainPodBuilder` (common case).
|
||||||
|
///
|
||||||
|
/// For split predicates, earlier chain links are applied as private, and only the final
|
||||||
|
/// piece is applied as public when `public` is true. For non-split predicates, the single
|
||||||
|
/// operation is applied with the provided `public` flag.
|
||||||
|
///
|
||||||
|
/// Arguments:
|
||||||
|
/// - `builder`: target builder to receive operations
|
||||||
|
/// - `name`: predicate name
|
||||||
|
/// - `statements`: user statements in original declaration order
|
||||||
|
/// - `public`: whether the final result should be public
|
||||||
|
pub fn apply_predicate(
|
||||||
|
&self,
|
||||||
|
builder: &mut crate::frontend::MainPodBuilder,
|
||||||
|
name: &str,
|
||||||
|
statements: Vec<Statement>,
|
||||||
|
public: bool,
|
||||||
|
) -> crate::frontend::Result<Statement> {
|
||||||
|
self.apply_predicate_with(name, statements, public, |is_public, op| {
|
||||||
|
if is_public {
|
||||||
|
builder.pub_op(op)
|
||||||
|
} else {
|
||||||
|
builder.priv_op(op)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Advanced variant: apply using a custom closure.
|
||||||
|
///
|
||||||
|
/// Prefer `apply_predicate` for common usage. This method allows callers to intercept each
|
||||||
|
/// operation (with its `public` flag) and decide how to execute it.
|
||||||
|
///
|
||||||
|
/// Arguments:
|
||||||
|
/// - `name`: predicate name
|
||||||
|
/// - `statements`: user statements in original declaration order
|
||||||
|
/// - `public`: whether the final result should be public
|
||||||
|
/// - `apply_op`: closure `(is_public, operation) -> Result<Statement>` used to execute each step
|
||||||
|
pub fn apply_predicate_with<F, E>(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
statements: Vec<Statement>,
|
||||||
|
public: bool,
|
||||||
|
mut apply_op: F,
|
||||||
|
) -> Result<Statement, E>
|
||||||
|
where
|
||||||
|
F: FnMut(bool, Operation) -> Result<Statement, E>,
|
||||||
|
E: From<MultiOperationError>,
|
||||||
|
{
|
||||||
|
let steps = self.build_steps(name, statements, public)?;
|
||||||
|
|
||||||
|
if steps.is_empty() {
|
||||||
|
return Err(MultiOperationError::NoSteps.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut prev_result: Option<Statement> = None;
|
||||||
|
|
||||||
|
for step in steps {
|
||||||
|
let op = if let Some(prev) = prev_result {
|
||||||
|
// Replace the last Statement::None arg with the previous result.
|
||||||
|
let mut args = step.operation.1;
|
||||||
|
let last = args
|
||||||
|
.last_mut()
|
||||||
|
.expect("chain statement should include placeholder arg");
|
||||||
|
assert!(
|
||||||
|
matches!(last, OperationArg::Statement(Statement::None)),
|
||||||
|
"expected last arg to be a Statement::None placeholder"
|
||||||
|
);
|
||||||
|
*last = OperationArg::Statement(prev);
|
||||||
|
Operation(step.operation.0, args, step.operation.2)
|
||||||
|
} else {
|
||||||
|
step.operation
|
||||||
|
};
|
||||||
|
|
||||||
|
prev_result = Some(apply_op(step.public, op)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(prev_result.unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build operation steps for a predicate (internal helper)
|
||||||
|
fn build_steps(
|
||||||
|
&self,
|
||||||
|
predicate_name: &str,
|
||||||
|
statements: Vec<Statement>,
|
||||||
|
public: bool,
|
||||||
|
) -> Result<Vec<OperationStep>, MultiOperationError> {
|
||||||
|
// Check if this predicate was split
|
||||||
|
let chain_info = match self.split_chains.get(predicate_name) {
|
||||||
|
Some(info) => info,
|
||||||
|
None => {
|
||||||
|
// Not split - single operation with all statements
|
||||||
|
let pred_ref = self.predicate_ref_by_name(predicate_name).ok_or_else(|| {
|
||||||
|
MultiOperationError::PredicateNotFound(predicate_name.to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
return Ok(vec![OperationStep {
|
||||||
|
operation: Operation::custom(pred_ref, statements),
|
||||||
|
public,
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Validate statement count
|
||||||
|
if statements.len() != chain_info.real_statement_count {
|
||||||
|
return Err(MultiOperationError::WrongStatementCount {
|
||||||
|
predicate: predicate_name.to_string(),
|
||||||
|
expected: chain_info.real_statement_count,
|
||||||
|
actual: statements.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reorder statements from original order to split order
|
||||||
|
let mut reordered = vec![Statement::None; statements.len()];
|
||||||
|
for (original_idx, stmt) in statements.into_iter().enumerate() {
|
||||||
|
let split_idx = chain_info.reorder_map[original_idx];
|
||||||
|
reordered[split_idx] = stmt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build operations for each piece in execution order
|
||||||
|
let num_pieces = chain_info.chain_pieces.len();
|
||||||
|
|
||||||
|
// Compute the starting offset for each piece
|
||||||
|
let mut piece_offsets = vec![0usize; num_pieces];
|
||||||
|
let mut offset = 0;
|
||||||
|
for i in (0..num_pieces).rev() {
|
||||||
|
piece_offsets[i] = offset;
|
||||||
|
offset += chain_info.chain_pieces[i].real_statement_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut steps = Vec::new();
|
||||||
|
for (piece_idx, piece) in chain_info.chain_pieces.iter().enumerate() {
|
||||||
|
let is_final = piece_idx == num_pieces - 1;
|
||||||
|
|
||||||
|
let piece_ref = self
|
||||||
|
.predicate_ref_by_name(&piece.name)
|
||||||
|
.ok_or_else(|| MultiOperationError::ChainPieceNotFound(piece.name.clone()))?;
|
||||||
|
|
||||||
|
let start = piece_offsets[piece_idx];
|
||||||
|
let end = start + piece.real_statement_count;
|
||||||
|
let mut args: Vec<Statement> = reordered[start..end].to_vec();
|
||||||
|
|
||||||
|
if piece.has_chain_call {
|
||||||
|
args.push(Statement::None);
|
||||||
|
}
|
||||||
|
|
||||||
|
steps.push(OperationStep {
|
||||||
|
operation: Operation::custom(piece_ref, args),
|
||||||
|
public: public && is_final,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(steps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single step in a multi-operation sequence for split predicates
|
||||||
|
struct OperationStep {
|
||||||
|
operation: Operation,
|
||||||
|
public: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a single Module from split predicate results.
|
||||||
|
///
|
||||||
|
/// Takes a list of split results (containing predicates and optional chain info)
|
||||||
|
/// and builds a single Module. With Merkle tree backing supporting up to 65536
|
||||||
|
/// predicates, all predicates from a document fit in one module.
|
||||||
|
///
|
||||||
|
/// `symbols` provides the symbol table for resolving predicate references,
|
||||||
|
/// including imported predicates from other modules and intro predicates.
|
||||||
|
pub fn build_module(
|
||||||
|
split_results: Vec<SplitResult>,
|
||||||
|
params: &Params,
|
||||||
|
module_name: &str,
|
||||||
|
symbols: &SymbolTable,
|
||||||
|
) -> Result<Module, BatchingError> {
|
||||||
|
// Extract predicates and collect split chains
|
||||||
|
let mut predicates = Vec::new();
|
||||||
|
let mut split_chains = HashMap::new();
|
||||||
|
|
||||||
|
for result in split_results {
|
||||||
|
// Collect chain info if present
|
||||||
|
if let Some(chain_info) = result.chain_info {
|
||||||
|
split_chains.insert(chain_info.original_name.clone(), chain_info);
|
||||||
|
}
|
||||||
|
// Flatten predicates
|
||||||
|
predicates.extend(result.predicates);
|
||||||
|
}
|
||||||
|
|
||||||
|
if predicates.is_empty() {
|
||||||
|
// Return an empty module
|
||||||
|
let empty_batch = CustomPredicateBatch::new(module_name.to_string(), vec![]);
|
||||||
|
return Ok(Module::new(empty_batch, split_chains));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build reference map: name -> index
|
||||||
|
let reference_map: HashMap<String, usize> = predicates
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(idx, pred)| (pred.name.name.clone(), idx))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Build the batch
|
||||||
|
let batch = build_single_batch(&predicates, &reference_map, symbols, params, module_name)?;
|
||||||
|
|
||||||
|
Ok(Module::new(batch, split_chains))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a batch with properly resolved references
|
||||||
|
fn build_single_batch(
|
||||||
|
predicates: &[CustomPredicateDef],
|
||||||
|
reference_map: &HashMap<String, usize>,
|
||||||
|
symbols: &SymbolTable,
|
||||||
|
params: &Params,
|
||||||
|
batch_name: &str,
|
||||||
|
) -> Result<Arc<CustomPredicateBatch>, BatchingError> {
|
||||||
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), batch_name.to_string());
|
||||||
|
|
||||||
|
for pred in predicates {
|
||||||
|
let name = &pred.name.name;
|
||||||
|
|
||||||
|
// Collect argument names
|
||||||
|
let public_args: Vec<&str> = pred
|
||||||
|
.args
|
||||||
|
.public_args
|
||||||
|
.iter()
|
||||||
|
.map(|a| a.name.as_str())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let private_args: Vec<&str> = pred
|
||||||
|
.args
|
||||||
|
.private_args
|
||||||
|
.as_ref()
|
||||||
|
.map(|args| args.iter().map(|a| a.name.as_str()).collect())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
// Build statement templates with resolved predicates
|
||||||
|
let statement_builders: Vec<StatementTmplBuilder> = pred
|
||||||
|
.statements
|
||||||
|
.iter()
|
||||||
|
.map(|stmt| build_statement_with_resolved_refs(stmt, reference_map, name, symbols))
|
||||||
|
.collect::<Result<_, _>>()?;
|
||||||
|
|
||||||
|
let conjunction = pred.conjunction_type == ConjunctionType::And;
|
||||||
|
|
||||||
|
builder
|
||||||
|
.predicate(
|
||||||
|
name,
|
||||||
|
conjunction,
|
||||||
|
&public_args,
|
||||||
|
&private_args,
|
||||||
|
&statement_builders,
|
||||||
|
)
|
||||||
|
.map_err(|e| BatchingError::Internal {
|
||||||
|
message: format!("Failed to add predicate '{}': {}", name, e),
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(builder.finish())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a statement template with properly resolved predicate references
|
||||||
|
fn build_statement_with_resolved_refs(
|
||||||
|
stmt: &crate::lang::frontend_ast::StatementTmpl,
|
||||||
|
reference_map: &HashMap<String, usize>,
|
||||||
|
custom_predicate_name: &str, // custom pred that defines this statement template
|
||||||
|
symbols: &SymbolTable,
|
||||||
|
) -> Result<StatementTmplBuilder, BatchingError> {
|
||||||
|
// Resolve the predicate using the unified resolution function
|
||||||
|
let context = ResolutionContext::Module {
|
||||||
|
reference_map,
|
||||||
|
custom_predicate_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
let pred_or_wc =
|
||||||
|
resolve_predicate_ref(&stmt.predicate, symbols, &context).ok_or_else(|| {
|
||||||
|
BatchingError::Internal {
|
||||||
|
message: format!("Unknown predicate reference: '{}'", stmt.predicate),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Build the statement template
|
||||||
|
let mut builder = StatementTmplBuilder::new(pred_or_wc);
|
||||||
|
|
||||||
|
for arg in &stmt.args {
|
||||||
|
builder = builder.arg(lower_statement_arg(arg));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{
|
||||||
|
lang::{
|
||||||
|
frontend_ast::parse::parse_document,
|
||||||
|
frontend_ast_split::split_predicate_if_needed,
|
||||||
|
frontend_ast_validate::{validate, ParseMode, ValidatedAST},
|
||||||
|
load_module,
|
||||||
|
parser::parse_podlang,
|
||||||
|
},
|
||||||
|
middleware::{CustomPredicateRef, Predicate, PredicateOrWildcard},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Helper: parse and validate input, returning predicates and symbol table
|
||||||
|
fn parse_and_validate(input: &str) -> (Vec<CustomPredicateDef>, ValidatedAST) {
|
||||||
|
let parsed = parse_podlang(input).expect("Failed to parse");
|
||||||
|
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
||||||
|
let validated = validate(document.clone(), &HashMap::new(), ParseMode::Module)
|
||||||
|
.expect("Failed to validate");
|
||||||
|
|
||||||
|
let predicates = document
|
||||||
|
.items
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|item| match item {
|
||||||
|
crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
(predicates, validated)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: wrap predicates into SplitResult (without actually splitting)
|
||||||
|
fn preds_to_split_results(predicates: Vec<CustomPredicateDef>) -> Vec<SplitResult> {
|
||||||
|
predicates
|
||||||
|
.into_iter()
|
||||||
|
.map(|pred| SplitResult {
|
||||||
|
predicates: vec![pred],
|
||||||
|
chain_info: None,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_single_predicate() {
|
||||||
|
let input = r#"
|
||||||
|
my_pred(A, B) = AND(
|
||||||
|
Equal(A["x"], B["y"])
|
||||||
|
)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let (predicates, validated) = parse_and_validate(input);
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
let result = build_module(
|
||||||
|
preds_to_split_results(predicates),
|
||||||
|
¶ms,
|
||||||
|
"TestModule",
|
||||||
|
validated.symbols(),
|
||||||
|
);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let module = result.unwrap();
|
||||||
|
assert_eq!(module.batch.predicates().len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiple_predicates() {
|
||||||
|
let input = r#"
|
||||||
|
pred1(A) = AND(Equal(A["x"], 1))
|
||||||
|
pred2(B) = AND(Equal(B["y"], 2))
|
||||||
|
pred3(C) = AND(Equal(C["z"], 3))
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let (predicates, validated) = parse_and_validate(input);
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
let result = build_module(
|
||||||
|
preds_to_split_results(predicates),
|
||||||
|
¶ms,
|
||||||
|
"TestModule",
|
||||||
|
validated.symbols(),
|
||||||
|
);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let module = result.unwrap();
|
||||||
|
assert_eq!(module.batch.predicates().len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_intra_batch_forward_reference() {
|
||||||
|
// pred2 calls pred1, but pred2 is declared first
|
||||||
|
// This should work because they're in the same batch
|
||||||
|
let input = r#"
|
||||||
|
pred2(B) = AND(pred1(B))
|
||||||
|
pred1(A) = AND(Equal(A["x"], 1))
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let (predicates, validated) = parse_and_validate(input);
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
let result = build_module(
|
||||||
|
preds_to_split_results(predicates),
|
||||||
|
¶ms,
|
||||||
|
"TestModule",
|
||||||
|
validated.symbols(),
|
||||||
|
);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let module = result.unwrap();
|
||||||
|
assert_eq!(module.batch.predicates().len(), 2);
|
||||||
|
|
||||||
|
// pred2 should reference pred1 via BatchSelf
|
||||||
|
let pred2 = &module.batch.predicates()[0];
|
||||||
|
let stmt = &pred2.statements[0];
|
||||||
|
assert!(matches!(
|
||||||
|
stmt.pred_or_wc(),
|
||||||
|
PredicateOrWildcard::Predicate(Predicate::BatchSelf(1))
|
||||||
|
)); // pred1 is at index 1
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mutual_recursion() {
|
||||||
|
// pred1 calls pred2, pred2 calls pred1 - mutual recursion
|
||||||
|
// This should work because they're in the same batch
|
||||||
|
let input = r#"
|
||||||
|
pred1(A) = AND(pred2(A))
|
||||||
|
pred2(B) = AND(pred1(B))
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let (predicates, validated) = parse_and_validate(input);
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
let result = build_module(
|
||||||
|
preds_to_split_results(predicates),
|
||||||
|
¶ms,
|
||||||
|
"TestModule",
|
||||||
|
validated.symbols(),
|
||||||
|
);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let module = result.unwrap();
|
||||||
|
assert_eq!(module.batch.predicates().len(), 2);
|
||||||
|
|
||||||
|
// Both should use BatchSelf references
|
||||||
|
let pred1 = &module.batch.predicates()[0];
|
||||||
|
let pred2 = &module.batch.predicates()[1];
|
||||||
|
assert!(matches!(
|
||||||
|
pred1.statements[0].pred_or_wc(),
|
||||||
|
PredicateOrWildcard::Predicate(Predicate::BatchSelf(1))
|
||||||
|
)); // calls pred2
|
||||||
|
assert!(matches!(
|
||||||
|
pred2.statements[0].pred_or_wc(),
|
||||||
|
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
|
||||||
|
)); // calls pred1
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_predicate_ref_by_name() {
|
||||||
|
let input = r#"
|
||||||
|
pred1(A) = AND(Equal(A["x"], 1))
|
||||||
|
pred2(B) = AND(Equal(B["y"], 2))
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let (predicates, validated) = parse_and_validate(input);
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
let module = build_module(
|
||||||
|
preds_to_split_results(predicates),
|
||||||
|
¶ms,
|
||||||
|
"TestModule",
|
||||||
|
validated.symbols(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Should be able to look up both predicates
|
||||||
|
assert!(module.predicate_ref_by_name("pred1").is_some());
|
||||||
|
assert!(module.predicate_ref_by_name("pred2").is_some());
|
||||||
|
assert!(module.predicate_ref_by_name("nonexistent").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_split_predicate() {
|
||||||
|
// A predicate that will split into 2 pieces
|
||||||
|
let input = r#"
|
||||||
|
large_pred(A) = AND(
|
||||||
|
Equal(A["a"], 1)
|
||||||
|
Equal(A["b"], 2)
|
||||||
|
Equal(A["c"], 3)
|
||||||
|
Equal(A["d"], 4)
|
||||||
|
Equal(A["e"], 5)
|
||||||
|
Equal(A["f"], 6)
|
||||||
|
)
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let (predicates, validated) = parse_and_validate(input);
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
// Split the predicate
|
||||||
|
let mut split_results = Vec::new();
|
||||||
|
for pred in predicates {
|
||||||
|
let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed");
|
||||||
|
split_results.push(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should split into 2 pieces
|
||||||
|
assert_eq!(split_results.len(), 1);
|
||||||
|
assert_eq!(split_results[0].predicates.len(), 2);
|
||||||
|
assert!(split_results[0].chain_info.is_some());
|
||||||
|
|
||||||
|
let module =
|
||||||
|
build_module(split_results, ¶ms, "TestModule", validated.symbols()).unwrap();
|
||||||
|
|
||||||
|
// Verify chain info is preserved
|
||||||
|
let chain_info = module.split_chains.get("large_pred").unwrap();
|
||||||
|
assert_eq!(chain_info.chain_pieces.len(), 2);
|
||||||
|
assert_eq!(chain_info.real_statement_count, 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_load_module_importing_two_modules() {
|
||||||
|
use hex::ToHex;
|
||||||
|
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
// Module "checks": defines is_equal
|
||||||
|
let checks = Arc::new(
|
||||||
|
load_module(
|
||||||
|
r#"is_equal(X, Y) = AND(Equal(X["val"], Y["val"]))"#,
|
||||||
|
"checks",
|
||||||
|
¶ms,
|
||||||
|
vec![],
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Module "ordering": defines is_less
|
||||||
|
let ordering = Arc::new(
|
||||||
|
load_module(
|
||||||
|
r#"is_less(X, Y) = AND(Lt(X["val"], Y["val"]))"#,
|
||||||
|
"ordering",
|
||||||
|
¶ms,
|
||||||
|
vec![],
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let checks_hash = checks.id().encode_hex::<String>();
|
||||||
|
let ordering_hash = ordering.id().encode_hex::<String>();
|
||||||
|
|
||||||
|
// Module "combined": imports both, uses predicates from each
|
||||||
|
let combined = load_module(
|
||||||
|
&format!(
|
||||||
|
r#"
|
||||||
|
use module 0x{} as checks
|
||||||
|
use module 0x{} as ordering
|
||||||
|
|
||||||
|
equal_and_ordered(A, B, C) = AND(
|
||||||
|
checks::is_equal(A, B)
|
||||||
|
ordering::is_less(B, C)
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
checks_hash, ordering_hash
|
||||||
|
),
|
||||||
|
"combined",
|
||||||
|
¶ms,
|
||||||
|
vec![checks.clone(), ordering.clone()],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(combined.batch.predicates().len(), 1);
|
||||||
|
let pred = &combined.batch.predicates()[0];
|
||||||
|
assert_eq!(pred.name, "equal_and_ordered");
|
||||||
|
assert_eq!(pred.statements.len(), 2);
|
||||||
|
|
||||||
|
// First statement references checks::is_equal (external Custom ref, not BatchSelf)
|
||||||
|
let checks_ref = CustomPredicateRef::new(checks.batch.clone(), 0);
|
||||||
|
assert_eq!(
|
||||||
|
*pred.statements[0].pred_or_wc(),
|
||||||
|
PredicateOrWildcard::Predicate(Predicate::Custom(checks_ref))
|
||||||
|
);
|
||||||
|
|
||||||
|
// Second statement references ordering::is_less (external Custom ref, not BatchSelf)
|
||||||
|
let ordering_ref = CustomPredicateRef::new(ordering.batch.clone(), 0);
|
||||||
|
assert_eq!(
|
||||||
|
*pred.statements[1].pred_or_wc(),
|
||||||
|
PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -219,7 +219,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
backends::plonky2::primitives::ec::schnorr::SecretKey,
|
backends::plonky2::primitives::ec::schnorr::SecretKey,
|
||||||
lang::parse,
|
lang::load_module,
|
||||||
middleware::{
|
middleware::{
|
||||||
CustomPredicate, Key, NativePredicate, Params, Predicate, StatementTmpl,
|
CustomPredicate, Key, NativePredicate, Params, Predicate, StatementTmpl,
|
||||||
StatementTmplArg, Value, Wildcard,
|
StatementTmplArg, Value, Wildcard,
|
||||||
|
|
@ -388,20 +388,19 @@ mod tests {
|
||||||
/// Helper function for round-trip testing
|
/// Helper function for round-trip testing
|
||||||
fn assert_round_trip(input: &str) {
|
fn assert_round_trip(input: &str) {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let available_batches = &[];
|
|
||||||
|
|
||||||
// Step 1: Parse the input
|
// Step 1: Parse the input
|
||||||
let parsed_result =
|
let module =
|
||||||
parse(input, ¶ms, available_batches).expect("Initial parsing should succeed");
|
load_module(input, "test", ¶ms, vec![]).expect("Initial parsing should succeed");
|
||||||
|
|
||||||
// Step 2: Pretty-print the parsed batch
|
// Step 2: Pretty-print the parsed batch
|
||||||
let batch = parsed_result.first_batch().expect("Expected batch");
|
let batch = &module.batch;
|
||||||
let pretty_printed = batch.to_podlang_string();
|
let pretty_printed = batch.to_podlang_string();
|
||||||
|
|
||||||
// Step 3: Parse the pretty-printed result
|
// Step 3: Parse the pretty-printed result
|
||||||
let reparsed_result =
|
let reparsed_module = load_module(&pretty_printed, "test", ¶ms, vec![])
|
||||||
parse(&pretty_printed, ¶ms, available_batches).expect("Reparsing should succeed");
|
.expect("Reparsing should succeed");
|
||||||
let reparsed_batch = reparsed_result.first_batch().expect("Expected batch");
|
let reparsed_batch = &reparsed_module.batch;
|
||||||
|
|
||||||
// Step 4: Verify the ASTs are equivalent
|
// Step 4: Verify the ASTs are equivalent
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|
@ -556,16 +555,17 @@ mod tests {
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let parsed_result = parse(input, ¶ms, &[]).expect("Parsing should succeed");
|
let module = load_module(input, "test", ¶ms, vec![]).expect("Parsing should succeed");
|
||||||
let batch = parsed_result.first_batch().expect("Expected batch");
|
let batch = &module.batch;
|
||||||
|
|
||||||
let pretty_printed = batch.to_podlang_string();
|
let pretty_printed = batch.to_podlang_string();
|
||||||
|
|
||||||
println!("Original input:\n{}", input);
|
println!("Original input:\n{}", input);
|
||||||
println!("\nPretty-printed output:\n{}", pretty_printed);
|
println!("\nPretty-printed output:\n{}", pretty_printed);
|
||||||
|
|
||||||
let reparsed = parse(&pretty_printed, ¶ms, &[]).expect("Reparsing should succeed");
|
let reparsed = load_module(&pretty_printed, "test", ¶ms, vec![])
|
||||||
let reparsed_batch = reparsed.first_batch().expect("Expected batch");
|
.expect("Reparsing should succeed");
|
||||||
|
let reparsed_batch = &reparsed.batch;
|
||||||
|
|
||||||
assert_eq!(batch.predicates(), reparsed_batch.predicates());
|
assert_eq!(batch.predicates(), reparsed_batch.predicates());
|
||||||
}
|
}
|
||||||
|
|
@ -629,14 +629,15 @@ mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
let parsed_result = parse(&input, ¶ms, &[]).expect("Should parse successfully");
|
let module =
|
||||||
let batch = parsed_result.first_batch().expect("Expected batch");
|
load_module(&input, "test", ¶ms, vec![]).expect("Should parse successfully");
|
||||||
|
let batch = &module.batch;
|
||||||
|
|
||||||
let pretty_printed = batch.to_podlang_string();
|
let pretty_printed = batch.to_podlang_string();
|
||||||
|
|
||||||
let reparsed_result =
|
let reparsed_module = load_module(&pretty_printed, "test", ¶ms, vec![])
|
||||||
parse(&pretty_printed, ¶ms, &[]).expect("Should reparse successfully");
|
.expect("Should reparse successfully");
|
||||||
let reparsed_batch = reparsed_result.first_batch().expect("Expected batch");
|
let reparsed_batch = &reparsed_module.batch;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
batch.predicates(),
|
batch.predicates(),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue