pod2/src/lang/frontend_ast_validate.rs
Rob Knight e9e3241263
Support "records" in Podlang (#507)
* Support both integer and string keys in anchored keys

* Podlang parser support for records

* Validate record usage in Podlang

* Lower records to middleware

* Cross-module record imports

* Tidying

* Record entry name literal

* More tidying

* More tests, make sure qualified record literals are supported

* Use snake-case for record entry names

* Review feedback
2026-05-06 06:21:22 -07:00

1409 lines
48 KiB
Rust

//! Validation for the frontend AST
//!
//! This module provides semantic validation for parsed AST documents,
//! including name resolution, arity checking, wildcard validation, record
//! schema checks, and enforcement of module/request mode rules.
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
};
use hex::ToHex;
use crate::{
lang::{frontend_ast::*, Module},
middleware::{CustomPredicateBatch, Hash, NativePredicate, Params},
};
/// A validated AST document with symbol table and diagnostics.
///
/// Produced by [`validate`]; the document is stored exactly as it was passed
/// in. Fields are private — use the accessor methods.
#[derive(Debug, Clone)]
pub struct ValidatedAST {
    /// The document that passed validation.
    document: Document,
    /// Symbols collected by the first validation pass (imports, records,
    /// custom predicates and their wildcard scopes).
    symbols: SymbolTable,
    /// Non-fatal findings, e.g. a warning for an empty REQUEST block.
    diagnostics: Vec<Diagnostic>,
}
impl ValidatedAST {
    /// Borrow the document that was validated.
    pub fn document(&self) -> &Document {
        &self.document
    }

    /// Borrow the symbol table built while validating the document.
    pub fn symbols(&self) -> &SymbolTable {
        &self.symbols
    }

    /// Non-fatal findings (warnings / info) produced during validation.
    pub fn diagnostics(&self) -> &[Diagnostic] {
        self.diagnostics.as_slice()
    }

    /// Consume the validated AST and keep only the document.
    pub fn into_document(self) -> Document {
        let Self { document, .. } = self;
        document
    }
}
/// Symbol table containing all predicates and their metadata.
///
/// Populated by the validator's first pass; the second pass only reads it.
#[derive(Debug, Clone)]
pub struct SymbolTable {
    /// Locally bound predicate names: custom predicates defined in this
    /// document and `use intro` imports. Native predicates and predicates of
    /// imported modules are resolved on demand and are not stored here.
    pub predicates: HashMap<String, PredicateInfo>,
    /// Wildcard scopes, keyed by custom predicate name.
    pub wildcard_scopes: HashMap<String, WildcardScope>,
    /// Imported modules (bound alias → Module reference)
    pub imported_modules: HashMap<String, Arc<Module>>,
    /// Records visible in this scope (local declarations + imports). Local
    /// records are keyed by their bare name, imported ones by `alias::Name`
    /// (see `qualified_record_key`).
    pub records: HashMap<String, RecordSchema>,
}
/// Resolved record schema: ordered entries plus a name→index lookup, with
/// provenance for diagnostics. Lowering uses `entry_index` to translate
/// dot-access like `r.foo` into the integer key for an `AnchoredKey`.
#[derive(Debug, Clone)]
pub struct RecordSchema {
    /// Entry names in declaration order (for imports: the order the module
    /// stores them in).
    pub entries: Vec<String>,
    /// Entry name → position in `entries`.
    pub entry_index: HashMap<String, usize>,
    /// Whether the record was declared locally or imported.
    pub source: RecordSource,
    /// For local records, the span of the record's name; for imported
    /// records, the span of the `use module` statement that brought it in.
    pub source_span: Option<Span>,
}
impl RecordSchema {
/// Build a schema from already-deduplicated entries. Callers that need
/// to surface a per-entry span on duplicates (e.g. local declarations)
/// should detect duplicates themselves before calling this.
pub fn from_entries(
entries: Vec<String>,
source: RecordSource,
source_span: Option<Span>,
) -> Self {
let entry_index = entries
.iter()
.enumerate()
.map(|(i, e)| (e.clone(), i))
.collect();
Self {
entries,
entry_index,
source,
source_span,
}
}
}
/// Where a [`RecordSchema`] came from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecordSource {
    /// Declared by a `record` definition in this document.
    Local,
    /// Imported via `use module ... as alias`. `module` holds the alias the
    /// module was bound to, not a name declared by the module itself.
    Imported { module: String },
}
/// Build the `SymbolTable.records` key for record `name` imported via
/// `use module ... as alias`, i.e. `"alias::name"` — the same `alias::Name`
/// spelling used for `TypeRef::Qualified`.
pub fn qualified_record_key(alias: &str, name: &str) -> String {
    format!("{alias}::{name}")
}
/// Information about a predicate
#[derive(Debug, Clone)]
pub struct PredicateInfo {
    /// How the predicate was introduced and where its definition lives.
    pub kind: PredicateKind,
    /// Total number of wildcards (public plus private arguments).
    pub arity: usize,
    /// Number of public arguments; statement argument counts are checked
    /// against this value.
    pub public_arity: usize,
    /// Span of the defining name or `use intro` statement; `None` for native
    /// and module-imported predicates.
    pub source_span: Option<Span>,
}
/// Kind of predicate, i.e. where its definition comes from.
#[derive(Debug, Clone)]
pub enum PredicateKind {
    /// A built-in predicate such as `Equal`.
    Native(NativePredicate),
    /// A predicate defined in this document; `index` is its position among
    /// the document's custom predicates, in declaration order.
    Custom {
        index: usize,
    },
    /// Presumably the predicate at `index` of an imported custom predicate
    /// batch. NOTE(review): not constructed by this validator — confirm
    /// semantics with its producer.
    BatchImported {
        batch: Arc<CustomPredicateBatch>,
        index: usize,
    },
    /// A qualified `alias::name` reference; `module_name` is the alias and
    /// `predicate_index` indexes the imported module's batch.
    ModuleImported {
        module_name: String,
        predicate_name: String,
        predicate_index: usize,
    },
    /// Bound by `use intro name(...) from 0x<hash>`.
    IntroImported {
        name: String,
        verifier_data_hash: Hash,
    },
}
/// Wildcard scope for a custom predicate: every argument it declares.
#[derive(Debug, Clone)]
pub struct WildcardScope {
    /// Wildcard name → its metadata.
    pub wildcards: HashMap<String, WildcardInfo>,
}
/// Information about a wildcard
#[derive(Debug, Clone)]
pub struct WildcardInfo {
    /// Position in the predicate's argument list, counting public args first
    /// and then private args.
    pub index: usize,
    /// `true` for public args, `false` for `private:` args.
    pub is_public: bool,
    /// Span of the argument declaration.
    pub source_span: Option<Span>,
    /// Record type tag for typed args (`name TypeName` syntax). The name
    /// references an entry in `SymbolTable.records` — a bare name for local
    /// records, `alias::Name` for imported ones.
    pub record_type: Option<String>,
}
/// Diagnostic message (warning or info). Diagnostics never fail validation;
/// they are collected alongside a successful result.
#[derive(Debug, Clone)]
pub struct Diagnostic {
    /// Severity of the finding.
    pub level: DiagnosticLevel,
    /// Human-readable description.
    pub message: String,
    /// Source location the finding refers to, if known.
    pub span: Option<Span>,
}
/// Severity of a [`Diagnostic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiagnosticLevel {
    /// Likely a mistake, but still valid input (e.g. an empty REQUEST block).
    Warning,
    /// Purely informational.
    Info,
}
pub use crate::lang::error::ValidationError;
/// Mode for parsing/validation - determines what constructs are allowed
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseMode {
    /// Module mode: predicate definitions allowed (at least one required),
    /// REQUEST block not allowed
    Module,
    /// Request mode: exactly one REQUEST block required, predicate
    /// definitions not allowed
    Request,
}
/// Validate an AST document in the given mode.
///
/// Runs two passes: the first collects imports, records and predicate
/// definitions into a [`SymbolTable`]; the second resolves every predicate,
/// wildcard and literal reference against it.
///
/// # Errors
/// Returns the first [`ValidationError`] encountered. Non-fatal findings are
/// reported through [`ValidatedAST::diagnostics`] instead.
pub fn validate(
    document: Document,
    available_modules: &HashMap<Hash, Arc<Module>>,
    params: &Params,
    mode: ParseMode,
) -> Result<ValidatedAST, ValidationError> {
    Validator::new(available_modules, params, mode).validate(document)
}
/// State for a single validation run; consumed by `Validator::validate`.
struct Validator {
    /// Modules that `use module` statements may reference, keyed by hash.
    available_modules: HashMap<Hash, Arc<Module>>,
    /// Parameters; read here for the record entry cap.
    params: Params,
    /// Symbol table filled in by pass 1.
    symbols: SymbolTable,
    /// Non-fatal findings collected so far.
    diagnostics: Vec<Diagnostic>,
    /// Number of custom predicates declared so far (the next one's index).
    custom_predicate_count: usize,
    /// Which top-level constructs are allowed.
    mode: ParseMode,
}
impl Validator {
    /// Create a validator with an empty symbol table.
    ///
    /// `available_modules` and `params` are cloned so the validator owns
    /// everything it needs for one `validate` call.
    fn new(
        available_modules: &HashMap<Hash, Arc<Module>>,
        params: &Params,
        mode: ParseMode,
    ) -> Self {
        Self {
            available_modules: available_modules.clone(),
            params: params.clone(),
            symbols: SymbolTable {
                predicates: HashMap::new(),
                wildcard_scopes: HashMap::new(),
                imported_modules: HashMap::new(),
                records: HashMap::new(),
            },
            diagnostics: Vec::new(),
            custom_predicate_count: 0,
            mode,
        }
    }

    /// Run both validation passes and package the result.
    ///
    /// Pass 1 declares every symbol before pass 2 resolves references, so a
    /// predicate may call another one defined later in the document.
    fn validate(mut self, document: Document) -> Result<ValidatedAST, ValidationError> {
        // Pass 1: Build symbol table
        self.build_symbol_table(&document)?;
        // Pass 2: Validate all references
        self.validate_references(&document)?;
        Ok(ValidatedAST {
            document,
            symbols: self.symbols,
            diagnostics: self.diagnostics,
        })
    }

    /// Pass 1: declare every symbol the document introduces and enforce the
    /// structural rules of the current [`ParseMode`].
    ///
    /// Order matters: imports first (so imported records exist), then local
    /// records (so typed predicate args can resolve), then predicate
    /// definitions, and finally the REQUEST blocks.
    ///
    /// # Errors
    /// Predicate definitions in request mode, REQUEST blocks in module mode,
    /// more than one REQUEST block, a module without predicates, a request
    /// without a REQUEST block, plus anything the `process_*` helpers reject.
    fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> {
        // First process imports
        for item in &document.items {
            match item {
                DocumentItem::UseModuleStatement(use_stmt) => {
                    self.process_use_module_statement(use_stmt)?;
                }
                DocumentItem::UseIntroStatement(use_stmt) => {
                    self.process_use_intro_statement(use_stmt)?;
                }
                _ => {}
            }
        }
        // Records before predicates so typed-arg resolution can find them.
        for item in &document.items {
            if let DocumentItem::RecordDef(record_def) = item {
                self.process_record_def(record_def)?;
            }
        }
        // Predicate definitions; not allowed at all in request mode.
        let mut has_predicates = false;
        for item in &document.items {
            if let DocumentItem::CustomPredicateDef(pred_def) = item {
                if self.mode == ParseMode::Request {
                    return Err(ValidationError::PredicatesNotAllowedInRequest {
                        span: pred_def.span,
                    });
                }
                has_predicates = true;
                self.process_custom_predicate_def(pred_def)?;
            }
        }
        // REQUEST blocks: not allowed in module mode, at most one otherwise.
        let mut has_request = false;
        let mut first_request_span = None;
        for item in &document.items {
            if let DocumentItem::RequestDef(req) = item {
                if self.mode == ParseMode::Module {
                    return Err(ValidationError::RequestNotAllowedInModule { span: req.span });
                }
                // Detect the duplicate via `has_request`, not via the stored
                // span: an AST built without spans (`span: None`) must still be
                // rejected when it carries two REQUEST blocks.
                if has_request {
                    return Err(ValidationError::MultipleRequestDefinitions {
                        first_span: first_request_span,
                        second_span: req.span,
                    });
                }
                first_request_span = req.span;
                has_request = true;
            }
        }
        // Enforce that modules have predicates and requests have a REQUEST block.
        match self.mode {
            ParseMode::Module if !has_predicates => {
                return Err(ValidationError::NoPredicatesInModule);
            }
            ParseMode::Request if !has_request => {
                return Err(ValidationError::NoRequestBlock);
            }
            _ => {}
        }
        Ok(())
    }

    /// Bind `use module 0x<hash> as alias`: look the module up by hash,
    /// expose its records under `alias::Name` keys, and remember the module
    /// under `alias` for qualified predicate resolution.
    ///
    /// # Errors
    /// `DuplicateImport` if `alias` was already bound by an earlier
    /// `use module`; `ModuleNotFound` if no available module has the hash.
    fn process_use_module_statement(
        &mut self,
        use_stmt: &UseModuleStatement,
    ) -> Result<(), ValidationError> {
        let alias = &use_stmt.alias.name;
        let hash = &use_stmt.hash.hash;
        // Rebinding an alias would silently shadow the earlier module for
        // qualified predicate lookups while leaving its `alias::Name` record
        // keys behind, mixing two modules under one name.
        if self.symbols.imported_modules.contains_key(alias) {
            return Err(ValidationError::DuplicateImport {
                name: alias.clone(),
                span: use_stmt.span,
            });
        }
        // Check if the module is available by hash
        let module = self
            .available_modules
            .get(hash)
            .ok_or_else(|| ValidationError::ModuleNotFound {
                name: hash.encode_hex::<String>(),
                span: use_stmt.span,
            })?;
        // Flatten the imported module's locally-declared records into the
        // symbol table under qualified keys (`alias::Name`). No transitive
        // re-export — `Module.records` only carries local declarations.
        for (record_name, entries) in &module.records {
            self.symbols.records.insert(
                qualified_record_key(alias, record_name),
                RecordSchema::from_entries(
                    entries.clone(),
                    RecordSource::Imported {
                        module: alias.clone(),
                    },
                    use_stmt.span,
                ),
            );
        }
        // Store the module keyed by alias for later qualified name resolution
        self.symbols
            .imported_modules
            .insert(alias.clone(), module.clone());
        Ok(())
    }

    /// Returns the resolved `SymbolTable.records` key for a typed arg, or
    /// `None` if the arg has no `type_name`. The key is the bare type name
    /// for locals and `"alias::Name"` for qualified imports.
    ///
    /// # Errors
    /// `UnknownRecord` if the tag doesn't refer to a known record.
    fn resolve_typed_arg(&self, arg: &TypedArg) -> Result<Option<String>, ValidationError> {
        let Some(type_ref) = &arg.type_name else {
            return Ok(None);
        };
        let key = type_ref.symbol_table_key();
        if !self.symbols.records.contains_key(&key) {
            return Err(ValidationError::UnknownRecord {
                name: key,
                span: type_ref.span(),
            });
        }
        Ok(Some(key))
    }

    /// Bind `use intro name(args...) from 0x<hash>` as a local predicate
    /// whose public and total arity are both the number of declared args.
    ///
    /// # Errors
    /// `DuplicateImport` if `name` is already a known predicate.
    fn process_use_intro_statement(
        &mut self,
        use_stmt: &UseIntroStatement,
    ) -> Result<(), ValidationError> {
        let intro_name = &use_stmt.name.name;
        let args = &use_stmt.args;
        let intro_predicate_ref = &use_stmt.intro_hash;
        if self.symbols.predicates.contains_key(intro_name) {
            return Err(ValidationError::DuplicateImport {
                name: intro_name.clone(),
                span: use_stmt.span,
            });
        }
        self.symbols.predicates.insert(
            intro_name.clone(),
            PredicateInfo {
                kind: PredicateKind::IntroImported {
                    name: intro_name.clone(),
                    // Hash is already parsed in the AST
                    verifier_data_hash: intro_predicate_ref.hash,
                },
                arity: args.len(),
                public_arity: args.len(),
                source_span: use_stmt.span,
            },
        );
        Ok(())
    }

    /// Declare a local `record Name = (entries...)`.
    ///
    /// # Errors
    /// `DuplicateRecord` for a repeated record name, `RecordTooManyEntries`
    /// above `params.max_record_entries()`, `DuplicateRecordEntry` for a
    /// repeated entry name.
    fn process_record_def(&mut self, record_def: &RecordDef) -> Result<(), ValidationError> {
        let name = &record_def.name.name;
        if let Some(existing) = self.symbols.records.get(name) {
            return Err(ValidationError::DuplicateRecord {
                name: name.clone(),
                first_span: existing.source_span,
                second_span: record_def.name.span,
            });
        }
        let max = self.params.max_record_entries();
        if record_def.entries.len() > max {
            return Err(ValidationError::RecordTooManyEntries {
                name: name.clone(),
                count: record_def.entries.len(),
                max,
                span: record_def.span,
            });
        }
        // Detect duplicates here (rather than in `from_entries`) so the error
        // can point at the offending entry's span.
        let mut seen = HashSet::with_capacity(record_def.entries.len());
        let mut entries = Vec::with_capacity(record_def.entries.len());
        for entry in &record_def.entries {
            if !seen.insert(&entry.name) {
                return Err(ValidationError::DuplicateRecordEntry {
                    record: name.clone(),
                    entry: entry.name.clone(),
                    span: entry.span,
                });
            }
            entries.push(entry.name.clone());
        }
        self.symbols.records.insert(
            name.clone(),
            RecordSchema::from_entries(entries, RecordSource::Local, record_def.name.span),
        );
        Ok(())
    }

    /// Declare a custom predicate: reject duplicate names and empty bodies,
    /// build its wildcard scope (resolving typed args), and record its arity.
    ///
    /// Wildcards are numbered in declaration order — public args first, then
    /// private args. The predicate's `Custom { index }` is its position among
    /// the custom predicates declared so far.
    ///
    /// # Errors
    /// `DuplicatePredicate`, `EmptyStatementList`, `DuplicateWildcard` (a
    /// name repeated anywhere across public and private args), or
    /// `UnknownRecord` from typed-arg resolution.
    fn process_custom_predicate_def(
        &mut self,
        pred_def: &CustomPredicateDef,
    ) -> Result<(), ValidationError> {
        let name = &pred_def.name.name;
        if let Some(existing) = self.symbols.predicates.get(name) {
            return Err(ValidationError::DuplicatePredicate {
                name: name.clone(),
                first_span: existing.source_span,
                second_span: pred_def.name.span,
            });
        }
        // Check for empty statement list
        if pred_def.statements.is_empty() {
            return Err(ValidationError::EmptyStatementList {
                context: format!("predicate '{}'", name),
                span: pred_def.span,
            });
        }
        let public_args = &pred_def.args.public_args;
        let private_count = pred_def
            .args
            .private_args
            .as_ref()
            .map_or(0, |args| args.len());
        // Single pass over public then private args, each tagged with its
        // visibility; the enumeration position is the wildcard index.
        let all_args = public_args.iter().map(|arg| (arg, true)).chain(
            pred_def
                .args
                .private_args
                .iter()
                .flatten()
                .map(|arg| (arg, false)),
        );
        let mut wildcards = HashMap::new();
        for (index, (arg, is_public)) in all_args.enumerate() {
            if wildcards.contains_key(&arg.name) {
                return Err(ValidationError::DuplicateWildcard {
                    name: arg.name.clone(),
                    span: arg.span,
                });
            }
            let record_type = self.resolve_typed_arg(arg)?;
            wildcards.insert(
                arg.name.clone(),
                WildcardInfo {
                    index,
                    is_public,
                    source_span: arg.span,
                    record_type,
                },
            );
        }
        // Add to symbol table
        self.symbols.predicates.insert(
            name.clone(),
            PredicateInfo {
                kind: PredicateKind::Custom {
                    index: self.custom_predicate_count,
                },
                arity: public_args.len() + private_count,
                public_arity: public_args.len(),
                source_span: pred_def.name.span,
            },
        );
        self.symbols
            .wildcard_scopes
            .insert(name.clone(), WildcardScope { wildcards });
        self.custom_predicate_count += 1;
        Ok(())
    }

    /// Pass 2: validate the statements of every predicate body and REQUEST
    /// block, in document order.
    fn validate_references(&mut self, document: &Document) -> Result<(), ValidationError> {
        for item in &document.items {
            match item {
                DocumentItem::CustomPredicateDef(pred_def) => {
                    self.validate_custom_predicate_statements(pred_def)?;
                }
                DocumentItem::RequestDef(req_def) => {
                    self.validate_request_statements(req_def)?;
                }
                _ => {}
            }
        }
        Ok(())
    }

    /// Validate a custom predicate's body against its own wildcard scope.
    ///
    /// The wildcard/predicate name collision check runs once here. Pass 1
    /// guarantees a non-empty body, so this matches checking before the
    /// first statement.
    fn validate_custom_predicate_statements(
        &self,
        pred_def: &CustomPredicateDef,
    ) -> Result<(), ValidationError> {
        let pred_name = pred_def.name.name.as_str();
        let wildcard_scope = self
            .symbols
            .wildcard_scopes
            .get(pred_name)
            .expect("Wildcard scope should exist after pass 1");
        let wildcard_names: HashSet<&String> = wildcard_scope.wildcards.keys().collect();
        self.validate_wildcard_names(&wildcard_names)?;
        for stmt in &pred_def.statements {
            self.validate_statement(stmt, Some((pred_name, wildcard_scope)))?;
        }
        Ok(())
    }

    /// Validate a REQUEST block's statements; an empty block only warns.
    fn validate_request_statements(&mut self, req_def: &RequestDef) -> Result<(), ValidationError> {
        if req_def.statements.is_empty() {
            self.diagnostics.push(Diagnostic {
                level: DiagnosticLevel::Warning,
                message: "Empty REQUEST block".to_string(),
                span: req_def.span,
            });
        }
        for stmt in &req_def.statements {
            self.validate_statement(stmt, None)?;
        }
        Ok(())
    }

    /// Validate that no wildcard name collides with a predicate name to avoid ambiguity when using
    /// wildcard predicates.
    fn validate_wildcard_names(&self, names: &HashSet<&String>) -> Result<(), ValidationError> {
        for name in names {
            if NativePredicate::from_str(name).is_ok()
                || self.symbols.predicates.contains_key(*name)
            {
                return Err(ValidationError::WildcardPredicateNameCollision {
                    name: (*name).clone(),
                });
            }
        }
        Ok(())
    }

    /// Resolve the predicate a statement calls, check its argument count,
    /// then validate its arguments.
    ///
    /// `wildcard_context` is `Some((predicate_name, scope))` inside a custom
    /// predicate body and `None` inside a REQUEST block. Inside a body, a
    /// local name matching one of the wildcards is accepted as a wildcard
    /// predicate, and its argument count is not checked here.
    fn validate_statement(
        &self,
        stmt: &StatementTmpl,
        wildcard_context: Option<(&str, &WildcardScope)>,
    ) -> Result<(), ValidationError> {
        let pred_name = stmt.predicate.predicate_name();
        let pred_span = stmt.predicate.span();
        // Check if predicate exists
        let pred_info = match &stmt.predicate {
            PredicateRef::Qualified { module, predicate } => {
                // `alias::pred`: the alias must name an imported module that
                // defines `pred`.
                let module_name = &module.name;
                let Some(imported_module) = self.symbols.imported_modules.get(module_name) else {
                    return Err(ValidationError::ModuleNotFound {
                        name: module_name.clone(),
                        span: module.span,
                    });
                };
                let Some(&idx) = imported_module.predicate_index.get(&predicate.name) else {
                    return Err(ValidationError::UndefinedPredicate {
                        name: format!("{}::{}", module_name, predicate.name),
                        span: pred_span,
                    });
                };
                let module_pred = &imported_module.batch.predicates()[idx];
                Some(PredicateInfo {
                    kind: PredicateKind::ModuleImported {
                        module_name: module_name.clone(),
                        predicate_name: predicate.name.clone(),
                        predicate_index: idx,
                    },
                    arity: module_pred.wildcard_names.len(),
                    public_arity: module_pred.args_len,
                    source_span: None,
                })
            }
            PredicateRef::Local(_) => {
                // Resolution order: native predicates, then symbols declared
                // in pass 1, then (inside a body) the predicate's wildcards.
                if let Ok(native) = NativePredicate::from_str(pred_name) {
                    Some(PredicateInfo {
                        kind: PredicateKind::Native(native),
                        arity: native.arity(),
                        public_arity: native.arity(),
                        source_span: None,
                    })
                } else if let Some(info) = self.symbols.predicates.get(pred_name) {
                    // Custom or imported predicate
                    Some(info.clone())
                } else if matches!(
                    wildcard_context,
                    Some((_, scope)) if scope.wildcards.contains_key(pred_name)
                ) {
                    None
                } else {
                    return Err(ValidationError::UndefinedPredicate {
                        name: pred_name.to_string(),
                        span: pred_span,
                    });
                }
            }
        };
        if let Some(ref pred_info) = pred_info {
            let expected_arity = pred_info.public_arity;
            if stmt.args.len() != expected_arity {
                return Err(ValidationError::ArgumentCountMismatch {
                    predicate: pred_name.to_string(),
                    expected: expected_arity,
                    found: stmt.args.len(),
                    span: stmt.span,
                });
            }
        }
        // Validate arguments
        self.validate_statement_args(stmt, wildcard_context)?;
        Ok(())
    }

    /// Validate each argument of a statement.
    ///
    /// Inside a predicate body, wildcard and anchored-key roots must be
    /// declared wildcards, and record-typed wildcards only allow `.entry`
    /// access to known entries. Literals and `@self_predicate` references are
    /// validated in both contexts.
    fn validate_statement_args(
        &self,
        stmt: &StatementTmpl,
        wildcard_context: Option<(&str, &WildcardScope)>,
    ) -> Result<(), ValidationError> {
        for arg in &stmt.args {
            match arg {
                StatementTmplArg::Wildcard(id) => {
                    if let Some((pred_name, scope)) = wildcard_context {
                        if !scope.wildcards.contains_key(&id.name) {
                            return Err(ValidationError::UndefinedWildcard {
                                name: id.name.clone(),
                                pred_name: pred_name.to_string(),
                                span: id.span,
                            });
                        }
                    }
                }
                StatementTmplArg::AnchoredKey(ak) => {
                    if let Some((pred_name, scope)) = wildcard_context {
                        let Some(wc_info) = scope.wildcards.get(&ak.root.name) else {
                            return Err(ValidationError::UndefinedWildcard {
                                name: ak.root.name.clone(),
                                pred_name: pred_name.to_string(),
                                span: ak.root.span,
                            });
                        };
                        // Records are integer-keyed, so string-key access on
                        // a typed wildcard is dead code at proof time. Reject
                        // dot access for unknown entries and bracket access
                        // outright; require `r.entry` for record-shaped data.
                        if let Some(record_name) = &wc_info.record_type {
                            match &ak.key {
                                AnchoredKeyPath::Dot(entry) => {
                                    let schema =
                                        self.symbols.records.get(record_name).expect(
                                            "record_type was resolved at predicate-def time",
                                        );
                                    if !schema.entry_index.contains_key(&entry.name) {
                                        return Err(ValidationError::UnknownRecordEntry {
                                            record: record_name.clone(),
                                            entry: entry.name.clone(),
                                            span: entry.span,
                                        });
                                    }
                                }
                                AnchoredKeyPath::Bracket(_) => {
                                    return Err(ValidationError::BracketAccessOnTypedWildcard {
                                        wildcard: ak.root.name.clone(),
                                        record: record_name.clone(),
                                        span: ak.span,
                                    });
                                }
                                AnchoredKeyPath::Index(_) => unreachable!(
                                    "AnchoredKeyPath::Index is introduced during lowering; \
                                     it cannot appear in the parsed AST that validation sees"
                                ),
                            }
                        }
                    }
                }
                StatementTmplArg::Literal(lit) => {
                    self.validate_literal_value(lit)?;
                }
                StatementTmplArg::SelfPredicateHash(id) => {
                    self.validate_self_predicate_hash(id, wildcard_context)?;
                }
            }
        }
        Ok(())
    }

    /// Validate a @self_predicate reference: the name must be a custom predicate in this module.
    fn validate_self_predicate_hash(
        &self,
        id: &Identifier,
        wildcard_context: Option<(&str, &WildcardScope)>,
    ) -> Result<(), ValidationError> {
        // @self_predicate only makes sense inside module predicate definitions
        if wildcard_context.is_none() {
            return Err(
                ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests {
                    span: id.span,
                },
            );
        }
        // Must refer to a custom predicate defined in this module (not intro/imported)
        match self.symbols.predicates.get(&id.name) {
            Some(info) if matches!(info.kind, PredicateKind::Custom { .. }) => Ok(()),
            _ => Err(ValidationError::UndefinedPredicate {
                name: id.name.clone(),
                span: id.span,
            }),
        }
    }

    /// Recursively validate a literal value, checking predicate hash references.
    ///
    /// Also checks record literals (known record, known and non-repeated
    /// entries, recursively valid values) and `Record::entry` index literals.
    /// Other literal kinds are accepted as-is.
    fn validate_literal_value(&self, lit: &LiteralValue) -> Result<(), ValidationError> {
        match lit {
            LiteralValue::NativePredicateHash(id) => {
                if NativePredicate::from_str(&id.name).is_err() {
                    return Err(ValidationError::UndefinedPredicate {
                        name: id.name.clone(),
                        span: id.span,
                    });
                }
                Ok(())
            }
            LiteralValue::ExternalPredicateHash { module, predicate } => {
                if let Some(imported) = self.symbols.imported_modules.get(&module.name) {
                    if !imported.predicate_index.contains_key(&predicate.name) {
                        return Err(ValidationError::UndefinedPredicate {
                            name: format!("{}::{}", module.name, predicate.name),
                            span: predicate.span,
                        });
                    }
                } else {
                    return Err(ValidationError::ModuleNotFound {
                        name: module.name.clone(),
                        span: module.span,
                    });
                }
                Ok(())
            }
            LiteralValue::Array(a) => {
                for elem in &a.elements {
                    self.validate_literal_value(elem)?;
                }
                Ok(())
            }
            LiteralValue::Set(s) => {
                for elem in &s.elements {
                    self.validate_literal_value(elem)?;
                }
                Ok(())
            }
            LiteralValue::Dict(d) => {
                for pair in &d.pairs {
                    self.validate_literal_value(&pair.value)?;
                }
                Ok(())
            }
            LiteralValue::Record(r) => {
                let key = r.name.symbol_table_key();
                let Some(schema) = self.symbols.records.get(&key) else {
                    return Err(ValidationError::UnknownRecord {
                        name: key,
                        span: r.name.span(),
                    });
                };
                let mut seen: HashSet<&String> = HashSet::new();
                for entry in &r.entries {
                    if !schema.entry_index.contains_key(&entry.name.name) {
                        return Err(ValidationError::UnknownRecordEntry {
                            record: key.clone(),
                            entry: entry.name.name.clone(),
                            span: entry.name.span,
                        });
                    }
                    if !seen.insert(&entry.name.name) {
                        return Err(ValidationError::DuplicateLiteralRecordEntry {
                            record: key.clone(),
                            entry: entry.name.name.clone(),
                            span: entry.name.span,
                        });
                    }
                    self.validate_literal_value(&entry.value)?;
                }
                Ok(())
            }
            LiteralValue::RecordEntryIndex { record, entry } => {
                let key = record.symbol_table_key();
                let Some(schema) = self.symbols.records.get(&key) else {
                    return Err(ValidationError::UnknownRecord {
                        name: key,
                        span: record.span(),
                    });
                };
                if !schema.entry_index.contains_key(&entry.name) {
                    return Err(ValidationError::UnknownRecordEntry {
                        record: key,
                        entry: entry.name.clone(),
                        span: entry.span,
                    });
                }
                Ok(())
            }
            _ => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        lang::{frontend_ast::parse::parse_document, parser::parse_podlang, Module},
        middleware::{CustomPredicate, Params, EMPTY_HASH},
    };

    /// Parse Podlang source into a frontend-AST document, panicking on any
    /// syntax error (these tests only exercise validation).
    fn parse_source(input: &str) -> Document {
        let parsed = parse_podlang(input).expect("Failed to parse");
        parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse")
    }

    /// Parse and validate `input` in module mode with default params.
    fn parse_and_validate_module(
        input: &str,
        modules: &HashMap<Hash, Arc<Module>>,
    ) -> Result<ValidatedAST, ValidationError> {
        validate(parse_source(input), modules, &Params::default(), ParseMode::Module)
    }

    /// Parse and validate `input` in request mode with default params.
    fn parse_and_validate_request(
        input: &str,
        modules: &HashMap<Hash, Arc<Module>>,
    ) -> Result<ValidatedAST, ValidationError> {
        validate(parse_source(input), modules, &Params::default(), ParseMode::Request)
    }

    #[test]
    fn test_validate_simple_request() {
        let input = r#"REQUEST(
            Equal(A["foo"], B["bar"])
        )"#;
        let result = parse_and_validate_request(input, &HashMap::new());
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_custom_predicate() {
        let input = r#"
            my_pred(A, B) = AND (
                Equal(A["foo"], B["bar"])
            )
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(result.is_ok());
        let validated = result.unwrap();
        assert!(validated.symbols.predicates.contains_key("my_pred"));
        assert!(validated.symbols.wildcard_scopes.contains_key("my_pred"));
    }

    #[test]
    fn test_undefined_predicate() {
        let input = r#"REQUEST(
            UndefinedPred(A, B)
        )"#;
        let result = parse_and_validate_request(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UndefinedPredicate { .. })
        ));
    }

    #[test]
    fn test_undefined_wildcard() {
        let input = r#"
            my_pred(A) = AND (
                Equal(A["foo"], B["bar"])
            )
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(
            matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B")
        );
    }

    #[test]
    fn test_arity_mismatch() {
        let input = r#"REQUEST(
            Equal(A, B, C)
        )"#;
        let result = parse_and_validate_request(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::ArgumentCountMismatch { .. })
        ));
    }

    #[test]
    fn test_duplicate_predicate() {
        let input = r#"
            my_pred(A) = AND (Equal(A["x"], 1))
            my_pred(B) = AND (Equal(B["y"], 2))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::DuplicatePredicate { .. })
        ));
    }

    #[test]
    fn test_duplicate_wildcard() {
        let input = r#"
            my_pred(A, A) = AND (Equal(A["x"], 1))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::DuplicateWildcard { .. })
        ));
    }

    #[test]
    fn test_wildcard_predicate_collision() {
        let input = r#"
            my_pred(A, Lt) = AND (Equal(A["x"], Lt))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::WildcardPredicateNameCollision { .. })
        ));
    }

    #[test]
    fn test_custom_predicate_with_anchored_key() {
        // First create a module with the predicate
        let params = Params::default();
        let pred = CustomPredicate::and(
            &params,
            "my_pred".to_string(),
            vec![],
            2,
            vec!["A".to_string(), "B".to_string()],
        )
        .unwrap();
        let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
        let test_module = Arc::new(Module::new(batch, HashMap::new()));
        let module_hash = test_module.id().encode_hex::<String>();
        let mut available_modules = HashMap::new();
        available_modules.insert(test_module.id(), test_module);
        // Passing an anchored key as an argument to a module-imported custom
        // predicate is accepted by validation.
        let input = format!(
            r#"
            use module 0x{} as testmod
            REQUEST(
                testmod::my_pred(X["key"], Y)
            )
            "#,
            module_hash
        );
        let result = parse_and_validate_request(&input, &available_modules);
        assert!(result.is_ok());
    }

    #[test]
    fn test_forward_reference() {
        let input = r#"
            pred1(A) = AND (
                pred2(A)
            )
            pred2(B) = AND (
                Equal(B["x"], 1)
            )
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(result.is_ok());
    }

    #[test]
    fn test_private_args() {
        let input = r#"
            my_pred(A, private: B, C) = AND (
                Equal(A["x"], B["y"])
                Equal(B["z"], C["w"])
            )
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(result.is_ok());
        let validated = result.unwrap();
        let pred_info = &validated.symbols.predicates["my_pred"];
        assert_eq!(pred_info.arity, 3);
        assert_eq!(pred_info.public_arity, 1);
    }

    #[test]
    fn test_wildcard_indices_follow_declaration_order() {
        // Public args are numbered first, then private args, in the order
        // they are declared.
        let input = r#"
            my_pred(A, private: B, C) = AND (
                Equal(A["x"], B["y"])
                Equal(B["z"], C["w"])
            )
        "#;
        let validated = parse_and_validate_module(input, &HashMap::new()).unwrap();
        let scope = &validated.symbols.wildcard_scopes["my_pred"];
        let expected = [("A", 0, true), ("B", 1, false), ("C", 2, false)];
        for &(name, index, is_public) in &expected {
            let info = &scope.wildcards[name];
            assert_eq!(info.index, index, "index of {}", name);
            assert_eq!(info.is_public, is_public, "visibility of {}", name);
        }
    }

    #[test]
    fn test_empty_statement_list() {
        // Create a custom predicate with empty statements to test validation
        let document = Document {
            items: vec![DocumentItem::CustomPredicateDef(CustomPredicateDef {
                name: Identifier {
                    name: "my_pred".to_string(),
                    span: None,
                },
                args: ArgSection {
                    public_args: vec![TypedArg {
                        name: "A".to_string(),
                        type_name: None,
                        span: None,
                    }],
                    private_args: None,
                    span: None,
                },
                conjunction_type: ConjunctionType::And,
                statements: vec![], // Empty statements
                span: None,
            })],
        };
        let result = validate(
            document,
            &HashMap::new(),
            &Params::default(),
            ParseMode::Module,
        );
        assert!(matches!(
            result,
            Err(ValidationError::EmptyStatementList { .. })
        ));
    }

    #[test]
    fn test_multiple_request_definitions() {
        let input = r#"
            REQUEST(Equal(A["x"], 1))
            REQUEST(Equal(B["y"], 2))
        "#;
        let result = parse_and_validate_request(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::MultipleRequestDefinitions { .. })
        ));
    }

    #[test]
    fn test_use_module_statement() {
        let params = Params::default();
        // Create a module to import
        let pred = CustomPredicate::and(
            &params,
            "imported".to_string(),
            vec![],
            2,
            vec!["X".to_string(), "Y".to_string()],
        )
        .unwrap();
        let batch = CustomPredicateBatch::new("TestBatch".to_string(), vec![pred]);
        let test_module = Arc::new(Module::new(batch, HashMap::new()));
        let module_hash = test_module.id().encode_hex::<String>();
        let mut available_modules = HashMap::new();
        available_modules.insert(test_module.id(), test_module);
        let input = format!(
            r#"
            use module 0x{} as testmod
            use intro intro_pred() from 0x{}
            REQUEST(
                testmod::imported(A, B)
                intro_pred()
            )
            "#,
            module_hash,
            EMPTY_HASH.encode_hex::<String>()
        );
        let result = parse_and_validate_request(&input, &available_modules);
        assert!(result.is_ok());
        let validated = result.unwrap();
        // Module predicates are accessed via qualified names, so no local binding
        assert!(!validated.symbols.predicates.contains_key("imported"));
        assert!(validated.symbols.predicates.contains_key("intro_pred"));
        assert!(validated.symbols.imported_modules.contains_key("testmod"));
    }

    #[test]
    fn test_syntactic_sugar_predicates() {
        let input = r#"REQUEST(
            GtEq(A["x"], B["y"])
            DictContains(D, K, V)
            SetNotContains(S, E)
        )"#;
        let result = parse_and_validate_request(input, &HashMap::new());
        assert!(result.is_ok());
    }

    // ----- Records ----------------------------------------------------------

    #[test]
    fn test_qualified_record_key_format() {
        assert_eq!(
            qualified_record_key("testmod", "ProcInputs"),
            "testmod::ProcInputs"
        );
    }

    #[test]
    fn test_record_decl_accepted() {
        let input = r#"
            record ProcInputs = (foo, bar, baz)
            my_pred(A) = AND(Equal(A["x"], 1))
        "#;
        let validated = parse_and_validate_module(input, &HashMap::new()).unwrap();
        let schema = validated.symbols.records.get("ProcInputs").unwrap();
        assert_eq!(schema.entries, vec!["foo", "bar", "baz"]);
        assert_eq!(schema.source, RecordSource::Local);
    }

    #[test]
    fn test_records_only_module_rejected() {
        // A module needs at least one predicate; record-only modules are not
        // a valid distribution unit.
        let input = r#"record R = (x)"#;
        assert!(matches!(
            parse_and_validate_module(input, &HashMap::new()),
            Err(ValidationError::NoPredicatesInModule)
        ));
    }

    #[test]
    fn test_duplicate_record() {
        // Include a predicate so the module is otherwise valid and the test
        // does not depend on records being checked before the
        // "module has no predicates" rule.
        let input = r#"
            record R = (foo)
            record R = (bar)
            my_pred(A) = AND(Equal(A["x"], 1))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::DuplicateRecord { name, .. }) if name == "R"
        ));
    }

    #[test]
    fn test_duplicate_entry_in_record() {
        let input = r#"
            record R = (foo, foo)
            my_pred(A) = AND(Equal(A["x"], 1))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::DuplicateRecordEntry { record, entry, .. })
                if record == "R" && entry == "foo"
        ));
    }

    #[test]
    fn test_record_entry_cap() {
        // Use a non-default depth so the cap reflects the parameter (not
        // some hard-coded default). This pins three facts in one test:
        // the param is wired through, the boundary is inclusive on accept,
        // and cap + 1 is rejected.
        let mut params = Params::default();
        params.containers.max_depth_small -= 1;
        let cap = params.max_record_entries();
        let validate_with_n_entries = |n: usize| {
            let entries: Vec<String> = (0..n).map(|i| format!("f{i}")).collect();
            let input = format!(
                "record Big = ({})\nmy_pred(A) = AND(Equal(A[\"x\"], 1))",
                entries.join(", ")
            );
            validate(parse_source(&input), &HashMap::new(), &params, ParseMode::Module)
        };
        assert!(validate_with_n_entries(cap).is_ok());
        let too_many = cap + 1;
        assert!(matches!(
            validate_with_n_entries(too_many),
            Err(ValidationError::RecordTooManyEntries { count, max, .. })
                if count == too_many && max == cap
        ));
    }

    #[test]
    fn test_typed_arg_resolves_known_record() {
        let input = r#"
            record R = (foo, bar)
            my_pred(in R) = AND(Equal(in.foo, in.bar))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(result.is_ok());
        let validated = result.unwrap();
        let scope = validated.symbols.wildcard_scopes.get("my_pred").unwrap();
        assert_eq!(scope.wildcards["in"].record_type.as_deref(), Some("R"));
    }

    #[test]
    fn test_typed_arg_unknown_record_rejected() {
        let input = r#"
            my_pred(in NonExistent) = AND(Equal(in.foo, 1))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UnknownRecord { name, .. }) if name == "NonExistent"
        ));
    }

    #[test]
    fn test_dot_access_unknown_entry_rejected() {
        let input = r#"
            record R = (foo, bar)
            my_pred(in R) = AND(Equal(in.quux, 1))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UnknownRecordEntry { record, entry, .. })
                if record == "R" && entry == "quux"
        ));
    }

    #[test]
    fn test_dot_access_on_untyped_wildcard_unchecked() {
        // r.foo on an untyped wildcard keeps current POD-string-key behavior;
        // no record exists named anything that would constrain `foo`.
        let input = r#"
            my_pred(r) = AND(Equal(r.foo, 1))
        "#;
        assert!(parse_and_validate_module(input, &HashMap::new()).is_ok());
    }

    #[test]
    fn test_bracket_access_on_typed_wildcard_rejected() {
        // Records are integer-keyed; string-key access on a record-typed
        // wildcard is incoherent and would never resolve at proof time.
        // Force the user to use `.entry` instead.
        let input = r#"
            record R = (foo)
            my_pred(r R) = AND(Equal(r["foo"], 1))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::BracketAccessOnTypedWildcard { wildcard, record, .. })
                if wildcard == "r" && record == "R"
        ));
    }

    #[test]
    fn test_record_literal_unknown_record() {
        let input = r#"
            my_pred(A) = AND(Equal(A["x"], NotARecord(f: 1)))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UnknownRecord { name, .. }) if name == "NotARecord"
        ));
    }

    #[test]
    fn test_record_literal_unknown_entry() {
        let input = r#"
            record R = (foo, bar)
            my_pred(A) = AND(Equal(A["x"], R(foo: 1, quux: 2)))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UnknownRecordEntry { record, entry, .. })
                if record == "R" && entry == "quux"
        ));
    }

    #[test]
    fn test_record_literal_nested() {
        // Nested literals recurse through `validate_literal_value`: an unknown
        // entry on the inner literal must still be caught.
        let input = r#"
            record Outer = (inner)
            record Inner = (x, y)
            my_pred(A) = AND(Equal(A["x"], Outer(inner: Inner(x: 1, z: 2))))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UnknownRecordEntry { record, entry, .. })
                if record == "Inner" && entry == "z"
        ));
    }

    #[test]
    fn test_record_literal_duplicate_entry() {
        let input = r#"
            record R = (foo, bar)
            my_pred(A) = AND(Equal(A["x"], R(foo: 1, foo: 2)))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::DuplicateLiteralRecordEntry { record, entry, .. })
                if record == "R" && entry == "foo"
        ));
    }

    #[test]
    fn test_record_entry_index_resolves() {
        // Validation accepts `R::bar` and the schema records bar at index 1
        // — the integer the literal will lower to.
        let input = r#"
            record R = (foo, bar)
            my_pred(A) = AND(Contains(A, R::bar, 7))
        "#;
        let validated = parse_and_validate_module(input, &HashMap::new()).unwrap();
        let schema = validated.symbols.records.get("R").unwrap();
        assert_eq!(schema.entry_index["bar"], 1);
    }

    #[test]
    fn test_record_entry_index_unknown_record() {
        let input = r#"
            my_pred(A) = AND(Contains(A, NotARecord::foo, 7))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UnknownRecord { name, .. }) if name == "NotARecord"
        ));
    }

    #[test]
    fn test_record_entry_index_unknown_entry() {
        let input = r#"
            record R = (foo, bar)
            my_pred(A) = AND(Contains(A, R::quux, 7))
        "#;
        let result = parse_and_validate_module(input, &HashMap::new());
        assert!(matches!(
            result,
            Err(ValidationError::UnknownRecordEntry { record, entry, .. })
                if record == "R" && entry == "quux"
        ));
    }
}