Podlang syntax for quoted predicates (#495)

This commit is contained in:
Rob Knight 2026-03-30 15:16:19 +01:00 committed by GitHub
parent a4069bcc55
commit 22d25e5cb2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 453 additions and 14 deletions

View file

@ -176,6 +176,12 @@ impl CustomPredicateBatchBuilder {
priv_args: &[&str],
sts: &[StatementTmplBuilder],
) -> Result<Predicate> {
if self.predicates.iter().any(|p| p.name == name) {
return Err(Error::custom(format!(
"Duplicate predicate name '{}' in batch",
name
)));
}
if self.predicates.len() >= Params::max_custom_batch_size() {
return Err(Error::max_length(
"self.predicates.len".to_string(),

View file

@ -287,6 +287,17 @@ fn render_validation_error(
ValidationError::NoRequestBlock => {
render_title_only(renderer, "requests must contain a REQUEST block")
}
ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { span } => {
render_with_optional_span(
renderer,
source,
path,
"self-referential predicate literal not allowed in requests",
span.as_ref(),
"not allowed here",
)
}
}
}

View file

@ -165,6 +165,9 @@ pub enum ValidationError {
#[error("Modules must contain at least one predicate definition")]
NoPredicatesInModule,
#[error("Self-referential predicate literal not allowed in requests")]
SelfReferentialPredicateLiteralNotAllowedInRequests { span: Option<Span> },
#[error("Requests must contain a REQUEST block")]
NoRequestBlock,
}

View file

@ -116,6 +116,8 @@ pub enum StatementTmplArg {
Literal(LiteralValue),
Wildcard(Identifier),
AnchoredKey(AnchoredKey),
/// Hash of a same-module predicate, resolved at batch finalization time.
SelfPredicateHash(Identifier),
}
/// Anchored key: Var["key"] or Var.key
@ -168,6 +170,13 @@ pub enum LiteralValue {
Array(LiteralArray),
Set(LiteralSet),
Dict(LiteralDict),
/// Hash of a native predicate (resolved immediately).
NativePredicateHash(Identifier),
/// Hash of an external module's predicate (resolved immediately).
ExternalPredicateHash {
module: Identifier,
predicate: Identifier,
},
}
/// Integer literal
@ -391,6 +400,9 @@ impl fmt::Display for StatementTmplArg {
StatementTmplArg::Literal(lit) => write!(f, "{}", lit),
StatementTmplArg::Wildcard(id) => write!(f, "{}", id),
StatementTmplArg::AnchoredKey(ak) => write!(f, "{}", ak),
StatementTmplArg::SelfPredicateHash(id) => {
write!(f, "@self_predicate({})", id)
}
}
}
}
@ -422,6 +434,12 @@ impl fmt::Display for LiteralValue {
LiteralValue::Array(a) => write!(f, "{}", a),
LiteralValue::Set(s) => write!(f, "{}", s),
LiteralValue::Dict(d) => write!(f, "{}", d),
LiteralValue::NativePredicateHash(id) => {
write!(f, "@native_predicate({})", id)
}
LiteralValue::ExternalPredicateHash {
module, predicate, ..
} => write!(f, "@external_predicate({}, {})", module, predicate),
}
}
}
@ -769,6 +787,10 @@ pub mod parse {
let inner = pair.into_inner().next().unwrap();
match inner.as_rule() {
Rule::predicate_hash_self => {
let id = parse_identifier(inner.into_inner().next().unwrap());
Ok(StatementTmplArg::SelfPredicateHash(id))
}
Rule::literal_value => Ok(StatementTmplArg::Literal(parse_literal_value(inner)?)),
Rule::identifier => Ok(StatementTmplArg::Wildcard(parse_identifier(inner))),
Rule::anchored_key => Ok(StatementTmplArg::AnchoredKey(parse_anchored_key(inner)?)),
@ -823,6 +845,16 @@ pub mod parse {
Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)),
Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)),
Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)),
Rule::predicate_hash_native => {
let id = parse_identifier(inner.into_inner().next().unwrap());
Ok(LiteralValue::NativePredicateHash(id))
}
Rule::predicate_hash_external => {
let mut parts = inner.into_inner();
let module = parse_identifier(parts.next().unwrap());
let predicate = parse_identifier(parts.next().unwrap());
Ok(LiteralValue::ExternalPredicateHash { module, predicate })
}
_ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()),
}
}
@ -1104,6 +1136,7 @@ mod tests {
AnchoredKeyPath::Dot(id) => id.span = None,
}
}
StatementTmplArg::SelfPredicateHash(id) => id.span = None,
}
}
}
@ -1139,6 +1172,13 @@ mod tests {
clear_literal_spans(&mut pair.value);
}
}
LiteralValue::NativePredicateHash(id) => id.span = None,
LiteralValue::ExternalPredicateHash {
module, predicate, ..
} => {
module.span = None;
predicate.span = None;
}
}
}

View file

@ -157,8 +157,10 @@ fn resolve_local_predicate(
/// Lower a literal value from AST to middleware Value.
///
/// This is a pure conversion that cannot fail.
pub fn lower_literal(lit: &LiteralValue) -> Value {
/// This is a pure conversion that cannot fail for context-free literals.
/// Panics on ExternalPredicateHash — use `lower_literal_with_context` when
/// external predicate references may appear (e.g. inside containers).
pub(crate) fn lower_literal(lit: &LiteralValue) -> Value {
match lit {
LiteralValue::Int(i) => Value::from(i.value),
LiteralValue::Bool(b) => Value::from(b.value),
@ -190,13 +192,83 @@ pub fn lower_literal(lit: &LiteralValue) -> Value {
let dict = containers::Dictionary::new(pairs);
Value::from(dict)
}
LiteralValue::NativePredicateHash(id) => {
let np = NativePredicate::from_str(&id.name).expect("validated native predicate");
Value::from(Predicate::Native(np).hash())
}
LiteralValue::ExternalPredicateHash { .. } => {
unreachable!(
"ExternalPredicateHash must be lowered with context via lower_literal_with_context"
)
}
}
}
/// Lower a literal value, resolving external predicate references using the symbol table.
pub fn lower_literal_with_context(
lit: &LiteralValue,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Result<Value, LoweringError> {
match lit {
LiteralValue::ExternalPredicateHash { module, predicate } => {
let pred_or_wc = resolve_predicate_ref(
&PredicateRef::Qualified {
module: module.clone(),
predicate: predicate.clone(),
},
symbols,
context,
)
.ok_or_else(|| LoweringError::PredicateNotFound {
name: format!("{}::{}", module.name, predicate.name),
})?;
let pred = match pred_or_wc {
crate::frontend::PredicateOrWildcard::Predicate(p) => p,
_ => unreachable!(
"`resolve_predicate_ref` always returns `PredicateOrWildcard::Predicate` on `PredicateRef::Qualified`"
)
};
Ok(Value::from(pred.hash()))
}
LiteralValue::Array(a) => {
let elements: Vec<_> = a
.elements
.iter()
.map(|e| lower_literal_with_context(e, symbols, context))
.collect::<Result<_, _>>()?;
Ok(Value::from(containers::Array::new(elements)))
}
LiteralValue::Set(s) => {
let elements: std::collections::HashSet<_> = s
.elements
.iter()
.map(|e| lower_literal_with_context(e, symbols, context))
.collect::<Result<_, _>>()?;
Ok(Value::from(containers::Set::new(elements)))
}
LiteralValue::Dict(d) => {
let pairs: HashMap<_, _> = d
.pairs
.iter()
.map(|pair| {
let key = Key::from(pair.key.value.as_str());
let value = lower_literal_with_context(&pair.value, symbols, context)?;
Ok((key, value))
})
.collect::<Result<_, LoweringError>>()?;
Ok(Value::from(containers::Dictionary::new(pairs)))
}
// All other variants are context-free
other => Ok(lower_literal(other)),
}
}
/// Lower a statement argument from AST to BuilderArg.
///
/// This is a pure conversion that cannot fail.
pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
/// Context-free for most arg types. Panics on ExternalPredicateHash inside literals —
/// use `lower_statement_arg_with_context` when external predicate references may appear.
pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
match arg {
StatementTmplArg::Literal(lit) => {
let value = lower_literal(lit);
@ -210,6 +282,25 @@ pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
};
BuilderArg::Key(ak.root.name.clone(), key_str)
}
StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()),
}
}
/// Lower a statement argument from AST to a `BuilderArg`, with access to the
/// symbol table so that predicate references inside literals can be resolved.
///
/// Literal arguments go through `lower_literal_with_context`; self-predicate
/// hashes are carried over by name; every other argument kind is handed to the
/// context-free `lower_statement_arg`.
///
/// # Errors
///
/// Propagates any `LoweringError` produced while lowering a literal.
pub fn lower_statement_arg_with_context(
    arg: &StatementTmplArg,
    symbols: &SymbolTable,
    context: &ResolutionContext,
) -> Result<BuilderArg, LoweringError> {
    let lowered = match arg {
        StatementTmplArg::Literal(lit) => {
            BuilderArg::Literal(lower_literal_with_context(lit, symbols, context)?)
        }
        StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()),
        other => lower_statement_arg(other),
    };
    Ok(lowered)
}
@ -324,7 +415,7 @@ impl<'a> Lowerer<'a> {
// Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate.clone());
for arg in &stmt.args {
let builder_arg = lower_statement_arg(arg);
let builder_arg = lower_statement_arg_with_context(arg, symbols, &context)?;
builder = builder.arg(builder_arg);
}
let desugared = builder.desugar();
@ -402,7 +493,7 @@ impl<'a> Lowerer<'a> {
names.push(ak.root.name.clone());
}
}
StatementTmplArg::Literal(_) => {}
StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {}
}
}
}

View file

@ -123,7 +123,7 @@ fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet<String> {
StatementTmplArg::AnchoredKey(ak) => {
wildcards.insert(ak.root.name.clone());
}
StatementTmplArg::Literal(_) => {}
StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {}
}
}

View file

@ -559,7 +559,12 @@ impl Validator {
}
}
}
StatementTmplArg::Literal(_) => {}
StatementTmplArg::Literal(lit) => {
self.validate_literal_value(lit)?;
}
StatementTmplArg::SelfPredicateHash(id) => {
self.validate_self_predicate_hash(id, wildcard_context)?;
}
}
}
} else {
@ -588,13 +593,92 @@ impl Validator {
}
}
}
StatementTmplArg::Literal(_) => {}
StatementTmplArg::Literal(lit) => {
self.validate_literal_value(lit)?;
}
StatementTmplArg::SelfPredicateHash(id) => {
self.validate_self_predicate_hash(id, wildcard_context)?;
}
}
}
}
Ok(())
}
/// Check a `@self_predicate(name)` reference.
///
/// The reference is rejected when there is no wildcard context (i.e. it is
/// not inside a predicate definition), and otherwise accepted only if `name`
/// is present in the symbol table with a `PredicateKind::Custom` kind.
///
/// # Errors
///
/// - `SelfReferentialPredicateLiteralNotAllowedInRequests` when
///   `wildcard_context` is `None`.
/// - `UndefinedPredicate` when the name is missing from the symbol table or
///   refers to a non-custom predicate.
fn validate_self_predicate_hash(
    &self,
    id: &Identifier,
    wildcard_context: Option<(&str, &WildcardScope)>,
) -> Result<(), ValidationError> {
    // Without a wildcard context there is no enclosing predicate definition
    // for the reference to point into.
    if wildcard_context.is_none() {
        return Err(
            ValidationError::SelfReferentialPredicateLiteralNotAllowedInRequests { span: id.span },
        );
    }

    let names_custom_predicate = matches!(
        self.symbols.predicates.get(&id.name),
        Some(info) if matches!(info.kind, PredicateKind::Custom { .. })
    );

    if names_custom_predicate {
        Ok(())
    } else {
        Err(ValidationError::UndefinedPredicate {
            name: id.name.clone(),
            span: id.span,
        })
    }
}
/// Walk a literal value and verify every predicate hash reference inside it.
///
/// - `@native_predicate(name)`: `name` must parse as a `NativePredicate`.
/// - `@external_predicate(module, predicate)`: `module` must be an imported
///   module and `predicate` must appear in that module's predicate index.
/// - Arrays, sets and dictionary values are validated element by element,
///   stopping at the first failure.
/// - All other literals are accepted as-is.
///
/// # Errors
///
/// Returns `UndefinedPredicate` for an unknown native or external predicate,
/// and `ModuleNotFound` for an external reference to a module that is not
/// imported.
fn validate_literal_value(&self, lit: &LiteralValue) -> Result<(), ValidationError> {
    match lit {
        LiteralValue::NativePredicateHash(id) => match NativePredicate::from_str(&id.name) {
            Ok(_) => Ok(()),
            Err(_) => Err(ValidationError::UndefinedPredicate {
                name: id.name.clone(),
                span: id.span,
            }),
        },
        LiteralValue::ExternalPredicateHash { module, predicate } => {
            // The module check comes first so a missing import is reported
            // in preference to a missing predicate.
            let imported = self
                .symbols
                .imported_modules
                .get(&module.name)
                .ok_or_else(|| ValidationError::ModuleNotFound {
                    name: module.name.clone(),
                    span: module.span,
                })?;
            if imported.predicate_index.contains_key(&predicate.name) {
                Ok(())
            } else {
                Err(ValidationError::UndefinedPredicate {
                    name: format!("{}::{}", module.name, predicate.name),
                    span: predicate.span,
                })
            }
        }
        LiteralValue::Array(a) => a
            .elements
            .iter()
            .try_for_each(|elem| self.validate_literal_value(elem)),
        LiteralValue::Set(s) => s
            .elements
            .iter()
            .try_for_each(|elem| self.validate_literal_value(elem)),
        LiteralValue::Dict(d) => d
            .pairs
            .iter()
            .try_for_each(|pair| self.validate_literal_value(&pair.value)),
        _ => Ok(()),
    }
}
}
#[cfg(test)]

View file

@ -49,7 +49,14 @@ custom_predicate_def = {
statement_list = { statement+ }
statement_arg = { literal_value | anchored_key | identifier }
// Predicate hash literals: resolve to the predicate's identity hash as a value.
// @native_predicate and @external_predicate are in literal_value (usable in containers).
// @self_predicate is only in statement_arg (not in containers — deferred resolution).
predicate_hash_native = { "@native_predicate" ~ "(" ~ identifier ~ ")" }
predicate_hash_external = { "@external_predicate" ~ "(" ~ identifier ~ "," ~ identifier ~ ")" }
predicate_hash_self = { "@self_predicate" ~ "(" ~ identifier ~ ")" }
statement_arg = { predicate_hash_self | literal_value | anchored_key | identifier }
statement_arg_list = { statement_arg ~ ("," ~ statement_arg)* }
// Predicate reference: either qualified (module::predicate) or local (predicate)
@ -74,6 +81,8 @@ literal_value = {
literal_bool |
literal_raw |
literal_string |
predicate_hash_native |
predicate_hash_external |
literal_int
}

View file

@ -11,7 +11,9 @@ use crate::{
lang::{
error::BatchingError,
frontend_ast::{ConjunctionType, CustomPredicateDef},
frontend_ast_lower::{lower_statement_arg, resolve_predicate_ref, ResolutionContext},
frontend_ast_lower::{
lower_statement_arg_with_context, resolve_predicate_ref, ResolutionContext,
},
frontend_ast_split::{SplitChainInfo, SplitResult},
frontend_ast_validate::SymbolTable,
},
@ -374,7 +376,13 @@ fn build_statement_with_resolved_refs(
let mut builder = StatementTmplBuilder::new(pred_or_wc);
for arg in &stmt.args {
builder = builder.arg(lower_statement_arg(arg));
let builder_arg =
lower_statement_arg_with_context(arg, symbols, &context).map_err(|e| {
BatchingError::Internal {
message: format!("Failed to lower argument: {}", e),
}
})?;
builder = builder.arg(builder_arg);
}
Ok(builder)
@ -670,4 +678,110 @@ mod tests {
PredicateOrWildcard::Predicate(Predicate::Custom(ordering_ref))
);
}
/// A `@self_predicate(pred_A)` argument in `pred_B` must normalize to a
/// literal holding the hash of `pred_A` (index 0 of the same batch).
#[test]
fn test_self_predicate_hash_podlang() {
    let source = r#"
pred_A(x, y) = AND(
Equal(x, y)
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#;
    let params = Params::default();
    let module = load_module(source, "test", &params, &[]).unwrap();

    // Expected value: the hash of the predicate at index 0 (pred_A).
    let target_ref = CustomPredicateRef::new(module.batch.clone(), 0);
    let expected_hash = crate::middleware::Value::from(Predicate::Custom(target_ref).hash());

    // pred_B lives at index 1; normalizing it resolves the self-reference.
    let referrer_ref = CustomPredicateRef::new(module.batch.clone(), 1);
    let resolved = referrer_ref.normalized_predicate();

    assert_eq!(
        resolved.statements[0].args[1],
        crate::middleware::StatementTmplArg::Literal(expected_hash)
    );
}
/// Two predicates that reference each other through `@self_predicate` must
/// each normalize to a literal holding the other's hash.
#[test]
fn test_self_predicate_hash_podlang_cyclic() {
    let source = r#"
pred_A(x) = AND(
Equal(x, @self_predicate(pred_B))
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#;
    let params = Params::default();
    let module = load_module(source, "test", &params, &[]).unwrap();

    let first_ref = CustomPredicateRef::new(module.batch.clone(), 0);
    let second_ref = CustomPredicateRef::new(module.batch.clone(), 1);

    // Hash of a predicate reference, wrapped as a middleware value.
    let hash_value = |pred_ref: &CustomPredicateRef| {
        crate::middleware::Value::from(Predicate::Custom(pred_ref.clone()).hash())
    };
    let first_hash = hash_value(&first_ref);
    let second_hash = hash_value(&second_ref);

    // pred_A refers to pred_B ...
    let first_normalized = first_ref.normalized_predicate();
    assert_eq!(
        first_normalized.statements[0].args[1],
        crate::middleware::StatementTmplArg::Literal(second_hash)
    );

    // ... and pred_B refers back to pred_A.
    let second_normalized = second_ref.normalized_predicate();
    assert_eq!(
        second_normalized.statements[0].args[1],
        crate::middleware::StatementTmplArg::Literal(first_hash)
    );
}
/// A `@native_predicate(Equal)` argument must be lowered directly to a
/// literal holding the hash of the native `Equal` predicate.
#[test]
fn test_native_predicate_hash_podlang() {
    let source = r#"
pred_C(x) = AND(
Equal(x, @native_predicate(Equal))
)
"#;
    let params = Params::default();
    let module = load_module(source, "test", &params, &[]).unwrap();

    let expected_hash = crate::middleware::Value::from(
        Predicate::Native(crate::middleware::NativePredicate::Equal).hash(),
    );

    let only_ref = CustomPredicateRef::new(module.batch.clone(), 0);
    let only_pred = only_ref.predicate();

    // Second argument of the single statement carries the predicate hash.
    assert_eq!(
        only_pred.statements[0].args[1],
        crate::middleware::StatementTmplArg::Literal(expected_hash)
    );
}
}

View file

@ -137,6 +137,9 @@ mod tests {
assert_inner(&Rule::anchored_key, "someVar[\"key\"]");
assert_inner(&Rule::literal_value, "true");
assert_inner(&Rule::literal_value, "PublicKey(abc)");
assert_inner(&Rule::predicate_hash_self, "@self_predicate(foo)");
assert_inner(&Rule::literal_value, "@native_predicate(Equal)");
assert_inner(&Rule::literal_value, "@external_predicate(mod_a, pred_b)");
}
#[test]
@ -207,6 +210,33 @@ mod tests {
"{ \"raw_val\": Raw(0x0000000000000000000000000000000000000000000000000000000000000000) } ",
);
assert_fails(Rule::literal_dict, "{ name: \"Alice\" }"); // Key must be string literal with quotes
// Predicate hash literals
assert_parses(Rule::predicate_hash_native, "@native_predicate(Equal)");
assert_parses(Rule::predicate_hash_native, "@native_predicate(Lt)");
assert_parses(
Rule::predicate_hash_external,
"@external_predicate(my_module, my_pred)",
);
assert_parses(Rule::predicate_hash_self, "@self_predicate(local_pred)");
// Predicate hashes inside containers (native and external only)
assert_parses(
Rule::literal_array,
"[1, @native_predicate(Equal), @external_predicate(m, p)]",
);
assert_parses(
Rule::literal_set,
"#[@native_predicate(Equal), @native_predicate(Lt)]",
);
assert_parses(
Rule::literal_dict,
"{ \"pred\": @external_predicate(m, p) }",
);
// @self_predicate is NOT a literal_value, so it cannot appear inside containers
assert_fails(Rule::test_literal_value, "@self_predicate(local_pred)");
assert_fails(Rule::literal_array, "[@self_predicate(foo)]");
}
#[test]

View file

@ -92,7 +92,7 @@ impl StatementTmpl {
if i > 0 {
write!(w, ", ")?;
}
arg.fmt_podlang(w)?;
arg.fmt_podlang_with_batch_context(w, batch_context)?;
}
write!(w, ")")?;
@ -102,7 +102,30 @@ impl StatementTmpl {
impl PrettyPrint for StatementTmplArg {
fn fmt_podlang_with_indent(&self, w: &mut dyn Write, _indent: usize) -> std::fmt::Result {
write!(w, "{}", self)
self.fmt_podlang_with_batch_context(w, None)
}
}
impl StatementTmplArg {
    /// Write this argument in Podlang syntax.
    ///
    /// For a `SelfPredicateHash(index)` the output is
    /// `@self_predicate(<name>)`, where `<name>` is the name of the predicate
    /// at `index` in `batch_context`. If no batch is supplied, or the index is
    /// not present in it, the placeholder name `self_<index>` is written
    /// instead. Every other argument kind is written via its `Display` impl.
    fn fmt_podlang_with_batch_context(
        &self,
        w: &mut dyn Write,
        batch_context: Option<&CustomPredicateBatch>,
    ) -> std::fmt::Result {
        let index = match self {
            StatementTmplArg::SelfPredicateHash(index) => index,
            other => return write!(w, "{}", other),
        };

        if let Some(batch) = batch_context {
            if let Some(predicate) = batch.predicates().get(*index) {
                return write!(w, "@self_predicate({})", predicate.name);
            }
        }

        // No batch, or the index is absent from it: emit a placeholder name.
        write!(w, "@self_predicate(self_{})", index)
    }
}
@ -540,6 +563,34 @@ mod tests {
assert_round_trip(&input);
}
/// Round-trip check for a module where one predicate references another
/// through `@self_predicate`.
#[test]
fn test_round_trip_self_predicate_hash() {
    assert_round_trip(
        r#"
pred_A(x, y) = AND(
Equal(x, y)
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#,
    );
}
/// Round-trip check for two predicates that reference each other through
/// `@self_predicate`.
#[test]
fn test_round_trip_self_predicate_hash_cyclic() {
    assert_round_trip(
        r#"
pred_A(x) = AND(
Equal(x, @self_predicate(pred_B))
)
pred_B(x) = AND(
Equal(x, @self_predicate(pred_A))
)
"#,
    );
}
#[test]
fn test_pretty_print_demonstration() {
let input = r#"