Remove batch splitting system (#475)

* First pass at removing batch splitting

* Refactor to separate module loading from request parsing

* Consolidate module functionality

* Tidy up comments

* Use array of modules instead of HashMap

* Formatting

* Use module hashes when importing modules
This commit is contained in:
Rob Knight 2026-02-09 10:31:47 +01:00 committed by GitHub
parent 5dab8195b4
commit acab26e5c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1425 additions and 1938 deletions

View file

@ -18,17 +18,17 @@ pub struct Document {
/// Top-level items that can appear in a document
#[derive(Debug, Clone, PartialEq)]
pub enum DocumentItem {
UseBatchStatement(UseBatchStatement),
UseModuleStatement(UseModuleStatement),
UseIntroStatement(UseIntroStatement),
CustomPredicateDef(CustomPredicateDef),
RequestDef(RequestDef),
}
/// Import statement: `use batch pred1, pred2, _ from 0x...`
/// Module import statement: `use module 0xHASH as alias`
#[derive(Debug, Clone, PartialEq)]
pub struct UseBatchStatement {
pub imports: Vec<ImportName>,
pub batch_ref: HashHex,
pub struct UseModuleStatement {
pub hash: HashHex,
pub alias: Identifier,
pub span: Option<Span>,
}
@ -40,19 +40,6 @@ pub struct UseIntroStatement {
pub intro_hash: HashHex,
pub span: Option<Span>,
}
/// Individual import name (identifier or unused "_")
#[derive(Debug, Clone, PartialEq)]
pub enum ImportName {
Named(String),
Unused, // "_"
}
/// Batch reference (hash)
#[derive(Debug, Clone, PartialEq)]
pub struct BatchRef {
pub hash: HashHex,
pub span: Option<Span>,
}
/// Intro predicate reference (hash)
#[derive(Debug, Clone, PartialEq)]
@ -96,11 +83,33 @@ pub enum ConjunctionType {
/// Statement template: predicate call with arguments
#[derive(Debug, Clone, PartialEq)]
pub struct StatementTmpl {
pub predicate: Identifier,
pub predicate: PredicateRef,
pub args: Vec<StatementTmplArg>,
pub span: Option<Span>,
}
/// Reference to a predicate (local or qualified with module name)
#[derive(Debug, Clone, PartialEq)]
pub enum PredicateRef {
/// Unqualified name (local or native predicate)
Local(Identifier),
/// Qualified name (module::predicate)
Qualified {
module: Identifier,
predicate: Identifier,
},
}
impl PredicateRef {
/// Get the predicate name (without module qualifier)
pub fn predicate_name(&self) -> &str {
match self {
PredicateRef::Local(id) => &id.name,
PredicateRef::Qualified { predicate, .. } => &predicate.name,
}
}
}
/// Arguments that can be passed to statements
#[derive(Debug, Clone, PartialEq)]
pub enum StatementTmplArg {
@ -256,7 +265,7 @@ impl fmt::Display for Document {
impl fmt::Display for DocumentItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DocumentItem::UseBatchStatement(u) => write!(f, "{}", u),
DocumentItem::UseModuleStatement(u) => write!(f, "{}", u),
DocumentItem::UseIntroStatement(u) => write!(f, "{}", u),
DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c),
DocumentItem::RequestDef(r) => write!(f, "{}", r),
@ -264,16 +273,9 @@ impl fmt::Display for DocumentItem {
}
}
impl fmt::Display for UseBatchStatement {
impl fmt::Display for UseModuleStatement {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "use batch ")?;
for (i, import) in self.imports.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", import)?;
}
write!(f, " from {}", self.batch_ref)
write!(f, "use module {} as {}", self.hash, self.alias)
}
}
@ -290,21 +292,17 @@ impl fmt::Display for UseIntroStatement {
}
}
impl fmt::Display for ImportName {
impl fmt::Display for PredicateRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ImportName::Named(name) => write!(f, "{}", name),
ImportName::Unused => write!(f, "_"),
PredicateRef::Local(id) => write!(f, "{}", id),
PredicateRef::Qualified { module, predicate } => {
write!(f, "{}::{}", module, predicate)
}
}
}
}
impl fmt::Display for BatchRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.hash)
}
}
impl fmt::Display for IntroPredicateRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.hash)
@ -536,10 +534,10 @@ pub mod parse {
for inner_pair in pair.into_inner() {
match inner_pair.as_rule() {
Rule::use_batch_statement => {
items.push(DocumentItem::UseBatchStatement(parse_use_batch_statement(
inner_pair,
)));
Rule::use_module_statement => {
items.push(DocumentItem::UseModuleStatement(
parse_use_module_statement(inner_pair),
));
}
Rule::use_intro_statement => {
items.push(DocumentItem::UseIntroStatement(parse_use_intro_statement(
@ -562,25 +560,17 @@ pub mod parse {
Ok(Document { items })
}
fn parse_use_batch_statement(pair: Pair<Rule>) -> UseBatchStatement {
assert_eq!(pair.as_rule(), Rule::use_batch_statement);
fn parse_use_module_statement(pair: Pair<Rule>) -> UseModuleStatement {
assert_eq!(pair.as_rule(), Rule::use_module_statement);
let span = get_span(&pair);
let mut inner = pair.into_inner();
let use_list_pair = inner
.find(|p| p.as_rule() == Rule::use_predicate_list)
.unwrap();
let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap();
let hash = parse_hash_hex(inner.next().unwrap());
let alias = parse_identifier(inner.next().unwrap());
let imports = use_list_pair
.into_inner()
.filter(|p| p.as_rule() == Rule::import_name)
.map(parse_import_name)
.collect();
UseBatchStatement {
imports,
batch_ref: parse_hash_hex(batch_ref_pair.into_inner().next().unwrap()),
UseModuleStatement {
hash,
alias,
span: Some(span),
}
}
@ -622,16 +612,6 @@ pub mod parse {
}
}
fn parse_import_name(pair: Pair<Rule>) -> ImportName {
assert_eq!(pair.as_rule(), Rule::import_name);
let s = pair.as_str();
if s == "_" {
ImportName::Unused
} else {
ImportName::Named(s.to_string())
}
}
fn parse_hash_hex(pair: Pair<Rule>) -> HashHex {
assert_eq!(pair.as_rule(), Rule::hash_hex);
let span = get_span(&pair);
@ -748,7 +728,7 @@ pub mod parse {
let span = get_span(&pair);
let mut inner = pair.into_inner();
let predicate = parse_identifier(inner.next().unwrap());
let predicate = parse_predicate_ref(inner.next().unwrap());
let mut args = Vec::new();
if let Some(arg_list) = inner.next() {
@ -768,6 +748,22 @@ pub mod parse {
})
}
fn parse_predicate_ref(pair: Pair<Rule>) -> PredicateRef {
assert_eq!(pair.as_rule(), Rule::predicate_ref);
let inner = pair.into_inner().next().unwrap();
match inner.as_rule() {
Rule::qualified_predicate_ref => {
let mut parts = inner.into_inner();
let module = parse_identifier(parts.next().unwrap());
let predicate = parse_identifier(parts.next().unwrap());
PredicateRef::Qualified { module, predicate }
}
Rule::identifier => PredicateRef::Local(parse_identifier(inner)),
_ => unreachable!("Unexpected predicate_ref rule: {:?}", inner.as_rule()),
}
}
fn parse_statement_arg(pair: Pair<Rule>) -> Result<StatementTmplArg, parser::ParseError> {
assert_eq!(pair.as_rule(), Rule::statement_arg);
let inner = pair.into_inner().next().unwrap();
@ -1047,9 +1043,10 @@ mod tests {
fn clear_spans(doc: &mut Document) {
for item in &mut doc.items {
match item {
DocumentItem::UseBatchStatement(u) => {
DocumentItem::UseModuleStatement(u) => {
u.span = None;
u.batch_ref.span = None;
u.hash.span = None;
u.alias.span = None;
}
DocumentItem::UseIntroStatement(u) => {
u.span = None;
@ -1082,9 +1079,19 @@ mod tests {
}
}
fn clear_predicate_ref_spans(pred_ref: &mut PredicateRef) {
match pred_ref {
PredicateRef::Local(id) => id.span = None,
PredicateRef::Qualified { module, predicate } => {
module.span = None;
predicate.span = None;
}
}
}
fn clear_statement_spans(stmt: &mut StatementTmpl) {
stmt.span = None;
stmt.predicate.span = None;
clear_predicate_ref_spans(&mut stmt.predicate);
for arg in &mut stmt.args {
match arg {
StatementTmplArg::Literal(lit) => clear_literal_spans(lit),
@ -1168,8 +1175,8 @@ mod tests {
}
#[test]
fn test_use_batch_statement() {
let input = r#"use batch pred1, pred2, _ from 0x0000000000000000000000000000000000000000000000000000000000000000"#;
fn test_use_module_statement() {
let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as helpers"#;
test_roundtrip(input);
}
@ -1223,11 +1230,11 @@ mod tests {
#[test]
fn test_complete_document() {
let input = r#"use batch imported_pred from 0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd
let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as imported
is_valid(User, private: Config) = AND (
Equal(User["age"], Config["min_age"])
imported_pred(User, Config)
imported::some_pred(User, Config)
)
check_both(A, B, C) = OR (
@ -1306,7 +1313,7 @@ REQUEST(
// Check request structure
if let DocumentItem::RequestDef(req) = &ast.items[1] {
assert_eq!(req.statements.len(), 1);
assert_eq!(req.statements[0].predicate.name, "my_pred");
assert_eq!(req.statements[0].predicate.predicate_name(), "my_pred");
assert_eq!(req.statements[0].args.len(), 2);
} else {
panic!("Expected RequestDef");