Fix parsing of intro statement templates inside custom predicates (#467)

* Fix parsing of intro statement templates inside custom predicates

* Tidy up comments
This commit is contained in:
Rob Knight 2026-01-30 19:30:57 +01:00 committed by GitHub
parent 337a51135e
commit 879c7201ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 263 additions and 169 deletions

View file

@ -6,6 +6,7 @@
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
};
use crate::{
@ -14,15 +15,124 @@ use crate::{
frontend_ast::*,
frontend_ast_batch::{self, PredicateBatches},
frontend_ast_split,
frontend_ast_validate::{PredicateKind, ValidatedAST},
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
},
middleware::{
self, containers, IntroPredicateRef, NativePredicate, Params, Predicate,
PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Wildcard,
containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
},
};
/// Context for predicate resolution - determines how local custom predicates are resolved
pub enum ResolutionContext<'a> {
/// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches
Request {
batches: Option<&'a PredicateBatches>,
},
/// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef
Batch {
current_batch_idx: usize,
reference_map: &'a HashMap<String, (usize, usize)>,
existing_batches: &'a [Arc<CustomPredicateBatch>],
},
}
/// Resolve a predicate name to a Predicate using the symbol table
pub fn resolve_predicate(
pred_name: &str,
symbols: &SymbolTable,
context: &ResolutionContext,
) -> Option<Predicate> {
// 1. Try native predicate first
if let Ok(native) = NativePredicate::from_str(pred_name) {
return Some(Predicate::Native(native));
}
// 2. Look up in symbol table
if let Some(info) = symbols.predicates.get(pred_name) {
let predicate = match &info.kind {
PredicateKind::Native(np) => Predicate::Native(*np),
PredicateKind::Custom { .. } => match context {
ResolutionContext::Request { batches } => {
let batches = batches.as_ref()?;
let pred_ref = batches.predicate_ref_by_name(pred_name)?;
Predicate::Custom(pred_ref)
}
ResolutionContext::Batch {
current_batch_idx,
reference_map,
existing_batches,
} => resolve_local_predicate(
pred_name,
*current_batch_idx,
reference_map,
existing_batches,
)?,
},
PredicateKind::BatchImported { batch, index } => {
Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index))
}
PredicateKind::IntroImported {
name,
verifier_data_hash,
} => Predicate::Intro(IntroPredicateRef {
name: name.clone(),
args_len: info.public_arity,
verifier_data_hash: *verifier_data_hash,
}),
};
return Some(predicate);
}
// 3. In batch context, also check reference_map for split chain pieces
// (predicates created by splitting that aren't in the original symbol table)
if let ResolutionContext::Batch {
current_batch_idx,
reference_map,
existing_batches,
} = context
{
if reference_map.contains_key(pred_name) {
return resolve_local_predicate(
pred_name,
*current_batch_idx,
reference_map,
existing_batches,
);
}
}
None
}
/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map
fn resolve_local_predicate(
pred_name: &str,
current_batch_idx: usize,
reference_map: &HashMap<String, (usize, usize)>,
existing_batches: &[Arc<CustomPredicateBatch>],
) -> Option<Predicate> {
let &(target_batch, target_idx) = reference_map.get(pred_name)?;
if target_batch == current_batch_idx {
Some(Predicate::BatchSelf(target_idx))
} else if target_batch < current_batch_idx {
let batch = &existing_batches[target_batch];
Some(Predicate::Custom(CustomPredicateRef::new(
batch.clone(),
target_idx,
)))
} else {
unreachable!(
"Forward cross-batch reference should be impossible: {} -> {}",
current_batch_idx, target_batch
);
}
}
// ============================================================================
// Shared lowering utilities
// ============================================================================
@ -33,37 +143,37 @@ use crate::{
/// Lower a literal value from AST to middleware Value.
///
/// This is a pure conversion that cannot fail.
pub fn lower_literal(lit: &LiteralValue) -> middleware::Value {
pub fn lower_literal(lit: &LiteralValue) -> Value {
match lit {
LiteralValue::Int(i) => middleware::Value::from(i.value),
LiteralValue::Bool(b) => middleware::Value::from(b.value),
LiteralValue::String(s) => middleware::Value::from(s.value.clone()),
LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash),
LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point),
LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()),
LiteralValue::Int(i) => Value::from(i.value),
LiteralValue::Bool(b) => Value::from(b.value),
LiteralValue::String(s) => Value::from(s.value.clone()),
LiteralValue::Raw(r) => Value::from(r.hash.hash),
LiteralValue::PublicKey(pk) => Value::from(pk.point),
LiteralValue::SecretKey(sk) => Value::from(sk.secret_key.clone()),
LiteralValue::Array(a) => {
let elements: Vec<_> = a.elements.iter().map(lower_literal).collect();
let array = containers::Array::new(elements);
middleware::Value::from(array)
Value::from(array)
}
LiteralValue::Set(s) => {
let elements: std::collections::HashSet<_> =
s.elements.iter().map(lower_literal).collect();
let set = containers::Set::new(elements);
middleware::Value::from(set)
Value::from(set)
}
LiteralValue::Dict(d) => {
let pairs: std::collections::HashMap<_, _> = d
let pairs: HashMap<_, _> = d
.pairs
.iter()
.map(|pair| {
let key = middleware::Key::from(pair.key.value.as_str());
let key = Key::from(pair.key.value.as_str());
let value = lower_literal(&pair.value);
(key, value)
})
.collect();
let dict = containers::Dictionary::new(pairs);
middleware::Value::from(dict)
Value::from(dict)
}
}
}
@ -151,41 +261,18 @@ impl<'a> Lowerer<'a> {
return Ok(None);
}
// Build map of imported predicates for batching
let imported_predicates = self.build_imported_predicates_map();
// Use the new batching module to pack predicates into batches
// Pass the symbol table for unified predicate resolution
let batches = frontend_ast_batch::batch_predicates(
custom_predicates,
self.params,
&batch_name,
&imported_predicates,
self.validated.symbols(),
)?;
Ok(Some(batches))
}
fn build_imported_predicates_map(
&self,
) -> HashMap<String, frontend_ast_batch::ImportedPredicateInfo> {
let symbols = self.validated.symbols();
let mut imported = HashMap::new();
for (name, info) in &symbols.predicates {
if let PredicateKind::BatchImported { batch, index } = &info.kind {
imported.insert(
name.clone(),
frontend_ast_batch::ImportedPredicateInfo {
batch: batch.clone(),
index: *index,
},
);
}
}
imported
}
fn lower_request(
&self,
batches: Option<&PredicateBatches>,
@ -232,42 +319,13 @@ impl<'a> Lowerer<'a> {
let pred_name = &stmt.predicate.name;
let symbols = self.validated.symbols();
// Resolve predicate - for request statements, local custom predicates
// must be resolved to CustomPredicateRef (not BatchSelf)
let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) {
Predicate::Native(native)
} else if let Some(info) = symbols.predicates.get(pred_name) {
match &info.kind {
PredicateKind::Native(np) => Predicate::Native(*np),
PredicateKind::Custom { .. } => {
// Local custom predicates - resolve to CustomPredicateRef
let batches = batches.ok_or_else(|| LoweringError::PredicateNotFound {
name: pred_name.clone(),
})?;
let pred_ref = batches.predicate_ref_by_name(pred_name).ok_or_else(|| {
LoweringError::PredicateNotFound {
name: pred_name.clone(),
}
})?;
Predicate::Custom(pred_ref)
}
PredicateKind::BatchImported { batch, index } => {
Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index))
}
PredicateKind::IntroImported {
name,
verifier_data_hash,
} => Predicate::Intro(IntroPredicateRef {
name: name.clone(),
args_len: info.public_arity,
verifier_data_hash: *verifier_data_hash,
}),
}
} else {
return Err(LoweringError::PredicateNotFound {
// Resolve predicate using the unified resolution function
let context = ResolutionContext::Request { batches };
let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| {
LoweringError::PredicateNotFound {
name: pred_name.clone(),
});
};
}
})?;
// Create a builder with the resolved predicate and desugar
let mut builder = StatementTmplBuilder::new(predicate);
@ -291,7 +349,7 @@ impl<'a> Lowerer<'a> {
.get(&root_name)
.expect("Root wildcard not found");
let wildcard = Wildcard::new(root_name, *root_index);
let key = middleware::Key::from(key_str.as_str());
let key = Key::from(key_str.as_str());
MWStatementTmplArg::AnchoredKey(wildcard, key)
}
};
@ -646,4 +704,57 @@ mod tests {
assert_eq!(batches.total_predicate_count(), 5);
assert_eq!(batches.batch_count(), 2);
}
#[test]
fn test_intro_predicate_in_custom_predicate() {
use hex::ToHex;
use crate::middleware::EMPTY_HASH;
// Import an intro predicate and use it inside a custom predicate definition
let intro_hash = EMPTY_HASH.encode_hex::<String>();
let input = format!(
r#"
use intro external_check(X) from 0x{intro_hash}
my_pred(A) = AND (
Equal(A["foo"], 42)
external_check(A)
)
"#
);
let params = Params::default();
// Parse, validate, and lower
let parsed = parse_podlang(&input).expect("Failed to parse");
let document =
parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document");
let validated = validate(document, &[]).expect("Failed to validate");
let result = lower(validated, &params, "test_batch".to_string());
assert!(result.is_ok(), "Lowering failed: {:?}", result.err());
let lowered = result.unwrap();
let batch = expect_batch(&lowered);
// Should have one custom predicate
assert_eq!(batch.predicates().len(), 1);
let pred = &batch.predicates()[0];
assert_eq!(pred.name, "my_pred");
// 2 statements: Equal and external_check
assert_eq!(pred.statements().len(), 2);
// Verify the second statement is an intro predicate reference
let intro_stmt = &pred.statements()[1];
match intro_stmt.pred_or_wc() {
PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
assert_eq!(intro_ref.name, "external_check");
assert_eq!(intro_ref.args_len, 1);
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
}
other => panic!("Expected Intro predicate, got {:?}", other),
}
}
}