760 lines
25 KiB
Rust
760 lines
25 KiB
Rust
//! Lowering from frontend AST to middleware structures
//!
//! This module converts validated frontend AST to middleware data structures.
//! Supports automatic predicate splitting and multi-batch packing.
|
|
|
|
use std::{
|
|
collections::{HashMap, HashSet},
|
|
str::FromStr,
|
|
sync::Arc,
|
|
};
|
|
|
|
use crate::{
|
|
frontend::{BuilderArg, StatementTmplBuilder},
|
|
lang::{
|
|
frontend_ast::*,
|
|
frontend_ast_batch::{self, PredicateBatches},
|
|
frontend_ast_split,
|
|
frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST},
|
|
},
|
|
middleware::{
|
|
containers, CustomPredicateBatch, CustomPredicateRef, IntroPredicateRef, Key,
|
|
NativePredicate, Params, Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
|
StatementTmplArg as MWStatementTmplArg, Value, Wildcard,
|
|
},
|
|
};
|
|
|
|
/// Context for predicate resolution - determines how local custom predicates are resolved
|
|
pub enum ResolutionContext<'a> {
|
|
/// Request context: local custom predicates resolve to Intro/CustomPredicateRef via batches
|
|
Request {
|
|
batches: Option<&'a PredicateBatches>,
|
|
},
|
|
/// Batch context: local custom predicates may resolve to BatchSelf or Intro/CustomPredicateRef
|
|
Batch {
|
|
current_batch_idx: usize,
|
|
reference_map: &'a HashMap<String, (usize, usize)>,
|
|
existing_batches: &'a [Arc<CustomPredicateBatch>],
|
|
},
|
|
}
|
|
|
|
/// Resolve a predicate name to a Predicate using the symbol table
|
|
pub fn resolve_predicate(
|
|
pred_name: &str,
|
|
symbols: &SymbolTable,
|
|
context: &ResolutionContext,
|
|
) -> Option<Predicate> {
|
|
// 1. Try native predicate first
|
|
if let Ok(native) = NativePredicate::from_str(pred_name) {
|
|
return Some(Predicate::Native(native));
|
|
}
|
|
|
|
// 2. Look up in symbol table
|
|
if let Some(info) = symbols.predicates.get(pred_name) {
|
|
let predicate = match &info.kind {
|
|
PredicateKind::Native(np) => Predicate::Native(*np),
|
|
|
|
PredicateKind::Custom { .. } => match context {
|
|
ResolutionContext::Request { batches } => {
|
|
let batches = batches.as_ref()?;
|
|
let pred_ref = batches.predicate_ref_by_name(pred_name)?;
|
|
Predicate::Custom(pred_ref)
|
|
}
|
|
ResolutionContext::Batch {
|
|
current_batch_idx,
|
|
reference_map,
|
|
existing_batches,
|
|
} => resolve_local_predicate(
|
|
pred_name,
|
|
*current_batch_idx,
|
|
reference_map,
|
|
existing_batches,
|
|
)?,
|
|
},
|
|
|
|
PredicateKind::BatchImported { batch, index } => {
|
|
Predicate::Custom(CustomPredicateRef::new(batch.clone(), *index))
|
|
}
|
|
|
|
PredicateKind::IntroImported {
|
|
name,
|
|
verifier_data_hash,
|
|
} => Predicate::Intro(IntroPredicateRef {
|
|
name: name.clone(),
|
|
args_len: info.public_arity,
|
|
verifier_data_hash: *verifier_data_hash,
|
|
}),
|
|
};
|
|
return Some(predicate);
|
|
}
|
|
|
|
// 3. In batch context, also check reference_map for split chain pieces
|
|
// (predicates created by splitting that aren't in the original symbol table)
|
|
if let ResolutionContext::Batch {
|
|
current_batch_idx,
|
|
reference_map,
|
|
existing_batches,
|
|
} = context
|
|
{
|
|
if reference_map.contains_key(pred_name) {
|
|
return resolve_local_predicate(
|
|
pred_name,
|
|
*current_batch_idx,
|
|
reference_map,
|
|
existing_batches,
|
|
);
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
/// Resolve a local predicate (one in this document or a split chain piece) using the reference_map
|
|
fn resolve_local_predicate(
|
|
pred_name: &str,
|
|
current_batch_idx: usize,
|
|
reference_map: &HashMap<String, (usize, usize)>,
|
|
existing_batches: &[Arc<CustomPredicateBatch>],
|
|
) -> Option<Predicate> {
|
|
let &(target_batch, target_idx) = reference_map.get(pred_name)?;
|
|
if target_batch == current_batch_idx {
|
|
Some(Predicate::BatchSelf(target_idx))
|
|
} else if target_batch < current_batch_idx {
|
|
let batch = &existing_batches[target_batch];
|
|
Some(Predicate::Custom(CustomPredicateRef::new(
|
|
batch.clone(),
|
|
target_idx,
|
|
)))
|
|
} else {
|
|
unreachable!(
|
|
"Forward cross-batch reference should be impossible: {} -> {}",
|
|
current_batch_idx, target_batch
|
|
);
|
|
}
|
|
}
|
|
|
|
// ============================================================================
// Shared lowering utilities
// ============================================================================
// These functions convert AST types to middleware/builder types and are used
// by both the request lowering (in this module) and predicate batching
// (in frontend_ast_batch).
|
|
|
|
/// Lower a literal value from AST to middleware Value.
|
|
///
|
|
/// This is a pure conversion that cannot fail.
|
|
pub fn lower_literal(lit: &LiteralValue) -> Value {
|
|
match lit {
|
|
LiteralValue::Int(i) => Value::from(i.value),
|
|
LiteralValue::Bool(b) => Value::from(b.value),
|
|
LiteralValue::String(s) => Value::from(s.value.clone()),
|
|
LiteralValue::Raw(r) => Value::from(r.hash.hash),
|
|
LiteralValue::PublicKey(pk) => Value::from(pk.point),
|
|
LiteralValue::SecretKey(sk) => Value::from(sk.secret_key.clone()),
|
|
LiteralValue::Array(a) => {
|
|
let elements: Vec<_> = a.elements.iter().map(lower_literal).collect();
|
|
let array = containers::Array::new(elements);
|
|
Value::from(array)
|
|
}
|
|
LiteralValue::Set(s) => {
|
|
let elements: std::collections::HashSet<_> =
|
|
s.elements.iter().map(lower_literal).collect();
|
|
let set = containers::Set::new(elements);
|
|
Value::from(set)
|
|
}
|
|
LiteralValue::Dict(d) => {
|
|
let pairs: HashMap<_, _> = d
|
|
.pairs
|
|
.iter()
|
|
.map(|pair| {
|
|
let key = Key::from(pair.key.value.as_str());
|
|
let value = lower_literal(&pair.value);
|
|
(key, value)
|
|
})
|
|
.collect();
|
|
let dict = containers::Dictionary::new(pairs);
|
|
Value::from(dict)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Lower a statement argument from AST to BuilderArg.
|
|
///
|
|
/// This is a pure conversion that cannot fail.
|
|
pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
|
|
match arg {
|
|
StatementTmplArg::Literal(lit) => {
|
|
let value = lower_literal(lit);
|
|
BuilderArg::Literal(value)
|
|
}
|
|
StatementTmplArg::Wildcard(id) => BuilderArg::WildcardLiteral(id.name.clone()),
|
|
StatementTmplArg::AnchoredKey(ak) => {
|
|
let key_str = match &ak.key {
|
|
AnchoredKeyPath::Bracket(s) => s.value.clone(),
|
|
AnchoredKeyPath::Dot(id) => id.name.clone(),
|
|
};
|
|
BuilderArg::Key(ak.root.name.clone(), key_str)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Result of lowering: optional custom predicate batches and optional request
|
|
///
|
|
/// A Podlang file can contain:
|
|
/// - Just custom predicates (batches: Some, request: None)
|
|
/// - Just a request (batches: None, request: Some)
|
|
/// - Both (batches: Some, request: Some)
|
|
/// - Neither (batches: None, request: None) - just imports
|
|
#[derive(Debug, Clone)]
|
|
pub struct LoweredOutput {
|
|
pub batches: Option<PredicateBatches>,
|
|
pub request: Option<crate::frontend::PodRequest>,
|
|
}
|
|
|
|
pub use crate::lang::error::LoweringError;
|
|
|
|
/// Lower a validated AST to middleware structures
|
|
///
|
|
/// Returns both the custom predicate batch (if any) and the request (if any).
|
|
/// At least one will be Some if the document contains custom predicates or a request.
|
|
pub fn lower(
|
|
validated: ValidatedAST,
|
|
params: &Params,
|
|
batch_name: String,
|
|
) -> Result<LoweredOutput, LoweringError> {
|
|
if !validated.diagnostics().is_empty() {
|
|
// For now, treat any diagnostics as errors
|
|
// In future we could allow warnings
|
|
return Err(LoweringError::ValidationErrors);
|
|
}
|
|
|
|
let lowerer = Lowerer::new(validated, params);
|
|
lowerer.lower(batch_name)
|
|
}
|
|
|
|
struct Lowerer<'a> {
|
|
validated: ValidatedAST,
|
|
params: &'a Params,
|
|
}
|
|
|
|
impl<'a> Lowerer<'a> {
|
|
fn new(validated: ValidatedAST, params: &'a Params) -> Self {
|
|
Self { validated, params }
|
|
}
|
|
|
|
fn lower(self, batch_name: String) -> Result<LoweredOutput, LoweringError> {
|
|
// Lower custom predicates (if any) - now supports multiple batches
|
|
let batches = self.lower_batches(batch_name)?;
|
|
|
|
// Lower request (if any) - pass batches so refs can be resolved
|
|
let request = self.lower_request(batches.as_ref())?;
|
|
|
|
Ok(LoweredOutput { batches, request })
|
|
}
|
|
|
|
fn lower_batches(&self, batch_name: String) -> Result<Option<PredicateBatches>, LoweringError> {
|
|
// Extract and split custom predicates from document
|
|
let custom_predicates = self.extract_and_split_predicates()?;
|
|
|
|
// If no custom predicates, return None
|
|
if custom_predicates.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
// Use the new batching module to pack predicates into batches
|
|
// Pass the symbol table for unified predicate resolution
|
|
let batches = frontend_ast_batch::batch_predicates(
|
|
custom_predicates,
|
|
self.params,
|
|
&batch_name,
|
|
self.validated.symbols(),
|
|
)?;
|
|
|
|
Ok(Some(batches))
|
|
}
|
|
|
|
fn lower_request(
|
|
&self,
|
|
batches: Option<&PredicateBatches>,
|
|
) -> Result<Option<crate::frontend::PodRequest>, LoweringError> {
|
|
let doc = self.validated.document();
|
|
|
|
// Find request definition (if any)
|
|
let request_def = doc.items.iter().find_map(|item| match item {
|
|
DocumentItem::RequestDef(req) => Some(req),
|
|
_ => None,
|
|
});
|
|
|
|
let Some(request_def) = request_def else {
|
|
return Ok(None);
|
|
};
|
|
|
|
// Build wildcard map from all wildcards used in the request statements
|
|
let wildcard_map = self.build_request_wildcard_map(request_def);
|
|
|
|
// Lower each statement to middleware templates, resolving predicates
|
|
let mut request_templates = Vec::new();
|
|
for stmt in &request_def.statements {
|
|
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map, batches)?;
|
|
request_templates.push(mw_stmt);
|
|
}
|
|
|
|
Ok(Some(crate::frontend::PodRequest::new(request_templates)))
|
|
}
|
|
|
|
fn lower_request_statement(
|
|
&self,
|
|
stmt: &StatementTmpl,
|
|
wildcard_map: &HashMap<String, usize>,
|
|
batches: Option<&PredicateBatches>,
|
|
) -> Result<MWStatementTmpl, LoweringError> {
|
|
// Enforce argument count limit for request statements
|
|
if stmt.args.len() > self.params.max_statement_args {
|
|
return Err(LoweringError::TooManyStatementArgs {
|
|
count: stmt.args.len(),
|
|
max: self.params.max_statement_args,
|
|
});
|
|
}
|
|
|
|
let pred_name = &stmt.predicate.name;
|
|
let symbols = self.validated.symbols();
|
|
|
|
// Resolve predicate using the unified resolution function
|
|
let context = ResolutionContext::Request { batches };
|
|
let predicate = resolve_predicate(pred_name, symbols, &context).ok_or_else(|| {
|
|
LoweringError::PredicateNotFound {
|
|
name: pred_name.clone(),
|
|
}
|
|
})?;
|
|
|
|
// Create a builder with the resolved predicate and desugar
|
|
let mut builder = StatementTmplBuilder::new(predicate);
|
|
for arg in &stmt.args {
|
|
let builder_arg = lower_statement_arg(arg);
|
|
builder = builder.arg(builder_arg);
|
|
}
|
|
let desugared = builder.desugar();
|
|
|
|
// Convert BuilderArgs to StatementTmplArgs
|
|
let mut mw_args = Vec::new();
|
|
for builder_arg in desugared.args {
|
|
let mw_arg = match builder_arg {
|
|
BuilderArg::Literal(value) => MWStatementTmplArg::Literal(value),
|
|
BuilderArg::WildcardLiteral(name) => {
|
|
let index = wildcard_map.get(&name).expect("Wildcard not found");
|
|
MWStatementTmplArg::Wildcard(Wildcard::new(name, *index))
|
|
}
|
|
BuilderArg::Key(root_name, key_str) => {
|
|
let root_index = wildcard_map
|
|
.get(&root_name)
|
|
.expect("Root wildcard not found");
|
|
let wildcard = Wildcard::new(root_name, *root_index);
|
|
let key = Key::from(key_str.as_str());
|
|
MWStatementTmplArg::AnchoredKey(wildcard, key)
|
|
}
|
|
};
|
|
mw_args.push(mw_arg);
|
|
}
|
|
|
|
Ok(MWStatementTmpl {
|
|
// TODO: Support wildcard
|
|
pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate),
|
|
args: mw_args,
|
|
})
|
|
}
|
|
|
|
fn build_request_wildcard_map(&self, request_def: &RequestDef) -> HashMap<String, usize> {
|
|
// Collect all unique wildcards from all statements
|
|
let mut wildcard_names = Vec::new();
|
|
let mut seen = HashSet::new();
|
|
|
|
for stmt in &request_def.statements {
|
|
self.collect_statement_wildcards(stmt, &mut wildcard_names, &mut seen);
|
|
}
|
|
|
|
// Build map from name to index
|
|
wildcard_names
|
|
.into_iter()
|
|
.enumerate()
|
|
.map(|(idx, name)| (name, idx))
|
|
.collect()
|
|
}
|
|
|
|
fn collect_statement_wildcards(
|
|
&self,
|
|
stmt: &StatementTmpl,
|
|
names: &mut Vec<String>,
|
|
seen: &mut HashSet<String>,
|
|
) {
|
|
for arg in &stmt.args {
|
|
match arg {
|
|
StatementTmplArg::Wildcard(id) => {
|
|
if !seen.contains(&id.name) {
|
|
seen.insert(id.name.clone());
|
|
names.push(id.name.clone());
|
|
}
|
|
}
|
|
StatementTmplArg::AnchoredKey(ak) => {
|
|
if !seen.contains(&ak.root.name) {
|
|
seen.insert(ak.root.name.clone());
|
|
names.push(ak.root.name.clone());
|
|
}
|
|
}
|
|
StatementTmplArg::Literal(_) => {}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn extract_and_split_predicates(
|
|
&self,
|
|
) -> Result<Vec<frontend_ast_split::SplitResult>, LoweringError> {
|
|
let doc = self.validated.document();
|
|
let predicates: Vec<CustomPredicateDef> = doc
|
|
.items
|
|
.iter()
|
|
.filter_map(|item| match item {
|
|
DocumentItem::CustomPredicateDef(pred) => Some(pred.clone()),
|
|
_ => None,
|
|
})
|
|
.collect();
|
|
|
|
// Apply splitting to each predicate as needed
|
|
let mut split_results = Vec::new();
|
|
for pred in predicates {
|
|
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
|
|
split_results.push(result);
|
|
}
|
|
|
|
Ok(split_results)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lang::{
        frontend_ast::parse::parse_document, frontend_ast_validate::validate, parser::parse_podlang,
    };

    /// Run the full pipeline (parse, build document, validate, lower) on
    /// `input`, panicking on any failure before the lowering step.
    fn parse_validate_and_lower(
        input: &str,
        params: &Params,
    ) -> Result<LoweredOutput, LoweringError> {
        let parsed = parse_podlang(input).expect("Failed to parse");
        let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse");
        let validated = validate(document, &[]).expect("Failed to validate");
        lower(validated, params, "test_batch".to_string())
    }

    // Helper to get the first batch from the output (expecting it to exist)
    fn expect_batch(
        output: &LoweredOutput,
    ) -> &std::sync::Arc<crate::middleware::CustomPredicateBatch> {
        output
            .batches
            .as_ref()
            .expect("Expected batches to be present")
            .first_batch()
            .expect("Expected at least one batch")
    }

    /// A single AND predicate lowers to one batch holding one predicate.
    #[test]
    fn test_simple_predicate() {
        let input = r#"
            my_pred(A, B) = AND (
                Equal(A["foo"], B["bar"])
            )
        "#;

        let params = Params::default();
        let result = parse_validate_and_lower(input, &params);
        if let Err(e) = &result {
            eprintln!("Error: {:?}", e);
        }
        assert!(result.is_ok());

        let lowered = result.unwrap();
        assert_eq!(expect_batch(&lowered).predicates().len(), 1);

        let pred = &expect_batch(&lowered).predicates()[0];
        assert_eq!(pred.name, "my_pred");
        assert_eq!(pred.args_len(), 2);
        assert_eq!(pred.wildcard_names().len(), 2);
        assert_eq!(pred.statements().len(), 1);
    }

    /// Private arguments count as wildcards but not as public arguments.
    #[test]
    fn test_private_args() {
        let input = r#"
            my_pred(A, private: B, C) = AND (
                Equal(A["x"], B["y"])
                Equal(B["z"], C["w"])
            )
        "#;

        let params = Params::default();
        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());

        let lowered = result.unwrap();
        let pred = &expect_batch(&lowered).predicates()[0];
        assert_eq!(pred.args_len(), 1); // Only A is public
        assert_eq!(pred.wildcard_names().len(), 3); // A, B, C total
    }

    /// An OR predicate is lowered as a disjunction.
    #[test]
    fn test_or_predicate() {
        let input = r#"
            my_pred(A, B) = OR (
                Equal(A["x"], 1)
                Equal(B["y"], 2)
            )
        "#;

        let params = Params::default();
        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());

        let lowered = result.unwrap();
        let pred = &expect_batch(&lowered).predicates()[0];
        assert!(pred.is_disjunction());
    }

    /// A predicate with too many statements is split into a chain of predicates.
    #[test]
    fn test_automatic_splitting() {
        let input = r#"
            my_pred(A) = AND (
                Equal(A["a"], 1)
                Equal(A["b"], 2)
                Equal(A["c"], 3)
                Equal(A["d"], 4)
                Equal(A["e"], 5)
                Equal(A["f"], 6)
            )
        "#;

        let params = Params::default(); // max_custom_predicate_arity = 5
        let result = parse_validate_and_lower(input, &params);
        if let Err(e) = &result {
            eprintln!("Splitting error: {:?}", e);
        }
        assert!(result.is_ok());

        let lowered = result.unwrap();
        // Should be automatically split into 2 predicates (my_pred and my_pred_1)
        let batches = lowered.batches.as_ref().expect("Expected batches");
        assert_eq!(batches.total_predicate_count(), 2);

        // With topological sorting, my_pred_1 comes first (since my_pred depends on it)
        // my_pred_1 has 2 statements
        // my_pred has 5 statements (4 + chain call)
        // Just verify we have the right total statement counts
        let batch = batches.first_batch().unwrap();
        let total_statements: usize = batch
            .predicates()
            .iter()
            .map(|p| p.statements().len())
            .sum();
        assert_eq!(total_statements, 7); // 5 + 2 = 7 total statements
    }

    /// Two independent predicates end up in the same batch.
    #[test]
    fn test_multiple_predicates() {
        let input = r#"
            pred1(A) = AND (
                Equal(A["x"], 1)
            )

            pred2(B) = AND (
                Equal(B["y"], 2)
            )
        "#;

        let params = Params::default();
        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());

        let lowered = result.unwrap();
        assert_eq!(expect_batch(&lowered).predicates().len(), 2);
    }

    /// A call to another predicate of the same batch lowers to `BatchSelf`.
    #[test]
    fn test_batch_self_reference() {
        let input = r#"
            pred1(A) = AND (
                Equal(A["x"], 1)
            )

            pred2(B) = AND (
                pred1(B)
            )
        "#;

        let params = Params::default();
        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());

        let lowered = result.unwrap();
        let pred2 = &expect_batch(&lowered).predicates()[1];
        let stmt = &pred2.statements()[0];

        // Should be BatchSelf(0) referring to pred1
        assert!(matches!(
            stmt.pred_or_wc,
            PredicateOrWildcard::Predicate(Predicate::BatchSelf(0))
        ));
    }

    /// Integer, boolean and string literals are accepted as statement arguments.
    #[test]
    fn test_literals() {
        let input = r#"
            my_pred(X) = AND (
                Equal(X["int"], 42)
                Equal(X["bool"], true)
                Equal(X["string"], "hello")
            )
        "#;

        let params = Params::default();
        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());
    }

    /// `DictContains` is syntactic sugar that lowers to the native `Contains`.
    #[test]
    fn test_syntactic_sugar_desugaring() {
        let input = r#"
            my_pred(D) = AND (
                DictContains(D, "key", "value")
            )
        "#;

        let params = Params::default();
        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());

        let lowered = result.unwrap();
        let pred = &expect_batch(&lowered).predicates()[0];
        let stmt = &pred.statements()[0];

        // Should desugar to the Contains predicate
        assert!(matches!(
            stmt.pred_or_wc,
            PredicateOrWildcard::Predicate(Predicate::Native(NativePredicate::Contains))
        ));
    }

    /// More predicates than fit in one batch are packed into several batches.
    #[test]
    fn test_multi_batch_packing() {
        // Create more predicates than fit in a single batch
        // With max_custom_batch_size = 4, 5 predicates should span 2 batches
        let input = r#"
            pred1(A) = AND(Equal(A["a"], 1))
            pred2(B) = AND(Equal(B["b"], 2))
            pred3(C) = AND(Equal(C["c"], 3))
            pred4(D) = AND(Equal(D["d"], 4))
            pred5(E) = AND(Equal(E["e"], 5))
        "#;

        let params = Params::default(); // max_custom_batch_size = 4

        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());

        let lowered = result.unwrap();
        let batches = lowered.batches.as_ref().expect("Expected batches");

        // Should have 2 batches
        assert_eq!(batches.batch_count(), 2);
        assert_eq!(batches.total_predicate_count(), 5);

        // First batch should have 4 predicates
        assert_eq!(batches.batches()[0].predicates().len(), 4);
        // Second batch should have 1 predicate
        assert_eq!(batches.batches()[1].predicates().len(), 1);
    }

    /// Split chains combined with other predicates can overflow into a second batch.
    #[test]
    fn test_split_chains_span_batches() {
        // Create predicates that will split, plus additional predicates
        // to force the split chains across batch boundaries
        let input = r#"
            pred1(A) = AND(Equal(A["a"], 1))
            pred2(B) = AND(Equal(B["b"], 2))
            pred3(C) = AND(Equal(C["c"], 3))
            large_pred(D) = AND(
                Equal(D["a"], 1)
                Equal(D["b"], 2)
                Equal(D["c"], 3)
                Equal(D["d"], 4)
                Equal(D["e"], 5)
                Equal(D["f"], 6)
            )
        "#;

        let params = Params::default();

        let result = parse_validate_and_lower(input, &params);
        assert!(result.is_ok());

        let lowered = result.unwrap();
        let batches = lowered.batches.as_ref().expect("Expected batches");

        // pred1, pred2, pred3 + large_pred split into 2 = 5 total predicates
        // Should span 2 batches
        assert_eq!(batches.total_predicate_count(), 5);
        assert_eq!(batches.batch_count(), 2);
    }

    /// An imported intro predicate used inside a custom predicate lowers to
    /// `Predicate::Intro` carrying its name, arity and verifier data hash.
    #[test]
    fn test_intro_predicate_in_custom_predicate() {
        use hex::ToHex;

        use crate::middleware::EMPTY_HASH;

        // Import an intro predicate and use it inside a custom predicate definition
        let intro_hash = EMPTY_HASH.encode_hex::<String>();
        let input = format!(
            r#"
            use intro external_check(X) from 0x{intro_hash}

            my_pred(A) = AND (
                Equal(A["foo"], 42)
                external_check(A)
            )
        "#
        );

        let params = Params::default();

        // Parse, validate, and lower
        let parsed = parse_podlang(&input).expect("Failed to parse");
        let document =
            parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document");
        let validated = validate(document, &[]).expect("Failed to validate");
        let result = lower(validated, &params, "test_batch".to_string());

        assert!(result.is_ok(), "Lowering failed: {:?}", result.err());

        let lowered = result.unwrap();
        let batch = expect_batch(&lowered);

        // Should have one custom predicate
        assert_eq!(batch.predicates().len(), 1);

        let pred = &batch.predicates()[0];
        assert_eq!(pred.name, "my_pred");
        // 2 statements: Equal and external_check
        assert_eq!(pred.statements().len(), 2);

        // Verify the second statement is an intro predicate reference
        let intro_stmt = &pred.statements()[1];
        match intro_stmt.pred_or_wc() {
            PredicateOrWildcard::Predicate(Predicate::Intro(intro_ref)) => {
                assert_eq!(intro_ref.name, "external_check");
                assert_eq!(intro_ref.args_len, 1);
                assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
            }
            other => panic!("Expected Intro predicate, got {:?}", other),
        }
    }
}
|