Remove batch splitting system (#475)

* First pass at removing batch splitting

* Refactor to separate module loading from request parsing

* Consolidate module functionality

* Tidy up comments

* Use array of modules instead of HashMap

* Formatting

* Use module hashes when importing modules
This commit is contained in:
Rob Knight 2026-02-09 10:31:47 +01:00 committed by GitHub
parent 5dab8195b4
commit acab26e5c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1425 additions and 1938 deletions

View file

@ -16,7 +16,7 @@ use pod2::{
primitives::ec::schnorr::SecretKey, signer::Signer,
},
frontend::{MainPodBuilder, Operation, SignedDictBuilder},
lang::parse,
lang::load_module,
middleware::{MainPodProver, Params, VDSet},
};
@ -88,10 +88,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
game_pk = game_pk,
);
println!("# custom predicate batch:{}", input);
let batch = parse(&input, &params, &[])?
.first_batch()
.expect("Expected batch")
.clone();
let module = load_module(&input, "points_module", &params, vec![])?;
let batch = module.batch.clone();
let points_pred = batch.predicate_ref_by_name("points").unwrap();
let over_9000_pred = batch.predicate_ref_by_name("over_9000").unwrap();

View file

@ -836,7 +836,7 @@ pub mod tests {
frontend::{
self, literal, CustomPredicateBatchBuilder, MainPodBuilder, StatementTmplBuilder as STB,
},
lang::parse,
lang::load_module,
middleware::{
self, containers::Set, CustomPredicateRef, NativePredicate as NP, Signer as _,
DEFAULT_VD_LIST, DEFAULT_VD_SET,
@ -1165,7 +1165,7 @@ pub mod tests {
#[test]
fn test_undetermined_values() {
let params = Default::default();
let batch = parse(
let module = load_module(
r#"
two_equal(x,y,z) = OR(
Equal(x,y)
@ -1173,13 +1173,12 @@ pub mod tests {
Equal(x,z)
)
"#,
"test",
&params,
&[],
vec![],
)
.unwrap()
.first_batch()
.unwrap()
.clone();
.unwrap();
let batch = module.batch.clone();
let mut builder = MainPodBuilder::new(&params, &DEFAULT_VD_SET);
let cpr = CustomPredicateRef { batch, index: 0 };
let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap();

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use hex::ToHex;
use std::{collections::HashMap, sync::Arc};
use crate::{
frontend::{PodRequest, Result},
lang::parse,
lang::{load_module, parse_request, Module},
middleware::{CustomPredicateBatch, Params},
};
@ -32,11 +30,8 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
eth_dos_ind(src, dst, distance)
)
"#;
let batch = parse(input, params, &[])
.expect("lang parse")
.first_batch()
.expect("Expected batch")
.clone();
let module = load_module(input, "eth_dos", params, vec![]).expect("lang parse");
let batch = module.batch.clone();
println!("a.0. {}", batch.predicates()[0]);
println!("a.1. {}", batch.predicates()[1]);
println!("a.2. {}", batch.predicates()[2]);
@ -45,18 +40,26 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
}
pub fn eth_dos_request() -> Result<PodRequest> {
use hex::ToHex;
let batch = eth_dos_batch(&Params::default())?;
let batch_id = batch.id().encode_hex::<String>();
let eth_dos_module = Arc::new(Module::new(batch, HashMap::new()));
let module_hash = eth_dos_module.id().encode_hex::<String>();
let input = format!(
r#"
use batch _, _, _, eth_dos from 0x{batch_id}
use module 0x{} as eth_dos
REQUEST(
eth_dos(src, dst, distance)
eth_dos::eth_dos(src, dst, distance)
)
"#,
module_hash
);
let parsed = parse(&input, &Params::default(), &[batch])?;
Ok(parsed.request)
Ok(parse_request(
&input,
&Params::default(),
&[eth_dos_module],
)?)
}
#[cfg(test)]

View file

@ -12,7 +12,7 @@ use crate::{
frontend::{
MainPod, MainPodBuilder, Operation, PodRequest, Result, SignedDict, SignedDictBuilder,
},
lang::parse,
lang::parse_request,
middleware::{
self, containers::Set, hash_values, CustomPredicateRef, Params, Predicate, PublicKey,
Signer as _, Statement, StatementArg, TypedValue, VDSet, Value,
@ -90,8 +90,7 @@ pub fn zu_kyc_pod_request(gov_signer: &Value, pay_signer: &Value) -> Result<PodR
)
"#,
);
let parsed = parse(&input, &Params::default(), &[])?;
Ok(parsed.request)
Ok(parse_request(&input, &Params::default(), &[])?)
}
// ETHDoS

View file

@ -798,7 +798,6 @@ impl MainPodCompiler {
#[cfg(test)]
pub mod tests {
use num::BigUint;
use super::*;
@ -813,7 +812,7 @@ pub mod tests {
tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request,
zu_kyc_sign_dict_builders, EthDosHelper, MOCK_VD_SET,
},
lang::parse,
lang::load_module,
middleware::{
containers::{Array, Set},
Signer as _, Value,
@ -1382,11 +1381,8 @@ pub mod tests {
Equal(b, 5)
)
"#;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let module = load_module(input, "test", &params, vec![]).unwrap();
let batch = module.batch.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
// Try to build with wrong type in 1st arg
@ -1434,11 +1430,8 @@ pub mod tests {
c(6, 3)
)
"#;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let module = load_module(input, "test", &params, vec![]).unwrap();
let batch = module.batch.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
@ -1459,11 +1452,8 @@ pub mod tests {
c(6, 3)
)
"#;
let batch = parse(input, &params, &[])
.unwrap()
.first_batch()
.unwrap()
.clone();
let module = load_module(input, "test", &params, vec![]).unwrap();
let batch = module.batch.clone();
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
let mut builder = MainPodBuilder::new(&params, vd_set);
@ -1501,12 +1491,11 @@ pub mod tests {
"#;
// Parse and batch the predicate (this handles splitting internally)
let parsed = parse(input, &params, &[])?;
let batches = &parsed.custom_batches;
let module = load_module(input, "test", &params, vec![])?;
// Verify it was split
assert!(batches.split_chain("large_pred").is_some());
let chain_info = batches.split_chain("large_pred").unwrap();
assert!(module.split_chains.contains_key("large_pred"));
let chain_info = module.split_chains.get("large_pred").unwrap();
assert_eq!(chain_info.chain_pieces.len(), 2);
assert_eq!(chain_info.real_statement_count, 6);
@ -1538,10 +1527,10 @@ pub mod tests {
let statements = vec![st_a, st_b, st_c, st_d, st_e, st_f];
// Use apply_predicate (primary API) to automatically wire the split chain
let result = batches.apply_predicate(&mut builder, "large_pred", statements, true)?;
let result = module.apply_predicate(&mut builder, "large_pred", statements, true)?;
// The result should be a valid statement
let predicate = batches.predicate_ref_by_name("large_pred").unwrap();
let predicate = module.predicate_ref_by_name("large_pred").unwrap();
match &result {
Statement::Custom(pred_ref, _) => {
assert_eq!(pred_ref, &predicate);

View file

@ -632,7 +632,7 @@ mod tests {
dict,
examples::MOCK_VD_SET,
frontend::{Operation as FrontendOp, SignedDictBuilder},
lang::parse,
lang::load_module,
};
#[test]
@ -756,18 +756,17 @@ mod tests {
// pred_a accepts a Contains statement
// pred_b accepts a pred_a statement (Custom statement from pred_a)
let parsed = parse(
let module = load_module(
r#"
pred_a(X) = AND(Contains(X, "k", 1))
pred_b(X) = AND(pred_a(X))
"#,
"test",
&params,
&[],
vec![],
)
.expect("parse predicates");
let batch = parsed
.first_batch()
.expect("parse predicates should have a batch");
.expect("load module");
let batch = &module.batch;
let mut builder = MultiPodBuilder::new(&params, vd_set);
@ -1484,20 +1483,19 @@ mod tests {
let vd_set = &*MOCK_VD_SET;
// Chain of predicates: each accepts the output of the previous
let parsed = parse(
let module = load_module(
r#"
pred_a(X) = AND(Contains(X, "k", 1))
pred_b(X) = AND(pred_a(X))
pred_c(X) = AND(pred_b(X))
pred_d(X) = AND(pred_c(X))
"#,
"test",
&params,
&[],
vec![],
)
.expect("parse predicates");
let batch = parsed
.first_batch()
.expect("parse predicates should have a batch");
.expect("load module");
let batch = &module.batch;
let mut builder = MultiPodBuilder::new(&params, vd_set);
@ -1612,7 +1610,7 @@ mod tests {
// pred_a takes TWO custom statement arguments (b_out and c_out)
// pred_b and pred_c each take a Contains
// Note: AND clauses are newline-separated, not comma-separated
let parsed = parse(
let module = load_module(
r#"
pred_b(X) = AND(Contains(X, "k", 1))
pred_c(X) = AND(Contains(X, "k", 1))
@ -1621,13 +1619,12 @@ mod tests {
pred_c(Y)
)
"#,
"test",
&params,
&[],
vec![],
)
.expect("parse predicates");
let batch = parsed
.first_batch()
.expect("parse predicates should have a batch");
.expect("load module");
let batch = &module.batch;
let mut builder = MultiPodBuilder::new(&params, vd_set);

View file

@ -185,7 +185,7 @@ mod tests {
zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, MOCK_VD_SET,
},
frontend::{MainPodBuilder, Operation},
lang::parse,
lang::parse_request,
middleware::{Params, Value},
};
@ -210,7 +210,7 @@ mod tests {
assert!(request.exact_match_pod(&*kyc.pod).is_ok());
// This request does not match the POD, because the POD does not contain a NotEqual statement.
let non_matching_request = parse(
let non_matching_request = parse_request(
r#"
REQUEST(
NotEqual(4, 5)
@ -219,8 +219,7 @@ mod tests {
&params,
&[],
)
.unwrap()
.request;
.unwrap();
assert!(non_matching_request.exact_match_pod(&*kyc.pod).is_err());
}
@ -240,7 +239,7 @@ mod tests {
println!("{pod}");
let request = parse(
let request = parse_request(
r#"
REQUEST(
SumOf(a, b, c)
@ -252,7 +251,7 @@ mod tests {
)
.unwrap();
let bindings = request.request.exact_match_pod(&*pod.pod).unwrap();
let bindings = request.exact_match_pod(&*pod.pod).unwrap();
assert_eq!(*bindings.get("a").unwrap(), 10.into());
assert_eq!(*bindings.get("b").unwrap(), 9.into());
assert_eq!(*bindings.get("c").unwrap(), 1.into());

View file

@ -50,8 +50,8 @@ pub enum ValidationError {
span: Option<Span>,
},
#[error("Batch not found: {id}")]
BatchNotFound { id: String, span: Option<Span> },
#[error("Module not found: {name}")]
ModuleNotFound { name: String, span: Option<Span> },
#[error("Undefined predicate: {name}")]
UndefinedPredicate { name: String, span: Option<Span> },
@ -91,6 +91,18 @@ pub enum ValidationError {
#[error("Wildcard '{name}' collides with a predicate name")]
WildcardPredicateNameCollision { name: String },
#[error("Predicate definitions are not allowed in requests")]
PredicatesNotAllowedInRequest { span: Option<Span> },
#[error("REQUEST block is not allowed in modules")]
RequestNotAllowedInModule { span: Option<Span> },
#[error("Modules must contain at least one predicate definition")]
NoPredicatesInModule,
#[error("Requests must contain a REQUEST block")]
NoRequestBlock,
}
/// Lowering errors from frontend AST lowering to middleware

View file

@ -18,17 +18,17 @@ pub struct Document {
/// Top-level items that can appear in a document
#[derive(Debug, Clone, PartialEq)]
pub enum DocumentItem {
UseBatchStatement(UseBatchStatement),
UseModuleStatement(UseModuleStatement),
UseIntroStatement(UseIntroStatement),
CustomPredicateDef(CustomPredicateDef),
RequestDef(RequestDef),
}
/// Import statement: `use batch pred1, pred2, _ from 0x...`
/// Module import statement: `use module 0xHASH as alias`
#[derive(Debug, Clone, PartialEq)]
pub struct UseBatchStatement {
pub imports: Vec<ImportName>,
pub batch_ref: HashHex,
pub struct UseModuleStatement {
pub hash: HashHex,
pub alias: Identifier,
pub span: Option<Span>,
}
@ -40,19 +40,6 @@ pub struct UseIntroStatement {
pub intro_hash: HashHex,
pub span: Option<Span>,
}
/// Individual import name (identifier or unused "_")
#[derive(Debug, Clone, PartialEq)]
pub enum ImportName {
Named(String),
Unused, // "_"
}
/// Batch reference (hash)
#[derive(Debug, Clone, PartialEq)]
pub struct BatchRef {
pub hash: HashHex,
pub span: Option<Span>,
}
/// Intro predicate reference (hash)
#[derive(Debug, Clone, PartialEq)]
@ -96,11 +83,33 @@ pub enum ConjunctionType {
/// Statement template: predicate call with arguments
#[derive(Debug, Clone, PartialEq)]
pub struct StatementTmpl {
pub predicate: Identifier,
pub predicate: PredicateRef,
pub args: Vec<StatementTmplArg>,
pub span: Option<Span>,
}
/// Reference to a predicate (local or qualified with module name)
#[derive(Debug, Clone, PartialEq)]
pub enum PredicateRef {
/// Unqualified name (local or native predicate)
Local(Identifier),
/// Qualified name (module::predicate)
Qualified {
module: Identifier,
predicate: Identifier,
},
}
impl PredicateRef {
/// Get the predicate name (without module qualifier)
pub fn predicate_name(&self) -> &str {
match self {
PredicateRef::Local(id) => &id.name,
PredicateRef::Qualified { predicate, .. } => &predicate.name,
}
}
}
/// Arguments that can be passed to statements
#[derive(Debug, Clone, PartialEq)]
pub enum StatementTmplArg {
@ -256,7 +265,7 @@ impl fmt::Display for Document {
impl fmt::Display for DocumentItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DocumentItem::UseBatchStatement(u) => write!(f, "{}", u),
DocumentItem::UseModuleStatement(u) => write!(f, "{}", u),
DocumentItem::UseIntroStatement(u) => write!(f, "{}", u),
DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c),
DocumentItem::RequestDef(r) => write!(f, "{}", r),
@ -264,16 +273,9 @@ impl fmt::Display for DocumentItem {
}
}
impl fmt::Display for UseBatchStatement {
impl fmt::Display for UseModuleStatement {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "use batch ")?;
for (i, import) in self.imports.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", import)?;
}
write!(f, " from {}", self.batch_ref)
write!(f, "use module {} as {}", self.hash, self.alias)
}
}
@ -290,21 +292,17 @@ impl fmt::Display for UseIntroStatement {
}
}
impl fmt::Display for ImportName {
impl fmt::Display for PredicateRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ImportName::Named(name) => write!(f, "{}", name),
ImportName::Unused => write!(f, "_"),
PredicateRef::Local(id) => write!(f, "{}", id),
PredicateRef::Qualified { module, predicate } => {
write!(f, "{}::{}", module, predicate)
}
}
}
}
impl fmt::Display for BatchRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.hash)
}
}
impl fmt::Display for IntroPredicateRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.hash)
@ -536,10 +534,10 @@ pub mod parse {
for inner_pair in pair.into_inner() {
match inner_pair.as_rule() {
Rule::use_batch_statement => {
items.push(DocumentItem::UseBatchStatement(parse_use_batch_statement(
inner_pair,
)));
Rule::use_module_statement => {
items.push(DocumentItem::UseModuleStatement(
parse_use_module_statement(inner_pair),
));
}
Rule::use_intro_statement => {
items.push(DocumentItem::UseIntroStatement(parse_use_intro_statement(
@ -562,25 +560,17 @@ pub mod parse {
Ok(Document { items })
}
fn parse_use_batch_statement(pair: Pair<Rule>) -> UseBatchStatement {
assert_eq!(pair.as_rule(), Rule::use_batch_statement);
fn parse_use_module_statement(pair: Pair<Rule>) -> UseModuleStatement {
assert_eq!(pair.as_rule(), Rule::use_module_statement);
let span = get_span(&pair);
let mut inner = pair.into_inner();
let use_list_pair = inner
.find(|p| p.as_rule() == Rule::use_predicate_list)
.unwrap();
let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap();
let hash = parse_hash_hex(inner.next().unwrap());
let alias = parse_identifier(inner.next().unwrap());
let imports = use_list_pair
.into_inner()
.filter(|p| p.as_rule() == Rule::import_name)
.map(parse_import_name)
.collect();
UseBatchStatement {
imports,
batch_ref: parse_hash_hex(batch_ref_pair.into_inner().next().unwrap()),
UseModuleStatement {
hash,
alias,
span: Some(span),
}
}
@ -622,16 +612,6 @@ pub mod parse {
}
}
fn parse_import_name(pair: Pair<Rule>) -> ImportName {
assert_eq!(pair.as_rule(), Rule::import_name);
let s = pair.as_str();
if s == "_" {
ImportName::Unused
} else {
ImportName::Named(s.to_string())
}
}
fn parse_hash_hex(pair: Pair<Rule>) -> HashHex {
assert_eq!(pair.as_rule(), Rule::hash_hex);
let span = get_span(&pair);
@ -748,7 +728,7 @@ pub mod parse {
let span = get_span(&pair);
let mut inner = pair.into_inner();
let predicate = parse_identifier(inner.next().unwrap());
let predicate = parse_predicate_ref(inner.next().unwrap());
let mut args = Vec::new();
if let Some(arg_list) = inner.next() {
@ -768,6 +748,22 @@ pub mod parse {
})
}
fn parse_predicate_ref(pair: Pair<Rule>) -> PredicateRef {
assert_eq!(pair.as_rule(), Rule::predicate_ref);
let inner = pair.into_inner().next().unwrap();
match inner.as_rule() {
Rule::qualified_predicate_ref => {
let mut parts = inner.into_inner();
let module = parse_identifier(parts.next().unwrap());
let predicate = parse_identifier(parts.next().unwrap());
PredicateRef::Qualified { module, predicate }
}
Rule::identifier => PredicateRef::Local(parse_identifier(inner)),
_ => unreachable!("Unexpected predicate_ref rule: {:?}", inner.as_rule()),
}
}
fn parse_statement_arg(pair: Pair<Rule>) -> Result<StatementTmplArg, parser::ParseError> {
assert_eq!(pair.as_rule(), Rule::statement_arg);
let inner = pair.into_inner().next().unwrap();
@ -1047,9 +1043,10 @@ mod tests {
fn clear_spans(doc: &mut Document) {
for item in &mut doc.items {
match item {
DocumentItem::UseBatchStatement(u) => {
DocumentItem::UseModuleStatement(u) => {
u.span = None;
u.batch_ref.span = None;
u.hash.span = None;
u.alias.span = None;
}
DocumentItem::UseIntroStatement(u) => {
u.span = None;
@ -1082,9 +1079,19 @@ mod tests {
}
}
fn clear_predicate_ref_spans(pred_ref: &mut PredicateRef) {
match pred_ref {
PredicateRef::Local(id) => id.span = None,
PredicateRef::Qualified { module, predicate } => {
module.span = None;
predicate.span = None;
}
}
}
fn clear_statement_spans(stmt: &mut StatementTmpl) {
stmt.span = None;
stmt.predicate.span = None;
clear_predicate_ref_spans(&mut stmt.predicate);
for arg in &mut stmt.args {
match arg {
StatementTmplArg::Literal(lit) => clear_literal_spans(lit),
@ -1168,8 +1175,8 @@ mod tests {
}
#[test]
fn test_use_batch_statement() {
let input = r#"use batch pred1, pred2, _ from 0x0000000000000000000000000000000000000000000000000000000000000000"#;
fn test_use_module_statement() {
let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as helpers"#;
test_roundtrip(input);
}
@ -1223,11 +1230,11 @@ mod tests {
#[test]
fn test_complete_document() {
let input = r#"use batch imported_pred from 0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd
let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as imported
is_valid(User, private: Config) = AND (
Equal(User["age"], Config["min_age"])
imported_pred(User, Config)
imported::some_pred(User, Config)
)
check_both(A, B, C) = OR (
@ -1306,7 +1313,7 @@ REQUEST(
// Check request structure
if let DocumentItem::RequestDef(req) = &ast.items[1] {
assert_eq!(req.statements.len(), 1);
assert_eq!(req.statements[0].predicate.name, "my_pred");
assert_eq!(req.statements[0].predicate.predicate_name(), "my_pred");
assert_eq!(req.statements[0].args.len(), 2);
} else {
panic!("Expected RequestDef");

File diff suppressed because it is too large Load diff

View file

@ -1,52 +1,69 @@
//! Lowering from frontend AST to middleware structures
//!
//! This module converts validated frontend AST to middleware data structures.
//! Supports automatic predicate splitting and multi-batch packing.
//! Supports automatic predicate splitting.
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
};
use crate::{
frontend::{BuilderArg, PredicateOrWildcard, StatementTmplBuilder},
lang::{
frontend_ast::*,
frontend_ast_batch::{self, PredicateBatches},
frontend_ast_split,
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
module, Module,
},
middleware::{
self, containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
self, containers, CustomPredicateRef, IntroPredicateRef, Key, NativePredicate, Params,
Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Value,
Wildcard,
},
};
/// Context for predicate resolution - determines how local custom predicates are resolved
/// Context for predicate resolution - determines how predicates are resolved
pub enum ResolutionContext<'a> {
/// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches
Request {
batches: Option<&'a PredicateBatches>,
},
/// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef
Batch {
current_batch_idx: usize,
reference_map: &'a HashMap<String, (usize, usize)>,
existing_batches: &'a [Arc<CustomPredicateBatch>],
/// Request context: predicates resolve via imports only (no local definitions)
Request,
/// Module context: local predicates resolve to BatchSelf
Module {
/// Maps predicate name to index within the module
reference_map: &'a HashMap<String, usize>,
/// Name of the custom predicate being defined (for wildcard scope lookup)
custom_predicate_name: &'a str,
},
}
/// Resolve a predicate reference to a Predicate using the symbol table
pub fn resolve_predicate_ref(
pred_ref: &PredicateRef,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Option<PredicateOrWildcard> {
match pred_ref {
PredicateRef::Qualified { module, predicate } => {
// Look up the module in the imported_modules
let imported_module = symbols.imported_modules.get(&module.name)?;
// Find the predicate index in the module
let idx = *imported_module.predicate_index.get(&predicate.name)?;
Some(PredicateOrWildcard::Predicate(Predicate::Custom(
CustomPredicateRef::new(imported_module.batch.clone(), idx),
)))
}
PredicateRef::Local(id) => resolve_predicate(&id.name, symbols, context),
}
}
/// Resolve a predicate name to a Predicate using the symbol table
pub fn resolve_predicate(
pred_name: &str,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Option<PredicateOrWildcard> {
// 0. Try wildcard first
if let ResolutionContext::Batch {
// 0. Try wildcard first (only in module context where we're defining predicates)
if let ResolutionContext::Module {
custom_predicate_name,
..
} = context
@ -69,28 +86,35 @@ pub fn resolve_predicate(
PredicateKind::Native(np) => Predicate::Native(*np),
PredicateKind::Custom { .. } => match context {
ResolutionContext::Request { batches } => {
let batches = batches.as_ref()?;
let pred_ref = batches.predicate_ref_by_name(pred_name)?;
Predicate::Custom(pred_ref)
ResolutionContext::Request => {
// Requests can't define local predicates, so this shouldn't happen
return None;
}
ResolutionContext::Module { reference_map, .. } => {
resolve_local_predicate(pred_name, reference_map)?
}
ResolutionContext::Batch {
current_batch_idx,
reference_map,
existing_batches,
..
} => resolve_local_predicate(
pred_name,
*current_batch_idx,
reference_map,
existing_batches,
)?,
},
PredicateKind::BatchImported { batch, index } => {
Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index))
}
PredicateKind::ModuleImported {
module_name,
predicate_index,
..
} => {
// Look up the module in the imported_modules
let module = symbols
.imported_modules
.get(module_name)
.expect("Module should exist if ModuleImported predicate kind exists");
Predicate::Custom(CustomPredicateRef::new(
module.batch.clone(),
*predicate_index,
))
}
PredicateKind::IntroImported {
name,
verifier_data_hash,
@ -103,51 +127,25 @@ pub fn resolve_predicate(
return Some(PredicateOrWildcard::Predicate(predicate));
}
// 3. In batch context, also check reference_map for split chain pieces
// 3. In module context, also check reference_map for split chain pieces
// (predicates created by splitting that aren't in the original symbol table)
if let ResolutionContext::Batch {
current_batch_idx,
reference_map,
existing_batches,
..
} = context
{
if let ResolutionContext::Module { reference_map, .. } = context {
if reference_map.contains_key(pred_name) {
return resolve_local_predicate(
pred_name,
*current_batch_idx,
reference_map,
existing_batches,
)
.map(PredicateOrWildcard::Predicate);
return resolve_local_predicate(pred_name, reference_map)
.map(PredicateOrWildcard::Predicate);
}
}
None
}
/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map
/// Resolve a local predicate (one in this module or a split chain piece) using the reference_map
fn resolve_local_predicate(
pred_name: &str,
current_batch_idx: usize,
reference_map: &HashMap<String, (usize, usize)>,
existing_batches: &[Arc<CustomPredicateBatch>],
reference_map: &HashMap<String, usize>,
) -> Option<Predicate> {
let &(target_batch, target_idx) = reference_map.get(pred_name)?;
if target_batch == current_batch_idx {
Some(Predicate::BatchSelf(target_idx))
} else if target_batch < current_batch_idx {
let batch = &existing_batches[target_batch];
Some(Predicate::Custom(CustomPredicateRef::new(
batch.clone(),
target_idx,
)))
} else {
unreachable!(
"Forward cross-batch reference should be impossible: {} -> {}",
current_batch_idx, target_batch
);
}
let &idx = reference_map.get(pred_name)?;
Some(Predicate::BatchSelf(idx))
}
// ============================================================================
@ -155,7 +153,7 @@ fn resolve_local_predicate(
// ============================================================================
// These functions convert AST types to middleware/builder types and are used
// by both the request lowering (in this module) and predicate batching
// (in frontend_ast_batch).
// (in module.rs).
/// Lower a literal value from AST to middleware Value.
///
@ -215,38 +213,37 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
}
}
/// Result of lowering: optional custom predicate batches and optional request
///
/// A Podlang file can contain:
/// - Just custom predicates (batches: Some, request: None)
/// - Just a request (batches: None, request: Some)
/// - Both (batches: Some, request: Some)
/// - Neither (batches: None, request: None) - just imports
#[derive(Debug, Clone)]
pub struct LoweredOutput {
pub batches: Option<PredicateBatches>,
pub request: Option<crate::frontend::PodRequest>,
}
pub use crate::lang::error::LoweringError;
/// Lower a validated AST to middleware structures
/// Lower a validated module AST to a Module
///
/// Returns both the custom predicate batch (if any) and the request (if any).
/// At least one will be Some if the document contains custom predicates or a request.
pub fn lower(
/// The validated AST must have been validated in Module mode.
pub fn lower_module(
validated: ValidatedAST,
params: &Params,
batch_name: String,
) -> Result<LoweredOutput, LoweringError> {
module_name: &str,
) -> Result<Module, LoweringError> {
if !validated.diagnostics().is_empty() {
// For now, treat any diagnostics as errors
// In future we could allow warnings
return Err(LoweringError::ValidationErrors);
}
let lowerer = Lowerer::new(validated, params);
lowerer.lower(batch_name)
lowerer.lower_module(module_name)
}
/// Lower a validated request AST to a PodRequest
///
/// The validated AST must have been validated in Request mode.
pub fn lower_request(
validated: ValidatedAST,
params: &Params,
) -> Result<crate::frontend::PodRequest, LoweringError> {
if !validated.diagnostics().is_empty() {
return Err(LoweringError::ValidationErrors);
}
let lowerer = Lowerer::new(validated, params);
lowerer.lower_request()
}
struct Lowerer<'a> {
@ -259,52 +256,33 @@ impl<'a> Lowerer<'a> {
Self { validated, params }
}
fn lower(self, batch_name: String) -> Result<LoweredOutput, LoweringError> {
// Lower custom predicates (if any) - now supports multiple batches
let batches = self.lower_batches(batch_name)?;
// Lower request (if any) - pass batches so refs can be resolved
let request = self.lower_request(batches.as_ref())?;
Ok(LoweredOutput { batches, request })
}
fn lower_batches(&self, batch_name: String) -> Result<Option<PredicateBatches>, LoweringError> {
fn lower_module(self, module_name: &str) -> Result<Module, LoweringError> {
// Extract and split custom predicates from document
let custom_predicates = self.extract_and_split_predicates()?;
// If no custom predicates, return None
if custom_predicates.is_empty() {
return Ok(None);
}
// Use the new batching module to pack predicates into batches
// Pass the symbol table for unified predicate resolution
let batches = frontend_ast_batch::batch_predicates(
// Build the module from split predicates
let module = module::build_module(
custom_predicates,
self.params,
&batch_name,
module_name,
self.validated.symbols(),
)?;
Ok(Some(batches))
Ok(module)
}
fn lower_request(
&self,
batches: Option<&PredicateBatches>,
) -> Result<Option<crate::frontend::PodRequest>, LoweringError> {
fn lower_request(self) -> Result<crate::frontend::PodRequest, LoweringError> {
let doc = self.validated.document();
// Find request definition (if any)
let request_def = doc.items.iter().find_map(|item| match item {
DocumentItem::RequestDef(req) => Some(req),
_ => None,
});
let Some(request_def) = request_def else {
return Ok(None);
};
// Find request definition
let request_def = doc
.items
.iter()
.find_map(|item| match item {
DocumentItem::RequestDef(req) => Some(req),
_ => None,
})
.expect("Request mode validation ensures REQUEST block exists");
// Build wildcard map from all wildcards used in the request statements
let wildcard_map = self.build_request_wildcard_map(request_def);
@ -312,18 +290,17 @@ impl<'a> Lowerer<'a> {
// Lower each statement to middleware templates, resolving predicates
let mut request_templates = Vec::new();
for stmt in &request_def.statements {
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map, batches)?;
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map)?;
request_templates.push(mw_stmt);
}
Ok(Some(crate::frontend::PodRequest::new(request_templates)))
Ok(crate::frontend::PodRequest::new(request_templates))
}
fn lower_request_statement(
&self,
stmt: &StatementTmpl,
wildcard_map: &HashMap<String, usize>,
batches: Option<&PredicateBatches>,
) -> Result<MWStatementTmpl, LoweringError> {
// Enforce argument count limit for request statements
if stmt.args.len() > Params::max_statement_args() {
@ -333,16 +310,16 @@ impl<'a> Lowerer<'a> {
});
}
let pred_name = &stmt.predicate.name;
let symbols = self.validated.symbols();
// Resolve predicate using the unified resolution function
let context = ResolutionContext::Request { batches };
let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| {
LoweringError::PredicateNotFound {
name: pred_name.clone(),
}
})?;
let context = ResolutionContext::Request;
let predicate =
resolve_predicate_ref(&stmt.predicate, symbols, &context).ok_or_else(|| {
LoweringError::PredicateNotFound {
name: format!("{}", stmt.predicate),
}
})?;
// Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate.clone());
@ -453,31 +430,24 @@ impl<'a> Lowerer<'a> {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::lang::{
frontend_ast::parse::parse_document, frontend_ast_validate::validate, parser::parse_podlang,
frontend_ast::parse::parse_document,
frontend_ast_validate::{validate, ParseMode},
parser::parse_podlang,
};
fn parse_validate_and_lower(
fn parse_validate_and_lower_module(
input: &str,
params: &Params,
) -> Result<LoweredOutput, LoweringError> {
) -> Result<Module, LoweringError> {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
let validated = validate(document, &[]).expect("Failed to validate");
lower(validated, params, "test_batch".to_string())
}
// Helper to get the first batch from the output (expecting it to exist)
fn expect_batch(
output: &LoweredOutput,
) -> &std::sync::Arc<crate::middleware::CustomPredicateBatch> {
output
.batches
.as_ref()
.expect("Expected batches to be present")
.first_batch()
.expect("Expected at least one batch")
let validated =
validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate");
lower_module(validated, params, "test_batch")
}
#[test]
@ -489,16 +459,16 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
if let Err(e) = &result {
eprintln!("Error: {:?}", e);
}
assert!(result.is_ok());
let lowered = result.unwrap();
assert_eq!(expect_batch(&lowered).predicates().len(), 1);
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 1);
let pred = &expect_batch(&lowered).predicates()[0];
let pred = &module.batch.predicates()[0];
assert_eq!(pred.name, "my_pred");
assert_eq!(pred.args_len(), 2);
assert_eq!(pred.wildcard_names().len(), 2);
@ -515,11 +485,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred = &expect_batch(&lowered).predicates()[0];
let module = result.unwrap();
let pred = &module.batch.predicates()[0];
assert_eq!(pred.args_len(), 1); // Only A is public
assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total
}
@ -534,11 +504,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred = &expect_batch(&lowered).predicates()[0];
let module = result.unwrap();
let pred = &module.batch.predicates()[0];
assert!(pred.is_disjunction());
}
@ -556,23 +526,22 @@ mod tests {
"#;
let params = Params::default(); // max_custom_predicate_arity = 5
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
if let Err(e) = &result {
eprintln!("Splitting error: {:?}", e);
}
assert!(result.is_ok());
let lowered = result.unwrap();
let module = result.unwrap();
// Should be automatically split into 2 predicates (my_pred and my_pred_1)
let batches = lowered.batches.as_ref().expect("Expected batches");
assert_eq!(batches.total_predicate_count(), 2);
assert_eq!(module.batch.predicates().len(), 2);
// With topological sorting, my_pred_1 comes first (since my_pred depends on it)
// my_pred_1 has 2 statements
// my_pred has 5 statements (4 + chain call)
// Just verify we have the right total statement counts
let batch = batches.first_batch().unwrap();
let total_statements: usize = batch
let total_statements: usize = module
.batch
.predicates()
.iter()
.map(|p| p.statements().len())
@ -593,11 +562,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
assert_eq!(expect_batch(&lowered).predicates().len(), 2);
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 2);
}
#[test]
@ -613,11 +582,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred2 = &expect_batch(&lowered).predicates()[1];
let module = result.unwrap();
let pred2 = &module.batch.predicates()[1];
let stmt = &pred2.statements()[0];
// Should be BatchSelf(0) referring to pred1
@ -638,7 +607,7 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
}
@ -651,11 +620,11 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
let lowered = result.unwrap();
let pred = &expect_batch(&lowered).predicates()[0];
let module = result.unwrap();
let pred = &module.batch.predicates()[0];
let stmt = &pred.statements()[0];
// Should desugar to the Contains predicate
@ -677,7 +646,7 @@ mod tests {
"#;
let params = Params::default();
let result = parse_validate_and_lower(input, &params);
let result = parse_validate_and_lower_module(input, &params);
assert!(result.is_ok());
}
@ -706,18 +675,18 @@ mod tests {
let parsed = parse_podlang(&input).expect("Failed to parse");
let document =
parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document");
let validated = validate(document, &[]).expect("Failed to validate");
let result = lower(validated, &params, "test_batch".to_string());
let validated =
validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate");
let result = lower_module(validated, &params, "test_batch");
assert!(result.is_ok(), "Lowering failed: {:?}", result.err());
let lowered = result.unwrap();
let batch = expect_batch(&lowered);
let module = result.unwrap();
// Should have one custom predicate
assert_eq!(batch.predicates().len(), 1);
assert_eq!(module.batch.predicates().len(), 1);
let pred = &batch.predicates()[0];
let pred = &module.batch.predicates()[0];
assert_eq!(pred.name, "my_pred");
// 2 statements: Equal and external_check
assert_eq!(pred.statements().len(), 2);

View file

@ -620,7 +620,7 @@ fn generate_chain_predicates(
.collect();
let chain_call = StatementTmpl {
predicate: next_pred_name,
predicate: PredicateRef::Local(next_pred_name),
args: chain_call_args,
span: None,
};
@ -832,7 +832,7 @@ mod tests {
let original = &chain[1];
assert_eq!(original.name.name, "complex");
let last_stmt = original.statements.last().unwrap();
assert_eq!(last_stmt.predicate.name, "complex_1");
assert_eq!(last_stmt.predicate.predicate_name(), "complex_1");
}
#[test]

View file

@ -12,7 +12,7 @@ use std::{
use hex::ToHex;
use crate::{
lang::frontend_ast::*,
lang::{frontend_ast::*, Module},
middleware::{CustomPredicateBatch, Hash, NativePredicate},
};
@ -49,6 +49,8 @@ pub struct SymbolTable {
pub predicates: HashMap<String, PredicateInfo>,
/// Wildcard scopes for each custom predicate
pub wildcard_scopes: HashMap<String, WildcardScope>,
/// Imported modules (bound name → Module reference)
pub imported_modules: HashMap<String, Arc<Module>>,
}
/// Information about a predicate
@ -71,6 +73,11 @@ pub enum PredicateKind {
batch: Arc<CustomPredicateBatch>,
index: usize,
},
ModuleImported {
module_name: String,
predicate_name: String,
predicate_index: usize,
},
IntroImported {
name: String,
verifier_data_hash: Hash,
@ -107,39 +114,45 @@ pub enum DiagnosticLevel {
pub use crate::lang::error::ValidationError;
/// Validate an AST document
/// Mode for parsing/validation - determines what constructs are allowed
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseMode {
/// Module mode: predicate definitions allowed, REQUEST block not allowed
Module,
/// Request mode: REQUEST block required, predicate definitions not allowed
Request,
}
/// Validate an AST document in the given mode
pub fn validate(
document: Document,
available_batches: &[Arc<CustomPredicateBatch>],
available_modules: &HashMap<Hash, Arc<Module>>,
mode: ParseMode,
) -> Result<ValidatedAST, ValidationError> {
let validator = Validator::new(available_batches);
let validator = Validator::new(available_modules, mode);
validator.validate(document)
}
struct Validator {
available_batches: HashMap<String, Arc<CustomPredicateBatch>>,
available_modules: HashMap<Hash, Arc<Module>>,
symbols: SymbolTable,
diagnostics: Vec<Diagnostic>,
custom_predicate_count: usize,
mode: ParseMode,
}
impl Validator {
fn new(batches: &[Arc<CustomPredicateBatch>]) -> Self {
let mut available_batches = HashMap::new();
for batch in batches {
// Store by hex ID for lookup
let id = format!("0x{}", batch.id().encode_hex::<String>());
available_batches.insert(id, batch.clone());
}
fn new(available_modules: &HashMap<Hash, Arc<Module>>, mode: ParseMode) -> Self {
Self {
available_batches,
available_modules: available_modules.clone(),
symbols: SymbolTable {
predicates: HashMap::new(),
wildcard_scopes: HashMap::new(),
imported_modules: HashMap::new(),
},
diagnostics: Vec::new(),
custom_predicate_count: 0,
mode,
}
}
@ -160,25 +173,36 @@ impl Validator {
fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> {
// First process imports
for item in &document.items {
if let DocumentItem::UseBatchStatement(use_stmt) = item {
self.process_use_batch_statement(use_stmt)?;
if let DocumentItem::UseModuleStatement(use_stmt) = item {
self.process_use_module_statement(use_stmt)?;
}
if let DocumentItem::UseIntroStatement(use_stmt) = item {
self.process_use_intro_statement(use_stmt)?;
}
}
// Then process custom predicate definitions
// Check mode constraints for predicate definitions
let mut has_predicates = false;
for item in &document.items {
if let DocumentItem::CustomPredicateDef(pred_def) = item {
if self.mode == ParseMode::Request {
return Err(ValidationError::PredicatesNotAllowedInRequest {
span: pred_def.span,
});
}
has_predicates = true;
self.process_custom_predicate_def(pred_def)?;
}
}
// Check for multiple REQUEST definitions (only one allowed)
// Check mode constraints for REQUEST blocks
let mut has_request = false;
let mut first_request_span = None;
for item in &document.items {
if let DocumentItem::RequestDef(req) = item {
if self.mode == ParseMode::Module {
return Err(ValidationError::RequestNotAllowedInModule { span: req.span });
}
if let Some(first_span) = first_request_span {
return Err(ValidationError::MultipleRequestDefinitions {
first_span: Some(first_span),
@ -186,61 +210,44 @@ impl Validator {
});
}
first_request_span = req.span;
has_request = true;
}
}
// Enforce that modules have predicates and requests have a REQUEST block
match self.mode {
ParseMode::Module if !has_predicates => {
return Err(ValidationError::NoPredicatesInModule);
}
ParseMode::Request if !has_request => {
return Err(ValidationError::NoRequestBlock);
}
_ => {}
}
Ok(())
}
fn process_use_batch_statement(
fn process_use_module_statement(
&mut self,
use_stmt: &UseBatchStatement,
use_stmt: &UseModuleStatement,
) -> Result<(), ValidationError> {
let batch_id = format!("0x{}", use_stmt.batch_ref.hash.encode_hex::<String>());
let alias = &use_stmt.alias.name;
let hash = &use_stmt.hash.hash;
let batch = self.available_batches.get(&batch_id).ok_or_else(|| {
ValidationError::BatchNotFound {
id: batch_id.clone(),
span: use_stmt.batch_ref.span,
}
})?;
// Check if the module is available by hash
let module =
self.available_modules
.get(hash)
.ok_or_else(|| ValidationError::ModuleNotFound {
name: hash.encode_hex::<String>(),
span: use_stmt.span,
})?;
if use_stmt.imports.len() != batch.predicates().len() {
return Err(ValidationError::ImportArityMismatch {
expected: batch.predicates().len(),
found: use_stmt.imports.len(),
span: use_stmt.span,
});
}
for (i, import) in use_stmt.imports.iter().enumerate() {
if let ImportName::Named(name) = import {
if self.symbols.predicates.contains_key(name) {
return Err(ValidationError::DuplicateImport {
name: name.clone(),
span: use_stmt.span,
});
}
let pred = &batch.predicates()[i];
// CustomPredicate has args_len (public args) and wildcard_names (total args)
let total_arity = pred.wildcard_names.len();
let public_arity = pred.args_len;
self.symbols.predicates.insert(
name.clone(),
PredicateInfo {
kind: PredicateKind::BatchImported {
batch: batch.clone(),
index: i,
},
arity: total_arity,
public_arity,
source_span: use_stmt.span,
},
);
}
}
// Store the module keyed by alias for later qualified name resolution
self.symbols
.imported_modules
.insert(alias.clone(), module.clone());
Ok(())
}
@ -435,7 +442,11 @@ impl Validator {
stmt: &StatementTmpl,
wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> {
let pred_name = &stmt.predicate.name;
let pred_name = stmt.predicate.predicate_name();
let pred_span = match &stmt.predicate {
PredicateRef::Local(id) => id.span,
PredicateRef::Qualified { predicate, .. } => predicate.span,
};
let wc_names = match wildcard_context {
Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(),
@ -444,31 +455,65 @@ impl Validator {
self.validate_wildcard_names(&wc_names)?;
// Check if predicate exists
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
// Native predicate
Some(PredicateInfo {
kind: PredicateKind::Native(native),
arity: native.arity(),
public_arity: native.arity(),
source_span: None,
})
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
// Custom or imported predicate
Some(info.clone())
} else if wc_names.contains(pred_name) {
None
} else {
return Err(ValidationError::UndefinedPredicate {
name: pred_name.clone(),
span: stmt.predicate.span,
});
let pred_info = match &stmt.predicate {
PredicateRef::Qualified { module, predicate } => {
// Look up the predicate in the imported module
let module_name = &module.name;
if let Some(imported_module) = self.symbols.imported_modules.get(module_name) {
// Find the predicate in the module
if let Some(&idx) = imported_module.predicate_index.get(&predicate.name) {
let module_pred = &imported_module.batch.predicates()[idx];
Some(PredicateInfo {
kind: PredicateKind::ModuleImported {
module_name: module_name.clone(),
predicate_name: predicate.name.clone(),
predicate_index: idx,
},
arity: module_pred.wildcard_names.len(),
public_arity: module_pred.args_len,
source_span: None,
})
} else {
return Err(ValidationError::UndefinedPredicate {
name: format!("{}::{}", module_name, predicate.name),
span: pred_span,
});
}
} else {
return Err(ValidationError::ModuleNotFound {
name: module_name.clone(),
span: module.span,
});
}
}
PredicateRef::Local(_) => {
if let Ok(native) = NativePredicate::from_str(pred_name) {
// Native predicate
Some(PredicateInfo {
kind: PredicateKind::Native(native),
arity: native.arity(),
public_arity: native.arity(),
source_span: None,
})
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
// Custom or imported predicate
Some(info.clone())
} else if wc_names.contains(&pred_name.to_string()) {
None
} else {
return Err(ValidationError::UndefinedPredicate {
name: pred_name.to_string(),
span: pred_span,
});
}
}
};
if let Some(ref pred_info) = pred_info {
let expected_arity = pred_info.public_arity;
if stmt.args.len() != expected_arity {
return Err(ValidationError::ArgumentCountMismatch {
predicate: pred_name.clone(),
predicate: pred_name.to_string(),
expected: expected_arity,
found: stmt.args.len(),
span: stmt.span,
@ -491,13 +536,15 @@ impl Validator {
// For custom predicates, only wildcards and literals are allowed
if matches!(
pred_info.map(|i| &i.kind),
Some(PredicateKind::Custom { .. }) | Some(PredicateKind::BatchImported { .. })
Some(PredicateKind::Custom { .. })
| Some(PredicateKind::BatchImported { .. })
| Some(PredicateKind::ModuleImported { .. })
) {
for arg in &stmt.args {
match arg {
StatementTmplArg::AnchoredKey(_) => {
return Err(ValidationError::InvalidArgumentType {
predicate: stmt.predicate.name.clone(),
predicate: stmt.predicate.predicate_name().to_string(),
span: stmt.span,
});
}
@ -552,25 +599,30 @@ impl Validator {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::{
lang::{frontend_ast::parse::parse_document, parser::parse_podlang},
lang::{frontend_ast::parse::parse_document, parser::parse_podlang, Module},
middleware::{CustomPredicate, Params, EMPTY_HASH},
};
fn parse_and_validate(
fn parse_and_validate_module(
input: &str,
batches: &[Arc<CustomPredicateBatch>],
modules: &HashMap<Hash, Arc<Module>>,
) -> Result<ValidatedAST, ValidationError> {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
validate(document, batches)
validate(document, modules, ParseMode::Module)
}
#[test]
fn test_validate_empty() {
let result = parse_and_validate("", &[]);
assert!(result.is_ok());
fn parse_and_validate_request(
input: &str,
modules: &HashMap<Hash, Arc<Module>>,
) -> Result<ValidatedAST, ValidationError> {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
validate(document, modules, ParseMode::Request)
}
#[test]
@ -578,7 +630,7 @@ mod tests {
let input = r#"REQUEST(
Equal(A["foo"], B["bar"])
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(result.is_ok());
}
@ -589,7 +641,7 @@ mod tests {
Equal(A["foo"], B["bar"])
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(result.is_ok());
let validated = result.unwrap();
@ -602,7 +654,7 @@ mod tests {
let input = r#"REQUEST(
UndefinedPred(A, B)
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::UndefinedPredicate { .. })
@ -616,7 +668,7 @@ mod tests {
Equal(A["foo"], B["bar"])
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(
matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B")
);
@ -627,7 +679,7 @@ mod tests {
let input = r#"REQUEST(
Equal(A, B, C)
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::ArgumentCountMismatch { .. })
@ -640,7 +692,7 @@ mod tests {
my_pred(A) = AND (Equal(A["x"], 1))
my_pred(B) = AND (Equal(B["y"], 2))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::DuplicatePredicate { .. })
@ -652,7 +704,7 @@ mod tests {
let input = r#"
my_pred(A, A) = AND (Equal(A["x"], 1))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::DuplicateWildcard { .. })
@ -664,7 +716,7 @@ mod tests {
let input = r#"
my_pred(A, Lt) = AND (Equal(A["x"], Lt))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::WildcardPredicateNameCollision { .. })
@ -673,16 +725,36 @@ mod tests {
#[test]
fn test_custom_predicate_with_anchored_key() {
let input = r#"
my_pred(A, B) = AND (
Equal(A["foo"], B["bar"])
)
// First create a module with the predicate
let params = Params::default();
let pred = CustomPredicate::and(
&params,
"my_pred".to_string(),
vec![],
2,
vec!["A".to_string(), "B".to_string()],
)
.unwrap();
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
let test_module = Arc::new(Module::new(batch, HashMap::new()));
let module_hash = test_module.id().encode_hex::<String>();
let mut available_modules = HashMap::new();
available_modules.insert(test_module.id(), test_module);
// Test that passing anchored key to custom predicate fails
let input = format!(
r#"
use module 0x{} as testmod
REQUEST(
my_pred(X["key"], Y)
testmod::my_pred(X["key"], Y)
)
"#;
let result = parse_and_validate(input, &[]);
"#,
module_hash
);
let result = parse_and_validate_request(&input, &available_modules);
assert!(matches!(
result,
Err(ValidationError::InvalidArgumentType { .. })
@ -700,7 +772,7 @@ mod tests {
Equal(B["x"], 1)
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(result.is_ok());
}
@ -712,7 +784,7 @@ mod tests {
Equal(B["z"], C["w"])
)
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_module(input, &HashMap::new());
assert!(result.is_ok());
let validated = result.unwrap();
@ -743,7 +815,7 @@ mod tests {
span: None,
})],
};
let result = validate(document, &[]);
let result = validate(document, &HashMap::new(), ParseMode::Module);
assert!(matches!(
result,
Err(ValidationError::EmptyStatementList { .. })
@ -756,7 +828,7 @@ mod tests {
REQUEST(Equal(A["x"], 1))
REQUEST(Equal(B["y"], 2))
"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(matches!(
result,
Err(ValidationError::MultipleRequestDefinitions { .. })
@ -764,10 +836,14 @@ mod tests {
}
#[test]
fn test_use_statement() {
fn test_use_module_statement() {
use std::sync::Arc;
use hex::ToHex;
let params = Params::default();
// Create a batch to import
// Create a module to import
let pred = CustomPredicate::and(
&params,
"imported".to_string(),
@ -778,28 +854,33 @@ mod tests {
.unwrap();
let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
let test_module = Arc::new(Module::new(batch, HashMap::new()));
let module_hash = test_module.id().encode_hex::<String>();
let mut available_modules = HashMap::new();
available_modules.insert(test_module.id(), test_module);
let batch_id = batch.id().encode_hex::<String>();
let input = format!(
r#"
use batch imported_pred from 0x{}
use module 0x{} as testmod
use intro intro_pred() from 0x{}
REQUEST(
imported_pred(A, B)
testmod::imported(A, B)
intro_pred()
)
"#,
batch_id,
module_hash,
EMPTY_HASH.encode_hex::<String>()
);
let result = parse_and_validate(&input, &[batch]);
let result = parse_and_validate_request(&input, &available_modules);
assert!(result.is_ok());
let validated = result.unwrap();
assert!(validated.symbols.predicates.contains_key("imported_pred"));
// Module predicates are accessed via qualified names, so no local binding
assert!(validated.symbols.predicates.contains_key("intro_pred"));
assert!(validated.symbols.imported_modules.contains_key("testmod"));
}
#[test]
@ -809,7 +890,7 @@ mod tests {
DictContains(D, K, V)
SetNotContains(S, E)
)"#;
let result = parse_and_validate(input, &[]);
let result = parse_and_validate_request(input, &HashMap::new());
assert!(result.is_ok());
}
}

View file

@ -26,12 +26,9 @@ arg_section = {
public_arg_list = { identifier ~ ("," ~ identifier)* }
private_arg_list = { identifier ~ ("," ~ identifier)* }
document = { SOI ~ (use_batch_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI }
document = { SOI ~ (use_module_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI }
use_batch_statement = { "use" ~ "batch" ~ use_predicate_list ~ "from" ~ batch_ref }
use_predicate_list = { import_name ~ ("," ~ import_name)* }
import_name = { identifier | "_" }
batch_ref = { hash_hex }
use_module_statement = { "use" ~ "module" ~ hash_hex ~ "as" ~ identifier }
use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ intro_predicate_ref }
use_intro_arg_list = { identifier ~ ("," ~ identifier)* }
@ -55,7 +52,11 @@ statement_list = { statement+ }
statement_arg = { literal_value | anchored_key | identifier }
statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* }
statement = { identifier ~ "(" ~ statement_arg_list? ~ ")" }
// Predicate reference: either qualified (module::predicate) or local (predicate)
predicate_ref = { qualified_predicate_ref | identifier }
qualified_predicate_ref = { identifier ~ "::" ~ identifier }
statement = { predicate_ref ~ "(" ~ statement_arg_list? ~ ")" }
// Anchored Key: Var["key_literal"] or Var.key_identifier
anchored_key = {

View file

@ -1,110 +1,118 @@
//! Podlang front-end: parsing, validation, lowering, and multi-batch output.
//! Podlang front-end: parsing, validation, and lowering.
//!
//! This module is the high-level entrypoint to the Podlang pipeline. It:
//! - Parses a Podlang document (`parse_podlang`).
//! - Validates names, imports, and well-formedness (`frontend_ast_validate`).
//! - Lowers to middleware structures, including automatic predicate splitting and
//! dependency-aware packing into one or more custom predicate batches (`frontend_ast_split`,
//! `frontend_ast_batch`, `frontend_ast_lower`).
//! This module is the high-level entrypoint to the Podlang pipeline.
//!
//! The result is a [`PodlangOutput`], which contains:
//! - `custom_batches`: a [`PredicateBatches`] container (possibly empty) with all custom
//! predicates defined in the document. Use
//! [`PredicateBatches::apply_predicate`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate)
//! to apply a predicate into a `MainPodBuilder` (recommended primary API), or
//! [`apply_predicate_with`](crate::lang::frontend_ast_batch::PredicateBatches::apply_predicate_with)
//! for advanced control.
//! - `request`: a `PodRequest` containing the request templates defined by a `REQUEST(...)` block
//! in the document (or empty if none was provided).
//! ## API
//!
//! Notes
//! - Predicate splitting: large predicates are automatically split into a chain of smaller
//! predicates while preserving semantics; only the final chain result is public when applying a
//! predicate as public.
//! - Multi-batch packing: predicates are packed dependency-aware; cross-batch references always
//! point to earlier batches and forward references cannot occur.
//! - Backwards compatibility: `PodlangOutput::first_batch()` is provided to ease migration of code
//! that expects a single custom predicate batch.
//! - [`load_module`]: Load a module file containing predicate definitions.
//! Returns a [`Module`] wrapping a `CustomPredicateBatch`.
//!
//! - [`parse_request`]: Parse a request file containing a REQUEST block.
//! Returns a [`PodRequest`] with statement templates.
//!
//! ## Module vs Request
//!
//! - **Modules** contain predicate definitions (`pred(A) = AND(...)`) and imports.
//! They cannot contain a REQUEST block.
//!
//! - **Requests** contain a REQUEST block and imports.
//! They cannot define predicates.
//!
//! ## Using Modules
//!
//! Use [`Module::apply_predicate`] to apply a predicate into a `MainPodBuilder`
//! (recommended), or [`Module::apply_predicate_with`] for advanced control.
//!
//! Large predicates are automatically split into chains of smaller predicates;
//! `apply_predicate` handles this transparently.
//!
pub mod error;
pub mod frontend_ast;
pub mod frontend_ast_batch;
pub mod frontend_ast_lower;
pub mod frontend_ast_split;
pub mod frontend_ast_validate;
pub mod module;
pub mod parser;
pub mod pretty_print;
use std::sync::Arc;
pub use error::LangError;
pub use frontend_ast_batch::{MultiOperationError, PredicateBatches};
pub use frontend_ast_split::{SplitChainInfo, SplitChainPiece, SplitResult};
pub use module::{Module, MultiOperationError};
pub use parser::{parse_podlang, Pairs, ParseError, Rule};
pub use pretty_print::PrettyPrint;
use crate::{
frontend::PodRequest,
middleware::{CustomPredicateBatch, Params},
};
use crate::{frontend::PodRequest, middleware::Params};
/// Final result of processing a Podlang document.
/// Load a module from Podlang source.
///
/// - `custom_batches`: all custom predicates defined in the document, possibly spanning multiple
/// batches. Use [`PredicateBatches`] APIs to look up predicates by name and apply them.
/// - `request`: the request templates defined in the document (empty if not present).
#[derive(Debug, Clone)]
pub struct PodlangOutput {
pub custom_batches: PredicateBatches,
pub request: PodRequest,
}
impl PodlangOutput {
/// Get the first batch, if any (for backwards compatibility).
///
/// Prefer using `custom_batches` directly if your code expects multiple batches.
pub fn first_batch(&self) -> Option<&Arc<CustomPredicateBatch>> {
self.custom_batches.first_batch()
}
}
/// Parse, validate, and lower a Podlang document into middleware structures.
/// Modules contain predicate definitions and imports, but no REQUEST block.
///
/// - `input`: Podlang source.
/// - `params`: middleware parameters limiting sizes/arity and controlling lowering behavior.
/// - `available_batches`: external batches available for `use batch ... from 0x...` imports.
///
/// Returns a [`PodlangOutput`] containing custom predicate batches (if any) and a `PodRequest`
/// (possibly empty).
pub fn parse(
input: &str,
/// - `source`: Podlang source code
/// - `name`: Name for the module (used in batch naming)
/// - `params`: Middleware parameters limiting sizes/arity
/// - `available_modules`: External modules available for `use module ...` imports
pub fn load_module(
source: &str,
name: &str,
params: &Params,
available_batches: &[Arc<CustomPredicateBatch>],
) -> Result<PodlangOutput, LangError> {
let pairs = parse_podlang(input)?;
available_modules: Vec<Arc<Module>>,
) -> Result<Module, LangError> {
let pairs = parse_podlang(source)?;
let document_pair = pairs
.into_iter()
.next()
.expect("parse_podlang should always return at least one pair for a valid document");
let document = frontend_ast::parse::parse_document(document_pair)?;
let validated = frontend_ast_validate::validate(document, available_batches)?;
let lowered = frontend_ast_lower::lower(validated, params, "PodlangBatch".to_string())?;
let available_modules_map = available_modules
.iter()
.map(|m| (m.id(), m.clone()))
.collect();
let validated = frontend_ast_validate::validate(
document,
&available_modules_map,
frontend_ast_validate::ParseMode::Module,
)?;
let module = frontend_ast_lower::lower_module(validated, params, name)?;
Ok(module)
}
let custom_batches = lowered.batches.unwrap_or_default();
let request = lowered.request.unwrap_or_else(|| {
// If no request, create an empty one
PodRequest::new(vec![])
});
Ok(PodlangOutput {
custom_batches,
request,
})
/// Parse a request from Podlang source.
///
/// Requests contain a REQUEST block and imports, but no predicate definitions.
///
/// - `source`: Podlang source code
/// - `params`: Middleware parameters limiting sizes/arity
/// - `available_modules`: External modules available for `use module ...` imports
pub fn parse_request(
source: &str,
params: &Params,
available_modules: &[Arc<Module>],
) -> Result<PodRequest, LangError> {
let pairs = parse_podlang(source)?;
let document_pair = pairs
.into_iter()
.next()
.expect("parse_podlang should always return at least one pair for a valid document");
let document = frontend_ast::parse::parse_document(document_pair)?;
let available_modules_map = available_modules
.iter()
.map(|m| (m.id(), m.clone()))
.collect();
let validated = frontend_ast_validate::validate(
document,
&available_modules_map,
frontend_ast_validate::ParseMode::Request,
)?;
let request = frontend_ast_lower::lower_request(validated, params)?;
Ok(request)
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use hex::ToHex;
use pretty_assertions::assert_eq;
@ -143,11 +151,6 @@ mod tests {
PredicateOrWildcard::Predicate(pred)
}
// Helper to get the first batch from the output
fn first_batch(output: &super::PodlangOutput) -> &Arc<CustomPredicateBatch> {
output.first_batch().expect("Expected at least one batch")
}
#[test]
fn test_e2e_simple_predicate() -> Result<(), LangError> {
let input = r#"
@ -157,12 +160,9 @@ mod tests {
"#;
let params = Params::default();
let processed = parse(input, &params, &[])?;
let batch_result = first_batch(&processed);
let request_result = processed.request.templates();
let module = load_module(input, "test_module", &params, vec![])?;
assert_eq!(request_result.len(), 0);
assert_eq!(batch_result.predicates().len(), 1);
assert_eq!(module.batch.predicates().len(), 1);
// Expected structure
let expected_statements = vec![StatementTmpl {
@ -180,9 +180,9 @@ mod tests {
names(&["PodA", "PodB"]),
)?;
let expected_batch =
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
CustomPredicateBatch::new("test_module".to_string(), vec![expected_predicate]);
assert_eq!(*batch_result, expected_batch);
assert_eq!(&*module.batch, &*expected_batch);
Ok(())
}
@ -197,10 +197,9 @@ mod tests {
"#;
let params = Params::default();
let processed = parse(input, &params, &[])?;
let request_templates = processed.request.templates();
let request = parse_request(input, &params, &[])?;
let request_templates = request.templates();
assert!(processed.custom_batches.is_empty());
assert!(!request_templates.is_empty());
// Expected structure
@ -236,12 +235,9 @@ mod tests {
"#;
let params = Params::default();
let processed = parse(input, &params, &[])?;
let batch_result = first_batch(&processed);
let request_result = processed.request.templates();
let module = load_module(input, "test_module", &params, vec![])?;
assert_eq!(request_result.len(), 0);
assert_eq!(batch_result.predicates().len(), 1);
assert_eq!(module.batch.predicates().len(), 1);
// Expected structure: Public args: A (index 0). Private args: Temp (index 1)
let expected_statements = vec![
@ -268,58 +264,51 @@ mod tests {
names(&["A", "Temp"]),
)?;
let expected_batch =
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
CustomPredicateBatch::new("test_module".to_string(), vec![expected_predicate]);
assert_eq!(*batch_result, expected_batch);
assert_eq!(&*module.batch, &*expected_batch);
Ok(())
}
#[test]
fn test_e2e_request_with_custom_call() -> Result<(), LangError> {
let input = r#"
// First, load the module
let module_input = r#"
my_pred(X, Y) = AND(
Equal(X["val"], Y["val"])
)
REQUEST(
my_pred(Pod1, Pod2)
)
"#;
let params = Params::default();
let processed = parse(input, &params, &[])?;
let batch_result = first_batch(&processed);
let request_templates = processed.request.templates();
let module = Arc::new(load_module(module_input, "my_module", &params, vec![])?);
assert_eq!(module.batch.predicates().len(), 1);
let module_hash = module.id().encode_hex::<String>();
// Then, parse the request using the module
let request_input = format!(
r#"
use module 0x{} as my_module
REQUEST(
my_module::my_pred(Pod1, Pod2)
)
"#,
module_hash
);
let request = parse_request(&request_input, &params, std::slice::from_ref(&module))?;
let request_templates = request.templates();
assert_eq!(batch_result.predicates().len(), 1);
assert!(!request_templates.is_empty());
// Expected Batch structure
let expected_pred_statements = vec![StatementTmpl {
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
sta_ak(("X", 0), "val"), // X["val"] -> Wildcard(0), Key("val")
sta_ak(("Y", 1), "val"), // Y["val"] -> Wildcard(1), Key("val")
],
}];
let expected_predicate = CustomPredicate::and(
&params,
"my_pred".to_string(),
expected_pred_statements,
2, // args_len (X, Y)
names(&["X", "Y"]),
)?;
let expected_batch =
CustomPredicateBatch::new("PodlangBatch".to_string(), vec![expected_predicate]);
assert_eq!(*batch_result, expected_batch);
// Expected Request structure
// Pod1 -> Wildcard 0, Pod2 -> Wildcard 1
let expected_request_templates = vec![StatementTmpl {
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
expected_batch,
module.batch.clone(),
0,
))),
args: vec![
@ -335,25 +324,36 @@ mod tests {
#[test]
fn test_e2e_request_with_various_args() -> Result<(), LangError> {
let input = r#"
// First, load the module
let module_input = r#"
some_pred(A, B, C) = AND( Equal(A["foo"], B["bar"]) )
REQUEST(
some_pred(
Var1, // Wildcard
12345, // Int Literal
"hello_string" // String Literal (Removed invalid AK args)
)
Equal(AnotherPod["another_key"], Var1["some_field"])
)
"#;
let params = Params::default();
let processed = parse(input, &params, &[])?;
let batch_result = first_batch(&processed);
let request_templates = processed.request.templates();
let module = Arc::new(load_module(module_input, "some_module", &params, vec![])?);
let module_hash = module.id().encode_hex::<String>();
// Then, parse the request
let request_input = format!(
r#"
use module 0x{} as some_module
REQUEST(
some_module::some_pred(
Var1, // Wildcard
12345, // Int Literal
"hello_string" // String Literal
)
Equal(AnotherPod["another_key"], Var1["some_field"])
)
"#,
module_hash
);
let request = parse_request(&request_input, &params, std::slice::from_ref(&module))?;
let request_templates = request.templates();
assert_eq!(batch_result.predicates().len(), 1); // some_pred is defined
assert!(!request_templates.is_empty());
// Expected Wildcard Indices in Request Scope:
@ -364,7 +364,7 @@ mod tests {
let expected_templates = vec![
StatementTmpl {
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
batch_result.clone(),
module.batch.clone(),
0,
))), // Refers to some_pred
args: vec![
@ -402,10 +402,9 @@ mod tests {
"#;
let params = Params::default();
let processed = parse(input, &params, &[])?;
let request_templates = processed.request.templates();
let request = parse_request(input, &params, &[])?;
let request_templates = request.templates();
assert!(processed.custom_batches.is_empty());
assert!(!request_templates.is_empty());
let expected_templates = vec![
@ -459,8 +458,8 @@ mod tests {
"#;
// Parse the input string
let processed = super::parse(input, &Params::default(), &[])?;
let parsed_templates = processed.request.templates();
let request = parse_request(input, &Params::default(), &[])?;
let parsed_templates = request.templates();
// Define Expected Templates (Copied from prover/mod.rs)
let now_minus_18y_val = Value::from(1169909388_i64);
@ -549,11 +548,6 @@ mod tests {
"Parsed ZuKYC request templates do not match the expected hard-coded version"
);
assert!(
processed.custom_batches.is_empty(),
"Expected no custom predicates for a REQUEST only input"
);
Ok(())
}
@ -591,14 +585,10 @@ mod tests {
)
"#;
let processed = super::parse(input, &params, &[])?;
let module = load_module(input, "ethdos", &params, vec![])?;
assert!(
processed.request.templates().is_empty(),
"Expected no request templates"
);
assert_eq!(
first_batch(&processed).predicates().len(),
module.batch.predicates().len(),
4,
"Expected 4 custom predicates"
);
@ -718,7 +708,7 @@ mod tests {
)?;
let expected_batch = CustomPredicateBatch::new(
"PodlangBatch".to_string(),
"ethdos".to_string(),
vec![
expected_friend_pred,
expected_base_pred,
@ -728,8 +718,7 @@ mod tests {
);
assert_eq!(
*first_batch(&processed),
expected_batch,
&*module.batch, &*expected_batch,
"Processed ETHDoS predicates do not match expected structure"
);
@ -737,10 +726,10 @@ mod tests {
}
#[test]
fn test_e2e_use_statement() -> Result<(), LangError> {
fn test_e2e_use_module_statement() -> Result<(), LangError> {
let params = Params::default();
// 1. Create a batch to be imported
// 1. Create a module with a predicate to be imported
let imported_pred_stmts = vec![StatementTmpl {
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![
@ -755,98 +744,75 @@ mod tests {
2,
names(&["A", "B"]),
)?;
let available_batch =
CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]);
let available_batches = vec![available_batch.clone()];
let batch = CustomPredicateBatch::new("my_module".to_string(), vec![imported_predicate]);
let module = Arc::new(Module::new(batch.clone(), HashMap::new()));
let module_hash = module.id().encode_hex::<String>();
// 2. Create the input string that uses the batch
let batch_id_str = available_batch.id().encode_hex::<String>();
// 2. Create the input string that uses the module
let input = format!(
r#"
use batch imported_pred from 0x{}
use module 0x{} as my_module
REQUEST(
imported_pred(Pod1, Pod2)
my_module::imported_equal(Pod1, Pod2)
)
"#,
batch_id_str
module_hash
);
// 3. Parse the input
let processed = parse(&input, &params, &available_batches)?;
let request_templates = processed.request.templates();
// 3. Parse the request
let request = parse_request(&input, &params, std::slice::from_ref(&module))?;
let request_templates = request.templates();
assert!(
processed.custom_batches.is_empty(),
"No custom predicates should be defined in the main input"
);
assert_eq!(request_templates.len(), 1, "Expected one request template");
// 4. Check the resulting request template
let expected_request_templates = vec![StatementTmpl {
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch,
0,
))),
args: vec![
StatementTmplArg::Wildcard(wc("Pod1", 0)),
StatementTmplArg::Wildcard(wc("Pod2", 1)),
],
}];
assert_eq!(request_templates, expected_request_templates);
// 4. Check the resulting request template uses the imported predicate
let template = &request_templates[0];
assert_eq!(template.args.len(), 2);
Ok(())
}
#[test]
fn test_e2e_use_statement_complex() -> Result<(), LangError> {
fn test_e2e_use_module_complex() -> Result<(), LangError> {
let params = Params::default();
// 1. Create a batch with multiple predicates
// 1. Create a module with multiple predicates
let pred1 = CustomPredicate::and(&params, "p1".into(), vec![], 1, names(&["A"]))?;
let pred2 = CustomPredicate::and(&params, "p2".into(), vec![], 2, names(&["B", "C"]))?;
let pred3 = CustomPredicate::and(&params, "p3".into(), vec![], 1, names(&["D"]))?;
let available_batch =
CustomPredicateBatch::new("MyBatch".to_string(), vec![pred1, pred2, pred3]);
let available_batches = vec![available_batch.clone()];
// 2. Create the input string that uses the batch with skips
let batch_id_str = available_batch.id().encode_hex::<String>();
let batch = CustomPredicateBatch::new("mymodule".to_string(), vec![pred1, pred2, pred3]);
let mymodule = Arc::new(Module::new(batch.clone(), HashMap::new()));
let module_hash = mymodule.id().encode_hex::<String>();
// 2. Create the input string that uses qualified predicate access
let input = format!(
r#"
use batch pred_one, _, pred_three from 0x{}
use module 0x{} as mymodule
REQUEST(
pred_one(Pod1)
pred_three(Pod2)
mymodule::p1(Pod1)
mymodule::p3(Pod2)
)
"#,
batch_id_str
module_hash
);
// 3. Parse the input
let processed = parse(&input, &params, &available_batches)?;
let request_templates = processed.request.templates();
// 3. Parse the request
let request = parse_request(&input, &params, std::slice::from_ref(&mymodule))?;
let request_templates = request.templates();
assert_eq!(request_templates.len(), 2, "Expected two request templates");
// 4. Check the resulting request templates
let expected_templates = vec![
StatementTmpl {
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch.clone(),
0,
))),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch.clone(), 0))),
args: vec![StatementTmplArg::Wildcard(wc("Pod1", 0))],
},
StatementTmpl {
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch,
2,
))),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch, 2))),
args: vec![StatementTmplArg::Wildcard(wc("Pod2", 1))],
},
];
@ -857,10 +823,10 @@ mod tests {
}
#[test]
fn test_e2e_custom_predicate_uses_import() -> Result<(), LangError> {
fn test_e2e_custom_predicate_uses_module() -> Result<(), LangError> {
let params = Params::default();
// 1. Create a batch with a predicate to be imported
// 1. Create a module with a predicate to be imported
let imported_pred_stmts = vec![StatementTmpl {
pred_or_wc: pred_lit(Predicate::Native(NativePredicate::Equal)),
args: vec![sta_ak(("A", 0), "foo"), sta_ak(("B", 1), "bar")],
@ -872,47 +838,38 @@ mod tests {
2,
names(&["A", "B"]),
)?;
let available_batch =
CustomPredicateBatch::new("MyBatch".to_string(), vec![imported_predicate]);
let available_batches = vec![available_batch.clone()];
let batch = CustomPredicateBatch::new("extmod".to_string(), vec![imported_predicate]);
let extmod = Arc::new(Module::new(batch.clone(), HashMap::new()));
let extmod_hash = extmod.id().encode_hex::<String>();
// 2. Create the input string that defines a new predicate using the imported one
let batch_id_str = available_batch.id().encode_hex::<String>();
let input = format!(
r#"
use batch imported_eq from 0x{}
use module 0x{} as extmod
wrapper_pred(X, Y) = AND(
imported_eq(X, Y)
extmod::imported_equal(X, Y)
)
"#,
batch_id_str
extmod_hash
);
// 3. Parse the input
let processed = parse(&input, &params, &available_batches)?;
// 3. Load as module
let module = load_module(&input, "test", &params, vec![extmod])?;
assert!(
processed.request.templates().is_empty(),
"No request should be defined"
);
assert_eq!(
first_batch(&processed).predicates().len(),
module.batch.predicates().len(),
1,
"Expected one custom predicate to be defined"
);
// 4. Check the resulting predicate definition
let defined_pred = &first_batch(&processed).predicates()[0];
let defined_pred = &module.batch.predicates()[0];
assert_eq!(defined_pred.name, "wrapper_pred");
assert_eq!(defined_pred.statements.len(), 1);
let expected_statement = StatementTmpl {
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(
available_batch.clone(),
0,
))),
pred_or_wc: pred_lit(Predicate::Custom(CustomPredicateRef::new(batch, 0))),
args: vec![
StatementTmplArg::Wildcard(wc("X", 0)),
StatementTmplArg::Wildcard(wc("Y", 1)),
@ -939,8 +896,8 @@ mod tests {
"#,
);
let processed = parse(&input, &params, &[])?;
let request_templates = processed.request.templates();
let request = parse_request(&input, &params, &[])?;
let request_templates = request.templates();
assert_eq!(request_templates.len(), 1);
if let PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) =
@ -998,8 +955,8 @@ mod tests {
*/
let params = Params::default();
let processed = parse(&input, &params, &[])?;
let request_templates = processed.request.templates();
let request = parse_request(&input, &params, &[])?;
let request_templates = request.templates();
let expected_templates = vec![
StatementTmpl {
@ -1034,29 +991,33 @@ mod tests {
}
#[test]
fn test_e2e_use_unknown_batch() {
fn test_e2e_use_unknown_module() {
let params = Params::default();
let available_batches = &[];
let unknown_batch_id = format!("0x{}", "a".repeat(64));
// Use a hash that doesn't correspond to any loaded module
let fake_hash = EMPTY_HASH.encode_hex::<String>();
let input = format!(
r#"
use batch some_pred from {}
use module 0x{} as unknown_module
REQUEST(
Equal(A["x"], 1)
)
"#,
unknown_batch_id
fake_hash
);
let result = parse(&input, &params, available_batches);
let result = parse_request(&input, &params, &[]);
assert!(result.is_err());
match result.err().unwrap() {
LangError::Validation(e) => match *e {
frontend_ast_validate::ValidationError::BatchNotFound { id, .. } => {
assert_eq!(id, unknown_batch_id);
frontend_ast_validate::ValidationError::ModuleNotFound { name, .. } => {
// The error now carries the hex-formatted hash
assert_eq!(name, fake_hash);
}
_ => panic!("Expected BatchNotFound error, but got {:?}", e),
_ => panic!("Expected ModuleNotFound error, but got {:?}", e),
},
e => panic!("Expected LangError::Validation, but got {:?}", e),
}
@ -1065,17 +1026,15 @@ mod tests {
#[test]
fn test_e2e_undefined_wildcard() {
let params = Params::default();
let available_batches = &[];
let input = r#"
identity_verified(username, private: identity_dict) = AND(
Equal(identity_dict["username"], username)
Equal(identity_dict["user_public_key"], user_public_key)
)
"#
.to_string();
"#;
let result = parse(&input, &params, available_batches);
let result = load_module(input, "test", &params, vec![]);
assert!(result.is_err());

671
src/lang/module.rs Normal file
View file

@ -0,0 +1,671 @@
//! Podlang Module: definition, construction, and predicate application.
//!
//! A [`Module`] wraps a middleware `CustomPredicateBatch` with name resolution
//! and split chain metadata. Use [`build_module`] to construct a Module from
//! validated and split predicates.
use std::{collections::HashMap, sync::Arc};
use crate::{
frontend::{CustomPredicateBatchBuilder, Operation, OperationArg, StatementTmplBuilder},
lang::{
error::BatchingError,
frontend_ast::{ConjunctionType, CustomPredicateDef},
frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext},
frontend_ast_split::{SplitChainInfo, SplitResult},
frontend_ast_validate::SymbolTable,
},
middleware::{CustomPredicateBatch, CustomPredicateRef, Hash, Params, Statement},
};
/// Errors that can occur when applying predicates
#[derive(Debug, Clone, thiserror::Error)]
pub enum MultiOperationError {
#[error("Predicate not found: {0}")]
PredicateNotFound(String),
#[error("Chain piece not found: {0}")]
ChainPieceNotFound(String),
#[error(
"Wrong statement count for predicate '{predicate}': expected {expected}, got {actual}"
)]
WrongStatementCount {
predicate: String,
expected: usize,
actual: usize,
},
#[error("No operation steps to apply")]
NoSteps,
}
/// A Podlang module wrapping a middleware CustomPredicateBatch with name resolution info.
#[derive(Debug, Clone)]
pub struct Module {
/// The middleware representation (CustomPredicateBatch)
pub batch: Arc<CustomPredicateBatch>,
/// Map from predicate name to index in batch
pub predicate_index: HashMap<String, usize>,
/// Split chain info for predicates that were split
pub split_chains: HashMap<String, SplitChainInfo>,
}
impl Module {
/// Create a new Module from a batch, building the predicate_index automatically
pub fn new(
batch: Arc<CustomPredicateBatch>,
split_chains: HashMap<String, SplitChainInfo>,
) -> Self {
let predicate_index = batch
.predicates()
.iter()
.enumerate()
.map(|(i, p)| (p.name.clone(), i))
.collect();
Self {
batch,
predicate_index,
split_chains,
}
}
/// Root hash of the module's Merkle tree
pub fn id(&self) -> Hash {
self.batch.id()
}
/// Get a reference to a predicate by name
pub fn predicate_ref_by_name(&self, name: &str) -> Option<CustomPredicateRef> {
let idx = self.predicate_index.get(name)?;
Some(CustomPredicateRef::new(self.batch.clone(), *idx))
}
/// Check if the module contains any predicates
pub fn is_empty(&self) -> bool {
self.batch.predicates().is_empty()
}
/// Apply a predicate directly into a `MainPodBuilder` (common case).
///
/// For split predicates, earlier chain links are applied as private, and only the final
/// piece is applied as public when `public` is true. For non-split predicates, the single
/// operation is applied with the provided `public` flag.
///
/// Arguments:
/// - `builder`: target builder to receive operations
/// - `name`: predicate name
/// - `statements`: user statements in original declaration order
/// - `public`: whether the final result should be public
pub fn apply_predicate(
&self,
builder: &mut crate::frontend::MainPodBuilder,
name: &str,
statements: Vec<Statement>,
public: bool,
) -> crate::frontend::Result<Statement> {
self.apply_predicate_with(name, statements, public, |is_public, op| {
if is_public {
builder.pub_op(op)
} else {
builder.priv_op(op)
}
})
}
/// Advanced variant: apply using a custom closure.
///
/// Prefer `apply_predicate` for common usage. This method allows callers to intercept each
/// operation (with its `public` flag) and decide how to execute it.
///
/// Arguments:
/// - `name`: predicate name
/// - `statements`: user statements in original declaration order
/// - `public`: whether the final result should be public
/// - `apply_op`: closure `(is_public, operation) -> Result<Statement>` used to execute each step
pub fn apply_predicate_with<F, E>(
&self,
name: &str,
statements: Vec<Statement>,
public: bool,
mut apply_op: F,
) -> Result<Statement, E>
where
F: FnMut(bool, Operation) -> Result<Statement, E>,
E: From<MultiOperationError>,
{
let steps = self.build_steps(name, statements, public)?;
if steps.is_empty() {
return Err(MultiOperationError::NoSteps.into());
}
let mut prev_result: Option<Statement> = None;
for step in steps {
let op = if let Some(prev) = prev_result {
// Replace the last Statement::None arg with the previous result.
let mut args = step.operation.1;
let last = args
.last_mut()
.expect("chain statement should include placeholder arg");
assert!(
matches!(last, OperationArg::Statement(Statement::None)),
"expected last arg to be a Statement::None placeholder"
);
*last = OperationArg::Statement(prev);
Operation(step.operation.0, args, step.operation.2)
} else {
step.operation
};
prev_result = Some(apply_op(step.public, op)?);
}
Ok(prev_result.unwrap())
}
/// Build operation steps for a predicate (internal helper)
fn build_steps(
&self,
predicate_name: &str,
statements: Vec<Statement>,
public: bool,
) -> Result<Vec<OperationStep>, MultiOperationError> {
// Check if this predicate was split
let chain_info = match self.split_chains.get(predicate_name) {
Some(info) => info,
None => {
// Not split - single operation with all statements
let pred_ref = self.predicate_ref_by_name(predicate_name).ok_or_else(|| {
MultiOperationError::PredicateNotFound(predicate_name.to_string())
})?;
return Ok(vec![OperationStep {
operation: Operation::custom(pred_ref, statements),
public,
}]);
}
};
// Validate statement count
if statements.len() != chain_info.real_statement_count {
return Err(MultiOperationError::WrongStatementCount {
predicate: predicate_name.to_string(),
expected: chain_info.real_statement_count,
actual: statements.len(),
});
}
// Reorder statements from original order to split order
let mut reordered = vec![Statement::None; statements.len()];
for (original_idx, stmt) in statements.into_iter().enumerate() {
let split_idx = chain_info.reorder_map[original_idx];
reordered[split_idx] = stmt;
}
// Build operations for each piece in execution order
let num_pieces = chain_info.chain_pieces.len();
// Compute the starting offset for each piece
let mut piece_offsets = vec![0usize; num_pieces];
let mut offset = 0;
for i in (0..num_pieces).rev() {
piece_offsets[i] = offset;
offset += chain_info.chain_pieces[i].real_statement_count;
}
let mut steps = Vec::new();
for (piece_idx, piece) in chain_info.chain_pieces.iter().enumerate() {
let is_final = piece_idx == num_pieces - 1;
let piece_ref = self
.predicate_ref_by_name(&piece.name)
.ok_or_else(|| MultiOperationError::ChainPieceNotFound(piece.name.clone()))?;
let start = piece_offsets[piece_idx];
let end = start + piece.real_statement_count;
let mut args: Vec<Statement> = reordered[start..end].to_vec();
if piece.has_chain_call {
args.push(Statement::None);
}
steps.push(OperationStep {
operation: Operation::custom(piece_ref, args),
public: public && is_final,
});
}
Ok(steps)
}
}
/// A single step in a multi-operation sequence for split predicates
struct OperationStep {
operation: Operation,
public: bool,
}
/// Build a single Module from split predicate results.
///
/// Takes a list of split results (containing predicates and optional chain info)
/// and builds a single Module. With Merkle tree backing supporting up to 65536
/// predicates, all predicates from a document fit in one module.
///
/// `symbols` provides the symbol table for resolving predicate references,
/// including imported predicates from other modules and intro predicates.
pub fn build_module(
split_results: Vec<SplitResult>,
params: &Params,
module_name: &str,
symbols: &SymbolTable,
) -> Result<Module, BatchingError> {
// Extract predicates and collect split chains
let mut predicates = Vec::new();
let mut split_chains = HashMap::new();
for result in split_results {
// Collect chain info if present
if let Some(chain_info) = result.chain_info {
split_chains.insert(chain_info.original_name.clone(), chain_info);
}
// Flatten predicates
predicates.extend(result.predicates);
}
if predicates.is_empty() {
// Return an empty module
let empty_batch = CustomPredicateBatch::new(module_name.to_string(), vec![]);
return Ok(Module::new(empty_batch, split_chains));
}
// Build reference map: name -> index
let reference_map: HashMap<String, usize> = predicates
.iter()
.enumerate()
.map(|(idx, pred)| (pred.name.name.clone(), idx))
.collect();
// Build the batch
let batch = build_single_batch(&predicates, &reference_map, symbols, params, module_name)?;
Ok(Module::new(batch, split_chains))
}
/// Build a batch with properly resolved references
fn build_single_batch(
predicates: &[CustomPredicateDef],
reference_map: &HashMap<String, usize>,
symbols: &SymbolTable,
params: &Params,
batch_name: &str,
) -> Result<Arc<CustomPredicateBatch>, BatchingError> {
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), batch_name.to_string());
for pred in predicates {
let name = &pred.name.name;
// Collect argument names
let public_args: Vec<&str> = pred
.args
.public_args
.iter()
.map(|a| a.name.as_str())
.collect();
let private_args: Vec<&str> = pred
.args
.private_args
.as_ref()
.map(|args| args.iter().map(|a| a.name.as_str()).collect())
.unwrap_or_default();
// Build statement templates with resolved predicates
let statement_builders: Vec<StatementTmplBuilder> = pred
.statements
.iter()
.map(|stmt| build_statement_with_resolved_refs(stmt, reference_map, name, symbols))
.collect::<Result<_, _>>()?;
let conjunction = pred.conjunction_type == ConjunctionType::And;
builder
.predicate(
name,
conjunction,
&public_args,
&private_args,
&statement_builders,
)
.map_err(|e| BatchingError::Internal {
message: format!("Failed to add predicate '{}': {}", name, e),
})?;
}
Ok(builder.finish())
}
/// Build a statement template with properly resolved predicate references
fn build_statement_with_resolved_refs(
stmt: &crate::lang::frontend_ast::StatementTmpl,
reference_map: &HashMap<String, usize>,
custom_predicate_name: &str, // custom pred that defines this statement template
symbols: &SymbolTable,
) -> Result<StatementTmplBuilder, BatchingError> {
// Resolve the predicate using the unified resolution function
let context = ResolutionContext::Module {
reference_map,
custom_predicate_name,
};
let pred_or_wc =
resolve_predicate_ref(&stmt.predicate, symbols, &context).ok_or_else(|| {
BatchingError::Internal {
message: format!("Unknown predicate reference: '{}'", stmt.predicate),
}
})?;
// Build the statement template
let mut builder = StatementTmplBuilder::new(pred_or_wc);
for arg in &stmt.args {
builder = builder.arg(lower_statement_arg(arg));
}
Ok(builder)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
lang::{
frontend_ast::parse::parse_document,
frontend_ast_split::split_predicate_if_needed,
frontend_ast_validate::{validate, ParseMode, ValidatedAST},
load_module,
parser::parse_podlang,
},
middleware::{CustomPredicateRef, Predicate, PredicateOrWildcard},
};
/// Helper: parse and validate input, returning predicates and symbol table
fn parse_and_validate(input: &str) -> (Vec<CustomPredicateDef>, ValidatedAST) {
let parsed = parse_podlang(input).expect("Failed to parse");
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
let validated = validate(document.clone(), &HashMap::new(), ParseMode::Module)
.expect("Failed to validate");
let predicates = document
.items
.into_iter()
.filter_map(|item| match item {
crate::lang::frontend_ast::DocumentItem::CustomPredicateDef(pred) => Some(pred),
_ => None,
})
.collect();
(predicates, validated)
}
/// Helper: wrap predicates into SplitResult (without actually splitting)
fn preds_to_split_results(predicates: Vec<CustomPredicateDef>) -> Vec<SplitResult> {
predicates
.into_iter()
.map(|pred| SplitResult {
predicates: vec![pred],
chain_info: None,
})
.collect()
}
#[test]
fn test_single_predicate() {
let input = r#"
my_pred(A, B) = AND(
Equal(A["x"], B["y"])
)
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = build_module(
preds_to_split_results(predicates),
&params,
"TestModule",
validated.symbols(),
);
assert!(result.is_ok());
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 1);
}
#[test]
fn test_multiple_predicates() {
let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
pred3(C) = AND(Equal(C["z"], 3))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = build_module(
preds_to_split_results(predicates),
&params,
"TestModule",
validated.symbols(),
);
assert!(result.is_ok());
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 3);
}
#[test]
fn test_intra_batch_forward_reference() {
// pred2 calls pred1, but pred2 is declared first
// This should work because they're in the same batch
let input = r#"
pred2(B) = AND(pred1(B))
pred1(A) = AND(Equal(A["x"], 1))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = build_module(
preds_to_split_results(predicates),
&params,
"TestModule",
validated.symbols(),
);
assert!(result.is_ok());
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 2);
// pred2 should reference pred1 via BatchSelf
let pred2 = &module.batch.predicates()[0];
let stmt = &pred2.statements[0];
assert!(matches!(
stmt.pred_or_wc(),
PredicateOrWildcard::Predicate(Predicate::BatchSelf(1))
)); // pred1 is at index 1
}
#[test]
fn test_mutual_recursion() {
// pred1 calls pred2, pred2 calls pred1 - mutual recursion
// This should work because they're in the same batch
let input = r#"
pred1(A) = AND(pred2(A))
pred2(B) = AND(pred1(B))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let result = build_module(
preds_to_split_results(predicates),
&params,
"TestModule",
validated.symbols(),
);
assert!(result.is_ok());
let module = result.unwrap();
assert_eq!(module.batch.predicates().len(), 2);
// Both should use BatchSelf references
let pred1 = &module.batch.predicates()[0];
let pred2 = &module.batch.predicates()[1];
assert!(matches!(
pred1.statements[0].pred_or_wc(),
PredicateOrWildcard::Predicate(Predicate::BatchSelf(1))
)); // calls pred2
assert!(matches!(
pred2.statements[0].pred_or_wc(),
PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
)); // calls pred1
}
#[test]
fn test_predicate_ref_by_name() {
let input = r#"
pred1(A) = AND(Equal(A["x"], 1))
pred2(B) = AND(Equal(B["y"], 2))
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
let module = build_module(
preds_to_split_results(predicates),
&params,
"TestModule",
validated.symbols(),
)
.unwrap();
// Should be able to look up both predicates
assert!(module.predicate_ref_by_name("pred1").is_some());
assert!(module.predicate_ref_by_name("pred2").is_some());
assert!(module.predicate_ref_by_name("nonexistent").is_none());
}
#[test]
fn test_split_predicate() {
// A predicate that will split into 2 pieces
let input = r#"
large_pred(A) = AND(
Equal(A["a"], 1)
Equal(A["b"], 2)
Equal(A["c"], 3)
Equal(A["d"], 4)
Equal(A["e"], 5)
Equal(A["f"], 6)
)
"#;
let (predicates, validated) = parse_and_validate(input);
let params = Params::default();
// Split the predicate
let mut split_results = Vec::new();
for pred in predicates {
let result = split_predicate_if_needed(pred, &params).expect("Split failed");
split_results.push(result);
}
// Should split into 2 pieces
assert_eq!(split_results.len(), 1);
assert_eq!(split_results[0].predicates.len(), 2);
assert!(split_results[0].chain_info.is_some());
let module =
build_module(split_results, &params, "TestModule", validated.symbols()).unwrap();
// Verify chain info is preserved
let chain_info = module.split_chains.get("large_pred").unwrap();
assert_eq!(chain_info.chain_pieces.len(), 2);
assert_eq!(chain_info.real_statement_count, 6);
}
#[test]
fn test_load_module_importing_two_modules() {
use hex::ToHex;
let params = Params::default();
// Module "checks": defines is_equal
let checks = Arc::new(
load_module(
r#"is_equal(X, Y) = AND(Equal(X["val"], Y["val"]))"#,
"checks",
&params,
vec![],
)
.unwrap(),
);
// Module "ordering": defines is_less
let ordering = Arc::new(
load_module(
r#"is_less(X, Y) = AND(Lt(X["val"], Y["val"]))"#,
"ordering",
&params,
vec![],
)
.unwrap(),
);
let checks_hash = checks.id().encode_hex::<String>();
let ordering_hash = ordering.id().encode_hex::<String>();
// Module "combined": imports both, uses predicates from each
let combined = load_module(
&format!(
r#"
use module 0x{} as checks
use module 0x{} as ordering
equal_and_ordered(A, B, C) = AND(
checks::is_equal(A, B)
ordering::is_less(B, C)
)
"#,
checks_hash, ordering_hash
),
"combined",
&params,
vec![checks.clone(), ordering.clone()],
)
.unwrap();
assert_eq!(combined.batch.predicates().len(), 1);
let pred = &combined.batch.predicates()[0];
assert_eq!(pred.name, "equal_and_ordered");
assert_eq!(pred.statements.len(), 2);
// First statement references checks::is_equal (external Custom ref, not BatchSelf)
let checks_ref = CustomPredicateRef::new(checks.batch.clone(), 0);
assert_eq!(
*pred.statements[0].pred_or_wc(),
PredicateOrWildcard::Predicate(Predicate::Custom(checks_ref))
);
// Second statement references ordering::is_less (external Custom ref, not BatchSelf)
let ordering_ref = CustomPredicateRef::new(ordering.batch.clone(), 0);
assert_eq!(
*pred.statements[1].pred_or_wc(),
PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref))
);
}
}

View file

@ -219,7 +219,7 @@ mod tests {
use super::*;
use crate::{
backends::plonky2::primitives::ec::schnorr::SecretKey,
lang::parse,
lang::load_module,
middleware::{
CustomPredicate, Key, NativePredicate, Params, Predicate, StatementTmpl,
StatementTmplArg, Value, Wildcard,
@ -388,20 +388,19 @@ mod tests {
/// Helper function for round-trip testing
fn assert_round_trip(input: &str) {
let params = Params::default();
let available_batches = &[];
// Step 1: Parse the input
let parsed_result =
parse(input, &params, available_batches).expect("Initial parsing should succeed");
let module =
load_module(input, "test", &params, vec![]).expect("Initial parsing should succeed");
// Step 2: Pretty-print the parsed batch
let batch = parsed_result.first_batch().expect("Expected batch");
let batch = &module.batch;
let pretty_printed = batch.to_podlang_string();
// Step 3: Parse the pretty-printed result
let reparsed_result =
parse(&pretty_printed, &params, available_batches).expect("Reparsing should succeed");
let reparsed_batch = reparsed_result.first_batch().expect("Expected batch");
let reparsed_module = load_module(&pretty_printed, "test", &params, vec![])
.expect("Reparsing should succeed");
let reparsed_batch = &reparsed_module.batch;
// Step 4: Verify the ASTs are equivalent
assert_eq!(
@ -556,16 +555,17 @@ mod tests {
"#;
let params = Params::default();
let parsed_result = parse(input, &params, &[]).expect("Parsing should succeed");
let batch = parsed_result.first_batch().expect("Expected batch");
let module = load_module(input, "test", &params, vec![]).expect("Parsing should succeed");
let batch = &module.batch;
let pretty_printed = batch.to_podlang_string();
println!("Original input:\n{}", input);
println!("\nPretty-printed output:\n{}", pretty_printed);
let reparsed = parse(&pretty_printed, &params, &[]).expect("Reparsing should succeed");
let reparsed_batch = reparsed.first_batch().expect("Expected batch");
let reparsed = load_module(&pretty_printed, "test", &params, vec![])
.expect("Reparsing should succeed");
let reparsed_batch = &reparsed.batch;
assert_eq!(batch.predicates(), reparsed_batch.predicates());
}
@ -629,14 +629,15 @@ mod tests {
);
let params = Params::default();
let parsed_result = parse(&input, &params, &[]).expect("Should parse successfully");
let batch = parsed_result.first_batch().expect("Expected batch");
let module =
load_module(&input, "test", &params, vec![]).expect("Should parse successfully");
let batch = &module.batch;
let pretty_printed = batch.to_podlang_string();
let reparsed_result =
parse(&pretty_printed, &params, &[]).expect("Should reparse successfully");
let reparsed_batch = reparsed_result.first_batch().expect("Expected batch");
let reparsed_module = load_module(&pretty_printed, "test", &params, vec![])
.expect("Should reparse successfully");
let reparsed_batch = &reparsed_module.batch;
assert_eq!(
batch.predicates(),