Frontend AST for Podlang (#432)
* Basic frontend AST and semantic validation * Intro statement support * Simplify validator lifetime * Fix arity validation * Lowering and splitting * Remove legacy processor and use frontend AST by default * Use builders instead of creating middleware types directly * Typos/formatting * Improve error messages when overflowing a batch due to splitting * Add FromStr implementation for NativePredicate * Remove 'raw' fields, and switch HashHex representation to byte vector rather than string * Simpler wrapper types for batch and intro predicate hashes * Parse secret and public keys to their respective data structures earlier * More detail around string escape validity * Simplify native predicate arity handling and move method to NativePredicate impl * Store hashes using middleware::Hash, and simplify lowering by using pre-parsed values * Simplify predicate building * Formatting * Better error messages/suggestions for cases where predicate splitting fails * Formatting * Clippy fix * Return error if we get a too-large int
This commit is contained in:
parent
c382bf487c
commit
42f979c408
11 changed files with 4250 additions and 1431 deletions
|
|
@ -165,7 +165,7 @@ impl CustomPredicateBatchBuilder {
|
|||
|
||||
/// creates the custom predicate from the given input, adds it to the
|
||||
/// self.predicates, and returns the index of the created predicate
|
||||
fn predicate(
|
||||
pub fn predicate(
|
||||
&mut self,
|
||||
name: &str,
|
||||
conjunction: bool,
|
||||
|
|
|
|||
|
|
@ -1,95 +1,283 @@
|
|||
use thiserror::Error;
|
||||
|
||||
use crate::{frontend, lang::parser::ParseError, middleware};
|
||||
use crate::{
|
||||
frontend,
|
||||
lang::{frontend_ast::Span, parser::ParseError},
|
||||
middleware,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LangError {
|
||||
#[error("Parsing failed: {0}")]
|
||||
Parse(Box<ParseError>),
|
||||
|
||||
#[error("AST processing error: {0}")]
|
||||
Processor(Box<ProcessorError>),
|
||||
|
||||
#[error("Middleware error during processing: {0}")]
|
||||
Middleware(Box<middleware::Error>),
|
||||
|
||||
#[error("Frontend error: {0}")]
|
||||
Frontend(Box<frontend::Error>),
|
||||
|
||||
#[error("Validation error: {0}")]
|
||||
Validation(Box<ValidationError>),
|
||||
|
||||
#[error("Lowering error: {0}")]
|
||||
Lowering(Box<LoweringError>),
|
||||
}
|
||||
|
||||
/// Errors that can occur during the processing of Podlang Pest tree into middleware structures.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ProcessorError {
|
||||
#[error("Undefined identifier: '{name}' at {span:?}")]
|
||||
UndefinedIdentifier {
|
||||
/// Validation errors from frontend AST validation
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ValidationError {
|
||||
#[error("Invalid hash: {hash}")]
|
||||
InvalidHash { hash: String, span: Option<Span> },
|
||||
|
||||
#[error("Duplicate predicate definition: {name}")]
|
||||
DuplicatePredicate {
|
||||
name: String,
|
||||
span: Option<(usize, usize)>,
|
||||
first_span: Option<Span>,
|
||||
second_span: Option<Span>,
|
||||
},
|
||||
#[error("Duplicate definition: '{name}' at {span:?}")]
|
||||
DuplicateDefinition {
|
||||
|
||||
#[error("Duplicate import name: {name}")]
|
||||
DuplicateImport { name: String, span: Option<Span> },
|
||||
|
||||
#[error("Import arity mismatch: expected {expected} predicates, found {found}")]
|
||||
ImportArityMismatch {
|
||||
expected: usize,
|
||||
found: usize,
|
||||
span: Option<Span>,
|
||||
},
|
||||
|
||||
#[error("Batch not found: {id}")]
|
||||
BatchNotFound { id: String, span: Option<Span> },
|
||||
|
||||
#[error("Undefined predicate: {name}")]
|
||||
UndefinedPredicate { name: String, span: Option<Span> },
|
||||
|
||||
#[error("Undefined wildcard: {name} in predicate {pred_name}")]
|
||||
UndefinedWildcard {
|
||||
name: String,
|
||||
span: Option<(usize, usize)>,
|
||||
pred_name: String,
|
||||
span: Option<Span>,
|
||||
},
|
||||
#[error("Duplicate wildcard: ?{name} in scope at {span:?}")]
|
||||
DuplicateWildcard {
|
||||
name: String,
|
||||
span: Option<(usize, usize)>,
|
||||
},
|
||||
#[error("Type error: expected {expected}, found {found} for '{item}' at {span:?}")]
|
||||
TypeError {
|
||||
expected: String,
|
||||
found: String,
|
||||
item: String,
|
||||
span: Option<(usize, usize)>,
|
||||
},
|
||||
#[error(
|
||||
"Invalid argument count for '{predicate}': expected {expected}, found {found} at {span:?}"
|
||||
)]
|
||||
|
||||
#[error("Argument count mismatch for {predicate}: expected {expected}, found {found}")]
|
||||
ArgumentCountMismatch {
|
||||
predicate: String,
|
||||
expected: usize,
|
||||
found: usize,
|
||||
span: Option<(usize, usize)>,
|
||||
span: Option<Span>,
|
||||
},
|
||||
#[error("Multiple REQUEST definitions found. Only one is allowed. First at {first_span:?}, second at {second_span:?}")]
|
||||
|
||||
#[error("Invalid argument type for {predicate}: anchored keys not allowed")]
|
||||
InvalidArgumentType {
|
||||
predicate: String,
|
||||
span: Option<Span>,
|
||||
},
|
||||
|
||||
#[error("Duplicate wildcard in predicate arguments: {name}")]
|
||||
DuplicateWildcard { name: String, span: Option<Span> },
|
||||
|
||||
#[error("Empty statement list in {context}")]
|
||||
EmptyStatementList { context: String, span: Option<Span> },
|
||||
|
||||
#[error("Multiple REQUEST definitions found. Only one is allowed.")]
|
||||
MultipleRequestDefinitions {
|
||||
first_span: Option<(usize, usize)>,
|
||||
second_span: Option<(usize, usize)>,
|
||||
first_span: Option<Span>,
|
||||
second_span: Option<Span>,
|
||||
},
|
||||
#[error("Internal processing error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
/// Lowering errors from frontend AST lowering to middleware
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum LoweringError {
|
||||
#[error("Too many custom predicates in batch '{batch_name}': {count} exceeds limit of {max}{}", if *.original_count != *.count { format!(" (started with {} predicates before automatic splitting)", original_count) } else { String::new() })]
|
||||
TooManyPredicates {
|
||||
batch_name: String,
|
||||
count: usize,
|
||||
max: usize,
|
||||
original_count: usize,
|
||||
},
|
||||
|
||||
#[error("Too many statements in predicate '{predicate}': {count} exceeds limit of {max}")]
|
||||
TooManyStatements {
|
||||
predicate: String,
|
||||
count: usize,
|
||||
max: usize,
|
||||
},
|
||||
|
||||
#[error("Too many wildcards in predicate '{predicate}': {count} exceeds limit of {max}")]
|
||||
TooManyWildcards {
|
||||
predicate: String,
|
||||
count: usize,
|
||||
max: usize,
|
||||
},
|
||||
|
||||
#[error("Too many arguments in statement template: {count} exceeds limit of {max}")]
|
||||
TooManyStatementArgs { count: usize, max: usize },
|
||||
|
||||
#[error("Predicate '{name}' not found in symbol table")]
|
||||
PredicateNotFound { name: String },
|
||||
|
||||
#[error("Invalid argument type in statement template")]
|
||||
InvalidArgumentType,
|
||||
|
||||
#[error("Middleware error: {0}")]
|
||||
Middleware(middleware::Error),
|
||||
#[error("Undefined wildcard: '?{name}' in predicate '{pred_name}' at {span:?}")]
|
||||
UndefinedWildcard {
|
||||
name: String,
|
||||
pred_name: String,
|
||||
span: Option<(usize, usize)>,
|
||||
Middleware(#[from] middleware::Error),
|
||||
|
||||
#[error("Splitting error: {0}")]
|
||||
Splitting(#[from] SplittingError),
|
||||
|
||||
#[error("Cannot lower document with validation errors")]
|
||||
ValidationErrors,
|
||||
}
|
||||
|
||||
/// Context information for split boundary failures
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SplitContext {
|
||||
/// Index of the split boundary (0-based)
|
||||
pub split_index: usize,
|
||||
/// Range of statement indices in the segment before the split
|
||||
pub statement_range: (usize, usize),
|
||||
/// Public arguments coming into this segment
|
||||
pub incoming_public: Vec<String>,
|
||||
/// Wildcards that cross this boundary (need to be promoted)
|
||||
pub crossing_wildcards: Vec<String>,
|
||||
/// Total public arguments needed (incoming + crossing)
|
||||
pub total_public: usize,
|
||||
}
|
||||
|
||||
/// Suggestions for refactoring predicates that fail to split
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RefactorSuggestion {
|
||||
/// A wildcard is used across too many statements
|
||||
ReduceWildcardSpan {
|
||||
wildcard: String,
|
||||
first_use: usize,
|
||||
last_use: usize,
|
||||
span: usize,
|
||||
},
|
||||
#[error("Invalid literal format for {kind}: '{value}' at {span:?}")]
|
||||
InvalidLiteralFormat {
|
||||
kind: String,
|
||||
value: String,
|
||||
span: Option<(usize, usize)>,
|
||||
/// Multiple wildcards should be grouped together
|
||||
GroupWildcardUsages { wildcards: Vec<String> },
|
||||
}
|
||||
|
||||
impl RefactorSuggestion {
|
||||
pub fn format(&self) -> String {
|
||||
match self {
|
||||
RefactorSuggestion::ReduceWildcardSpan {
|
||||
wildcard,
|
||||
first_use,
|
||||
last_use,
|
||||
span,
|
||||
} => {
|
||||
format!(
|
||||
"Wildcard '{}' is used across {} statements (statements {}-{}).\n\
|
||||
Consider grouping all '{}' operations together, or split the wildcard\n\
|
||||
into separate early/late variables.",
|
||||
wildcard, span, first_use, last_use, wildcard
|
||||
)
|
||||
}
|
||||
RefactorSuggestion::GroupWildcardUsages { wildcards } => {
|
||||
format!(
|
||||
"Group operations for wildcards: {}\n\
|
||||
These wildcards are used across multiple segments. Try to complete\n\
|
||||
all operations for each wildcard before moving to the next.",
|
||||
wildcards.join(", ")
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Formats a detailed error message for TooManyPublicArgsAtSplit
|
||||
fn format_public_args_at_split_error(
|
||||
predicate: &str,
|
||||
context: &SplitContext,
|
||||
max_allowed: usize,
|
||||
suggestion: &Option<Box<RefactorSuggestion>>,
|
||||
) -> String {
|
||||
let mut msg = format!(
|
||||
"Too many public arguments at split boundary {} in predicate '{}':\n",
|
||||
context.split_index, predicate
|
||||
);
|
||||
|
||||
msg.push_str(&format!(
|
||||
" {} incoming public + {} crossing wildcards = {} total (exceeds max of {})\n",
|
||||
context.incoming_public.len(),
|
||||
context.crossing_wildcards.len(),
|
||||
context.total_public,
|
||||
max_allowed
|
||||
));
|
||||
|
||||
msg.push_str(&format!(
|
||||
" Statements {}-{} in this segment\n",
|
||||
context.statement_range.0, context.statement_range.1
|
||||
));
|
||||
|
||||
if !context.incoming_public.is_empty() {
|
||||
msg.push_str(&format!(
|
||||
" Incoming public args: {}\n",
|
||||
context.incoming_public.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
if !context.crossing_wildcards.is_empty() {
|
||||
msg.push_str(&format!(
|
||||
" Wildcards crossing this boundary: {}\n",
|
||||
context.crossing_wildcards.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(suggestion) = suggestion {
|
||||
msg.push_str("\nSuggestion:\n");
|
||||
msg.push_str(&suggestion.format());
|
||||
}
|
||||
|
||||
msg
|
||||
}
|
||||
|
||||
/// Splitting errors from predicate splitting
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SplittingError {
|
||||
#[error("Too many public arguments in predicate '{predicate}': {count} exceeds max of {max_allowed}. {message}")]
|
||||
TooManyPublicArgs {
|
||||
predicate: String,
|
||||
count: usize,
|
||||
max_allowed: usize,
|
||||
message: String,
|
||||
},
|
||||
#[error("Batch with ID '{id}' not found at {span:?}")]
|
||||
BatchNotFound {
|
||||
id: String,
|
||||
span: Option<(usize, usize)>,
|
||||
|
||||
#[error("Too many total arguments in predicate '{predicate}': {count} exceeds max of {max_allowed}. {message}")]
|
||||
TooManyTotalArgs {
|
||||
predicate: String,
|
||||
count: usize,
|
||||
max_allowed: usize,
|
||||
message: String,
|
||||
},
|
||||
#[error("Number of predicates in 'use' statement ({found}) exceeds the number of predicates in the batch ({expected}) at {span:?}")]
|
||||
ImportArityMismatch {
|
||||
expected: usize,
|
||||
found: usize,
|
||||
span: Option<(usize, usize)>,
|
||||
|
||||
#[error("Too many total arguments in chain link {link_index} of predicate '{predicate}': {public_count} public + {private_count} private = {total_count} total (exceeds max of {max_allowed})")]
|
||||
TooManyTotalArgsInChainLink {
|
||||
predicate: String,
|
||||
link_index: usize,
|
||||
public_count: usize,
|
||||
private_count: usize,
|
||||
total_count: usize,
|
||||
max_allowed: usize,
|
||||
},
|
||||
#[error("Duplicate import name '{name}' at {span:?}")]
|
||||
DuplicateImportName {
|
||||
name: String,
|
||||
span: Option<(usize, usize)>,
|
||||
|
||||
#[error("{}", format_public_args_at_split_error(.predicate, .context, *.max_allowed, .suggestion))]
|
||||
TooManyPublicArgsAtSplit {
|
||||
predicate: String,
|
||||
context: Box<SplitContext>,
|
||||
max_allowed: usize,
|
||||
suggestion: Option<Box<RefactorSuggestion>>,
|
||||
},
|
||||
|
||||
#[error("Too many predicates in chain for '{predicate}': {count} exceeds batch limit of {max_allowed}")]
|
||||
TooManyPredicatesInChain {
|
||||
predicate: String,
|
||||
count: usize,
|
||||
max_allowed: usize,
|
||||
},
|
||||
#[error("Frontend error: {0}")]
|
||||
Frontend(#[from] frontend::Error),
|
||||
}
|
||||
|
||||
impl From<ParseError> for LangError {
|
||||
|
|
@ -98,14 +286,20 @@ impl From<ParseError> for LangError {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ProcessorError> for LangError {
|
||||
fn from(err: ProcessorError) -> Self {
|
||||
LangError::Processor(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<middleware::Error> for LangError {
|
||||
fn from(err: middleware::Error) -> Self {
|
||||
LangError::Middleware(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ValidationError> for LangError {
|
||||
fn from(err: ValidationError) -> Self {
|
||||
LangError::Validation(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LoweringError> for LangError {
|
||||
fn from(err: LoweringError) -> Self {
|
||||
LangError::Lowering(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
1328
src/lang/frontend_ast.rs
Normal file
1328
src/lang/frontend_ast.rs
Normal file
File diff suppressed because it is too large
Load diff
749
src/lang/frontend_ast_lower.rs
Normal file
749
src/lang/frontend_ast_lower.rs
Normal file
|
|
@ -0,0 +1,749 @@
|
|||
//! Lowering from frontend AST to middleware structures
|
||||
//!
|
||||
//! This module converts validated frontend AST to middleware data structures.
|
||||
//! Currently implements basic 1:1 conversion without automatic predicate splitting.
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder},
|
||||
lang::{
|
||||
frontend_ast::*,
|
||||
frontend_ast_split,
|
||||
frontend_ast_validate::{PredicateKind, ValidatedAST},
|
||||
},
|
||||
middleware::{
|
||||
self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params,
|
||||
Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg,
|
||||
Wildcard,
|
||||
},
|
||||
};
|
||||
|
||||
/// Result of lowering: optional custom predicate batch and optional request
|
||||
///
|
||||
/// A Podlang file can contain:
|
||||
/// - Just custom predicates (batch: Some, request: None)
|
||||
/// - Just a request (batch: None, request: Some)
|
||||
/// - Both (batch: Some, request: Some)
|
||||
/// - Neither (batch: None, request: None) - just imports
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoweredOutput {
|
||||
pub batch: Option<Arc<CustomPredicateBatch>>,
|
||||
pub request: Option<crate::frontend::PodRequest>,
|
||||
}
|
||||
|
||||
pub use crate::lang::error::LoweringError;
|
||||
|
||||
/// Lower a validated AST to middleware structures
|
||||
///
|
||||
/// Returns both the custom predicate batch (if any) and the request (if any).
|
||||
/// At least one will be Some if the document contains custom predicates or a request.
|
||||
pub fn lower(
|
||||
validated: ValidatedAST,
|
||||
params: &Params,
|
||||
batch_name: String,
|
||||
) -> Result<LoweredOutput, LoweringError> {
|
||||
if !validated.diagnostics().is_empty() {
|
||||
// For now, treat any diagnostics as errors
|
||||
// In future we could allow warnings
|
||||
return Err(LoweringError::ValidationErrors);
|
||||
}
|
||||
|
||||
let lowerer = Lowerer::new(validated, params);
|
||||
lowerer.lower(batch_name)
|
||||
}
|
||||
|
||||
struct Lowerer<'a> {
|
||||
validated: ValidatedAST,
|
||||
params: &'a Params,
|
||||
/// Map of predicate names to their index in the current batch (for split predicates)
|
||||
batch_predicate_index: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
impl<'a> Lowerer<'a> {
|
||||
fn new(validated: ValidatedAST, params: &'a Params) -> Self {
|
||||
Self {
|
||||
validated,
|
||||
params,
|
||||
batch_predicate_index: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn lower(mut self, batch_name: String) -> Result<LoweredOutput, LoweringError> {
|
||||
// Lower custom predicates (if any)
|
||||
let batch = self.lower_batch(batch_name)?;
|
||||
|
||||
// Lower request (if any) - pass batch so BatchSelf refs can be converted to Custom refs
|
||||
let request = self.lower_request(batch.as_ref())?;
|
||||
|
||||
Ok(LoweredOutput { batch, request })
|
||||
}
|
||||
|
||||
fn lower_batch(
|
||||
&mut self,
|
||||
batch_name: String,
|
||||
) -> Result<Option<Arc<CustomPredicateBatch>>, LoweringError> {
|
||||
// Extract and split custom predicates from document
|
||||
let (custom_predicates, original_count) = self.extract_and_split_predicates()?;
|
||||
|
||||
// If no custom predicates, return None
|
||||
if custom_predicates.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Check batch size constraint
|
||||
if custom_predicates.len() > self.params.max_custom_batch_size {
|
||||
return Err(LoweringError::TooManyPredicates {
|
||||
batch_name: batch_name.clone(),
|
||||
count: custom_predicates.len(),
|
||||
max: self.params.max_custom_batch_size,
|
||||
original_count,
|
||||
});
|
||||
}
|
||||
|
||||
// Build index of all predicates in the batch
|
||||
for (idx, pred) in custom_predicates.iter().enumerate() {
|
||||
self.batch_predicate_index
|
||||
.insert(pred.name.name.clone(), idx);
|
||||
}
|
||||
|
||||
// Create custom predicate batch using builder
|
||||
let mut cpb_builder =
|
||||
CustomPredicateBatchBuilder::new(self.params.clone(), batch_name.clone());
|
||||
|
||||
for pred_def in &custom_predicates {
|
||||
self.lower_custom_predicate(pred_def, &mut cpb_builder)?;
|
||||
}
|
||||
|
||||
Ok(Some(cpb_builder.finish()))
|
||||
}
|
||||
|
||||
fn lower_request(
|
||||
&self,
|
||||
batch: Option<&Arc<CustomPredicateBatch>>,
|
||||
) -> Result<Option<crate::frontend::PodRequest>, LoweringError> {
|
||||
let doc = self.validated.document();
|
||||
|
||||
// Find request definition (if any)
|
||||
let request_def = doc.items.iter().find_map(|item| match item {
|
||||
DocumentItem::RequestDef(req) => Some(req),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let Some(request_def) = request_def else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Build wildcard map from all wildcards used in the request statements
|
||||
let wildcard_map = self.build_request_wildcard_map(request_def);
|
||||
|
||||
// Lower each statement to a builder first
|
||||
let mut statement_builders = Vec::new();
|
||||
for stmt in &request_def.statements {
|
||||
let stmt_builder = self.lower_statement_to_builder(stmt)?;
|
||||
statement_builders.push(stmt_builder);
|
||||
}
|
||||
|
||||
// Resolve builders to middleware statement templates
|
||||
let mut request_templates = Vec::new();
|
||||
for stmt_builder in statement_builders {
|
||||
let mw_stmt =
|
||||
self.resolve_request_statement_builder(stmt_builder, &wildcard_map, batch)?;
|
||||
request_templates.push(mw_stmt);
|
||||
}
|
||||
|
||||
Ok(Some(crate::frontend::PodRequest::new(request_templates)))
|
||||
}
|
||||
|
||||
fn resolve_request_statement_builder(
|
||||
&self,
|
||||
stmt_builder: StatementTmplBuilder,
|
||||
wildcard_map: &HashMap<String, usize>,
|
||||
batch: Option<&Arc<CustomPredicateBatch>>,
|
||||
) -> Result<MWStatementTmpl, LoweringError> {
|
||||
// First desugar the builder
|
||||
let desugared = stmt_builder.desugar();
|
||||
|
||||
// Convert BatchSelf predicate to Custom if we have a batch
|
||||
let mut predicate = desugared.predicate;
|
||||
if let Some(batch_ref) = batch {
|
||||
if let Predicate::BatchSelf(index) = predicate {
|
||||
predicate = Predicate::Custom(middleware::CustomPredicateRef::new(
|
||||
batch_ref.clone(),
|
||||
index,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BuilderArgs to StatementTmplArgs
|
||||
let mut mw_args = Vec::new();
|
||||
for builder_arg in desugared.args {
|
||||
let mw_arg = match builder_arg {
|
||||
BuilderArg::Literal(value) => MWStatementTmplArg::Literal(value),
|
||||
BuilderArg::WildcardLiteral(name) => {
|
||||
let index = wildcard_map.get(&name).expect("Wildcard not found");
|
||||
MWStatementTmplArg::Wildcard(Wildcard::new(name, *index))
|
||||
}
|
||||
BuilderArg::Key(root_name, key_str) => {
|
||||
let root_index = wildcard_map
|
||||
.get(&root_name)
|
||||
.expect("Root wildcard not found");
|
||||
let wildcard = Wildcard::new(root_name, *root_index);
|
||||
let key = middleware::Key::from(key_str.as_str());
|
||||
MWStatementTmplArg::AnchoredKey(wildcard, key)
|
||||
}
|
||||
};
|
||||
mw_args.push(mw_arg);
|
||||
}
|
||||
|
||||
Ok(MWStatementTmpl {
|
||||
pred: predicate,
|
||||
args: mw_args,
|
||||
})
|
||||
}
|
||||
|
||||
fn build_request_wildcard_map(&self, request_def: &RequestDef) -> HashMap<String, usize> {
|
||||
// Collect all unique wildcards from all statements
|
||||
let mut wildcard_names = Vec::new();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
for stmt in &request_def.statements {
|
||||
self.collect_statement_wildcards(stmt, &mut wildcard_names, &mut seen);
|
||||
}
|
||||
|
||||
// Build map from name to index
|
||||
wildcard_names
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, name)| (name, idx))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn collect_statement_wildcards(
|
||||
&self,
|
||||
stmt: &StatementTmpl,
|
||||
names: &mut Vec<String>,
|
||||
seen: &mut HashSet<String>,
|
||||
) {
|
||||
for arg in &stmt.args {
|
||||
match arg {
|
||||
StatementTmplArg::Wildcard(id) => {
|
||||
if !seen.contains(&id.name) {
|
||||
seen.insert(id.name.clone());
|
||||
names.push(id.name.clone());
|
||||
}
|
||||
}
|
||||
StatementTmplArg::AnchoredKey(ak) => {
|
||||
if !seen.contains(&ak.root.name) {
|
||||
seen.insert(ak.root.name.clone());
|
||||
names.push(ak.root.name.clone());
|
||||
}
|
||||
}
|
||||
StatementTmplArg::Literal(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_and_split_predicates(
|
||||
&self,
|
||||
) -> Result<(Vec<CustomPredicateDef>, usize), LoweringError> {
|
||||
let doc = self.validated.document();
|
||||
let predicates: Vec<CustomPredicateDef> = doc
|
||||
.items
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
DocumentItem::CustomPredicateDef(pred) => Some(pred.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let original_count = predicates.len();
|
||||
|
||||
// Apply splitting to each predicate as needed
|
||||
let mut split_predicates = Vec::new();
|
||||
for pred in predicates {
|
||||
let chain = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
|
||||
split_predicates.extend(chain);
|
||||
}
|
||||
|
||||
Ok((split_predicates, original_count))
|
||||
}
|
||||
|
||||
fn lower_custom_predicate(
|
||||
&self,
|
||||
pred_def: &CustomPredicateDef,
|
||||
cpb_builder: &mut CustomPredicateBatchBuilder,
|
||||
) -> Result<(), LoweringError> {
|
||||
let name = pred_def.name.name.clone();
|
||||
|
||||
// Note: Constraint checking is handled by the splitting phase
|
||||
// Predicates passed here should already be within limits
|
||||
|
||||
// Collect public and private argument names
|
||||
let mut public_arg_names = Vec::new();
|
||||
let mut private_arg_names = Vec::new();
|
||||
|
||||
for arg in &pred_def.args.public_args {
|
||||
public_arg_names.push(arg.name.clone());
|
||||
}
|
||||
|
||||
if let Some(private_args) = &pred_def.args.private_args {
|
||||
for arg in private_args {
|
||||
private_arg_names.push(arg.name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Lower statements to builders
|
||||
let mut statement_builders = Vec::new();
|
||||
for stmt in &pred_def.statements {
|
||||
let stmt_builder = self.lower_statement_to_builder(stmt)?;
|
||||
statement_builders.push(stmt_builder);
|
||||
}
|
||||
|
||||
// Convert to &str slices for builder API
|
||||
let public_args_str: Vec<&str> = public_arg_names.iter().map(|s| s.as_str()).collect();
|
||||
let private_args_str: Vec<&str> = private_arg_names.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
// Add predicate to batch using builder
|
||||
let conjunction = pred_def.conjunction_type == ConjunctionType::And;
|
||||
|
||||
cpb_builder
|
||||
.predicate(
|
||||
&name,
|
||||
conjunction,
|
||||
&public_args_str,
|
||||
&private_args_str,
|
||||
&statement_builders,
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
crate::frontend::Error::Middleware(mw_err) => LoweringError::Middleware(mw_err),
|
||||
_ => LoweringError::InvalidArgumentType,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn lower_statement_to_builder(
|
||||
&self,
|
||||
stmt: &StatementTmpl,
|
||||
) -> Result<StatementTmplBuilder, LoweringError> {
|
||||
// Get predicate
|
||||
let pred_name = &stmt.predicate.name;
|
||||
let symbols = self.validated.symbols();
|
||||
|
||||
// Check for native predicates first
|
||||
let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
Predicate::Native(native)
|
||||
} else if let Some(&index) = self.batch_predicate_index.get(pred_name) {
|
||||
// References to other predicates in the same batch (including split chains)
|
||||
Predicate::BatchSelf(index)
|
||||
} else if let Some(info) = symbols.predicates.get(pred_name) {
|
||||
match &info.kind {
|
||||
PredicateKind::Native(np) => Predicate::Native(*np),
|
||||
PredicateKind::Custom { index } => Predicate::BatchSelf(*index),
|
||||
PredicateKind::BatchImported { batch, index } => {
|
||||
Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index))
|
||||
}
|
||||
PredicateKind::IntroImported {
|
||||
name,
|
||||
verifier_data_hash,
|
||||
} => Predicate::Intro(IntroPredicateRef {
|
||||
name: name.clone(),
|
||||
args_len: info.public_arity,
|
||||
verifier_data_hash: *verifier_data_hash,
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
unreachable!("Predicate {} not found", pred_name);
|
||||
};
|
||||
|
||||
// Check args count
|
||||
if stmt.args.len() > self.params.max_statement_args {
|
||||
return Err(LoweringError::TooManyStatementArgs {
|
||||
count: stmt.args.len(),
|
||||
max: self.params.max_statement_args,
|
||||
});
|
||||
}
|
||||
|
||||
// Convert AST args to BuilderArgs
|
||||
let mut builder = StatementTmplBuilder::new(predicate);
|
||||
for arg in &stmt.args {
|
||||
let builder_arg = self.lower_statement_arg_to_builder(arg)?;
|
||||
builder = builder.arg(builder_arg);
|
||||
}
|
||||
|
||||
// Return builder without calling .desugar() - that will happen later
|
||||
Ok(builder)
|
||||
}
|
||||
|
||||
fn lower_statement_arg_to_builder(
|
||||
&self,
|
||||
arg: &StatementTmplArg,
|
||||
) -> Result<BuilderArg, LoweringError> {
|
||||
match arg {
|
||||
StatementTmplArg::Literal(lit) => {
|
||||
let value = self.lower_literal(lit)?;
|
||||
Ok(BuilderArg::Literal(value))
|
||||
}
|
||||
StatementTmplArg::Wildcard(id) => {
|
||||
// For builder, we just need the wildcard name
|
||||
Ok(BuilderArg::WildcardLiteral(id.name.clone()))
|
||||
}
|
||||
StatementTmplArg::AnchoredKey(ak) => {
|
||||
let key_str = match &ak.key {
|
||||
AnchoredKeyPath::Bracket(s) => s.value.clone(),
|
||||
AnchoredKeyPath::Dot(id) => id.name.clone(),
|
||||
};
|
||||
Ok(BuilderArg::Key(ak.root.name.clone(), key_str))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lower_literal(&self, lit: &LiteralValue) -> Result<middleware::Value, LoweringError> {
|
||||
let value = match lit {
|
||||
LiteralValue::Int(i) => middleware::Value::from(i.value),
|
||||
LiteralValue::Bool(b) => middleware::Value::from(b.value),
|
||||
LiteralValue::String(s) => middleware::Value::from(s.value.clone()),
|
||||
LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash),
|
||||
LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point),
|
||||
LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()),
|
||||
LiteralValue::Array(a) => {
|
||||
let elements: Result<Vec<_>, _> =
|
||||
a.elements.iter().map(|e| self.lower_literal(e)).collect();
|
||||
let array = containers::Array::new(self.params.max_depth_mt_containers, elements?)?;
|
||||
middleware::Value::from(array)
|
||||
}
|
||||
LiteralValue::Set(s) => {
|
||||
let elements: Result<Vec<_>, _> =
|
||||
s.elements.iter().map(|e| self.lower_literal(e)).collect();
|
||||
let set_values: std::collections::HashSet<_> = elements?.into_iter().collect();
|
||||
let set = containers::Set::new(self.params.max_depth_mt_containers, set_values)?;
|
||||
middleware::Value::from(set)
|
||||
}
|
||||
LiteralValue::Dict(d) => {
|
||||
let pairs: Result<Vec<(middleware::Key, middleware::Value)>, LoweringError> = d
|
||||
.pairs
|
||||
.iter()
|
||||
.map(|pair| {
|
||||
let key = middleware::Key::from(pair.key.value.as_str());
|
||||
let value = self.lower_literal(&pair.value)?;
|
||||
Ok((key, value))
|
||||
})
|
||||
.collect();
|
||||
let dict_map: std::collections::HashMap<_, _> = pairs?.into_iter().collect();
|
||||
let dict =
|
||||
containers::Dictionary::new(self.params.max_depth_mt_containers, dict_map)?;
|
||||
middleware::Value::from(dict)
|
||||
}
|
||||
};
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::lang::{
|
||||
frontend_ast::parse::parse_document, frontend_ast_validate::validate, parser::parse_podlang,
|
||||
};
|
||||
|
||||
fn parse_validate_and_lower(
|
||||
input: &str,
|
||||
params: &Params,
|
||||
) -> Result<LoweredOutput, LoweringError> {
|
||||
let parsed = parse_podlang(input).expect("Failed to parse");
|
||||
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
||||
let validated = validate(document, &[]).expect("Failed to validate");
|
||||
lower(validated, params, "test_batch".to_string())
|
||||
}
|
||||
|
||||
// Helper to get the batch from the output (expecting it to exist)
|
||||
fn expect_batch(output: &LoweredOutput) -> &Arc<CustomPredicateBatch> {
|
||||
output.batch.as_ref().expect("Expected batch to be present")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_predicate() {
|
||||
let input = r#"
|
||||
my_pred(A, B) = AND (
|
||||
Equal(A["foo"], B["bar"])
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
if let Err(e) = &result {
|
||||
eprintln!("Error: {:?}", e);
|
||||
}
|
||||
assert!(result.is_ok());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
assert_eq!(expect_batch(&lowered).predicates().len(), 1);
|
||||
|
||||
let pred = &expect_batch(&lowered).predicates()[0];
|
||||
assert_eq!(pred.name, "my_pred");
|
||||
assert_eq!(pred.args_len(), 2);
|
||||
assert_eq!(pred.wildcard_names().len(), 2);
|
||||
assert_eq!(pred.statements().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_private_args() {
|
||||
let input = r#"
|
||||
my_pred(A, private: B, C) = AND (
|
||||
Equal(A["x"], B["y"])
|
||||
Equal(B["z"], C["w"])
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
let pred = &expect_batch(&lowered).predicates()[0];
|
||||
assert_eq!(pred.args_len(), 1); // Only A is public
|
||||
assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_or_predicate() {
|
||||
let input = r#"
|
||||
my_pred(A, B) = OR (
|
||||
Equal(A["x"], 1)
|
||||
Equal(B["y"], 2)
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
let pred = &expect_batch(&lowered).predicates()[0];
|
||||
assert!(pred.is_disjunction());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_automatic_splitting() {
|
||||
let input = r#"
|
||||
my_pred(A) = AND (
|
||||
Equal(A["a"], 1)
|
||||
Equal(A["b"], 2)
|
||||
Equal(A["c"], 3)
|
||||
Equal(A["d"], 4)
|
||||
Equal(A["e"], 5)
|
||||
Equal(A["f"], 6)
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default(); // max_custom_predicate_arity = 5
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
if let Err(e) = &result {
|
||||
eprintln!("Splitting error: {:?}", e);
|
||||
}
|
||||
assert!(result.is_ok());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
// Should be automatically split into 2 predicates (my_pred and my_pred_1)
|
||||
assert_eq!(expect_batch(&lowered).predicates().len(), 2);
|
||||
|
||||
// First predicate should have 5 statements (4 + chain call)
|
||||
assert_eq!(expect_batch(&lowered).predicates()[0].statements().len(), 5);
|
||||
|
||||
// Second predicate should have 2 statements
|
||||
assert_eq!(expect_batch(&lowered).predicates()[1].statements().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_predicates() {
|
||||
let input = r#"
|
||||
pred1(A) = AND (
|
||||
Equal(A["x"], 1)
|
||||
)
|
||||
|
||||
pred2(B) = AND (
|
||||
Equal(B["y"], 2)
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
assert_eq!(expect_batch(&lowered).predicates().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_self_reference() {
|
||||
let input = r#"
|
||||
pred1(A) = AND (
|
||||
Equal(A["x"], 1)
|
||||
)
|
||||
|
||||
pred2(B) = AND (
|
||||
pred1(B)
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
let pred2 = &expect_batch(&lowered).predicates()[1];
|
||||
let stmt = &pred2.statements()[0];
|
||||
|
||||
// Should be BatchSelf(0) referring to pred1
|
||||
assert!(matches!(stmt.pred, Predicate::BatchSelf(0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_literals() {
|
||||
let input = r#"
|
||||
my_pred(X) = AND (
|
||||
Equal(X["int"], 42)
|
||||
Equal(X["bool"], true)
|
||||
Equal(X["string"], "hello")
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_syntactic_sugar_desugaring() {
|
||||
let input = r#"
|
||||
my_pred(D) = AND (
|
||||
DictContains(D, "key", "value")
|
||||
)
|
||||
"#;
|
||||
|
||||
let params = Params::default();
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let lowered = result.unwrap();
|
||||
let pred = &expect_batch(&lowered).predicates()[0];
|
||||
let stmt = &pred.statements()[0];
|
||||
|
||||
// Should desugar to the Contains predicate
|
||||
assert!(matches!(
|
||||
stmt.pred,
|
||||
Predicate::Native(NativePredicate::Contains)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_message_with_splitting() {
|
||||
// Create a document with predicates that will exceed the batch limit after splitting
|
||||
// We'll create 2 predicates with 4 statements each (max arity = 5)
|
||||
// Each will NOT split individually, but together they exceed a small batch limit
|
||||
let input = r#"
|
||||
pred1(A) = AND (
|
||||
Equal(A["a"], 1)
|
||||
Equal(A["b"], 2)
|
||||
)
|
||||
pred2(B) = AND (
|
||||
Equal(B["c"], 3)
|
||||
Equal(B["d"], 4)
|
||||
)
|
||||
"#;
|
||||
|
||||
// Use very restrictive params to force the error
|
||||
let params = Params {
|
||||
max_custom_batch_size: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
|
||||
// Should fail with TooManyPredicates error
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
|
||||
if let LoweringError::TooManyPredicates {
|
||||
count,
|
||||
max,
|
||||
original_count,
|
||||
..
|
||||
} = err
|
||||
{
|
||||
assert_eq!(count, 2); // 2 predicates after splitting (no splitting occurred)
|
||||
assert_eq!(max, 1);
|
||||
assert_eq!(original_count, 2); // Started with 2 predicates
|
||||
|
||||
// Error message should NOT mention splitting since no splitting occurred
|
||||
let err_msg = format!("{}", err);
|
||||
assert!(!err_msg.contains("before automatic splitting"));
|
||||
} else {
|
||||
panic!("Expected TooManyPredicates error, got: {:?}", err);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_message_after_splitting() {
|
||||
// Create TWO predicates that will EACH split into 2 predicates
|
||||
// This tests the case where splitting causes the batch to be too large
|
||||
// but no individual predicate chain exceeds the limit
|
||||
let input = r#"
|
||||
pred1(A) = AND (
|
||||
Equal(A["a"], 1)
|
||||
Equal(A["b"], 2)
|
||||
Equal(A["c"], 3)
|
||||
Equal(A["d"], 4)
|
||||
Equal(A["e"], 5)
|
||||
Equal(A["f"], 6)
|
||||
)
|
||||
pred2(B) = AND (
|
||||
Equal(B["a"], 1)
|
||||
Equal(B["b"], 2)
|
||||
Equal(B["c"], 3)
|
||||
Equal(B["d"], 4)
|
||||
Equal(B["e"], 5)
|
||||
Equal(B["f"], 6)
|
||||
)
|
||||
"#;
|
||||
|
||||
// Use params where each predicate splits into 2, but total of 4 exceeds batch limit
|
||||
let params = Params {
|
||||
// Allow 3 predicates in batch
|
||||
// Default max_custom_predicate_arity is 5, so each will split into 2 predicates
|
||||
// Total: 2 original predicates -> 4 after splitting (exceeds limit of 3)
|
||||
max_custom_batch_size: 3,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
|
||||
// Should fail with TooManyPredicates error
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
|
||||
if let LoweringError::TooManyPredicates {
|
||||
count,
|
||||
max,
|
||||
original_count,
|
||||
..
|
||||
} = err
|
||||
{
|
||||
assert_eq!(count, 4); // 4 predicates after splitting (2 from each)
|
||||
assert_eq!(max, 3);
|
||||
assert_eq!(original_count, 2); // Started with 2 predicates
|
||||
|
||||
// Error message SHOULD mention splitting since splitting occurred
|
||||
let err_msg = format!("{}", err);
|
||||
assert!(err_msg.contains("before automatic splitting"));
|
||||
assert!(err_msg.contains("started with 2 predicates"));
|
||||
} else {
|
||||
panic!("Expected TooManyPredicates error, got: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
1002
src/lang/frontend_ast_split.rs
Normal file
1002
src/lang/frontend_ast_split.rs
Normal file
File diff suppressed because it is too large
Load diff
775
src/lang/frontend_ast_validate.rs
Normal file
775
src/lang/frontend_ast_validate.rs
Normal file
|
|
@ -0,0 +1,775 @@
|
|||
//! Validation for the frontend AST
|
||||
//!
|
||||
//! This module provides semantic validation for parsed AST documents,
|
||||
//! including name resolution, arity checking, and wildcard validation.
|
||||
|
||||
use std::{collections::HashMap, str::FromStr, sync::Arc};
|
||||
|
||||
use hex::ToHex;
|
||||
|
||||
use crate::{
|
||||
lang::frontend_ast::*,
|
||||
middleware::{CustomPredicateBatch, Hash, NativePredicate},
|
||||
};
|
||||
|
||||
/// A validated AST document with symbol table and diagnostics
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ValidatedAST {
|
||||
document: Document,
|
||||
symbols: SymbolTable,
|
||||
diagnostics: Vec<Diagnostic>,
|
||||
}
|
||||
|
||||
impl ValidatedAST {
|
||||
pub fn document(&self) -> &Document {
|
||||
&self.document
|
||||
}
|
||||
|
||||
pub fn symbols(&self) -> &SymbolTable {
|
||||
&self.symbols
|
||||
}
|
||||
|
||||
pub fn diagnostics(&self) -> &[Diagnostic] {
|
||||
&self.diagnostics
|
||||
}
|
||||
|
||||
pub fn into_document(self) -> Document {
|
||||
self.document
|
||||
}
|
||||
}
|
||||
|
||||
/// Symbol table containing all predicates and their metadata
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SymbolTable {
|
||||
/// All predicates available in this scope
|
||||
pub predicates: HashMap<String, PredicateInfo>,
|
||||
/// Wildcard scopes for each custom predicate
|
||||
pub wildcard_scopes: HashMap<String, WildcardScope>,
|
||||
}
|
||||
|
||||
/// Information about a predicate
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PredicateInfo {
|
||||
pub kind: PredicateKind,
|
||||
pub arity: usize,
|
||||
pub public_arity: usize,
|
||||
pub source_span: Option<Span>,
|
||||
}
|
||||
|
||||
/// Kind of predicate
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PredicateKind {
|
||||
Native(NativePredicate),
|
||||
Custom {
|
||||
index: usize,
|
||||
},
|
||||
BatchImported {
|
||||
batch: Arc<CustomPredicateBatch>,
|
||||
index: usize,
|
||||
},
|
||||
IntroImported {
|
||||
name: String,
|
||||
verifier_data_hash: Hash,
|
||||
},
|
||||
}
|
||||
|
||||
/// Wildcard scope for a custom predicate
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WildcardScope {
|
||||
pub wildcards: HashMap<String, WildcardInfo>,
|
||||
}
|
||||
|
||||
/// Information about a wildcard
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WildcardInfo {
|
||||
pub index: usize,
|
||||
pub is_public: bool,
|
||||
pub source_span: Option<Span>,
|
||||
}
|
||||
|
||||
/// Diagnostic message (warning or info)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Diagnostic {
|
||||
pub level: DiagnosticLevel,
|
||||
pub message: String,
|
||||
pub span: Option<Span>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum DiagnosticLevel {
|
||||
Warning,
|
||||
Info,
|
||||
}
|
||||
|
||||
pub use crate::lang::error::ValidationError;
|
||||
|
||||
/// Validate an AST document
|
||||
pub fn validate(
|
||||
document: Document,
|
||||
available_batches: &[Arc<CustomPredicateBatch>],
|
||||
) -> Result<ValidatedAST, ValidationError> {
|
||||
let validator = Validator::new(available_batches);
|
||||
validator.validate(document)
|
||||
}
|
||||
|
||||
struct Validator {
|
||||
available_batches: HashMap<String, Arc<CustomPredicateBatch>>,
|
||||
symbols: SymbolTable,
|
||||
diagnostics: Vec<Diagnostic>,
|
||||
custom_predicate_count: usize,
|
||||
}
|
||||
|
||||
impl Validator {
|
||||
fn new(batches: &[Arc<CustomPredicateBatch>]) -> Self {
|
||||
let mut available_batches = HashMap::new();
|
||||
for batch in batches {
|
||||
// Store by hex ID for lookup
|
||||
let id = format!("0x{}", batch.id().encode_hex::<String>());
|
||||
available_batches.insert(id, batch.clone());
|
||||
}
|
||||
|
||||
Self {
|
||||
available_batches,
|
||||
symbols: SymbolTable {
|
||||
predicates: HashMap::new(),
|
||||
wildcard_scopes: HashMap::new(),
|
||||
},
|
||||
diagnostics: Vec::new(),
|
||||
custom_predicate_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn validate(mut self, document: Document) -> Result<ValidatedAST, ValidationError> {
|
||||
// Pass 1: Build symbol table
|
||||
self.build_symbol_table(&document)?;
|
||||
|
||||
// Pass 2: Validate all references
|
||||
self.validate_references(&document)?;
|
||||
|
||||
Ok(ValidatedAST {
|
||||
document,
|
||||
symbols: self.symbols,
|
||||
diagnostics: self.diagnostics,
|
||||
})
|
||||
}
|
||||
|
||||
fn build_symbol_table(&mut self, document: &Document) -> Result<(), ValidationError> {
|
||||
// First process imports
|
||||
for item in &document.items {
|
||||
if let DocumentItem::UseBatchStatement(use_stmt) = item {
|
||||
self.process_use_batch_statement(use_stmt)?;
|
||||
}
|
||||
if let DocumentItem::UseIntroStatement(use_stmt) = item {
|
||||
self.process_use_intro_statement(use_stmt)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Then process custom predicate definitions
|
||||
for item in &document.items {
|
||||
if let DocumentItem::CustomPredicateDef(pred_def) = item {
|
||||
self.process_custom_predicate_def(pred_def)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Check for multiple REQUEST definitions (only one allowed)
|
||||
let mut first_request_span = None;
|
||||
for item in &document.items {
|
||||
if let DocumentItem::RequestDef(req) = item {
|
||||
if let Some(first_span) = first_request_span {
|
||||
return Err(ValidationError::MultipleRequestDefinitions {
|
||||
first_span: Some(first_span),
|
||||
second_span: req.span,
|
||||
});
|
||||
}
|
||||
first_request_span = req.span;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_use_batch_statement(
|
||||
&mut self,
|
||||
use_stmt: &UseBatchStatement,
|
||||
) -> Result<(), ValidationError> {
|
||||
let batch_id = format!("0x{}", use_stmt.batch_ref.hash.encode_hex::<String>());
|
||||
|
||||
let batch = self.available_batches.get(&batch_id).ok_or_else(|| {
|
||||
ValidationError::BatchNotFound {
|
||||
id: batch_id.clone(),
|
||||
span: use_stmt.batch_ref.span,
|
||||
}
|
||||
})?;
|
||||
|
||||
if use_stmt.imports.len() != batch.predicates().len() {
|
||||
return Err(ValidationError::ImportArityMismatch {
|
||||
expected: batch.predicates().len(),
|
||||
found: use_stmt.imports.len(),
|
||||
span: use_stmt.span,
|
||||
});
|
||||
}
|
||||
|
||||
for (i, import) in use_stmt.imports.iter().enumerate() {
|
||||
if let ImportName::Named(name) = import {
|
||||
if self.symbols.predicates.contains_key(name) {
|
||||
return Err(ValidationError::DuplicateImport {
|
||||
name: name.clone(),
|
||||
span: use_stmt.span,
|
||||
});
|
||||
}
|
||||
|
||||
let pred = &batch.predicates()[i];
|
||||
// CustomPredicate has args_len (public args) and wildcard_names (total args)
|
||||
let total_arity = pred.wildcard_names.len();
|
||||
let public_arity = pred.args_len;
|
||||
|
||||
self.symbols.predicates.insert(
|
||||
name.clone(),
|
||||
PredicateInfo {
|
||||
kind: PredicateKind::BatchImported {
|
||||
batch: batch.clone(),
|
||||
index: i,
|
||||
},
|
||||
arity: total_arity,
|
||||
public_arity,
|
||||
source_span: use_stmt.span,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_use_intro_statement(
|
||||
&mut self,
|
||||
use_stmt: &UseIntroStatement,
|
||||
) -> Result<(), ValidationError> {
|
||||
let intro_name = &use_stmt.name.name;
|
||||
let args = &use_stmt.args;
|
||||
let intro_predicate_ref = &use_stmt.intro_hash;
|
||||
|
||||
if self.symbols.predicates.contains_key(intro_name) {
|
||||
return Err(ValidationError::DuplicateImport {
|
||||
name: intro_name.clone(),
|
||||
span: use_stmt.span,
|
||||
});
|
||||
}
|
||||
|
||||
self.symbols.predicates.insert(
|
||||
intro_name.clone(),
|
||||
PredicateInfo {
|
||||
kind: PredicateKind::IntroImported {
|
||||
name: intro_name.clone(),
|
||||
// Hash is already parsed in the AST
|
||||
verifier_data_hash: intro_predicate_ref.hash,
|
||||
},
|
||||
arity: args.len(),
|
||||
public_arity: args.len(),
|
||||
source_span: use_stmt.span,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_custom_predicate_def(
|
||||
&mut self,
|
||||
pred_def: &CustomPredicateDef,
|
||||
) -> Result<(), ValidationError> {
|
||||
let name = &pred_def.name.name;
|
||||
|
||||
if self.symbols.predicates.contains_key(name) {
|
||||
let first_span = self.symbols.predicates[name].source_span;
|
||||
return Err(ValidationError::DuplicatePredicate {
|
||||
name: name.clone(),
|
||||
first_span,
|
||||
second_span: pred_def.name.span,
|
||||
});
|
||||
}
|
||||
|
||||
// Check for empty statement list
|
||||
if pred_def.statements.is_empty() {
|
||||
return Err(ValidationError::EmptyStatementList {
|
||||
context: format!("predicate '{}'", name),
|
||||
span: pred_def.span,
|
||||
});
|
||||
}
|
||||
|
||||
// Build wildcard scope
|
||||
let mut wildcards = HashMap::new();
|
||||
let mut wildcard_index = 0;
|
||||
|
||||
// Process public arguments
|
||||
for arg in &pred_def.args.public_args {
|
||||
if wildcards.contains_key(&arg.name) {
|
||||
return Err(ValidationError::DuplicateWildcard {
|
||||
name: arg.name.clone(),
|
||||
span: arg.span,
|
||||
});
|
||||
}
|
||||
wildcards.insert(
|
||||
arg.name.clone(),
|
||||
WildcardInfo {
|
||||
index: wildcard_index,
|
||||
is_public: true,
|
||||
source_span: arg.span,
|
||||
},
|
||||
);
|
||||
wildcard_index += 1;
|
||||
}
|
||||
|
||||
// Process private arguments
|
||||
let mut private_count = 0;
|
||||
if let Some(private_args) = &pred_def.args.private_args {
|
||||
for arg in private_args {
|
||||
if wildcards.contains_key(&arg.name) {
|
||||
return Err(ValidationError::DuplicateWildcard {
|
||||
name: arg.name.clone(),
|
||||
span: arg.span,
|
||||
});
|
||||
}
|
||||
wildcards.insert(
|
||||
arg.name.clone(),
|
||||
WildcardInfo {
|
||||
index: wildcard_index,
|
||||
is_public: false,
|
||||
source_span: arg.span,
|
||||
},
|
||||
);
|
||||
wildcard_index += 1;
|
||||
private_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Add to symbol table
|
||||
self.symbols.predicates.insert(
|
||||
name.clone(),
|
||||
PredicateInfo {
|
||||
kind: PredicateKind::Custom {
|
||||
index: self.custom_predicate_count,
|
||||
},
|
||||
arity: pred_def.args.public_args.len() + private_count,
|
||||
public_arity: pred_def.args.public_args.len(),
|
||||
source_span: pred_def.name.span,
|
||||
},
|
||||
);
|
||||
|
||||
self.symbols
|
||||
.wildcard_scopes
|
||||
.insert(name.clone(), WildcardScope { wildcards });
|
||||
self.custom_predicate_count += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_references(&mut self, document: &Document) -> Result<(), ValidationError> {
|
||||
for item in &document.items {
|
||||
match item {
|
||||
DocumentItem::CustomPredicateDef(pred_def) => {
|
||||
self.validate_custom_predicate_statements(pred_def)?;
|
||||
}
|
||||
DocumentItem::RequestDef(req_def) => {
|
||||
self.validate_request_statements(req_def)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_custom_predicate_statements(
|
||||
&self,
|
||||
pred_def: &CustomPredicateDef,
|
||||
) -> Result<(), ValidationError> {
|
||||
let pred_name = pred_def.name.name.clone();
|
||||
|
||||
for stmt in &pred_def.statements {
|
||||
let wildcard_scope = self
|
||||
.symbols
|
||||
.wildcard_scopes
|
||||
.get(&pred_name)
|
||||
.expect("Wildcard scope should exist after pass 1");
|
||||
self.validate_statement(stmt, Some((&pred_name, wildcard_scope)))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_request_statements(&mut self, req_def: &RequestDef) -> Result<(), ValidationError> {
|
||||
if req_def.statements.is_empty() {
|
||||
self.diagnostics.push(Diagnostic {
|
||||
level: DiagnosticLevel::Warning,
|
||||
message: "Empty REQUEST block".to_string(),
|
||||
span: req_def.span,
|
||||
});
|
||||
}
|
||||
|
||||
for stmt in &req_def.statements {
|
||||
self.validate_statement(stmt, None)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_statement(
|
||||
&self,
|
||||
stmt: &StatementTmpl,
|
||||
wildcard_context: Option<(&str, &WildcardScope)>,
|
||||
) -> Result<(), ValidationError> {
|
||||
let pred_name = &stmt.predicate.name;
|
||||
|
||||
// Check if predicate exists
|
||||
let pred_info = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
// Native predicate
|
||||
PredicateInfo {
|
||||
kind: PredicateKind::Native(native),
|
||||
arity: native.arity(),
|
||||
public_arity: native.arity(),
|
||||
source_span: None,
|
||||
}
|
||||
} else if let Some(info) = self.symbols.predicates.get(pred_name) {
|
||||
// Custom or imported predicate
|
||||
info.clone()
|
||||
} else {
|
||||
return Err(ValidationError::UndefinedPredicate {
|
||||
name: pred_name.clone(),
|
||||
span: stmt.predicate.span,
|
||||
});
|
||||
};
|
||||
|
||||
let expected_arity = pred_info.public_arity;
|
||||
|
||||
if stmt.args.len() != expected_arity {
|
||||
return Err(ValidationError::ArgumentCountMismatch {
|
||||
predicate: pred_name.clone(),
|
||||
expected: expected_arity,
|
||||
found: stmt.args.len(),
|
||||
span: stmt.span,
|
||||
});
|
||||
}
|
||||
|
||||
// Validate arguments
|
||||
self.validate_statement_args(stmt, &pred_info, wildcard_context)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_statement_args(
|
||||
&self,
|
||||
stmt: &StatementTmpl,
|
||||
pred_info: &PredicateInfo,
|
||||
wildcard_context: Option<(&str, &WildcardScope)>,
|
||||
) -> Result<(), ValidationError> {
|
||||
// For custom predicates, only wildcards and literals are allowed
|
||||
if matches!(
|
||||
pred_info.kind,
|
||||
PredicateKind::Custom { .. } | PredicateKind::BatchImported { .. }
|
||||
) {
|
||||
for arg in &stmt.args {
|
||||
match arg {
|
||||
StatementTmplArg::AnchoredKey(_) => {
|
||||
return Err(ValidationError::InvalidArgumentType {
|
||||
predicate: stmt.predicate.name.clone(),
|
||||
span: stmt.span,
|
||||
});
|
||||
}
|
||||
StatementTmplArg::Wildcard(id) => {
|
||||
if let Some((pred_name, scope)) = wildcard_context {
|
||||
if !scope.wildcards.contains_key(&id.name) {
|
||||
return Err(ValidationError::UndefinedWildcard {
|
||||
name: id.name.clone(),
|
||||
pred_name: pred_name.to_string(),
|
||||
span: id.span,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
StatementTmplArg::Literal(_) => {}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Native predicates can have anchored keys
|
||||
for arg in &stmt.args {
|
||||
match arg {
|
||||
StatementTmplArg::Wildcard(id) => {
|
||||
if let Some((pred_name, scope)) = wildcard_context {
|
||||
if !scope.wildcards.contains_key(&id.name) {
|
||||
return Err(ValidationError::UndefinedWildcard {
|
||||
name: id.name.clone(),
|
||||
pred_name: pred_name.to_string(),
|
||||
span: id.span,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
StatementTmplArg::AnchoredKey(ak) => {
|
||||
if let Some((pred_name, scope)) = wildcard_context {
|
||||
if !scope.wildcards.contains_key(&ak.root.name) {
|
||||
return Err(ValidationError::UndefinedWildcard {
|
||||
name: ak.root.name.clone(),
|
||||
pred_name: pred_name.to_string(),
|
||||
span: ak.root.span,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
StatementTmplArg::Literal(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
lang::{frontend_ast::parse::parse_document, parser::parse_podlang},
|
||||
middleware::{CustomPredicate, Params, EMPTY_HASH},
|
||||
};
|
||||
|
||||
fn parse_and_validate(
|
||||
input: &str,
|
||||
batches: &[Arc<CustomPredicateBatch>],
|
||||
) -> Result<ValidatedAST, ValidationError> {
|
||||
let parsed = parse_podlang(input).expect("Failed to parse");
|
||||
let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
|
||||
validate(document, batches)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_empty() {
|
||||
let result = parse_and_validate("", &[]);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_simple_request() {
|
||||
let input = r#"REQUEST(
|
||||
Equal(A["foo"], B["bar"])
|
||||
)"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_custom_predicate() {
|
||||
let input = r#"
|
||||
my_pred(A, B) = AND (
|
||||
Equal(A["foo"], B["bar"])
|
||||
)
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let validated = result.unwrap();
|
||||
assert!(validated.symbols.predicates.contains_key("my_pred"));
|
||||
assert!(validated.symbols.wildcard_scopes.contains_key("my_pred"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_undefined_predicate() {
|
||||
let input = r#"REQUEST(
|
||||
UndefinedPred(A, B)
|
||||
)"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::UndefinedPredicate { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_undefined_wildcard() {
|
||||
let input = r#"
|
||||
my_pred(A) = AND (
|
||||
Equal(A["foo"], B["bar"])
|
||||
)
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(
|
||||
matches!(result, Err(ValidationError::UndefinedWildcard { name, .. }) if name == "B")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_arity_mismatch() {
|
||||
let input = r#"REQUEST(
|
||||
Equal(A, B, C)
|
||||
)"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::ArgumentCountMismatch { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_predicate() {
|
||||
let input = r#"
|
||||
my_pred(A) = AND (Equal(A["x"], 1))
|
||||
my_pred(B) = AND (Equal(B["y"], 2))
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::DuplicatePredicate { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_wildcard() {
|
||||
let input = r#"
|
||||
my_pred(A, A) = AND (Equal(A["x"], 1))
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::DuplicateWildcard { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_predicate_with_anchored_key() {
|
||||
let input = r#"
|
||||
my_pred(A, B) = AND (
|
||||
Equal(A["foo"], B["bar"])
|
||||
)
|
||||
|
||||
REQUEST(
|
||||
my_pred(X["key"], Y)
|
||||
)
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::InvalidArgumentType { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forward_reference() {
|
||||
let input = r#"
|
||||
pred1(A) = AND (
|
||||
pred2(A)
|
||||
)
|
||||
|
||||
pred2(B) = AND (
|
||||
Equal(B["x"], 1)
|
||||
)
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_private_args() {
|
||||
let input = r#"
|
||||
my_pred(A, private: B, C) = AND (
|
||||
Equal(A["x"], B["y"])
|
||||
Equal(B["z"], C["w"])
|
||||
)
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let validated = result.unwrap();
|
||||
let pred_info = &validated.symbols.predicates["my_pred"];
|
||||
assert_eq!(pred_info.arity, 3);
|
||||
assert_eq!(pred_info.public_arity, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_statement_list() {
|
||||
// Create a custom predicate with empty statements to test validation
|
||||
let document = Document {
|
||||
items: vec![DocumentItem::CustomPredicateDef(CustomPredicateDef {
|
||||
name: Identifier {
|
||||
name: "my_pred".to_string(),
|
||||
span: None,
|
||||
},
|
||||
args: ArgSection {
|
||||
public_args: vec![Identifier {
|
||||
name: "A".to_string(),
|
||||
span: None,
|
||||
}],
|
||||
private_args: None,
|
||||
span: None,
|
||||
},
|
||||
conjunction_type: ConjunctionType::And,
|
||||
statements: vec![], // Empty statements
|
||||
span: None,
|
||||
})],
|
||||
};
|
||||
let result = validate(document, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::EmptyStatementList { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_request_definitions() {
|
||||
let input = r#"
|
||||
REQUEST(Equal(A["x"], 1))
|
||||
REQUEST(Equal(B["y"], 2))
|
||||
"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(matches!(
|
||||
result,
|
||||
Err(ValidationError::MultipleRequestDefinitions { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_use_statement() {
|
||||
let params = Params::default();
|
||||
|
||||
// Create a batch to import
|
||||
let pred = CustomPredicate::and(
|
||||
¶ms,
|
||||
"imported".to_string(),
|
||||
vec![],
|
||||
2,
|
||||
vec!["X".to_string(), "Y".to_string()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let batch = CustomPredicateBatch::new(¶ms, "TestBatch".to_string(), vec![pred]);
|
||||
|
||||
let batch_id = batch.id().encode_hex::<String>();
|
||||
let input = format!(
|
||||
r#"
|
||||
use batch imported_pred from 0x{}
|
||||
use intro intro_pred() from 0x{}
|
||||
|
||||
REQUEST(
|
||||
imported_pred(A, B)
|
||||
intro_pred()
|
||||
)
|
||||
"#,
|
||||
batch_id,
|
||||
EMPTY_HASH.encode_hex::<String>()
|
||||
);
|
||||
|
||||
let result = parse_and_validate(&input, &[batch]);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let validated = result.unwrap();
|
||||
assert!(validated.symbols.predicates.contains_key("imported_pred"));
|
||||
assert!(validated.symbols.predicates.contains_key("intro_pred"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_syntactic_sugar_predicates() {
|
||||
let input = r#"REQUEST(
|
||||
GtEq(A["x"], B["y"])
|
||||
DictContains(D, K, V)
|
||||
SetNotContains(S, E)
|
||||
)"#;
|
||||
let result = parse_and_validate(input, &[]);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
|
@ -33,8 +33,9 @@ use_predicate_list = { import_name ~ ("," ~ import_name)* }
|
|||
import_name = { identifier | "_" }
|
||||
batch_ref = { hash_hex }
|
||||
|
||||
use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ batch_ref }
|
||||
use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ intro_predicate_ref }
|
||||
use_intro_arg_list = { identifier ~ ("," ~ identifier)* }
|
||||
intro_predicate_ref = { hash_hex }
|
||||
|
||||
request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" }
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,27 @@
|
|||
pub mod error;
|
||||
pub mod frontend_ast;
|
||||
pub mod frontend_ast_lower;
|
||||
pub mod frontend_ast_split;
|
||||
pub mod frontend_ast_validate;
|
||||
pub mod parser;
|
||||
pub mod pretty_print;
|
||||
pub mod processor;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use error::LangError;
|
||||
pub use parser::{parse_podlang, Pairs, ParseError, Rule};
|
||||
pub use pretty_print::PrettyPrint;
|
||||
pub use processor::process_pest_tree;
|
||||
use processor::PodlangOutput;
|
||||
|
||||
use crate::middleware::{CustomPredicateBatch, Params};
|
||||
use crate::{
|
||||
frontend::PodRequest,
|
||||
middleware::{CustomPredicateBatch, Params},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PodlangOutput {
|
||||
pub custom_batch: Arc<CustomPredicateBatch>,
|
||||
pub request: PodRequest,
|
||||
}
|
||||
|
||||
pub fn parse(
|
||||
input: &str,
|
||||
|
|
@ -19,7 +29,28 @@ pub fn parse(
|
|||
available_batches: &[Arc<CustomPredicateBatch>],
|
||||
) -> Result<PodlangOutput, LangError> {
|
||||
let pairs = parse_podlang(input)?;
|
||||
processor::process_pest_tree(pairs, params, available_batches).map_err(LangError::from)
|
||||
let document_pair = pairs
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("parse_podlang should always return at least one pair for a valid document");
|
||||
let document = frontend_ast::parse::parse_document(document_pair)?;
|
||||
let validated = frontend_ast_validate::validate(document, available_batches)?;
|
||||
let lowered = frontend_ast_lower::lower(validated, params, "PodlangBatch".to_string())?;
|
||||
|
||||
let custom_batch = lowered.batch.unwrap_or_else(|| {
|
||||
// If no batch, create an empty one
|
||||
CustomPredicateBatch::new(params, "PodlangBatch".to_string(), vec![])
|
||||
});
|
||||
|
||||
let request = lowered.request.unwrap_or_else(|| {
|
||||
// If no request, create an empty one
|
||||
PodRequest::new(vec![])
|
||||
});
|
||||
|
||||
Ok(PodlangOutput {
|
||||
custom_batch,
|
||||
request,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -30,7 +61,6 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::{
|
||||
backends::plonky2::primitives::ec::schnorr::SecretKey,
|
||||
lang::error::ProcessorError,
|
||||
middleware::{
|
||||
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
|
||||
Params, Predicate, RawValue, StatementTmpl, StatementTmplArg, Value, Wildcard,
|
||||
|
|
@ -963,13 +993,13 @@ mod tests {
|
|||
assert!(result.is_err());
|
||||
|
||||
match result.err().unwrap() {
|
||||
LangError::Processor(e) => match *e {
|
||||
ProcessorError::BatchNotFound { id, .. } => {
|
||||
LangError::Validation(e) => match *e {
|
||||
frontend_ast_validate::ValidationError::BatchNotFound { id, .. } => {
|
||||
assert_eq!(id, unknown_batch_id);
|
||||
}
|
||||
_ => panic!("Expected BatchNotFound error, but got {:?}", e),
|
||||
},
|
||||
e => panic!("Expected LangError::Processor, but got {:?}", e),
|
||||
e => panic!("Expected LangError::Validation, but got {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -991,16 +1021,18 @@ mod tests {
|
|||
assert!(result.is_err());
|
||||
|
||||
match result.err().unwrap() {
|
||||
LangError::Processor(e) => match *e {
|
||||
ProcessorError::UndefinedWildcard {
|
||||
name, pred_name, ..
|
||||
LangError::Validation(e) => match *e {
|
||||
frontend_ast_validate::ValidationError::UndefinedWildcard {
|
||||
name,
|
||||
pred_name,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(name, "user_public_key");
|
||||
assert_eq!(pred_name, "identity_verified");
|
||||
}
|
||||
_ => panic!("Expected UndefinedWildcard error, but got {:?}", e),
|
||||
},
|
||||
e => panic!("Expected LangError::Processor, but got {:?}", e),
|
||||
e => panic!("Expected LangError::Validation, but got {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,8 +12,20 @@ pub type Pairs<'a, R> = PestPairs<'a, R>;
|
|||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ParseError {
|
||||
#[error("Invalid integer: {0}")]
|
||||
InvalidInt(String),
|
||||
|
||||
#[error("Pest parsing error: {0}")]
|
||||
Pest(Box<pest::error::Error<Rule>>),
|
||||
|
||||
#[error("Invalid public key: {0}")]
|
||||
InvalidPublicKey(String),
|
||||
|
||||
#[error("Invalid secret key: {0}")]
|
||||
InvalidSecretKey(String),
|
||||
|
||||
#[error("Invalid escape sequence in string: {0}")]
|
||||
InvalidEscapeSequence(String),
|
||||
}
|
||||
|
||||
impl From<pest::error::Error<Rule>> for ParseError {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,7 @@
|
|||
use std::{
|
||||
fmt::{self, Display},
|
||||
iter,
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use plonky2::field::types::Field;
|
||||
|
|
@ -51,6 +52,42 @@ pub enum NativePredicate {
|
|||
ArrayUpdate = 1014,
|
||||
}
|
||||
|
||||
impl NativePredicate {
|
||||
pub fn arity(&self) -> usize {
|
||||
match self {
|
||||
NativePredicate::None | NativePredicate::False => 0,
|
||||
NativePredicate::Equal
|
||||
| NativePredicate::NotEqual
|
||||
| NativePredicate::Lt
|
||||
| NativePredicate::Gt
|
||||
| NativePredicate::GtEq
|
||||
| NativePredicate::LtEq
|
||||
| NativePredicate::NotContains
|
||||
| NativePredicate::SetNotContains
|
||||
| NativePredicate::DictNotContains
|
||||
| NativePredicate::PublicKeyOf
|
||||
| NativePredicate::SignedBy
|
||||
| NativePredicate::SetContains => 2,
|
||||
NativePredicate::Contains
|
||||
| NativePredicate::DictContains
|
||||
| NativePredicate::ArrayContains
|
||||
| NativePredicate::SumOf
|
||||
| NativePredicate::ProductOf
|
||||
| NativePredicate::MaxOf
|
||||
| NativePredicate::HashOf
|
||||
| NativePredicate::SetInsert
|
||||
| NativePredicate::SetDelete => 3,
|
||||
NativePredicate::DictInsert
|
||||
| NativePredicate::DictUpdate
|
||||
| NativePredicate::DictDelete
|
||||
| NativePredicate::ArrayUpdate
|
||||
| NativePredicate::ContainerInsert
|
||||
| NativePredicate::ContainerUpdate
|
||||
| NativePredicate::ContainerDelete => 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for NativePredicate {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let s = match self {
|
||||
|
|
@ -95,6 +132,45 @@ impl ToFields for NativePredicate {
|
|||
}
|
||||
}
|
||||
|
||||
impl FromStr for NativePredicate {
|
||||
type Err = Error;
|
||||
fn from_str(s: &str) -> Result<Self> {
|
||||
match s {
|
||||
"Equal" => Ok(NativePredicate::Equal),
|
||||
"NotEqual" => Ok(NativePredicate::NotEqual),
|
||||
"Gt" => Ok(NativePredicate::Gt),
|
||||
"GtEq" => Ok(NativePredicate::GtEq),
|
||||
"Lt" => Ok(NativePredicate::Lt),
|
||||
"LtEq" => Ok(NativePredicate::LtEq),
|
||||
"Contains" => Ok(NativePredicate::Contains),
|
||||
"NotContains" => Ok(NativePredicate::NotContains),
|
||||
"SumOf" => Ok(NativePredicate::SumOf),
|
||||
"ProductOf" => Ok(NativePredicate::ProductOf),
|
||||
"MaxOf" => Ok(NativePredicate::MaxOf),
|
||||
"HashOf" => Ok(NativePredicate::HashOf),
|
||||
"PublicKeyOf" => Ok(NativePredicate::PublicKeyOf),
|
||||
"SignedBy" => Ok(NativePredicate::SignedBy),
|
||||
"ContainerInsert" => Ok(NativePredicate::ContainerInsert),
|
||||
"ContainerUpdate" => Ok(NativePredicate::ContainerUpdate),
|
||||
"ContainerDelete" => Ok(NativePredicate::ContainerDelete),
|
||||
"DictContains" => Ok(NativePredicate::DictContains),
|
||||
"DictNotContains" => Ok(NativePredicate::DictNotContains),
|
||||
"ArrayContains" => Ok(NativePredicate::ArrayContains),
|
||||
"SetContains" => Ok(NativePredicate::SetContains),
|
||||
"SetNotContains" => Ok(NativePredicate::SetNotContains),
|
||||
"DictInsert" => Ok(NativePredicate::DictInsert),
|
||||
"DictUpdate" => Ok(NativePredicate::DictUpdate),
|
||||
"DictDelete" => Ok(NativePredicate::DictDelete),
|
||||
"SetInsert" => Ok(NativePredicate::SetInsert),
|
||||
"SetDelete" => Ok(NativePredicate::SetDelete),
|
||||
"ArrayUpdate" => Ok(NativePredicate::ArrayUpdate),
|
||||
"None" => Ok(NativePredicate::None),
|
||||
"False" => Ok(NativePredicate::False),
|
||||
_ => Err(Error::custom(format!("Invalid native predicate: {}", s))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct IntroPredicateRef {
|
||||
pub name: String,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue