Improved predicate splitting (#445)
* Multi-batch splitting * Invoke split predicates by name, passing in full argument list * Reorder batches to prevent failure of forward references where possible * Rename APIs for clarity * Simplify example * Add more docs * Review updates * Remove duplicate code * Comment topological sort algorithm
This commit is contained in:
parent
9c9a2c454c
commit
d1b7b4d37e
12 changed files with 2090 additions and 466 deletions
|
|
@ -1,38 +1,103 @@
|
|||
//! Lowering from frontend AST to middleware structures
|
||||
//!
|
||||
//! This module converts validated frontend AST to middleware data structures.
|
||||
//! Currently implements basic 1:1 conversion without automatic predicate splitting.
|
||||
//! Supports automatic predicate splitting and multi-batch packing.
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder},
|
||||
frontend::{BuilderArg, StatementTmplBuilder},
|
||||
lang::{
|
||||
frontend_ast::*,
|
||||
frontend_ast_batch::{self, PredicateBatches},
|
||||
frontend_ast_split,
|
||||
frontend_ast_validate::{PredicateKind, ValidatedAST},
|
||||
},
|
||||
middleware::{
|
||||
self, containers, CustomPredicateBatch, IntroPredicateRef, NativePredicate, Params,
|
||||
Predicate, PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
||||
self, containers, IntroPredicateRef, NativePredicate, Params, Predicate,
|
||||
PredicateOrWildcard, StatementTmpl as MWStatementTmpl,
|
||||
StatementTmplArg as MWStatementTmplArg, Wildcard,
|
||||
},
|
||||
};
|
||||
|
||||
/// Result of lowering: optional custom predicate batch and optional request
|
||||
// ============================================================================
|
||||
// Shared lowering utilities
|
||||
// ============================================================================
|
||||
// These functions convert AST types to middleware/builder types and are used
|
||||
// by both the request lowering (in this module) and predicate batching
|
||||
// (in frontend_ast_batch).
|
||||
|
||||
/// Lower a literal value from AST to middleware Value.
|
||||
///
|
||||
/// This is a pure conversion that cannot fail.
|
||||
pub fn lower_literal(lit: &LiteralValue) -> middleware::Value {
|
||||
match lit {
|
||||
LiteralValue::Int(i) => middleware::Value::from(i.value),
|
||||
LiteralValue::Bool(b) => middleware::Value::from(b.value),
|
||||
LiteralValue::String(s) => middleware::Value::from(s.value.clone()),
|
||||
LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash),
|
||||
LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point),
|
||||
LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()),
|
||||
LiteralValue::Array(a) => {
|
||||
let elements: Vec<_> = a.elements.iter().map(lower_literal).collect();
|
||||
let array = containers::Array::new(elements);
|
||||
middleware::Value::from(array)
|
||||
}
|
||||
LiteralValue::Set(s) => {
|
||||
let elements: std::collections::HashSet<_> =
|
||||
s.elements.iter().map(lower_literal).collect();
|
||||
let set = containers::Set::new(elements);
|
||||
middleware::Value::from(set)
|
||||
}
|
||||
LiteralValue::Dict(d) => {
|
||||
let pairs: std::collections::HashMap<_, _> = d
|
||||
.pairs
|
||||
.iter()
|
||||
.map(|pair| {
|
||||
let key = middleware::Key::from(pair.key.value.as_str());
|
||||
let value = lower_literal(&pair.value);
|
||||
(key, value)
|
||||
})
|
||||
.collect();
|
||||
let dict = containers::Dictionary::new(pairs);
|
||||
middleware::Value::from(dict)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Lower a statement argument from AST to BuilderArg.
|
||||
///
|
||||
/// This is a pure conversion that cannot fail.
|
||||
pub fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg {
|
||||
match arg {
|
||||
StatementTmplArg::Literal(lit) => {
|
||||
let value = lower_literal(lit);
|
||||
BuilderArg::Literal(value)
|
||||
}
|
||||
StatementTmplArg::Wildcard(id) => BuilderArg::WildcardLiteral(id.name.clone()),
|
||||
StatementTmplArg::AnchoredKey(ak) => {
|
||||
let key_str = match &ak.key {
|
||||
AnchoredKeyPath::Bracket(s) => s.value.clone(),
|
||||
AnchoredKeyPath::Dot(id) => id.name.clone(),
|
||||
};
|
||||
BuilderArg::Key(ak.root.name.clone(), key_str)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of lowering: optional custom predicate batches and optional request
|
||||
///
|
||||
/// A Podlang file can contain:
|
||||
/// - Just custom predicates (batch: Some, request: None)
|
||||
/// - Just a request (batch: None, request: Some)
|
||||
/// - Both (batch: Some, request: Some)
|
||||
/// - Neither (batch: None, request: None) - just imports
|
||||
/// - Just custom predicates (batches: Some, request: None)
|
||||
/// - Just a request (batches: None, request: Some)
|
||||
/// - Both (batches: Some, request: Some)
|
||||
/// - Neither (batches: None, request: None) - just imports
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoweredOutput {
|
||||
pub batch: Option<Arc<CustomPredicateBatch>>,
|
||||
pub batches: Option<PredicateBatches>,
|
||||
pub request: Option<crate::frontend::PodRequest>,
|
||||
}
|
||||
|
||||
|
|
@ -60,71 +125,70 @@ pub fn lower(
|
|||
struct Lowerer<'a> {
|
||||
validated: ValidatedAST,
|
||||
params: &'a Params,
|
||||
/// Map of predicate names to their index in the current batch (for split predicates)
|
||||
batch_predicate_index: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
impl<'a> Lowerer<'a> {
|
||||
fn new(validated: ValidatedAST, params: &'a Params) -> Self {
|
||||
Self {
|
||||
validated,
|
||||
params,
|
||||
batch_predicate_index: HashMap::new(),
|
||||
}
|
||||
Self { validated, params }
|
||||
}
|
||||
|
||||
fn lower(mut self, batch_name: String) -> Result<LoweredOutput, LoweringError> {
|
||||
// Lower custom predicates (if any)
|
||||
let batch = self.lower_batch(batch_name)?;
|
||||
fn lower(self, batch_name: String) -> Result<LoweredOutput, LoweringError> {
|
||||
// Lower custom predicates (if any) - now supports multiple batches
|
||||
let batches = self.lower_batches(batch_name)?;
|
||||
|
||||
// Lower request (if any) - pass batch so BatchSelf refs can be converted to Custom refs
|
||||
let request = self.lower_request(batch.as_ref())?;
|
||||
// Lower request (if any) - pass batches so refs can be resolved
|
||||
let request = self.lower_request(batches.as_ref())?;
|
||||
|
||||
Ok(LoweredOutput { batch, request })
|
||||
Ok(LoweredOutput { batches, request })
|
||||
}
|
||||
|
||||
fn lower_batch(
|
||||
&mut self,
|
||||
batch_name: String,
|
||||
) -> Result<Option<Arc<CustomPredicateBatch>>, LoweringError> {
|
||||
fn lower_batches(&self, batch_name: String) -> Result<Option<PredicateBatches>, LoweringError> {
|
||||
// Extract and split custom predicates from document
|
||||
let (custom_predicates, original_count) = self.extract_and_split_predicates()?;
|
||||
let custom_predicates = self.extract_and_split_predicates()?;
|
||||
|
||||
// If no custom predicates, return None
|
||||
if custom_predicates.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Check batch size constraint
|
||||
if custom_predicates.len() > self.params.max_custom_batch_size {
|
||||
return Err(LoweringError::TooManyPredicates {
|
||||
batch_name: batch_name.clone(),
|
||||
count: custom_predicates.len(),
|
||||
max: self.params.max_custom_batch_size,
|
||||
original_count,
|
||||
});
|
||||
// Build map of imported predicates for batching
|
||||
let imported_predicates = self.build_imported_predicates_map();
|
||||
|
||||
// Use the new batching module to pack predicates into batches
|
||||
let batches = frontend_ast_batch::batch_predicates(
|
||||
custom_predicates,
|
||||
self.params,
|
||||
&batch_name,
|
||||
&imported_predicates,
|
||||
)?;
|
||||
|
||||
Ok(Some(batches))
|
||||
}
|
||||
|
||||
fn build_imported_predicates_map(
|
||||
&self,
|
||||
) -> HashMap<String, frontend_ast_batch::ImportedPredicateInfo> {
|
||||
let symbols = self.validated.symbols();
|
||||
let mut imported = HashMap::new();
|
||||
|
||||
for (name, info) in &symbols.predicates {
|
||||
if let PredicateKind::BatchImported { batch, index } = &info.kind {
|
||||
imported.insert(
|
||||
name.clone(),
|
||||
frontend_ast_batch::ImportedPredicateInfo {
|
||||
batch: batch.clone(),
|
||||
index: *index,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Build index of all predicates in the batch
|
||||
for (idx, pred) in custom_predicates.iter().enumerate() {
|
||||
self.batch_predicate_index
|
||||
.insert(pred.name.name.clone(), idx);
|
||||
}
|
||||
|
||||
// Create custom predicate batch using builder
|
||||
let mut cpb_builder =
|
||||
CustomPredicateBatchBuilder::new(self.params.clone(), batch_name.clone());
|
||||
|
||||
for pred_def in &custom_predicates {
|
||||
self.lower_custom_predicate(pred_def, &mut cpb_builder)?;
|
||||
}
|
||||
|
||||
Ok(Some(cpb_builder.finish()))
|
||||
imported
|
||||
}
|
||||
|
||||
fn lower_request(
|
||||
&self,
|
||||
batch: Option<&Arc<CustomPredicateBatch>>,
|
||||
batches: Option<&PredicateBatches>,
|
||||
) -> Result<Option<crate::frontend::PodRequest>, LoweringError> {
|
||||
let doc = self.validated.document();
|
||||
|
||||
|
|
@ -141,44 +205,78 @@ impl<'a> Lowerer<'a> {
|
|||
// Build wildcard map from all wildcards used in the request statements
|
||||
let wildcard_map = self.build_request_wildcard_map(request_def);
|
||||
|
||||
// Lower each statement to a builder first
|
||||
let mut statement_builders = Vec::new();
|
||||
for stmt in &request_def.statements {
|
||||
let stmt_builder = self.lower_statement_to_builder(stmt)?;
|
||||
statement_builders.push(stmt_builder);
|
||||
}
|
||||
|
||||
// Resolve builders to middleware statement templates
|
||||
// Lower each statement to middleware templates, resolving predicates
|
||||
let mut request_templates = Vec::new();
|
||||
for stmt_builder in statement_builders {
|
||||
let mw_stmt =
|
||||
self.resolve_request_statement_builder(stmt_builder, &wildcard_map, batch)?;
|
||||
for stmt in &request_def.statements {
|
||||
let mw_stmt = self.lower_request_statement(stmt, &wildcard_map, batches)?;
|
||||
request_templates.push(mw_stmt);
|
||||
}
|
||||
|
||||
Ok(Some(crate::frontend::PodRequest::new(request_templates)))
|
||||
}
|
||||
|
||||
fn resolve_request_statement_builder(
|
||||
fn lower_request_statement(
|
||||
&self,
|
||||
stmt_builder: StatementTmplBuilder,
|
||||
stmt: &StatementTmpl,
|
||||
wildcard_map: &HashMap<String, usize>,
|
||||
batch: Option<&Arc<CustomPredicateBatch>>,
|
||||
batches: Option<&PredicateBatches>,
|
||||
) -> Result<MWStatementTmpl, LoweringError> {
|
||||
// First desugar the builder
|
||||
let desugared = stmt_builder.desugar();
|
||||
|
||||
// Convert BatchSelf predicate to Custom if we have a batch
|
||||
let mut predicate = desugared.predicate;
|
||||
if let Some(batch_ref) = batch {
|
||||
if let Predicate::BatchSelf(index) = predicate {
|
||||
predicate = Predicate::Custom(middleware::CustomPredicateRef::new(
|
||||
batch_ref.clone(),
|
||||
index,
|
||||
));
|
||||
}
|
||||
// Enforce argument count limit for request statements
|
||||
if stmt.args.len() > self.params.max_statement_args {
|
||||
return Err(LoweringError::TooManyStatementArgs {
|
||||
count: stmt.args.len(),
|
||||
max: self.params.max_statement_args,
|
||||
});
|
||||
}
|
||||
|
||||
let pred_name = &stmt.predicate.name;
|
||||
let symbols = self.validated.symbols();
|
||||
|
||||
// Resolve predicate - for request statements, local custom predicates
|
||||
// must be resolved to CustomPredicateRef (not BatchSelf)
|
||||
let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
Predicate::Native(native)
|
||||
} else if let Some(info) = symbols.predicates.get(pred_name) {
|
||||
match &info.kind {
|
||||
PredicateKind::Native(np) => Predicate::Native(*np),
|
||||
PredicateKind::Custom { .. } => {
|
||||
// Local custom predicates - resolve to CustomPredicateRef
|
||||
let batches = batches.ok_or_else(|| LoweringError::PredicateNotFound {
|
||||
name: pred_name.clone(),
|
||||
})?;
|
||||
let pred_ref = batches.predicate_ref_by_name(pred_name).ok_or_else(|| {
|
||||
LoweringError::PredicateNotFound {
|
||||
name: pred_name.clone(),
|
||||
}
|
||||
})?;
|
||||
Predicate::Custom(pred_ref)
|
||||
}
|
||||
PredicateKind::BatchImported { batch, index } => {
|
||||
Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index))
|
||||
}
|
||||
PredicateKind::IntroImported {
|
||||
name,
|
||||
verifier_data_hash,
|
||||
} => Predicate::Intro(IntroPredicateRef {
|
||||
name: name.clone(),
|
||||
args_len: info.public_arity,
|
||||
verifier_data_hash: *verifier_data_hash,
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
return Err(LoweringError::PredicateNotFound {
|
||||
name: pred_name.clone(),
|
||||
});
|
||||
};
|
||||
|
||||
// Create a builder with the resolved predicate and desugar
|
||||
let mut builder = StatementTmplBuilder::new(predicate);
|
||||
for arg in &stmt.args {
|
||||
let builder_arg = lower_statement_arg(arg);
|
||||
builder = builder.arg(builder_arg);
|
||||
}
|
||||
let desugared = builder.desugar();
|
||||
|
||||
// Convert BuilderArgs to StatementTmplArgs
|
||||
let mut mw_args = Vec::new();
|
||||
for builder_arg in desugared.args {
|
||||
|
|
@ -202,7 +300,7 @@ impl<'a> Lowerer<'a> {
|
|||
|
||||
Ok(MWStatementTmpl {
|
||||
// TODO: Support wildcard
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(predicate),
|
||||
pred_or_wc: PredicateOrWildcard::Predicate(desugared.predicate),
|
||||
args: mw_args,
|
||||
})
|
||||
}
|
||||
|
|
@ -251,7 +349,7 @@ impl<'a> Lowerer<'a> {
|
|||
|
||||
fn extract_and_split_predicates(
|
||||
&self,
|
||||
) -> Result<(Vec<CustomPredicateDef>, usize), LoweringError> {
|
||||
) -> Result<Vec<frontend_ast_split::SplitResult>, LoweringError> {
|
||||
let doc = self.validated.document();
|
||||
let predicates: Vec<CustomPredicateDef> = doc
|
||||
.items
|
||||
|
|
@ -262,182 +360,14 @@ impl<'a> Lowerer<'a> {
|
|||
})
|
||||
.collect();
|
||||
|
||||
let original_count = predicates.len();
|
||||
|
||||
// Apply splitting to each predicate as needed
|
||||
let mut split_predicates = Vec::new();
|
||||
let mut split_results = Vec::new();
|
||||
for pred in predicates {
|
||||
let chain = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
|
||||
split_predicates.extend(chain);
|
||||
let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?;
|
||||
split_results.push(result);
|
||||
}
|
||||
|
||||
Ok((split_predicates, original_count))
|
||||
}
|
||||
|
||||
fn lower_custom_predicate(
|
||||
&self,
|
||||
pred_def: &CustomPredicateDef,
|
||||
cpb_builder: &mut CustomPredicateBatchBuilder,
|
||||
) -> Result<(), LoweringError> {
|
||||
let name = pred_def.name.name.clone();
|
||||
|
||||
// Note: Constraint checking is handled by the splitting phase
|
||||
// Predicates passed here should already be within limits
|
||||
|
||||
// Collect public and private argument names
|
||||
let mut public_arg_names = Vec::new();
|
||||
let mut private_arg_names = Vec::new();
|
||||
|
||||
for arg in &pred_def.args.public_args {
|
||||
public_arg_names.push(arg.name.clone());
|
||||
}
|
||||
|
||||
if let Some(private_args) = &pred_def.args.private_args {
|
||||
for arg in private_args {
|
||||
private_arg_names.push(arg.name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Lower statements to builders
|
||||
let mut statement_builders = Vec::new();
|
||||
for stmt in &pred_def.statements {
|
||||
let stmt_builder = self.lower_statement_to_builder(stmt)?;
|
||||
statement_builders.push(stmt_builder);
|
||||
}
|
||||
|
||||
// Convert to &str slices for builder API
|
||||
let public_args_str: Vec<&str> = public_arg_names.iter().map(|s| s.as_str()).collect();
|
||||
let private_args_str: Vec<&str> = private_arg_names.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
// Add predicate to batch using builder
|
||||
let conjunction = pred_def.conjunction_type == ConjunctionType::And;
|
||||
|
||||
cpb_builder
|
||||
.predicate(
|
||||
&name,
|
||||
conjunction,
|
||||
&public_args_str,
|
||||
&private_args_str,
|
||||
&statement_builders,
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
crate::frontend::Error::Middleware(mw_err) => LoweringError::Middleware(mw_err),
|
||||
_ => LoweringError::InvalidArgumentType,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn lower_statement_to_builder(
|
||||
&self,
|
||||
stmt: &StatementTmpl,
|
||||
) -> Result<StatementTmplBuilder, LoweringError> {
|
||||
// Get predicate
|
||||
let pred_name = &stmt.predicate.name;
|
||||
let symbols = self.validated.symbols();
|
||||
|
||||
// Check for native predicates first
|
||||
let predicate = if let Ok(native) = NativePredicate::from_str(pred_name) {
|
||||
Predicate::Native(native)
|
||||
} else if let Some(&index) = self.batch_predicate_index.get(pred_name) {
|
||||
// References to other predicates in the same batch (including split chains)
|
||||
Predicate::BatchSelf(index)
|
||||
} else if let Some(info) = symbols.predicates.get(pred_name) {
|
||||
match &info.kind {
|
||||
PredicateKind::Native(np) => Predicate::Native(*np),
|
||||
PredicateKind::Custom { index } => Predicate::BatchSelf(*index),
|
||||
PredicateKind::BatchImported { batch, index } => {
|
||||
Predicate::Custom(middleware::CustomPredicateRef::new(batch.clone(), *index))
|
||||
}
|
||||
PredicateKind::IntroImported {
|
||||
name,
|
||||
verifier_data_hash,
|
||||
} => Predicate::Intro(IntroPredicateRef {
|
||||
name: name.clone(),
|
||||
args_len: info.public_arity,
|
||||
verifier_data_hash: *verifier_data_hash,
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
unreachable!("Predicate {} not found", pred_name);
|
||||
};
|
||||
|
||||
// Check args count
|
||||
if stmt.args.len() > self.params.max_statement_args {
|
||||
return Err(LoweringError::TooManyStatementArgs {
|
||||
count: stmt.args.len(),
|
||||
max: self.params.max_statement_args,
|
||||
});
|
||||
}
|
||||
|
||||
// Convert AST args to BuilderArgs
|
||||
let mut builder = StatementTmplBuilder::new(predicate);
|
||||
for arg in &stmt.args {
|
||||
let builder_arg = Self::lower_statement_arg_to_builder(arg)?;
|
||||
builder = builder.arg(builder_arg);
|
||||
}
|
||||
|
||||
// Return builder without calling .desugar() - that will happen later
|
||||
Ok(builder)
|
||||
}
|
||||
|
||||
fn lower_statement_arg_to_builder(arg: &StatementTmplArg) -> Result<BuilderArg, LoweringError> {
|
||||
match arg {
|
||||
StatementTmplArg::Literal(lit) => {
|
||||
let value = Self::lower_literal(lit)?;
|
||||
Ok(BuilderArg::Literal(value))
|
||||
}
|
||||
StatementTmplArg::Wildcard(id) => {
|
||||
// For builder, we just need the wildcard name
|
||||
Ok(BuilderArg::WildcardLiteral(id.name.clone()))
|
||||
}
|
||||
StatementTmplArg::AnchoredKey(ak) => {
|
||||
let key_str = match &ak.key {
|
||||
AnchoredKeyPath::Bracket(s) => s.value.clone(),
|
||||
AnchoredKeyPath::Dot(id) => id.name.clone(),
|
||||
};
|
||||
Ok(BuilderArg::Key(ak.root.name.clone(), key_str))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lower_literal(lit: &LiteralValue) -> Result<middleware::Value, LoweringError> {
|
||||
let value = match lit {
|
||||
LiteralValue::Int(i) => middleware::Value::from(i.value),
|
||||
LiteralValue::Bool(b) => middleware::Value::from(b.value),
|
||||
LiteralValue::String(s) => middleware::Value::from(s.value.clone()),
|
||||
LiteralValue::Raw(r) => middleware::Value::from(r.hash.hash),
|
||||
LiteralValue::PublicKey(pk) => middleware::Value::from(pk.point),
|
||||
LiteralValue::SecretKey(sk) => middleware::Value::from(sk.secret_key.clone()),
|
||||
LiteralValue::Array(a) => {
|
||||
let elements: Result<Vec<_>, _> =
|
||||
a.elements.iter().map(Self::lower_literal).collect();
|
||||
let array = containers::Array::new(elements?);
|
||||
middleware::Value::from(array)
|
||||
}
|
||||
LiteralValue::Set(s) => {
|
||||
let elements: Result<Vec<_>, _> =
|
||||
s.elements.iter().map(Self::lower_literal).collect();
|
||||
let set_values: std::collections::HashSet<_> = elements?.into_iter().collect();
|
||||
let set = containers::Set::new(set_values);
|
||||
middleware::Value::from(set)
|
||||
}
|
||||
LiteralValue::Dict(d) => {
|
||||
let pairs: Result<Vec<(middleware::Key, middleware::Value)>, LoweringError> = d
|
||||
.pairs
|
||||
.iter()
|
||||
.map(|pair| {
|
||||
let key = middleware::Key::from(pair.key.value.as_str());
|
||||
let value = Self::lower_literal(&pair.value)?;
|
||||
Ok((key, value))
|
||||
})
|
||||
.collect();
|
||||
let dict_map: std::collections::HashMap<_, _> = pairs?.into_iter().collect();
|
||||
let dict = containers::Dictionary::new(dict_map);
|
||||
middleware::Value::from(dict)
|
||||
}
|
||||
};
|
||||
Ok(value)
|
||||
Ok(split_results)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -458,9 +388,16 @@ mod tests {
|
|||
lower(validated, params, "test_batch".to_string())
|
||||
}
|
||||
|
||||
// Helper to get the batch from the output (expecting it to exist)
|
||||
fn expect_batch(output: &LoweredOutput) -> &Arc<CustomPredicateBatch> {
|
||||
output.batch.as_ref().expect("Expected batch to be present")
|
||||
// Helper to get the first batch from the output (expecting it to exist)
|
||||
fn expect_batch(
|
||||
output: &LoweredOutput,
|
||||
) -> &std::sync::Arc<crate::middleware::CustomPredicateBatch> {
|
||||
output
|
||||
.batches
|
||||
.as_ref()
|
||||
.expect("Expected batches to be present")
|
||||
.first_batch()
|
||||
.expect("Expected at least one batch")
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -547,13 +484,20 @@ mod tests {
|
|||
|
||||
let lowered = result.unwrap();
|
||||
// Should be automatically split into 2 predicates (my_pred and my_pred_1)
|
||||
assert_eq!(expect_batch(&lowered).predicates().len(), 2);
|
||||
let batches = lowered.batches.as_ref().expect("Expected batches");
|
||||
assert_eq!(batches.total_predicate_count(), 2);
|
||||
|
||||
// First predicate should have 5 statements (4 + chain call)
|
||||
assert_eq!(expect_batch(&lowered).predicates()[0].statements().len(), 5);
|
||||
|
||||
// Second predicate should have 2 statements
|
||||
assert_eq!(expect_batch(&lowered).predicates()[1].statements().len(), 2);
|
||||
// With topological sorting, my_pred_1 comes first (since my_pred depends on it)
|
||||
// my_pred_1 has 2 statements
|
||||
// my_pred has 5 statements (4 + chain call)
|
||||
// Just verify we have the right total statement counts
|
||||
let batch = batches.first_batch().unwrap();
|
||||
let total_statements: usize = batch
|
||||
.predicates()
|
||||
.iter()
|
||||
.map(|p| p.statements().len())
|
||||
.sum();
|
||||
assert_eq!(total_statements, 7); // 5 + 2 = 7 total statements
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -642,108 +586,64 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_message_with_splitting() {
|
||||
// Create a document with predicates that will exceed the batch limit after splitting
|
||||
// We'll create 2 predicates with 4 statements each (max arity = 5)
|
||||
// Each will NOT split individually, but together they exceed a small batch limit
|
||||
fn test_multi_batch_packing() {
|
||||
// Create more predicates than fit in a single batch
|
||||
// With max_custom_batch_size = 4, 5 predicates should span 2 batches
|
||||
let input = r#"
|
||||
pred1(A) = AND (
|
||||
Equal(A["a"], 1)
|
||||
Equal(A["b"], 2)
|
||||
)
|
||||
pred2(B) = AND (
|
||||
Equal(B["c"], 3)
|
||||
Equal(B["d"], 4)
|
||||
)
|
||||
pred1(A) = AND(Equal(A["a"], 1))
|
||||
pred2(B) = AND(Equal(B["b"], 2))
|
||||
pred3(C) = AND(Equal(C["c"], 3))
|
||||
pred4(D) = AND(Equal(D["d"], 4))
|
||||
pred5(E) = AND(Equal(E["e"], 5))
|
||||
"#;
|
||||
|
||||
// Use very restrictive params to force the error
|
||||
let params = Params {
|
||||
max_custom_batch_size: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let params = Params::default(); // max_custom_batch_size = 4
|
||||
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Should fail with TooManyPredicates error
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
let lowered = result.unwrap();
|
||||
let batches = lowered.batches.as_ref().expect("Expected batches");
|
||||
|
||||
if let LoweringError::TooManyPredicates {
|
||||
count,
|
||||
max,
|
||||
original_count,
|
||||
..
|
||||
} = err
|
||||
{
|
||||
assert_eq!(count, 2); // 2 predicates after splitting (no splitting occurred)
|
||||
assert_eq!(max, 1);
|
||||
assert_eq!(original_count, 2); // Started with 2 predicates
|
||||
// Should have 2 batches
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
assert_eq!(batches.total_predicate_count(), 5);
|
||||
|
||||
// Error message should NOT mention splitting since no splitting occurred
|
||||
let err_msg = format!("{}", err);
|
||||
assert!(!err_msg.contains("before automatic splitting"));
|
||||
} else {
|
||||
panic!("Expected TooManyPredicates error, got: {:?}", err);
|
||||
}
|
||||
// First batch should have 4 predicates
|
||||
assert_eq!(batches.batches()[0].predicates().len(), 4);
|
||||
// Second batch should have 1 predicate
|
||||
assert_eq!(batches.batches()[1].predicates().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_message_after_splitting() {
|
||||
// Create TWO predicates that will EACH split into 2 predicates
|
||||
// This tests the case where splitting causes the batch to be too large
|
||||
// but no individual predicate chain exceeds the limit
|
||||
fn test_split_chains_span_batches() {
|
||||
// Create predicates that will split, plus additional predicates
|
||||
// to force the split chains across batch boundaries
|
||||
let input = r#"
|
||||
pred1(A) = AND (
|
||||
Equal(A["a"], 1)
|
||||
Equal(A["b"], 2)
|
||||
Equal(A["c"], 3)
|
||||
Equal(A["d"], 4)
|
||||
Equal(A["e"], 5)
|
||||
Equal(A["f"], 6)
|
||||
)
|
||||
pred2(B) = AND (
|
||||
Equal(B["a"], 1)
|
||||
Equal(B["b"], 2)
|
||||
Equal(B["c"], 3)
|
||||
Equal(B["d"], 4)
|
||||
Equal(B["e"], 5)
|
||||
Equal(B["f"], 6)
|
||||
pred1(A) = AND(Equal(A["a"], 1))
|
||||
pred2(B) = AND(Equal(B["b"], 2))
|
||||
pred3(C) = AND(Equal(C["c"], 3))
|
||||
large_pred(D) = AND(
|
||||
Equal(D["a"], 1)
|
||||
Equal(D["b"], 2)
|
||||
Equal(D["c"], 3)
|
||||
Equal(D["d"], 4)
|
||||
Equal(D["e"], 5)
|
||||
Equal(D["f"], 6)
|
||||
)
|
||||
"#;
|
||||
|
||||
// Use params where each predicate splits into 2, but total of 4 exceeds batch limit
|
||||
let params = Params {
|
||||
// Allow 3 predicates in batch
|
||||
// Default max_custom_predicate_arity is 5, so each will split into 2 predicates
|
||||
// Total: 2 original predicates -> 4 after splitting (exceeds limit of 3)
|
||||
max_custom_batch_size: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let params = Params::default();
|
||||
|
||||
let result = parse_validate_and_lower(input, ¶ms);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Should fail with TooManyPredicates error
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
let lowered = result.unwrap();
|
||||
let batches = lowered.batches.as_ref().expect("Expected batches");
|
||||
|
||||
if let LoweringError::TooManyPredicates {
|
||||
count,
|
||||
max,
|
||||
original_count,
|
||||
..
|
||||
} = err
|
||||
{
|
||||
assert_eq!(count, 4); // 4 predicates after splitting (2 from each)
|
||||
assert_eq!(max, 3);
|
||||
assert_eq!(original_count, 2); // Started with 2 predicates
|
||||
|
||||
// Error message SHOULD mention splitting since splitting occurred
|
||||
let err_msg = format!("{}", err);
|
||||
assert!(err_msg.contains("before automatic splitting"));
|
||||
assert!(err_msg.contains("started with 2 predicates"));
|
||||
} else {
|
||||
panic!("Expected TooManyPredicates error, got: {:?}", err);
|
||||
}
|
||||
// pred1, pred2, pred3 + large_pred split into 2 = 5 total predicates
|
||||
// Should span 2 batches
|
||||
assert_eq!(batches.total_predicate_count(), 5);
|
||||
assert_eq!(batches.batch_count(), 2);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue