Basic 'use' syntax for importing custom predicates (#286)
* Basic 'use' syntax for importing custom predicates * Add extra test for unknown batches * Fix unused import * Enforce that imports must match number of predicates in a batch
This commit is contained in:
parent
f7bb6af219
commit
21ab3c2d0d
6 changed files with 499 additions and 136 deletions
|
|
@ -53,13 +53,15 @@ pub fn native_predicate_from_string(s: &str) -> Option<NativePredicate> {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ProcessedOutput {
|
||||
pub struct PodlangOutput {
|
||||
pub custom_batch: Arc<CustomPredicateBatch>,
|
||||
pub request_templates: Vec<StatementTmpl>,
|
||||
}
|
||||
|
||||
struct ProcessingContext<'a> {
|
||||
params: &'a Params,
|
||||
/// Maps imported predicate names to their full reference (batch and index)
|
||||
imported_predicates: HashMap<String, CustomPredicateRef>,
|
||||
/// Maps predicate names to their batch index and public argument count (from Pass 1)
|
||||
custom_predicate_signatures: HashMap<String, (usize, usize)>,
|
||||
/// Stores the original Pest pairs for custom predicate definitions for Pass 2
|
||||
|
|
@ -72,6 +74,7 @@ impl<'a> ProcessingContext<'a> {
|
|||
fn new(params: &'a Params) -> Self {
|
||||
ProcessingContext {
|
||||
params,
|
||||
imported_predicates: HashMap::new(),
|
||||
custom_predicate_signatures: HashMap::new(),
|
||||
custom_predicate_pairs: Vec::new(),
|
||||
request_pair: None,
|
||||
|
|
@ -82,7 +85,8 @@ impl<'a> ProcessingContext<'a> {
|
|||
pub fn process_pest_tree(
|
||||
mut pairs_iterator_for_document_rule: Pairs<'_, Rule>,
|
||||
params: &Params,
|
||||
) -> Result<ProcessedOutput, ProcessorError> {
|
||||
available_batches: &[Arc<CustomPredicateBatch>],
|
||||
) -> Result<PodlangOutput, ProcessorError> {
|
||||
let mut processing_ctx = ProcessingContext::new(params);
|
||||
|
||||
let document_node = pairs_iterator_for_document_rule.next().ok_or_else(|| {
|
||||
|
|
@ -102,7 +106,11 @@ pub fn process_pest_tree(
|
|||
|
||||
let document_content_pairs = document_node.into_inner();
|
||||
|
||||
first_pass(document_content_pairs, &mut processing_ctx)?;
|
||||
first_pass(
|
||||
document_content_pairs,
|
||||
&mut processing_ctx,
|
||||
available_batches,
|
||||
)?;
|
||||
|
||||
second_pass(&mut processing_ctx)
|
||||
}
|
||||
|
|
@ -112,12 +120,16 @@ pub fn process_pest_tree(
|
|||
fn first_pass<'a>(
|
||||
document_pairs: Pairs<'a, Rule>,
|
||||
ctx: &mut ProcessingContext<'a>,
|
||||
available_batches: &[Arc<CustomPredicateBatch>],
|
||||
) -> Result<(), ProcessorError> {
|
||||
let mut defined_custom_names: HashSet<String> = HashSet::new();
|
||||
let mut first_request_span: Option<(usize, usize)> = None;
|
||||
|
||||
for pair in document_pairs {
|
||||
match pair.as_rule() {
|
||||
Rule::use_statement => {
|
||||
process_use_statement(&pair, ctx, available_batches)?;
|
||||
}
|
||||
Rule::custom_predicate_def => {
|
||||
let pred_name_pair = pair
|
||||
.clone()
|
||||
|
|
@ -126,7 +138,9 @@ fn first_pass<'a>(
|
|||
.unwrap();
|
||||
let pred_name = pred_name_pair.as_str().to_string();
|
||||
|
||||
if defined_custom_names.contains(&pred_name) {
|
||||
if defined_custom_names.contains(&pred_name)
|
||||
|| ctx.imported_predicates.contains_key(&pred_name)
|
||||
{
|
||||
return Err(ProcessorError::DuplicateDefinition {
|
||||
name: pred_name,
|
||||
span: Some(get_span(&pred_name_pair)),
|
||||
|
|
@ -179,9 +193,85 @@ fn count_public_args(pred_def_pair: &Pair<Rule>) -> Result<usize, ProcessorError
|
|||
.count())
|
||||
}
|
||||
|
||||
fn second_pass(ctx: &mut ProcessingContext) -> Result<ProcessedOutput, ProcessorError> {
|
||||
fn process_use_statement(
|
||||
use_pair: &Pair<Rule>,
|
||||
ctx: &mut ProcessingContext,
|
||||
available_batches: &[Arc<CustomPredicateBatch>],
|
||||
) -> Result<(), ProcessorError> {
|
||||
let mut inner = use_pair.clone().into_inner();
|
||||
|
||||
let import_list_pair = inner
|
||||
.find(|p| p.as_rule() == Rule::use_predicate_list)
|
||||
.unwrap();
|
||||
let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap();
|
||||
let batch_id_pair = batch_ref_pair.into_inner().next().unwrap();
|
||||
let batch_id_str_full = batch_id_pair.as_str();
|
||||
|
||||
let batch_id_hex = batch_id_str_full
|
||||
.strip_prefix("0x")
|
||||
.unwrap_or(batch_id_str_full);
|
||||
let batch_id_val = parse_hex_str_to_raw_value(batch_id_hex).map_err(|_| {
|
||||
ProcessorError::InvalidLiteralFormat {
|
||||
kind: "batch ID hash".to_string(),
|
||||
value: batch_id_str_full.to_string(),
|
||||
span: Some(get_span(&batch_id_pair)),
|
||||
}
|
||||
})?;
|
||||
|
||||
let target_batch = available_batches
|
||||
.iter()
|
||||
.find(|b| b.id().0 == batch_id_val.0)
|
||||
.ok_or_else(|| ProcessorError::BatchNotFound {
|
||||
id: batch_id_str_full.to_string(),
|
||||
span: Some(get_span(&batch_id_pair)),
|
||||
})?;
|
||||
|
||||
let import_names: Vec<Pair<Rule>> = import_list_pair
|
||||
.into_inner()
|
||||
.filter(|p| p.as_rule() == Rule::import_name)
|
||||
.collect();
|
||||
|
||||
if import_names.len() != target_batch.predicates().len() {
|
||||
return Err(ProcessorError::ImportArityMismatch {
|
||||
expected: target_batch.predicates().len(),
|
||||
found: import_names.len(),
|
||||
span: Some(get_span(use_pair)),
|
||||
});
|
||||
}
|
||||
|
||||
for (i, import_name_pair) in import_names.into_iter().enumerate() {
|
||||
if import_name_pair.as_str() == "_" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let name = import_name_pair.as_str().to_string();
|
||||
|
||||
if ctx.imported_predicates.contains_key(&name) {
|
||||
return Err(ProcessorError::DuplicateImportName {
|
||||
name,
|
||||
span: Some(get_span(&import_name_pair)),
|
||||
});
|
||||
}
|
||||
|
||||
let custom_pred_ref = CustomPredicateRef::new(target_batch.clone(), i);
|
||||
ctx.imported_predicates.insert(name, custom_pred_ref);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
enum StatementContext<'a> {
|
||||
CustomPredicate,
|
||||
Request {
|
||||
custom_batch: &'a Arc<CustomPredicateBatch>,
|
||||
wildcard_names: &'a mut Vec<String>,
|
||||
defined_wildcards: &'a mut HashSet<String>,
|
||||
},
|
||||
}
|
||||
|
||||
fn second_pass(ctx: &mut ProcessingContext) -> Result<PodlangOutput, ProcessorError> {
|
||||
let mut cpb_builder =
|
||||
CustomPredicateBatchBuilder::new(ctx.params.clone(), "PodlogBatch".to_string());
|
||||
CustomPredicateBatchBuilder::new(ctx.params.clone(), "PodlangBatch".to_string());
|
||||
|
||||
for pred_pair in &ctx.custom_predicate_pairs {
|
||||
process_and_add_custom_predicate_to_batch(pred_pair, ctx, &mut cpb_builder)?;
|
||||
|
|
@ -195,7 +285,7 @@ fn second_pass(ctx: &mut ProcessingContext) -> Result<ProcessedOutput, Processor
|
|||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ProcessedOutput {
|
||||
Ok(PodlangOutput {
|
||||
custom_batch,
|
||||
request_templates,
|
||||
})
|
||||
|
|
@ -326,7 +416,32 @@ fn validate_and_build_statement_template(
|
|||
}
|
||||
}
|
||||
}
|
||||
Predicate::Custom(_) | Predicate::BatchSelf(_) => {
|
||||
Predicate::Custom(custom_ref) => {
|
||||
let expected_arity = custom_ref.predicate().args_len;
|
||||
if args.len() != expected_arity {
|
||||
return Err(ProcessorError::ArgumentCountMismatch {
|
||||
predicate: stmt_name_str.to_string(),
|
||||
expected: expected_arity,
|
||||
found: args.len(),
|
||||
span: Some(stmt_name_span),
|
||||
});
|
||||
}
|
||||
for (idx, arg) in args.iter().enumerate() {
|
||||
if !matches!(arg, BuilderArg::WildcardLiteral(_) | BuilderArg::Literal(_)) {
|
||||
return Err(ProcessorError::TypeError {
|
||||
expected: "Wildcard or Literal".to_string(),
|
||||
found: format!("{:?}", arg),
|
||||
item: format!(
|
||||
"argument {} of custom predicate call '{}'",
|
||||
idx + 1,
|
||||
stmt_name_str
|
||||
),
|
||||
span: Some(stmt_span),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Predicate::BatchSelf(_) => {
|
||||
let (_original_pred_idx, expected_arity_val) = processing_ctx
|
||||
.custom_predicate_signatures
|
||||
.get(stmt_name_str)
|
||||
|
|
@ -454,36 +569,10 @@ fn process_and_add_custom_predicate_to_batch(
|
|||
.into_inner()
|
||||
.filter(|p| p.as_rule() == Rule::statement)
|
||||
{
|
||||
let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
|
||||
let stmt_name_pair = inner_stmt_pairs
|
||||
.find(|p| p.as_rule() == Rule::identifier)
|
||||
.unwrap_or_else(|| unreachable!("statement name must be present in statement"));
|
||||
let stmt_name_str = stmt_name_pair.as_str();
|
||||
|
||||
let builder_args = parse_statement_args(&stmt_pair)?;
|
||||
|
||||
let middleware_predicate_type =
|
||||
if let Some(native_pred) = native_predicate_from_string(stmt_name_str) {
|
||||
Predicate::Native(native_pred)
|
||||
} else if let Some((pred_index, _expected_arity)) = processing_ctx
|
||||
.custom_predicate_signatures
|
||||
.get(stmt_name_str)
|
||||
{
|
||||
Predicate::BatchSelf(*pred_index)
|
||||
} else {
|
||||
return Err(ProcessorError::UndefinedIdentifier {
|
||||
name: stmt_name_str.to_string(),
|
||||
span: Some(get_span(&stmt_name_pair)),
|
||||
});
|
||||
};
|
||||
|
||||
let stb = validate_and_build_statement_template(
|
||||
stmt_name_str,
|
||||
&middleware_predicate_type,
|
||||
builder_args,
|
||||
let stb = process_statement_template(
|
||||
&stmt_pair,
|
||||
processing_ctx,
|
||||
get_span(&stmt_pair),
|
||||
get_span(&stmt_name_pair),
|
||||
StatementContext::CustomPredicate,
|
||||
)?;
|
||||
statement_builders.push(stb);
|
||||
}
|
||||
|
|
@ -520,12 +609,14 @@ fn process_request_def(
|
|||
.into_inner()
|
||||
.filter(|p| p.as_rule() == Rule::statement)
|
||||
{
|
||||
let built_stb = process_proof_request_statement_template(
|
||||
let built_stb = process_statement_template(
|
||||
&stmt_pair,
|
||||
processing_ctx,
|
||||
Some(custom_batch), // Pass as Option<&Arc<...>>
|
||||
&mut request_wildcard_names,
|
||||
&mut defined_request_wildcards,
|
||||
StatementContext::Request {
|
||||
custom_batch,
|
||||
wildcard_names: &mut request_wildcard_names,
|
||||
defined_wildcards: &mut defined_request_wildcards,
|
||||
},
|
||||
)?;
|
||||
request_statement_builders.push(built_stb);
|
||||
}
|
||||
|
|
@ -542,12 +633,10 @@ fn process_request_def(
|
|||
Ok(request_templates)
|
||||
}
|
||||
|
||||
fn process_proof_request_statement_template(
|
||||
fn process_statement_template(
|
||||
stmt_pair: &Pair<Rule>,
|
||||
processing_ctx: &ProcessingContext,
|
||||
custom_batch_for_request: Option<&Arc<CustomPredicateBatch>>,
|
||||
request_wildcard_names: &mut Vec<String>,
|
||||
defined_request_wildcards: &mut HashSet<String>,
|
||||
mut context: StatementContext,
|
||||
) -> Result<StatementTmplBuilder, ProcessorError> {
|
||||
let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
|
||||
let name_pair = inner_stmt_pairs
|
||||
|
|
@ -556,50 +645,58 @@ fn process_proof_request_statement_template(
|
|||
let stmt_name_str = name_pair.as_str();
|
||||
|
||||
let builder_args = parse_statement_args(stmt_pair)?;
|
||||
let mut temp_stmt_wildcard_names: Vec<String> = Vec::new();
|
||||
|
||||
for arg in &builder_args {
|
||||
match arg {
|
||||
BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()),
|
||||
BuilderArg::Key(pod_id_str, key_wc_str) => {
|
||||
if let SelfOrWildcardStr::Wildcard(name) = pod_id_str {
|
||||
temp_stmt_wildcard_names.push(name.clone());
|
||||
}
|
||||
if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str {
|
||||
temp_stmt_wildcard_names.push(key_wc_name.clone());
|
||||
if let StatementContext::Request {
|
||||
wildcard_names,
|
||||
defined_wildcards,
|
||||
..
|
||||
} = &mut context
|
||||
{
|
||||
let mut temp_stmt_wildcard_names: Vec<String> = Vec::new();
|
||||
for arg in &builder_args {
|
||||
match arg {
|
||||
BuilderArg::WildcardLiteral(name) => temp_stmt_wildcard_names.push(name.clone()),
|
||||
BuilderArg::Key(pod_id_str, key_wc_str) => {
|
||||
if let SelfOrWildcardStr::Wildcard(name) = pod_id_str {
|
||||
temp_stmt_wildcard_names.push(name.clone());
|
||||
}
|
||||
if let KeyOrWildcardStr::Wildcard(key_wc_name) = key_wc_str {
|
||||
temp_stmt_wildcard_names.push(key_wc_name.clone());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for name in temp_stmt_wildcard_names {
|
||||
if defined_wildcards.insert(name.clone()) {
|
||||
wildcard_names.push(name);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
for name in temp_stmt_wildcard_names {
|
||||
if defined_request_wildcards.insert(name.clone()) {
|
||||
request_wildcard_names.push(name);
|
||||
}
|
||||
}
|
||||
|
||||
let middleware_predicate_type =
|
||||
if let Some(native_pred) = native_predicate_from_string(stmt_name_str) {
|
||||
Predicate::Native(native_pred)
|
||||
} else if let Some((pred_index, _expected_arity)) = processing_ctx
|
||||
.custom_predicate_signatures
|
||||
.get(stmt_name_str)
|
||||
{
|
||||
if let Some(batch_ref) = custom_batch_for_request {
|
||||
Predicate::Custom(CustomPredicateRef::new(batch_ref.clone(), *pred_index))
|
||||
} else {
|
||||
return Err(ProcessorError::Internal(format!(
|
||||
"Custom predicate '{}' found but no custom batch provided for request processing.",
|
||||
stmt_name_str
|
||||
)));
|
||||
let middleware_predicate_type = if let Some(native_pred) =
|
||||
native_predicate_from_string(stmt_name_str)
|
||||
{
|
||||
Predicate::Native(native_pred)
|
||||
} else if let Some(custom_ref) = processing_ctx.imported_predicates.get(stmt_name_str) {
|
||||
Predicate::Custom(custom_ref.clone())
|
||||
} else if let Some((pred_index, _expected_arity)) = processing_ctx
|
||||
.custom_predicate_signatures
|
||||
.get(stmt_name_str)
|
||||
{
|
||||
match context {
|
||||
StatementContext::CustomPredicate => Predicate::BatchSelf(*pred_index),
|
||||
StatementContext::Request { custom_batch, .. } => {
|
||||
let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index);
|
||||
Predicate::Custom(custom_pred_ref)
|
||||
}
|
||||
} else {
|
||||
return Err(ProcessorError::UndefinedIdentifier {
|
||||
name: stmt_name_str.to_string(),
|
||||
span: Some(get_span(&name_pair)),
|
||||
});
|
||||
};
|
||||
}
|
||||
} else {
|
||||
return Err(ProcessorError::UndefinedIdentifier {
|
||||
name: stmt_name_str.to_string(),
|
||||
span: Some(get_span(&name_pair)),
|
||||
});
|
||||
};
|
||||
|
||||
let stb = validate_and_build_statement_template(
|
||||
stmt_name_str,
|
||||
|
|
@ -882,13 +979,13 @@ mod processor_tests {
|
|||
use crate::{
|
||||
lang::{
|
||||
error::ProcessorError,
|
||||
parser::{parse_podlog, Rule},
|
||||
parser::{parse_podlang, Rule},
|
||||
},
|
||||
middleware::Params,
|
||||
};
|
||||
|
||||
fn get_document_content_pairs(input: &str) -> Result<Pairs<Rule>, ProcessorError> {
|
||||
let full_parse_tree = parse_podlog(input)
|
||||
let full_parse_tree = parse_podlang(input)
|
||||
.map_err(|e| ProcessorError::Internal(format!("Test parsing failed: {:?}", e)))?;
|
||||
|
||||
let document_node = full_parse_tree.peek().ok_or_else(|| {
|
||||
|
|
@ -910,7 +1007,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input)?;
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
assert!(ctx.custom_predicate_signatures.is_empty());
|
||||
assert!(ctx.custom_predicate_pairs.is_empty());
|
||||
assert!(ctx.request_pair.is_none());
|
||||
|
|
@ -923,7 +1020,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input)?;
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
assert!(ctx.custom_predicate_signatures.is_empty());
|
||||
assert!(ctx.custom_predicate_pairs.is_empty());
|
||||
assert!(ctx.request_pair.is_some());
|
||||
|
|
@ -940,7 +1037,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input)?;
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
assert_eq!(ctx.custom_predicate_signatures.len(), 1);
|
||||
assert_eq!(ctx.custom_predicate_pairs.len(), 1);
|
||||
assert!(ctx.request_pair.is_none());
|
||||
|
|
@ -964,7 +1061,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input)?;
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
assert_eq!(ctx.custom_predicate_signatures.len(), 2);
|
||||
assert_eq!(ctx.custom_predicate_pairs.len(), 2);
|
||||
|
||||
|
|
@ -991,11 +1088,12 @@ mod processor_tests {
|
|||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext {
|
||||
params: ¶ms,
|
||||
imported_predicates: HashMap::new(),
|
||||
custom_predicate_signatures: HashMap::new(),
|
||||
custom_predicate_pairs: Vec::new(),
|
||||
request_pair: None,
|
||||
};
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
let pred_name = ctx
|
||||
.custom_predicate_signatures
|
||||
.keys()
|
||||
|
|
@ -1016,7 +1114,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input).unwrap();
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
let result = first_pass(pairs, &mut ctx);
|
||||
let result = first_pass(pairs, &mut ctx, &[]);
|
||||
assert!(result.is_err());
|
||||
match result.err().unwrap() {
|
||||
// Use .err().unwrap() for ProcessorError
|
||||
|
|
@ -1036,7 +1134,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input).unwrap();
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
let result = first_pass(pairs, &mut ctx);
|
||||
let result = first_pass(pairs, &mut ctx, &[]);
|
||||
assert!(result.is_err());
|
||||
match result.err().unwrap() {
|
||||
// Use .err().unwrap() for ProcessorError
|
||||
|
|
@ -1055,7 +1153,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input)?;
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
|
||||
assert_eq!(ctx.custom_predicate_signatures.len(), 2);
|
||||
assert_eq!(ctx.custom_predicate_pairs.len(), 2);
|
||||
|
|
@ -1093,7 +1191,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input)?;
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
let result = second_pass(&mut ctx);
|
||||
assert!(result.is_err());
|
||||
match result.err().unwrap() {
|
||||
|
|
@ -1112,7 +1210,7 @@ mod processor_tests {
|
|||
let pairs = get_document_content_pairs(input)?;
|
||||
let params = Params::default();
|
||||
let mut ctx = ProcessingContext::new(¶ms);
|
||||
first_pass(pairs, &mut ctx)?;
|
||||
first_pass(pairs, &mut ctx, &[])?;
|
||||
let result = second_pass(&mut ctx);
|
||||
assert!(result.is_err());
|
||||
match result.err().unwrap() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue