diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 7ca4d8c..8de6871 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -15,8 +15,8 @@ use crate::{ #[derive(Clone, Debug)] pub enum BuilderArg { Literal(Value), - /// Key: (origin, key), where origin is the wildcard name. - Key(String, Key), + /// Key: (origin, key), where origin is Wildcard and key is Key + Key(String, String), WildcardLiteral(String), /// Reference to a same-batch predicate's identity hash (resolved by name in finish()). SelfPredicateHash(String), @@ -29,7 +29,7 @@ pub enum BuilderArg { /// case i. impl From<(&str, &str)> for BuilderArg { fn from((origin, field): (&str, &str)) -> Self { - Self::Key(origin.to_string(), Key::from(field)) + Self::Key(origin.to_string(), field.to_string()) } } /// case ii. @@ -219,9 +219,9 @@ impl CustomPredicateBatchBuilder { .map(|(arg_idx, a)| { Ok::<_, Error>(match a { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), - BuilderArg::Key(root_wc, key) => StatementTmplArg::AnchoredKey( + BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey( resolve_wildcard(args, priv_args, root_wc)?, - key.clone(), + Key::from(key_str), ), BuilderArg::WildcardLiteral(v) => { StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 999a3a4..b6e8691 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -15,10 +15,10 @@ pub use serialization::SerializedMainPod; use crate::middleware::{ self, check_custom_pred, containers::{Container, Dictionary}, - fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, + fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key, MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, - RawValue, Signature, Signer, Statement, StatementArg, StrKey, VDSet, Value, ValueRef, - BASE_PARAMS, EMPTY_VALUE, + RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS, + EMPTY_VALUE, }; mod custom; @@ -39,7 +39,7 @@ pub use pod_request::*; #[derive(Clone, Debug)] pub struct SignedDictBuilder { pub params: Params, - pub kvs: HashMap, + pub kvs: HashMap, } impl fmt::Display for SignedDictBuilder { @@ -60,7 +60,7 @@ impl SignedDictBuilder { } } - pub fn insert(&mut self, key: impl Into, value: impl Into) { + pub fn insert(&mut self, key: impl Into, value: impl Into) { self.kvs.insert(key.into(), value.into()); } @@ -111,12 +111,12 @@ impl SignedDict { .then_some(()) .ok_or(Error::custom("Invalid signature!")) } - pub fn get(&self, key: impl Into) -> Option { + pub fn get(&self, key: impl Into) -> Option { self.dict.get(&key.into()).unwrap() } // Returns the Contains statement that defines key if it exists. - pub fn get_statement(&self, key: impl Into) -> Option { - let key: StrKey = key.into(); + pub fn get_statement(&self, key: impl Into) -> Option { + let key: Key = key.into(); self.dict.get(&key).unwrap().map(|value| { Statement::Contains( ValueRef::Literal(Value::from(self.dict.clone())), @@ -1112,11 +1112,11 @@ pub mod tests { OperationArg::Statement(st1), OperationArg::Literal(Value::from(1)), ], - OperationAux::MerkleProof(dict.prove(&StrKey::from("a")).unwrap().1), + OperationAux::MerkleProof(dict.prove(&Key::from("a")).unwrap().1), ))?; let mut new_dict = dict.clone(); - new_dict.insert(&StrKey::from("d"), &Value::from(4))?; + new_dict.insert(&Key::from("d"), &Value::from(4))?; builder.pub_op(Operation( OperationType::Native(NativeOperation::DictInsertFromEntries), @@ -1130,7 +1130,7 @@ pub mod tests { ))?; let mut new_old_dict = new_dict.clone(); - new_old_dict.delete(&StrKey::from("d"))?; + new_old_dict.delete(&Key::from("d"))?; assert_eq!(new_old_dict, dict); @@ -1144,7 +1144,7 @@ pub mod tests { OperationAux::None, ))?; - new_old_dict.update(&StrKey::from("c"), &55.into())?; + new_old_dict.update(&Key::from("c"), &55.into())?; builder.pub_op(Operation( OperationType::Native(NativeOperation::DictUpdateFromEntries), diff --git a/src/lang/diagnostics.rs b/src/lang/diagnostics.rs index 8c33e40..7807318 100644 --- a/src/lang/diagnostics.rs +++ b/src/lang/diagnostics.rs @@ -286,123 +286,6 @@ fn render_validation_error( "not allowed here", ) } - - ValidationError::DuplicateRecord { - name, - first_span, - second_span, - } => { - let title = format!("duplicate record definition: {}", name); - render_dual_span( - renderer, - source, - path, - &title, - first_span.as_ref(), - "first definition here", - second_span.as_ref(), - "duplicate definition", - ) - } - - ValidationError::RecordTooManyEntries { - name, - count, - max, - span, - } => { - let title = format!( - "record `{}` has {} entries, exceeding the limit of {}", - name, count, max - ); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "too many entries", - ) - } - - ValidationError::DuplicateRecordEntry { - record, - entry, - span, - } => { - let title = format!("duplicate entry `{}` in record `{}`", entry, record); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "already declared", - ) - } - - ValidationError::UnknownRecord { name, span } => { - let title = format!("unknown record type: {}", name); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "no such record", - ) - } - - ValidationError::UnknownRecordEntry { - record, - entry, - span, - } => { - let title = format!("record `{}` has no entry `{}`", record, entry); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "unknown entry", - ) - } - - ValidationError::DuplicateLiteralRecordEntry { - record, - entry, - span, - } => { - let title = format!("duplicate entry `{}` in `{}` literal", entry, record); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "already given", - ) - } - - ValidationError::BracketAccessOnTypedWildcard { - wildcard, - record, - span, - } => { - let title = format!( - "bracket access on `{}` (typed as record `{}`); use `{}.entry` instead", - wildcard, record, wildcard - ); - render_with_optional_span( - renderer, - source, - path, - &title, - span.as_ref(), - "string-key access on integer-keyed record", - ) - } } } diff --git a/src/lang/error.rs b/src/lang/error.rs index ae0d4ca..792d4d8 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -164,52 +164,6 @@ pub enum ValidationError { #[error("Requests must contain a REQUEST block")] NoRequestBlock, - - #[error("Duplicate record definition: {name}")] - DuplicateRecord { - name: String, - first_span: Option, - second_span: Option, - }, - - #[error("Record '{name}' has {count} entries, exceeding the limit of {max}")] - RecordTooManyEntries { - name: String, - count: usize, - max: usize, - span: Option, - }, - - #[error("Duplicate entry name '{entry}' in record '{record}'")] - DuplicateRecordEntry { - record: String, - entry: String, - span: Option, - }, - - #[error("Unknown record type: {name}")] - UnknownRecord { name: String, span: Option }, - - #[error("Record '{record}' has no entry '{entry}'")] - UnknownRecordEntry { - record: String, - entry: String, - span: Option, - }, - - #[error("Duplicate entry '{entry}' in record literal '{record}'")] - DuplicateLiteralRecordEntry { - record: String, - entry: String, - span: Option, - }, - - #[error("Bracket access '{wildcard}[...]' is not allowed on a wildcard typed as record '{record}'; use `{wildcard}.entry` instead")] - BracketAccessOnTypedWildcard { - wildcard: String, - record: String, - span: Option, - }, } /// Lowering errors from frontend AST lowering to middleware @@ -251,6 +205,111 @@ pub enum LoweringError { ValidationErrors, } +/// Context information for split boundary failures +#[derive(Debug, Clone)] +pub struct SplitContext { + /// Index of the split boundary (0-based) + pub split_index: usize, + /// Range of statement indices in the segment before the split + pub statement_range: (usize, usize), + /// Public arguments coming into this segment + pub incoming_public: Vec, + /// Wildcards that cross this boundary (need to be promoted) + pub crossing_wildcards: Vec, + /// Total public arguments needed (incoming + crossing) + pub total_public: usize, +} + +/// Suggestions for refactoring predicates that fail to split +#[derive(Debug, Clone)] +pub enum RefactorSuggestion { + /// A wildcard is used across too many statements + ReduceWildcardSpan { + wildcard: String, + first_use: usize, + last_use: usize, + span: usize, + }, + /// Multiple wildcards should be grouped together + GroupWildcardUsages { wildcards: Vec }, +} + +impl RefactorSuggestion { + pub fn format(&self) -> String { + match self { + RefactorSuggestion::ReduceWildcardSpan { + wildcard, + first_use, + last_use, + span, + } => { + format!( + "Wildcard '{}' is used across {} statements (statements {}-{}).\n\ + Consider grouping all '{}' operations together, or split the wildcard\n\ + into separate early/late variables.", + wildcard, span, first_use, last_use, wildcard + ) + } + RefactorSuggestion::GroupWildcardUsages { wildcards } => { + format!( + "Group operations for wildcards: {}\n\ + These wildcards are used across multiple segments. Try to complete\n\ + all operations for each wildcard before moving to the next.", + wildcards.join(", ") + ) + } + } + } +} + +/// Formats a detailed error message for TooManyPublicArgsAtSplit +fn format_public_args_at_split_error( + predicate: &str, + context: &SplitContext, + max_allowed: usize, + suggestion: &Option>, +) -> String { + let mut msg = format!( + "Too many public arguments at split boundary {} in predicate '{}':\n", + context.split_index, predicate + ); + + msg.push_str(&format!( + " {} incoming public + {} crossing wildcards = {} total (exceeds max of {})\n", + context.incoming_public.len(), + context.crossing_wildcards.len(), + context.total_public, + max_allowed + )); + + msg.push_str(&format!( + " Statements {}-{} in this segment\n", + context.statement_range.0, + context.statement_range.1 - 1 + )); + + if !context.incoming_public.is_empty() { + msg.push_str(&format!( + " Incoming public args: {}\n", + context.incoming_public.join(", ") + )); + } + + if !context.crossing_wildcards.is_empty() { + msg.push_str(&format!( + " Wildcards crossing this boundary: {}\n", + context.crossing_wildcards.join(", ") + )); + } + + if let Some(suggestion) = suggestion { + msg.push_str("\nSuggestion:\n"); + msg.push_str(&suggestion.format()); + } + + msg +} + /// Batching errors from multi-batch packing #[derive(Debug, thiserror::Error)] pub enum BatchingError { @@ -269,14 +328,30 @@ pub enum SplittingError { message: String, }, - #[error("Could not split predicate '{predicate}' into a chain: no feasible partition exists with up to {max_links} links. \ - The predicate's wildcard structure may be too dense for any chain to fit within max_statement_args ({max_statement_args}) \ - and max_custom_predicate_wildcards ({max_wildcards}) per link.")] - Infeasible { + #[error("Too many total arguments in predicate '{predicate}': {count} exceeds max of {max_allowed}. {message}")] + TooManyTotalArgs { predicate: String, - max_links: usize, - max_statement_args: usize, - max_wildcards: usize, + count: usize, + max_allowed: usize, + message: String, + }, + + #[error("Too many total arguments in chain link {link_index} of predicate '{predicate}': {public_count} public + {private_count} private = {total_count} total (exceeds max of {max_allowed})")] + TooManyTotalArgsInChainLink { + predicate: String, + link_index: usize, + public_count: usize, + private_count: usize, + total_count: usize, + max_allowed: usize, + }, + + #[error("{}", format_public_args_at_split_error(.predicate, .context, *.max_allowed, .suggestion))] + TooManyPublicArgsAtSplit { + predicate: String, + context: Box, + max_allowed: usize, + suggestion: Option>, }, } diff --git a/src/lang/frontend_ast.rs b/src/lang/frontend_ast.rs index 14ea33d..dd0052c 100644 --- a/src/lang/frontend_ast.rs +++ b/src/lang/frontend_ast.rs @@ -20,19 +20,10 @@ pub struct Document { pub enum DocumentItem { UseModuleStatement(UseModuleStatement), UseIntroStatement(UseIntroStatement), - RecordDef(RecordDef), CustomPredicateDef(CustomPredicateDef), RequestDef(RequestDef), } -/// Record definition: `record Name = (entry1, entry2, ...)` -#[derive(Debug, Clone, PartialEq)] -pub struct RecordDef { - pub name: Identifier, - pub entries: Vec, - pub span: Option, -} - /// Module import statement: `use module 0xHASH as alias` #[derive(Debug, Clone, PartialEq)] pub struct UseModuleStatement { @@ -77,48 +68,11 @@ pub struct RequestDef { /// Argument section with public and optional private arguments #[derive(Debug, Clone, PartialEq)] pub struct ArgSection { - pub public_args: Vec, - pub private_args: Option>, + pub public_args: Vec, + pub private_args: Option>, pub span: Option, } -/// Predicate argument: `name`, `name TypeName`, or `name module::TypeName`. -/// The optional `type_name` names a record type whose dot-access entries are -/// resolved at lowering time. -#[derive(Debug, Clone, PartialEq)] -pub struct TypedArg { - pub name: String, - pub type_name: Option, - pub span: Option, -} - -/// Reference to a record type — either a local declaration in this module -/// or an import via `use module ... as alias`. -#[derive(Debug, Clone, PartialEq)] -pub enum TypeRef { - Local(Identifier), - Qualified { - module: Identifier, - name: Identifier, - }, -} - -impl TypeRef { - pub fn span(&self) -> Option { - match self { - TypeRef::Local(id) => id.span, - TypeRef::Qualified { name, .. } => name.span, - } - } - - /// Key used to look this reference up in `SymbolTable.records`: the bare - /// name for locals, `alias::Name` for qualified imports. Pairs with - /// `qualified_record_key` (which builds the same shape from raw parts). - pub fn symbol_table_key(&self) -> String { - self.to_string() - } -} - /// Conjunction type for custom predicates #[derive(Debug, Clone, Copy, PartialEq)] pub enum ConjunctionType { @@ -134,18 +88,6 @@ pub struct StatementTmpl { pub span: Option, } -impl StatementTmpl { - /// Names of all wildcards referenced by this statement's arguments, - /// in argument order with duplicates included. - pub fn wildcard_names(&self) -> impl Iterator { - self.args.iter().filter_map(|arg| match arg { - StatementTmplArg::Wildcard(id) => Some(id.name.as_str()), - StatementTmplArg::AnchoredKey(ak) => Some(ak.root.name.as_str()), - StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => None, - }) - } -} - /// Reference to a predicate (local or qualified with module name) #[derive(Debug, Clone, PartialEq)] pub enum PredicateRef { @@ -166,13 +108,6 @@ impl PredicateRef { PredicateRef::Qualified { predicate, .. } => &predicate.name, } } - - pub fn span(&self) -> Option { - match self { - PredicateRef::Local(id) => id.span, - PredicateRef::Qualified { predicate, .. } => predicate.span, - } - } } /// Arguments that can be passed to statements @@ -193,15 +128,20 @@ pub struct AnchoredKey { pub span: Option, } +impl AnchoredKey { + pub fn key_str(&self) -> &str { + match &self.key { + AnchoredKeyPath::Bracket(ls) => &ls.value, + AnchoredKeyPath::Dot(id) => &id.name, + } + } +} + /// Key path in an anchored key #[derive(Debug, Clone, PartialEq)] pub enum AnchoredKeyPath { Bracket(LiteralString), // ["key"] Dot(Identifier), // .key - /// Integer-indexed key. Not produced by the parser; introduced by lowering - /// when a `Dot` access on a record-typed wildcard is resolved to an entry - /// index. - Index(i64), } /// Identifier (variable names, predicate names, etc.) @@ -230,7 +170,6 @@ pub enum LiteralValue { Array(LiteralArray), Set(LiteralSet), Dict(LiteralDict), - Record(LiteralRecord), /// Hash of a native predicate (resolved immediately). NativePredicateHash(Identifier), /// Hash of an external module's predicate (resolved immediately). @@ -238,13 +177,6 @@ pub enum LiteralValue { module: Identifier, predicate: Identifier, }, - /// Compile-time integer literal that resolves to the index of a named - /// entry in a record schema: `R::foo` (local) or `mod::R::foo` (imported). - /// Lowers to `Value::from(idx as i64)` after schema resolution. - RecordEntryIndex { - record: TypeRef, - entry: Identifier, - }, } /// Integer literal @@ -318,23 +250,6 @@ pub struct DictPair { pub span: Option, } -/// Record literal: `Name(Entry: value, ...)` (local) or -/// `module::Name(Entry: value, ...)` (imported). Entries may appear in any -/// order; the schema (resolved in validation) maps each to its index. -#[derive(Debug, Clone, PartialEq)] -pub struct LiteralRecord { - pub name: TypeRef, - pub entries: Vec, - pub span: Option, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct RecordEntryLiteral { - pub name: Identifier, - pub value: LiteralValue, - pub span: Option, -} - /// Source location information for error reporting and formatting #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Span { @@ -361,7 +276,6 @@ impl fmt::Display for DocumentItem { match self { DocumentItem::UseModuleStatement(u) => write!(f, "{}", u), DocumentItem::UseIntroStatement(u) => write!(f, "{}", u), - DocumentItem::RecordDef(r) => write!(f, "{}", r), DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c), DocumentItem::RequestDef(r) => write!(f, "{}", r), } @@ -448,38 +362,6 @@ impl fmt::Display for ArgSection { } } -impl fmt::Display for TypedArg { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.name)?; - if let Some(t) = &self.type_name { - write!(f, " {}", t)?; - } - Ok(()) - } -} - -impl fmt::Display for TypeRef { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TypeRef::Local(id) => write!(f, "{}", id), - TypeRef::Qualified { module, name } => write!(f, "{}::{}", module, name), - } - } -} - -impl fmt::Display for RecordDef { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "record {} = (", self.name)?; - for (i, entry) in self.entries.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", entry)?; - } - write!(f, ")") - } -} - impl fmt::Display for ConjunctionType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -536,7 +418,6 @@ impl fmt::Display for AnchoredKey { match &self.key { AnchoredKeyPath::Bracket(s) => write!(f, "{}[{}]", self.root, s), AnchoredKeyPath::Dot(id) => write!(f, "{}.{}", self.root, id), - AnchoredKeyPath::Index(i) => write!(f, "{}[{}]", self.root, i), } } } @@ -553,39 +434,16 @@ impl fmt::Display for LiteralValue { LiteralValue::Array(a) => write!(f, "{}", a), LiteralValue::Set(s) => write!(f, "{}", s), LiteralValue::Dict(d) => write!(f, "{}", d), - LiteralValue::Record(r) => write!(f, "{}", r), LiteralValue::NativePredicateHash(id) => { write!(f, "@native_predicate({})", id) } LiteralValue::ExternalPredicateHash { module, predicate, .. } => write!(f, "@external_predicate({}, {})", module, predicate), - LiteralValue::RecordEntryIndex { record, entry } => { - write!(f, "{}::{}", record, entry) - } } } } -impl fmt::Display for LiteralRecord { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}(", self.name)?; - for (i, entry) in self.entries.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", entry)?; - } - write!(f, ")") - } -} - -impl fmt::Display for RecordEntryLiteral { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}: {}", self.name, self.value) - } -} - impl fmt::Display for LiteralInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.value) @@ -704,9 +562,6 @@ pub mod parse { inner_pair, ))); } - Rule::record_def => { - items.push(DocumentItem::RecordDef(parse_record_def(inner_pair))); - } Rule::custom_predicate_def => { items.push(DocumentItem::CustomPredicateDef( parse_custom_predicate_def(inner_pair)?, @@ -832,16 +687,16 @@ pub mod parse { Rule::public_arg_list => { public_args = inner_pair .into_inner() - .filter(|p| p.as_rule() == Rule::typed_arg) - .map(parse_typed_arg) + .filter(|p| p.as_rule() == Rule::identifier) + .map(parse_identifier) .collect(); } Rule::private_arg_list => { private_args = Some( inner_pair .into_inner() - .filter(|p| p.as_rule() == Rule::typed_arg) - .map(parse_typed_arg) + .filter(|p| p.as_rule() == Rule::identifier) + .map(parse_identifier) .collect(), ); } @@ -856,50 +711,6 @@ pub mod parse { } } - fn parse_typed_arg(pair: Pair) -> TypedArg { - assert_eq!(pair.as_rule(), Rule::typed_arg); - let span = get_span(&pair); - let mut inner = pair.into_inner(); - let name_pair = inner.next().unwrap(); - let name = name_pair.as_str().to_string(); - let type_name = inner.next().map(parse_type_tag); - TypedArg { - name, - type_name, - span: Some(span), - } - } - - fn parse_type_tag(pair: Pair) -> TypeRef { - assert_eq!(pair.as_rule(), Rule::type_tag); - let inner = pair.into_inner().next().expect("type_tag has one child"); - match inner.as_rule() { - Rule::identifier => TypeRef::Local(parse_identifier(inner)), - Rule::qualified_type_ref => { - let mut idents = inner.into_inner(); - let module = parse_identifier(idents.next().unwrap()); - let name = parse_identifier(idents.next().unwrap()); - TypeRef::Qualified { module, name } - } - other => unreachable!("unexpected type_tag inner rule: {other:?}"), - } - } - - fn parse_record_def(pair: Pair) -> RecordDef { - assert_eq!(pair.as_rule(), Rule::record_def); - let span = get_span(&pair); - let mut idents = pair - .into_inner() - .filter(|p| p.as_rule() == Rule::identifier); - let name = parse_identifier(idents.next().unwrap()); - let entries: Vec<_> = idents.map(parse_identifier).collect(); - RecordDef { - name, - entries, - span: Some(span), - } - } - fn parse_conjunction_type(pair: Pair) -> ConjunctionType { assert_eq!(pair.as_rule(), Rule::conjunction_type); match pair.as_str() { @@ -1034,7 +845,6 @@ pub mod parse { Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)), Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)), Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)), - Rule::literal_record => Ok(LiteralValue::Record(parse_literal_record(inner)?)), Rule::predicate_hash_native => { let id = parse_identifier(inner.into_inner().next().unwrap()); Ok(LiteralValue::NativePredicateHash(id)) @@ -1045,55 +855,10 @@ pub mod parse { let predicate = parse_identifier(parts.next().unwrap()); Ok(LiteralValue::ExternalPredicateHash { module, predicate }) } - Rule::record_entry_index => { - let mut parts = inner.into_inner(); - let first = parse_identifier(parts.next().unwrap()); - let second = parse_identifier(parts.next().unwrap()); - let (record, entry) = match parts.next().map(parse_identifier) { - Some(third) => ( - TypeRef::Qualified { - module: first, - name: second, - }, - third, - ), - None => (TypeRef::Local(first), second), - }; - Ok(LiteralValue::RecordEntryIndex { record, entry }) - } _ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()), } } - fn parse_literal_record(pair: Pair) -> Result { - assert_eq!(pair.as_rule(), Rule::literal_record); - let span = get_span(&pair); - let mut inner = pair.into_inner(); - let name = parse_type_tag(inner.next().unwrap()); - let entries: Result, _> = inner - .filter(|p| p.as_rule() == Rule::record_entry) - .map(parse_record_entry) - .collect(); - Ok(LiteralRecord { - name, - entries: entries?, - span: Some(span), - }) - } - - fn parse_record_entry(pair: Pair) -> Result { - assert_eq!(pair.as_rule(), Rule::record_entry); - let span = get_span(&pair); - let mut inner = pair.into_inner(); - let name = parse_identifier(inner.next().unwrap()); - let value = parse_literal_value(inner.next().unwrap())?; - Ok(RecordEntryLiteral { - name, - value, - span: Some(span), - }) - } - fn parse_literal_int(pair: Pair) -> Result { assert_eq!(pair.as_rule(), Rule::literal_int); let value = pair @@ -1320,29 +1085,16 @@ mod tests { u.name.span = None; u.intro_hash.span = None; } - DocumentItem::RecordDef(r) => { - r.span = None; - r.name.span = None; - for e in &mut r.entries { - e.span = None; - } - } DocumentItem::CustomPredicateDef(c) => { c.span = None; c.name.span = None; c.args.span = None; for arg in &mut c.args.public_args { arg.span = None; - if let Some(t) = &mut arg.type_name { - clear_type_ref_spans(t); - } } if let Some(private) = &mut c.args.private_args { for arg in private { arg.span = None; - if let Some(t) = &mut arg.type_name { - clear_type_ref_spans(t); - } } } for stmt in &mut c.statements { @@ -1359,16 +1111,6 @@ mod tests { } } - fn clear_type_ref_spans(t: &mut TypeRef) { - match t { - TypeRef::Local(id) => id.span = None, - TypeRef::Qualified { module, name } => { - module.span = None; - name.span = None; - } - } - } - fn clear_predicate_ref_spans(pred_ref: &mut PredicateRef) { match pred_ref { PredicateRef::Local(id) => id.span = None, @@ -1392,7 +1134,6 @@ mod tests { match &mut ak.key { AnchoredKeyPath::Bracket(s) => s.span = None, AnchoredKeyPath::Dot(id) => id.span = None, - AnchoredKeyPath::Index(_) => {} } } StatementTmplArg::SelfPredicateHash(id) => id.span = None, @@ -1431,15 +1172,6 @@ mod tests { clear_literal_spans(&mut pair.value); } } - LiteralValue::Record(r) => { - r.span = None; - clear_type_ref_spans(&mut r.name); - for entry in &mut r.entries { - entry.span = None; - entry.name.span = None; - clear_literal_spans(&mut entry.value); - } - } LiteralValue::NativePredicateHash(id) => id.span = None, LiteralValue::ExternalPredicateHash { module, predicate, .. @@ -1447,10 +1179,6 @@ mod tests { module.span = None; predicate.span = None; } - LiteralValue::RecordEntryIndex { record, entry } => { - clear_type_ref_spans(record); - entry.span = None; - } } } @@ -1540,139 +1268,6 @@ mod tests { test_roundtrip(input); } - #[test] - fn test_record_decl() { - let input = r#"record ProcInputs = (foo, bar, baz)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_decl_single_entry() { - let input = r#"record Singleton = (only)"#; - test_roundtrip(input); - } - - #[test] - fn test_typed_arg_in_predicate() { - let input = r#"record ProcInputs = (foo, bar, baz) -my_pred(in ProcInputs, other) = AND ( - Equal(in.foo, other) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_typed_arg_mixed_with_untyped() { - let input = r#"record R = (x, y) -mixed(a, b R, c, private: d, e R) = AND ( - Equal(a, c) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_typed_arg_qualified() { - // Qualified type tag references an imported record; the parser - // accepts it without inspecting the import (validation is downstream). - let input = r#"my_pred(in some_module::ProcInputs) = AND ( - Equal(in.foo, in.bar) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_literal_full() { - let input = r#"REQUEST( - Equal(A["data"], ProcInputs(foo: 1, bar: 2, baz: 3)) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_literal_sparse() { - let input = r#"REQUEST( - Equal(A["data"], ProcInputs(bar: 42)) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_literal_empty_rejected() { - // Record literals require at least one entry — an empty literal - // would never validate (no schema has zero entries), so reject at - // parse time for a clearer error. - let input = r#"REQUEST( - Equal(A["data"], Empty()) -)"#; - let parsed = crate::lang::parser::parse_podlang(input); - assert!( - parsed.is_err(), - "expected empty record literal `Empty()` to be rejected" - ); - } - - #[test] - fn test_record_entry_index_local() { - // `R::foo` resolves to the entry's integer index at compile time. - let input = r#"REQUEST( - Contains(A, R::foo, 7) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_entry_index_qualified() { - // `mod::R::foo` for an imported record. - let input = r#"REQUEST( - Contains(A, some_mod::R::foo, 7) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_literal_nested_value() { - let input = r#"REQUEST( - Equal(A["data"], R(items: [1, 2, 3], other: {"k": "v"})) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_literal_qualified() { - // Imported record literal: `module::R(foo: 1)`. Parses with - // `TypeRef::Qualified` as the head; PEG ordering means the - // `module::R` prefix is consumed by `literal_record` rather than - // shadowed by `record_entry_index`. - let input = r#"REQUEST( - Equal(A["data"], some_mod::R(foo: 1, bar: 2)) -)"#; - test_roundtrip(input); - } - - #[test] - fn test_record_keyword_reserved() { - // `record` may not appear as an identifier name. - let input = r#"record record = (foo)"#; - let parsed = crate::lang::parser::parse_podlang(input); - assert!( - parsed.is_err(), - "expected `record` to be rejected as an identifier" - ); - } - - #[test] - fn test_reserved_word_prefix_allowed_as_identifier() { - // The reserved-word check is anchored at a word boundary, so only - // the exact keyword is rejected. Identifiers that merely begin with - // a reserved word (`record_count`, `recorder`, `record_field`, the - // predicate name `record_using_pred`) must parse normally. - let input = r#"record Outer = (record_field, recorder) -record_using_pred(record_count, recordOwner) = AND ( - Equal(record_count, recordOwner) -)"#; - test_roundtrip(input); - } - #[test] fn test_complete_document() { let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as imported diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index 154b687..fb00def 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -13,13 +13,13 @@ use crate::{ lang::{ frontend_ast::*, frontend_ast_split, - frontend_ast_validate::{PredicateKind, RecordSource, SymbolTable, ValidatedAST}, + frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST}, module, Module, }, middleware::{ - self, containers, db::mem::MemDB, CustomPredicateRef, IntroPredicateRef, Key, - NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl, - StatementTmplArg as MWStatementTmplArg, StrKey, Value, Wildcard, + self, containers, CustomPredicateRef, IntroPredicateRef, Key, NativePredicate, Params, + Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Value, + Wildcard, }, }; @@ -158,10 +158,8 @@ fn resolve_local_predicate( /// Lower a literal value from AST to middleware Value. /// /// This is a pure conversion that cannot fail for context-free literals. -/// Panics on `ExternalPredicateHash`, `Record`, and `RecordEntryIndex` — -/// use `lower_literal_with_context` when any of those may appear (records -/// and entry indices need the symbol table to resolve the record schema; -/// external predicate hashes need the imported-module table). +/// Panics on ExternalPredicateHash — use `lower_literal_with_context` when +/// external predicate references may appear (e.g. inside containers). pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { match lit { LiteralValue::Int(i) => Value::from(i.value), @@ -186,7 +184,7 @@ pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { .pairs .iter() .map(|pair| { - let key = StrKey::from(pair.key.value.as_str()); + let key = Key::from(pair.key.value.as_str()); let value = lower_literal(&pair.value); (key, value) }) @@ -194,11 +192,6 @@ pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { let dict = containers::Dictionary::new(pairs); Value::from(dict) } - LiteralValue::Record(_) => { - unreachable!( - "Record literals must be lowered with context via lower_literal_with_context" - ) - } LiteralValue::NativePredicateHash(id) => { let np = NativePredicate::from_str(&id.name).expect("validated native predicate"); Value::from(Predicate::Native(np).hash()) @@ -208,11 +201,6 @@ pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { "ExternalPredicateHash must be lowered with context via lower_literal_with_context" ) } - LiteralValue::RecordEntryIndex { .. } => { - unreachable!( - "RecordEntryIndex must be lowered with context via lower_literal_with_context" - ) - } } } @@ -264,36 +252,13 @@ pub fn lower_literal_with_context( .pairs .iter() .map(|pair| { - let key = StrKey::from(pair.key.value.as_str()); + let key = Key::from(pair.key.value.as_str()); let value = lower_literal_with_context(&pair.value, symbols, context)?; Ok((key, value)) }) .collect::>()?; Ok(Value::from(containers::Dictionary::new(pairs))) } - LiteralValue::Record(r) => { - // The schema fixes each entry's index, so source order doesn't - // affect the merkle root and missing entries stay missing. - let schema = symbols - .records - .get(&r.name.symbol_table_key()) - .expect("record schema validated"); - let mut arr = containers::Array::empty_with_db(Box::new(MemDB::new())); - for entry_lit in &r.entries { - let idx = schema.entry_index[&entry_lit.name.name]; - let value = lower_literal_with_context(&entry_lit.value, symbols, context)?; - arr.insert(idx, value)?; - } - Ok(Value::from(arr)) - } - LiteralValue::RecordEntryIndex { record, entry } => { - let schema = symbols - .records - .get(&record.symbol_table_key()) - .expect("record schema validated"); - let idx = schema.entry_index[&entry.name]; - Ok(Value::from(idx as i64)) - } // All other variants are context-free other => Ok(lower_literal(other)), } @@ -311,12 +276,11 @@ pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { } StatementTmplArg::Wildcard(id) => BuilderArg::WildcardLiteral(id.name.clone()), StatementTmplArg::AnchoredKey(ak) => { - let key = match &ak.key { - AnchoredKeyPath::Bracket(s) => Key::new(s.value.clone()), - AnchoredKeyPath::Dot(id) => Key::new(id.name.clone()), - AnchoredKeyPath::Index(i) => Key::from(*i), + let key_str = match &ak.key { + AnchoredKeyPath::Bracket(s) => s.value.clone(), + AnchoredKeyPath::Dot(id) => id.name.clone(), }; - BuilderArg::Key(ak.root.name.clone(), key) + BuilderArg::Key(ak.root.name.clone(), key_str) } StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()), } @@ -386,7 +350,6 @@ impl<'a> Lowerer<'a> { fn lower_module(self, module_name: &str) -> Result { // Extract and split custom predicates from document let custom_predicates = self.extract_and_split_predicates()?; - let local_records = self.collect_local_records(); // Build the module from split predicates let module = module::build_module( @@ -394,24 +357,11 @@ impl<'a> Lowerer<'a> { self.params, module_name, self.validated.symbols(), - local_records, )?; Ok(module) } - /// Collect record declarations from this module's source. No transitive - /// re-export — imported records aren't included. - fn collect_local_records(&self) -> HashMap> { - self.validated - .symbols() - .records - .iter() - .filter(|(_, schema)| matches!(schema.source, RecordSource::Local)) - .map(|(name, schema)| (name.clone(), schema.entries.clone())) - .collect() - } - fn lower_request(self) -> Result { let doc = self.validated.document(); @@ -479,11 +429,12 @@ impl<'a> Lowerer<'a> { let index = wildcard_map.get(&name).expect("Wildcard not found"); MWStatementTmplArg::Wildcard(Wildcard::new(name, *index)) } - BuilderArg::Key(root_name, key) => { + BuilderArg::Key(root_name, key_str) => { let root_index = wildcard_map .get(&root_name) .expect("Root wildcard not found"); let wildcard = Wildcard::new(root_name, *root_index); + let key = Key::from(key_str.as_str()); MWStatementTmplArg::AnchoredKey(wildcard, key) } BuilderArg::SelfPredicateHash(_) => { @@ -528,9 +479,21 @@ impl<'a> Lowerer<'a> { names: &mut Vec, seen: &mut HashSet, ) { - for name in stmt.wildcard_names() { - if seen.insert(name.to_string()) { - names.push(name.to_string()); + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + if !seen.contains(&id.name) { + seen.insert(id.name.clone()); + names.push(id.name.clone()); + } + } + StatementTmplArg::AnchoredKey(ak) => { + if !seen.contains(&ak.root.name) { + seen.insert(ak.root.name.clone()); + names.push(ak.root.name.clone()); + } + } + StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} } } } @@ -548,56 +511,15 @@ impl<'a> Lowerer<'a> { }) .collect(); - // Apply splitting to each predicate as needed. The typed-key rewrite - // happens before splitting so split chain pieces inherit `Index` keys - // unchanged. + // Apply splitting to each predicate as needed let mut split_results = Vec::new(); - for mut pred in predicates { - self.rewrite_typed_dot_access(&mut pred); - let result = frontend_ast_split::split_predicate_if_needed(&pred, self.params)?; + for pred in predicates { + let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; split_results.push(result); } Ok(split_results) } - - /// Rewrite `r.foo` to `r[i]` when `r` is a typed wildcard, using the - /// record schema's entry-index map. Untyped wildcards keep - /// `Dot`/`Bracket` keys unchanged (POD-string-key semantics). - fn rewrite_typed_dot_access(&self, pred: &mut CustomPredicateDef) { - let symbols = self.validated.symbols(); - let scope = symbols - .wildcard_scopes - .get(&pred.name.name) - .expect("wildcard scope exists for every custom predicate after validation"); - // Skip the per-arg walk for predicates with no typed wildcards — - // the common case before records see widespread use. - if !scope.wildcards.values().any(|wc| wc.record_type.is_some()) { - return; - } - for stmt in &mut pred.statements { - for arg in &mut stmt.args { - let StatementTmplArg::AnchoredKey(ak) = arg else { - continue; - }; - let Some(wc_info) = scope.wildcards.get(&ak.root.name) else { - continue; - }; - let Some(record_name) = &wc_info.record_type else { - continue; - }; - let AnchoredKeyPath::Dot(entry) = &ak.key else { - continue; - }; - let schema = symbols - .records - .get(record_name) - .expect("record_type was resolved at predicate-def time"); - let idx = schema.entry_index[&entry.name]; - ak.key = AnchoredKeyPath::Index(idx as i64); - } - } - } } #[cfg(test)] @@ -617,8 +539,8 @@ mod tests { ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - let validated = validate(document, &HashMap::new(), params, ParseMode::Module) - .expect("Failed to validate"); + let validated = + validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate"); lower_module(validated, params, "test_batch") } @@ -847,8 +769,8 @@ mod tests { let parsed = parse_podlang(&input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document"); - let validated = validate(document, &HashMap::new(), ¶ms, ParseMode::Module) - .expect("Failed to validate"); + let validated = + validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate"); let result = lower_module(validated, ¶ms, "test_batch"); assert!(result.is_ok(), "Lowering failed: {:?}", result.err()); @@ -874,251 +796,4 @@ mod tests { other => panic!("Expected Intro predicate, got {:?}", other), } } - - // ---- Records: predicate-side dot-access lowering ----------------------- - - /// Pull the single `Key` out of statement N, arg N of the first predicate. - fn anchored_key_at( - module: &Module, - pred_idx: usize, - stmt_idx: usize, - arg_idx: usize, - ) -> middleware::Key { - let pred = &module.batch.predicates()[pred_idx]; - let stmt = &pred.statements()[stmt_idx]; - match &stmt.args()[arg_idx] { - middleware::StatementTmplArg::AnchoredKey(_, k) => k.clone(), - other => panic!("expected AnchoredKey at arg {arg_idx}, got {other:?}"), - } - } - - fn anchored_index_at(module: &Module, pred_idx: usize, stmt_idx: usize, arg_idx: usize) -> i64 { - anchored_key_at(module, pred_idx, stmt_idx, arg_idx) - .as_index() - .expect("expected Index key") - .value() - } - - #[test] - fn test_typed_dot_lowers_to_index_key() { - // Single entry on a typed wildcard becomes an integer-keyed - // AnchoredKey at the schema's entry index. - let input = r#" - record R = (foo, bar, baz) - my_pred(in R) = AND(Equal(in.bar, 0)) - "#; - let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); - assert_eq!(anchored_index_at(&module, 0, 0, 0), 1); - } - - #[test] - fn test_dot_on_untyped_wildcard_stays_str_key() { - // No type tag, no schema lookup: dot-access keeps POD-string-key - // semantics. - let input = r#" - my_pred(r) = AND(Equal(r.foo, 1)) - "#; - let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); - match anchored_key_at(&module, 0, 0, 0) { - middleware::Key::Str(sk) => assert_eq!(sk.name(), "foo"), - other => panic!("expected Str key, got {other:?}"), - } - } - - #[test] - fn test_typed_dot_multiple_entries_distinct_indices() { - let input = r#" - record R = (foo, bar, baz) - my_pred(in R) = AND( - Equal(in.foo, in.baz) - Equal(in.bar, 0) - ) - "#; - let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); - assert_eq!(anchored_index_at(&module, 0, 0, 0), 0); - assert_eq!(anchored_index_at(&module, 0, 0, 1), 2); - assert_eq!(anchored_index_at(&module, 0, 1, 0), 1); - } - - #[test] - fn test_typed_dot_in_or_predicate() { - // OR predicates: the lowering produces a single AnchoredKey per - // statement, no cross-statement coupling, so OR works the same as AND. - let input = r#" - record R = (foo, bar) - my_pred(in R) = OR( - Equal(in.foo, 1) - Equal(in.bar, 2) - ) - "#; - let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); - assert!(module.batch.predicates()[0].is_disjunction()); - assert_eq!(anchored_index_at(&module, 0, 0, 0), 0); - assert_eq!(anchored_index_at(&module, 0, 1, 0), 1); - } - - #[test] - fn test_record_predicate_hash_matches_handwritten_index_form() { - // Source-level records are syntactic sugar: the predicate hash for - // `record R = (foo, bar); p(in R) = AND(Equal(in.bar, 7))` must equal - // the hash of the same predicate built directly with an integer-keyed - // anchored key. There is no Podlang surface syntax for `in[1]`, so we - // build the reference batch via the builder API. - use crate::{ - frontend::{CustomPredicateBatchBuilder, StatementTmplBuilder}, - middleware::NativePredicate, - }; - - let with_record = r#" - record R = (foo, bar) - p(in R) = AND(Equal(in.bar, 7)) - "#; - let params = Params::default(); - let m_record = parse_validate_and_lower_module(with_record, ¶ms).unwrap(); - - let mut b = CustomPredicateBatchBuilder::new(params.clone(), "test_batch".into()); - let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) - .arg(BuilderArg::Key("in".into(), Key::from(1i64))) - .arg(BuilderArg::Literal(Value::from(7i64))); - b.predicate_and("p", &["in"], &[], &[stb]).unwrap(); - let plain_batch = b.finish().unwrap(); - - assert_eq!(m_record.batch.id(), plain_batch.id()); - } - - // ---- Records: literal lowering ----------------------------------------- - - fn lower_literal_in_pred(input: &str) -> Value { - let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); - let pred = &module.batch.predicates()[0]; - let stmt = &pred.statements()[0]; - match &stmt.args()[1] { - middleware::StatementTmplArg::Literal(v) => v.clone(), - other => panic!("expected Literal at arg 1, got {other:?}"), - } - } - - #[test] - fn test_record_literal_full_matches_array_root() { - // A fully populated literal must hash identically to the same values - // packed into an `Array::new(...)` (which inserts at indices 0..n in - // order). - let input = r#" - record R = (foo, bar, baz) - my_pred(A) = AND(Equal(A["x"], R(foo: 1, bar: 2, baz: 3))) - "#; - let v = lower_literal_in_pred(input); - let expected = Value::from(containers::Array::new(vec![ - Value::from(1i64), - Value::from(2i64), - Value::from(3i64), - ])); - assert_eq!(v.raw(), expected.raw()); - } - - #[test] - fn test_record_literal_entry_order_doesnt_matter() { - // Schema fixes the index, so source order never affects the root. - let input_a = r#" - record R = (foo, bar) - my_pred(A) = AND(Equal(A["x"], R(foo: 1, bar: 2))) - "#; - let input_b = r#" - record R = (foo, bar) - my_pred(A) = AND(Equal(A["x"], R(bar: 2, foo: 1))) - "#; - assert_eq!( - lower_literal_in_pred(input_a).raw(), - lower_literal_in_pred(input_b).raw() - ); - } - - #[test] - fn test_record_literal_sparse_stays_sparse() { - // Missing entries stay missing (no zero-fill). Compare against an - // explicit sparse Array built the same way. - let input = r#" - record R = (foo, bar, baz) - my_pred(A) = AND(Equal(A["x"], R(bar: 42))) - "#; - let v = lower_literal_in_pred(input); - - let mut sparse = containers::Array::empty_with_db(Box::new(MemDB::new())); - sparse.insert(1, Value::from(42i64)).unwrap(); - let expected = Value::from(sparse); - - assert_eq!(v.raw(), expected.raw()); - } - - #[test] - fn test_record_literal_nested_record_value() { - // A record literal whose entry value is itself a record literal. - // The outer literal commits to whatever root the inner produces. - let input = r#" - record Inner = (x, y) - record Outer = (inner) - my_pred(A) = AND(Equal(A["x"], Outer(inner: Inner(x: 1, y: 2)))) - "#; - let v = lower_literal_in_pred(input); - - let inner = Value::from(containers::Array::new(vec![ - Value::from(1i64), - Value::from(2i64), - ])); - let expected = Value::from(containers::Array::new(vec![inner])); - - assert_eq!(v.raw(), expected.raw()); - } - - #[test] - fn test_typed_dot_survives_predicate_splitting() { - // The rewrite runs before splitting, so chain pieces inherit - // `Index` keys unchanged. Force a split by exceeding the - // per-predicate statement cap. - let input = r#" - record R = (a, b, c, d, e, f) - my_pred(in R) = AND( - Equal(in.a, 1) - Equal(in.b, 2) - Equal(in.c, 3) - Equal(in.d, 4) - Equal(in.e, 5) - Equal(in.f, 6) - ) - "#; - let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); - // Splitter ran (max_custom_predicate_arity = 5). - assert!(module.batch.predicates().len() > 1); - // Every AnchoredKey across all chain pieces is integer-keyed. - for pred in module.batch.predicates() { - for stmt in pred.statements() { - for arg in stmt.args() { - if let middleware::StatementTmplArg::AnchoredKey(_, k) = arg { - assert!( - matches!(k, middleware::Key::Index(_)), - "expected Index key in split chain piece, got {k:?}" - ); - } - } - } - } - } - - #[test] - fn test_record_entry_index_lowers_to_integer() { - // `R::bar` resolves to integer 1 (bar is the second entry of R). - let input = r#" - record R = (foo, bar, baz) - my_pred(A) = AND(Contains(A, R::bar, 7)) - "#; - let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); - let pred = &module.batch.predicates()[0]; - let stmt = &pred.statements()[0]; - match &stmt.args()[1] { - middleware::StatementTmplArg::Literal(v) => { - assert_eq!(v.raw(), Value::from(1i64).raw()); - } - other => panic!("expected Literal at arg 1, got {other:?}"), - } - } } diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index cfb16ef..482db7a 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -1,66 +1,26 @@ -// REVIEW(EDU): Summary -// Overall looks good to me! But before merging please address these two points, explianed below: -// - Briefly document the strict and elastic model in the top level docstring, and in relevant -// functions that generate constraints document whether those constraints are for the strict, -// elastic or both models (I think some functions don't have this information) -// - Use the wildcard order defined in the args of a custom predicate instead of sotring the -// wildcards by name //! Predicate splitting for frontend AST //! -//! Predicates whose statement count exceeds the middleware's -//! `max_custom_predicate_arity` are split into a chain of smaller predicates, -//! each calling the next via a tail-position chain call. Private wildcards -//! that span a split boundary must be promoted to public arguments on the -//! continuation, since they need the same binding on both sides. +//! This module implements automatic predicate splitting when predicates exceed +//! middleware constraints. //! -//! The split is computed by an MILP that, for a given number of links K: +//! When splitting a predicate, we try to group statements that use the same +//! wildcards together. However, if a private wildcard must be used across a +//! split boundary, it must be promoted to a public argument in the latter +//! predicate, to ensure that it is bound to the same value in both predicates. //! -// REVIEW(Edu): what about calling it `statement template` instead of `statement` here? -// I see that in variables and comments below you use the term statement for statement -// template. Perhaps I'm being too pedantic and just saying statement is clear enough, so maybe -// ignore my suggestion. -//! - Assigns each statement to exactly one link. -//! - Tracks which wildcards are used and where, derives "live ranges," and -//! counts each link's declared public/private wildcards. -//! - Caps each link's public-arg count at `max_statement_args` and total -//! declared wildcards at `max_custom_predicate_wildcards`. -//! - Reserves a chain-call slot on every non-last link. +//! A wildcard is "live" at a split boundary if it is used in a statement on both +//! sides of the boundary. We want to minimize the number of live wildcards at +//! split boundaries, to minimize the number of promotions required. //! -//! We try `K = K_min, K_min+1, ...` and stop at the first feasible K. The -//! objective is a tiny `Σ (n-s) · i · assign[s][i]` tiebreaker that biases -//! statements with low original index toward low-index links — so the chain -//! roughly follows source order when nothing else forces a rearrangement. -//! -//! On infeasibility for every K up to `n`, we emit -//! [`SplittingError::Infeasible`]. +//! We use a greedy algorithm to order the statements in a predicate to minimize +//! the number of live wildcards at split boundaries. -#![allow(clippy::needless_range_loop)] - -use std::{ - collections::{HashMap, HashSet}, - fmt, -}; - -use good_lp::{ - constraint, solvers::scip::SCIPProblem, variable, Expression, ProblemVariables, Solution, - SolverModel, Variable, -}; - -/// Solver random-seed shift. Pinning this gives within-version reproducibility -/// against any internal SCIP heuristics that consult the seed (random -/// branching, restart shuffles, etc.). Cross-version determinism still -/// depends on SCIP not changing its algorithms; pin russcip in `Cargo.toml` -/// to control that. -const SCIP_RANDOM_SEED: i32 = 0; +use std::{cmp::Reverse, collections::HashSet}; +// SplittingError is now defined in error.rs pub use crate::lang::error::SplittingError; use crate::{lang::frontend_ast::*, middleware::Params}; -/// Threshold for interpreting MILP solver's floating-point results as binary. -/// The solver returns continuous values in [0, 1] for binary variables; -/// values > 0.5 are interpreted as "true" (1), otherwise "false" (0). -const SOLVER_BINARY_THRESHOLD: f64 = 0.5; - /// A link in the predicate chain #[derive(Debug, Clone)] pub struct ChainLink { @@ -108,70 +68,6 @@ pub struct SplitResult { pub chain_info: Option, } -/// Per-link bottleneck found by [`analyze_infeasibility`]: how far each -/// binding link overshoots the per-link caps, and which wildcards crowd it. -#[derive(Debug, Clone)] -pub struct LinkOvershoot { - pub link_index: usize, - /// Number of public-args slots over `max_statement_args` for this link. - pub public_args_overflow: usize, - /// Number of total declared-wildcard slots over `max_custom_predicate_wildcards`. - pub total_args_overflow: usize, - /// Wildcards passed in to this link as public args (in the elastic solution). - pub public_args_in: Vec, - /// Wildcards declared as private at this link (in the elastic solution). - pub private_args: Vec, -} - -/// Diagnostic report explaining why [`split_predicate_if_needed`] returned -/// [`SplittingError::Infeasible`]. Produced by [`analyze_infeasibility`] on -/// demand — the splitter itself doesn't compute it, since computing it -/// requires a second LP solve. -// REVIEW(Edu): Does this mean that for an infeasible solution we get a result that doesn't pass -// all LP constraints? -#[derive(Debug, Clone)] -pub struct InfeasibilityReport { - pub predicate: String, - /// Number of links the elastic LP was solved at (the minimum K). - pub k: usize, - /// Per-link overshoots in link-index order. Links not over any cap are omitted. - pub overshoots: Vec, -} - -impl fmt::Display for InfeasibilityReport { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "Predicate '{}' cannot be split into {} link(s) without overflowing per-link caps:", - self.predicate, self.k - )?; - let max_args = Params::max_statement_args(); - for o in &self.overshoots { - if o.public_args_overflow > 0 { - writeln!( - f, - " link {}: public_args_in = [{}] ({} args, {} over the {}-arg cap)", - o.link_index, - o.public_args_in.join(", "), - o.public_args_in.len(), - o.public_args_overflow, - max_args - )?; - } - if o.total_args_overflow > 0 { - writeln!( - f, - " link {}: declared {} wildcards (public_args_in + private_args), {} over the cap", - o.link_index, - o.public_args_in.len() + o.private_args.len(), - o.total_args_overflow, - )?; - } - } - Ok(()) - } -} - /// Early validation: Check if predicate is fundamentally splittable pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> { let public_args = pred.args.public_args.len(); @@ -190,20 +86,23 @@ pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), Ok(()) } -/// Split a predicate into a chain if it exceeds statement limit. +/// Split a predicate into a chain if it exceeds statement limit pub fn split_predicate_if_needed( - pred: &CustomPredicateDef, + pred: CustomPredicateDef, params: &Params, ) -> Result { - validate_predicate_is_splittable(pred)?; + // Early validation + validate_predicate_is_splittable(&pred)?; + // If within limits, no splitting needed if pred.statements.len() <= Params::max_custom_predicate_arity() { return Ok(SplitResult { - predicates: vec![pred.clone()], + predicates: vec![pred], chain_info: None, }); } + // Need to split - execute the splitting algorithm let (predicates, chain_info) = split_into_chain(pred, params)?; Ok(SplitResult { @@ -212,379 +111,357 @@ pub fn split_predicate_if_needed( }) } +/// Collect all wildcard names from a statement fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet { - stmt.wildcard_names().map(str::to_string).collect() -} + let mut wildcards = HashSet::new(); -/// Compute the minimum number of chain links needed to fit `n` statements, -/// given that non-last links reserve 1 slot for the chain call (so they hold -/// up to `max_arity - 1` real statements) and the last link uses all of -/// `max_arity`. -fn compute_min_links(n: usize) -> usize { - let max_arity = Params::max_custom_predicate_arity(); - if n <= max_arity { - 1 - } else { - // Smallest K such that (K-1)·(max_arity-1) + max_arity >= n - (n - max_arity).div_ceil(max_arity - 1) + 1 - } -} - -/// MILP outcome for a single K: `links[i]` is the list of original statement -/// indices placed in link i, in original order. -type LinkAssignment = Vec>; - -/// MILP variables shared by the strict feasibility solve and the elastic -/// diagnostic solve. -/// -/// All variables are binary. Constraints (C1..C7 below) make every variable -/// other than `assign` an exact function of the assignment, so the strict and -/// elastic models differ only in how they handle the per-link caps (C8/C9). -// REVIEW(Edu): what's the difference between strict and elastic model? Is strict a model that -// gives a valid solution and elastic a model that gives invalid solutions that showcase the -// bottlenecks? If this implementation has 2 different solving models, could you briefly document -// them in the top level docstring? -struct MilpVars { - n: usize, - k: usize, - num_wildcards: usize, - /// `assign[s][i]`: statement `s` placed in link `i`. - assign: Vec>, - /// `u[w][i]`: wildcard `w` referenced by some statement at link `i`. - u: Vec>, - /// `before[w][i]`: cumulative OR of `u[w][·]` from the left — w is used at link ≤ i. - before: Vec>, - // REVIEW(Edu): At first I was reading these as `public input` and `private input`; but I think - // they mean `public in this link` and `private in this link`. Is this correct? - /// `after[w][i]`: cumulative OR of `u[w][·]` from the right — w is used at link ≥ i. - after: Vec>, - /// `pubin[w][i]`: w appears in link i's `public_args_in`. - pubin: Vec>, - /// `privin[w][i]`: w appears in link i's `private_args` list. - privin: Vec>, -} - -fn mk_binary_grid(vars: &mut ProblemVariables, rows: usize, cols: usize) -> Vec> { - (0..rows) - .map(|_| (0..cols).map(|_| vars.add(variable().binary())).collect()) - .collect() -} - -fn declare_milp_vars( - vars: &mut ProblemVariables, - n: usize, - k: usize, - num_wildcards: usize, -) -> MilpVars { - MilpVars { - n, - k, - num_wildcards, - assign: mk_binary_grid(vars, n, k), - u: mk_binary_grid(vars, num_wildcards, k), - before: mk_binary_grid(vars, num_wildcards, k), - after: mk_binary_grid(vars, num_wildcards, k), - pubin: mk_binary_grid(vars, num_wildcards, k), - privin: mk_binary_grid(vars, num_wildcards, k), - } -} - -/// Source-order tiebreaker: prefers low-original-index statements at low-link -/// indices, so the chain roughly preserves source order when nothing else -/// forces a rearrangement. -/// -/// Coefficient `(n - s)` is strictly positive for every statement, so every -/// pairwise swap of distinct `(s, i)` assignments changes the objective by a -/// non-zero amount. That makes the tiebreaker uniquely-optimising — the -/// solver can't pick between equivalent placements of any single statement. -fn source_order_tiebreaker(v: &MilpVars) -> Expression { - (0..v.n) - .flat_map(|s| (0..v.k).map(move |i| (s, i))) - .map(|(s, i)| ((v.n - s) as f64) * (i as f64) * v.assign[s][i]) - .sum() -} - -/// Build a SCIP model with the splitter's deterministic-build settings: -/// pinned random seed and silent output. -fn build_scip_model(vars: ProblemVariables, objective: Expression) -> SCIPProblem { - vars.minimise(objective) - .using(good_lp::solvers::scip::scip) - .set_option("randomization/randomseedshift", SCIP_RANDOM_SEED) - .set_verbose(false) -} - -// REVIEW(Edu): So this defines constraints that can always find a solution, but the solution may -// have too many wildcards (public or total) in a link. I guess this is used for the elastic -// solution. -/// Add the MILP's structural constraints (C1..C7): assignment, link size, -/// `u`/`before`/`after`/`pubin`/`privin` definitions. Cap constraints (C8/C9) -/// are added by the caller — the strict and elastic versions differ there. -fn add_structural_constraints( - model: &mut M, - v: &MilpVars, - statements_using: &[Vec], - is_original_public: &[bool], -) { - let max_arity = Params::max_custom_predicate_arity(); - let MilpVars { - n, - k, - num_wildcards, - assign, - u, - before, - after, - pubin, - privin, - } = v; - let (n, k, num_wildcards) = (*n, *k, *num_wildcards); - - // C1: Each statement assigned to exactly one link. - for s in 0..n { - let sum: Expression = (0..k).map(|i| assign[s][i]).sum(); - model.add_constraint(constraint!(sum == 1)); - } - - // C2: Per-link statement count. Non-last links reserve a slot for the - // chain call. Also require at least one statement per link. - for i in 0..k { - let cap = if i + 1 < k { max_arity - 1 } else { max_arity }; - let sum: Expression = (0..n).map(|s| assign[s][i]).sum(); - model.add_constraint(constraint!(sum.clone() <= cap as f64)); - model.add_constraint(constraint!(sum >= 1)); - } - - // C3: u[w][i] is exactly the OR over s referencing w of assign[s][i]. - for w in 0..num_wildcards { - for i in 0..k { - for &s in &statements_using[w] { - // If statement s is assigned to link i, then link i uses all wildcards w that - // appear in s. - model.add_constraint(constraint!(u[w][i] >= assign[s][i])); + for arg in &stmt.args { + match arg { + StatementTmplArg::Wildcard(id) => { + wildcards.insert(id.name.clone()); } - // sum of statements in link i that use wildcard w - let upper: Expression = statements_using[w] - .iter() - .map(|&s| Expression::from(assign[s][i])) - .sum(); - // If wildcard w is used in link i, at least one statement requires the wildcard - model.add_constraint(constraint!(u[w][i] <= upper)); + StatementTmplArg::AnchoredKey(ak) => { + wildcards.insert(ak.root.name.clone()); + } + StatementTmplArg::Literal(_) | StatementTmplArg::SelfPredicateHash(_) => {} } } - // C4: before[w][i] = u[w][0] OR u[w][1] OR ... OR u[w][i]. - for w in 0..num_wildcards { - model.add_constraint(constraint!(before[w][0] == u[w][0])); - for i in 1..k { - model.add_constraint(constraint!(before[w][i] >= before[w][i - 1])); - model.add_constraint(constraint!(before[w][i] >= u[w][i])); - model.add_constraint(constraint!(before[w][i] <= before[w][i - 1] + u[w][i])); - } - } - - // C5: after[w][i] = u[w][i] OR u[w][i+1] OR ... OR u[w][k-1]. - for w in 0..num_wildcards { - model.add_constraint(constraint!(after[w][k - 1] == u[w][k - 1])); - for i in (0..k - 1).rev() { - model.add_constraint(constraint!(after[w][i] >= after[w][i + 1])); - model.add_constraint(constraint!(after[w][i] >= u[w][i])); - model.add_constraint(constraint!(after[w][i] <= after[w][i + 1] + u[w][i])); - } - } - - // C6: pubin definitions. - for w in 0..num_wildcards { - if is_original_public[w] { - // Original public args: declared at link 0 (predicate signature) - // and forwarded to link i iff used at some link ≥ i. - model.add_constraint(constraint!(pubin[w][0] == 1)); - for i in 1..k { - model.add_constraint(constraint!(pubin[w][i] == after[w][i])); - } - } else { - // Private wildcards: pubin[w][i] = before[w][i-1] AND after[w][i] - // (used somewhere strictly before AND somewhere at i or later). - model.add_constraint(constraint!(pubin[w][0] == 0)); - for i in 1..k { - model.add_constraint(constraint!(pubin[w][i] <= before[w][i - 1])); - model.add_constraint(constraint!(pubin[w][i] <= after[w][i])); - model.add_constraint(constraint!( - pubin[w][i] >= before[w][i - 1] + after[w][i] - 1 - )); - } - } - } - - // C7: privin definitions. - for w in 0..num_wildcards { - if is_original_public[w] { - for i in 0..k { - model.add_constraint(constraint!(privin[w][i] == 0)); - } - } else { - // privin[w][0] = u[w][0]: at link 0 there is no "before," so a - // private wildcard used at link 0 is necessarily declared private. - model.add_constraint(constraint!(privin[w][0] == u[w][0])); - for i in 1..k { - // privin[w][i] = u[w][i] AND NOT before[w][i-1] - model.add_constraint(constraint!(privin[w][i] <= u[w][i])); - model.add_constraint(constraint!(privin[w][i] <= 1 - before[w][i - 1])); - model.add_constraint(constraint!(privin[w][i] >= u[w][i] - before[w][i - 1])); - } - } - } + wildcards } -// REVIEW(Edu): This is the strict solution constraints -/// Try to partition `n` statements into exactly `k` links using MILP. -/// -/// Returns `Some(assignment)` if a feasible partition exists, `None` if the -/// model is infeasible at this K (caller should try a larger K). -fn solve_milp_for_k( - n: usize, - k: usize, - statements_using: &[Vec], - is_original_public: &[bool], - params: &Params, -) -> Option { - let max_args = Params::max_statement_args(); - let max_wildcards = params.max_custom_predicate_wildcards; - // REVIEW(Edu): I'm confused by the name `is_original_public`. This makes me think about - // public wildcards in the original definition. But you use it as `num_wildcards` which - // includes the private ones. - // OH, `is_original_public` has length=num_wildcards, it's just a "map" - let num_wildcards = is_original_public.len(); - - let mut vars = ProblemVariables::new(); - let v = declare_milp_vars(&mut vars, n, k, num_wildcards); - let objective = source_order_tiebreaker(&v); - let mut model = build_scip_model(vars, objective); - add_structural_constraints(&mut model, &v, statements_using, is_original_public); - - // C8: per-link public-args cap (incoming chain-call args). - for i in 0..k { - let sum: Expression = (0..num_wildcards).map(|w| v.pubin[w][i]).sum(); - model.add_constraint(constraint!(sum <= max_args as f64)); - } - - // C9: per-link total declared wildcards cap. - for i in 0..k { - let sum: Expression = (0..num_wildcards) - .map(|w| Expression::from(v.pubin[w][i]) + v.privin[w][i]) - .sum(); - model.add_constraint(constraint!(sum <= max_wildcards as f64)); - } - - let solution = model.solve().ok()?; - - // Extract per-link statement lists in original-index order. - let mut links: LinkAssignment = vec![Vec::new(); k]; - for s in 0..n { - for i in 0..k { - if solution.value(v.assign[s][i]) > SOLVER_BINARY_THRESHOLD { - links[i].push(s); - break; - } - } - } - Some(links) +/// Order constraints optimally to minimize liveness at boundaries +/// Result of ordering statements optimally for splitting +struct OrderingResult { + /// Reordered statements + statements: Vec, + /// Maps original statement index → reordered index + /// reorder_map[original_idx] = new_idx + reorder_map: Vec, } -/// Convert an MILP link assignment into [`ChainLink`]s, computing each link's -/// public/private/promoted wildcards from the assignment plus the original -/// public-args list. -fn build_chain_links_from_assignment( - links: LinkAssignment, - statements: &[StatementTmpl], - original_public_args: &[String], -) -> Vec { - let k = links.len(); - let stmt_wcs: Vec> = statements - .iter() - .map(collect_wildcards_from_statement) - .collect(); - let link_wcs: Vec> = (0..k) - .map(|i| { - links[i] - .iter() - .flat_map(|&s| stmt_wcs[s].iter().cloned()) - .collect() - }) - .collect(); +fn order_constraints_optimally( + statements: Vec, + public_args: &HashSet, +) -> OrderingResult { + let n = statements.len(); - let mut result = Vec::with_capacity(k); - let mut incoming: Vec = original_public_args.to_vec(); + // If no splitting needed, preserve original order (identity mapping) + if n <= Params::max_custom_predicate_arity() { + return OrderingResult { + statements, + reorder_map: (0..n).collect(), + }; + } - for i in 0..k { - let stmts: Vec = links[i].iter().map(|&s| statements[s].clone()).collect(); + let mut ordered = Vec::new(); + let mut reorder_map = vec![0; n]; + let mut remaining: HashSet = (0..n).collect(); + let mut active_wildcards: HashSet = HashSet::new(); - // Wildcards crossing forward from link i (used here AND later). - let after_wcs: HashSet = (i + 1..k) - .flat_map(|j| link_wcs[j].iter().cloned()) - .collect(); - let crossings: HashSet = link_wcs[i].intersection(&after_wcs).cloned().collect(); + while !remaining.is_empty() { + let best_idx = find_best_next_statement( + &statements, + &remaining, + &active_wildcards, + ordered.len(), + public_args, + ); - let incoming_set: HashSet = incoming.iter().cloned().collect(); + remaining.remove(&best_idx); + let stmt = &statements[best_idx]; - let mut promotions: Vec = crossings + // Record the mapping: original index best_idx → new index ordered.len() + reorder_map[best_idx] = ordered.len(); + ordered.push(stmt.clone()); + + // Only track private wildcards in the active set — public args are always + // available at every boundary so their liveness is irrelevant to split cost. + let stmt_wildcards = collect_wildcards_from_statement(stmt); + active_wildcards.extend( + stmt_wildcards + .into_iter() + .filter(|w| !public_args.contains(w)), + ); + + // Remove private wildcards no longer needed by remaining statements + let needed_later: HashSet<_> = remaining .iter() - .filter(|w| !incoming_set.contains(*w)) - .cloned() + .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) + .filter(|w| !public_args.contains(w)) .collect(); - promotions.sort(); - - let mut private_args: Vec = link_wcs[i] - .difference(&incoming_set) - .filter(|w| !crossings.contains(*w)) - .cloned() - .collect(); - private_args.sort(); - - result.push(ChainLink { - statements: stmts, - public_args_in: incoming.clone(), - private_args, - promoted_wildcards: promotions.clone(), - }); - - incoming.extend(promotions); + active_wildcards.retain(|w| needed_later.contains(w)); } - // REVIEW(Edu): shouldn't this be addressed by the existing constraints? - // - If a wildcard is pub in link i, then it's used in that link or after (C6) - // - If a wildcard is "at_or_after" i, then a statement "at_or_after" i uses it (C5) - // Backward pruning: drop public args from continuations that no link - // (this one or downstream) actually references. Link 0 keeps its full - // user-declared signature. - let num_links = result.len(); - if num_links > 1 { - let last = num_links - 1; - result[last] - .public_args_in - .retain(|a| link_wcs[last].contains(a)); - for i in (1..last).rev() { - let needed_downstream: HashSet = - result[i + 1].public_args_in.iter().cloned().collect(); - result[i] - .public_args_in - .retain(|a| link_wcs[i].contains(a) || needed_downstream.contains(a)); + OrderingResult { + statements: ordered, + reorder_map, + } +} + +/// Compute tie-breaker metrics for deterministic ordering when scores are equal +/// Returns (simplicity, public_closure, negative_fanout) tuple for use in max_by_key +fn compute_tie_breakers( + stmt: &StatementTmpl, + active_wildcards: &HashSet, + statements: &[StatementTmpl], + remaining: &HashSet, + needed_later: &HashSet, + public_args: &HashSet, +) -> (usize, usize, i32) { + let all_wildcards = collect_wildcards_from_statement(stmt); + // Only consider private wildcards for tie-breaking metrics + let stmt_wildcards: HashSet<_> = all_wildcards + .into_iter() + .filter(|w| !public_args.contains(w)) + .collect(); + + // Metric 1: Simplicity - prefer statements with fewer private wildcards + let simplicity = usize::MAX - stmt_wildcards.len(); + + // Metric 2: Closure - prefer statements that close active private wildcards + // (wildcards that won't be needed by any remaining statements) + let closes_count = stmt_wildcards + .intersection(active_wildcards) + .filter(|w| !needed_later.contains(*w)) + .count(); + + // Metric 3: Fanout - prefer statements with lower future usage + // (number of remaining statements sharing private wildcards with this statement) + let fanout = remaining + .iter() + .filter(|&&i| { + let other_wildcards: HashSet<_> = collect_wildcards_from_statement(&statements[i]) + .into_iter() + .filter(|w| !public_args.contains(w)) + .collect(); + !stmt_wildcards.is_disjoint(&other_wildcards) + }) + .count(); + + (simplicity, closes_count, -(fanout as i32)) +} + +fn statement_selection_key( + idx: usize, + statements: &[StatementTmpl], + active_wildcards: &HashSet, + remaining: &HashSet, + approaching_split: bool, + public_args: &HashSet, +) -> (i32, (usize, usize, i32), Reverse) { + // Pre-compute needed_later once and share between primary score and tie-breakers. + // Exclude the candidate itself: we want to know what the *other* remaining statements + // need, so that wildcards used only by this candidate correctly appear as closeable. + let needed_later: HashSet = remaining + .iter() + .filter(|&&i| i != idx) + .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) + .filter(|w| !public_args.contains(w)) + .collect(); + + let primary_score = score_statement( + &statements[idx], + active_wildcards, + approaching_split, + public_args, + &needed_later, + ); + let tie_breakers = compute_tie_breakers( + &statements[idx], + active_wildcards, + statements, + remaining, + &needed_later, + public_args, + ); + + // Final deterministic tie-breaker: prefer smaller original indices. + // This avoids hash-iteration-dependent selection when scores are equal. + (primary_score, tie_breakers, Reverse(idx)) +} + +/// Find the best next statement to add based on scoring heuristic +fn find_best_next_statement( + statements: &[StatementTmpl], + remaining: &HashSet, + active_wildcards: &HashSet, + ordered_count: usize, + public_args: &HashSet, +) -> usize { + // Calculate distance to next split point + let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call + let distance_to_split = bucket_size - (ordered_count % bucket_size); + let approaching_split = distance_to_split <= 2; + + remaining + .iter() + .max_by_key(|&&idx| { + statement_selection_key( + idx, + statements, + active_wildcards, + remaining, + approaching_split, + public_args, + ) + }) + .copied() + .unwrap() +} + +/// Score a statement based on how well it minimizes private-wildcard liveness at boundaries. +/// `needed_later` is the set of private wildcards used by any remaining statement. +fn score_statement( + stmt: &StatementTmpl, + active_wildcards: &HashSet, + approaching_split: bool, + public_args: &HashSet, + needed_later: &HashSet, +) -> i32 { + let all_wildcards = collect_wildcards_from_statement(stmt); + + // Only score based on private wildcards. Public args are always available at every + // split boundary — they never consume a promotion slot, so their liveness is free. + let stmt_wildcards: HashSet<_> = all_wildcards + .into_iter() + .filter(|w| !public_args.contains(w)) + .collect(); + + // Statements that touch only public args ("cheap" statements) waste a bucket slot + // that could be used to cluster private wildcards. Strongly defer them while any + // private-wildcard statements remain, so they fill leftover space at the end. + // `needed_later` is non-empty iff some remaining statement has a private wildcard. + if stmt_wildcards.is_empty() { + return if needed_later.is_empty() { + 0 + } else { + i32::MIN / 2 + }; + } + + // How many active private wildcards does this reuse? + let reuse_count = stmt_wildcards.intersection(active_wildcards).count(); + + // How many new private wildcards does this introduce? + let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count(); + + // Which of the projected-active wildcards are still needed after this statement? + let mut projected_active = active_wildcards.clone(); + projected_active.extend(stmt_wildcards); + projected_active.retain(|w| needed_later.contains(w)); + let still_active_count = projected_active.len(); + + // Base score: + // +3 per reused wildcard — rewards clustering (wildcard already open, no new cost) + // -4 per new wildcard — penalises opening new live ranges + // -2 per still-live — penalises carrying many wildcards toward the boundary + let base_score = (reuse_count * 3) as i32 + - (new_wildcard_count * 4) as i32 + - (still_active_count * 2) as i32; + + // When close to a split boundary, strongly reward statements that close wildcards + // (active.len() + new - still_active = number of wildcards resolved by this statement). + // Weight 10 >> max base-score magnitude to make closing the dominant factor. + if approaching_split { + let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count; + base_score + (closes_count * 10) as i32 + } else { + base_score + } +} + +/// Calculate which wildcards are live at a split boundary +fn calculate_live_wildcards( + before_split: &[StatementTmpl], + after_split: &[StatementTmpl], +) -> HashSet { + let before: HashSet<_> = before_split + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + + let after: HashSet<_> = after_split + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + + // Live = in both sets (crosses boundary) + before.intersection(&after).cloned().collect() +} + +/// Generate a refactor suggestion for wildcards crossing a boundary +fn generate_refactor_suggestion( + crossing_wildcards: &[String], + ordered_statements: &[StatementTmpl], +) -> Option { + use crate::lang::error::RefactorSuggestion; + + if crossing_wildcards.is_empty() { + return None; + } + + // Normalize wildcard order so diagnostics are deterministic. + let mut sorted_crossing_wildcards = crossing_wildcards.to_vec(); + sorted_crossing_wildcards.sort(); + + // Analyze the span of each crossing wildcard + let mut wildcard_spans: Vec<(String, usize, usize, usize)> = Vec::new(); + + for wildcard in &sorted_crossing_wildcards { + let mut first_use = None; + let mut last_use = None; + + for (i, stmt) in ordered_statements.iter().enumerate() { + let wildcards = collect_wildcards_from_statement(stmt); + if wildcards.contains(wildcard) { + if first_use.is_none() { + first_use = Some(i); + } + last_use = Some(i); + } + } + + if let (Some(first), Some(last)) = (first_use, last_use) { + let span = last - first; + wildcard_spans.push((wildcard.clone(), first, last, span)); } } - result + // Sort by span (largest first) + wildcard_spans.sort_by(|a, b| b.3.cmp(&a.3)); + + if let Some((wildcard, first, last, span)) = wildcard_spans.first() { + // If a single wildcard has a large span, suggest reducing it + if *span > 3 { + return Some(RefactorSuggestion::ReduceWildcardSpan { + wildcard: wildcard.clone(), + first_use: *first, + last_use: *last, + span: *span, + }); + } + } + + // If multiple wildcards cross the boundary, suggest grouping + if sorted_crossing_wildcards.len() > 1 { + return Some(RefactorSuggestion::GroupWildcardUsages { + wildcards: sorted_crossing_wildcards, + }); + } + + None } -/// Numeric encoding of a predicate's wildcard graph, ready for either the -/// strict MILP or the elastic diagnostic LP. -struct MilpInput { - n: usize, - wildcard_names: Vec, - statements_using: Vec>, - is_original_public: Vec, - original_public_args: Vec, -} +/// Split into chain using bucket-filling approach +/// Returns the split predicates and metadata about the split +fn split_into_chain( + pred: CustomPredicateDef, + params: &Params, +) -> Result<(Vec, SplitChainInfo), SplittingError> { + let original_name = pred.name.name.clone(); + let conjunction = pred.conjunction_type; -fn prepare_milp_input(pred: &CustomPredicateDef) -> MilpInput { let original_public_args: Vec = pred .args .public_args @@ -592,209 +469,154 @@ fn prepare_milp_input(pred: &CustomPredicateDef) -> MilpInput { .map(|id| id.name.clone()) .collect(); - // REVIEW(Edu): Wait a second, in `pred.args` we define an order of wildcards, I think we - // should use that one instead of sorting them by name. - // Stable, sorted index over wildcards referenced by statements OR declared - // as public args (a public arg may be unused in any statement). - let mut wildcard_set: HashSet = pred - .statements - .iter() - .flat_map(collect_wildcards_from_statement) - .collect(); - for name in &original_public_args { - wildcard_set.insert(name.clone()); - } - let mut wildcard_names: Vec = wildcard_set.into_iter().collect(); - wildcard_names.sort(); - let wildcard_index: HashMap = wildcard_names - .iter() - .enumerate() - .map(|(i, name)| (name.clone(), i)) - .collect(); + let public_args_set: HashSet = original_public_args.iter().cloned().collect(); - // Inverse: which statements reference each wildcard (by index). - let mut statements_using: Vec> = vec![Vec::new(); wildcard_names.len()]; - for (s, stmt) in pred.statements.iter().enumerate() { - let mut seen: HashSet = HashSet::new(); - for name in stmt.wildcard_names() { - let w = wildcard_index[name]; - if seen.insert(w) { - statements_using[w].push(s); - } - } - } - - let mut is_original_public = vec![false; wildcard_names.len()]; - for name in &original_public_args { - is_original_public[wildcard_index[name]] = true; - } - - MilpInput { - n: pred.statements.len(), - wildcard_names, - statements_using, - is_original_public, - original_public_args, - } -} - -// REVIEW(Edu): This is very cool, at a very low cost (because it resuses most of the existing -// cost) we get a different model that shows the bottleneck. -/// Solve the elastic LP at the given K, returning per-link slack and -/// wildcard membership for the binding links. Slack variables on each cap -/// turn the otherwise-infeasible model into one that minimises constraint -/// violation, exposing exactly which links are over their caps and by how -/// much. -fn solve_elastic_lp(k: usize, input: &MilpInput, params: &Params) -> Option> { - let max_args = Params::max_statement_args(); - let max_wildcards = params.max_custom_predicate_wildcards; - let num_wildcards = input.wildcard_names.len(); - let n = input.n; - - let mut vars = ProblemVariables::new(); - let v = declare_milp_vars(&mut vars, n, k, num_wildcards); - // REVIEW(Edu): In the circuit, adding more public arguments would be much more expensive than - // adding more private wildcards. Perhaps it makes sense to give more weight to `slack_pub`, - // so that we can learn more from the infiseability results. - let slack_pub: Vec = (0..k).map(|_| vars.add(variable().min(0.0))).collect(); - let slack_total: Vec = (0..k).map(|_| vars.add(variable().min(0.0))).collect(); - - let slack_term: Expression = (0..k) - .map(|i| Expression::from(slack_pub[i]) + slack_total[i]) - .sum(); - // Tiebreaker bound is n²k². Scale so even the worst-case tiebreaker total - // is < 1 — never enough to outweigh a single unit of slack. - let scale = 1.0 / ((n * n * k * k + 1) as f64); - let objective = slack_term + scale * source_order_tiebreaker(&v); - - let mut model = build_scip_model(vars, objective); - add_structural_constraints( - &mut model, - &v, - &input.statements_using, - &input.is_original_public, - ); - - // C8 elastic: Σ pubin[w][i] ≤ max_args + slack_pub[i]. - for i in 0..k { - let sum: Expression = (0..num_wildcards).map(|w| v.pubin[w][i]).sum(); - model.add_constraint(constraint!(sum <= max_args as f64 + slack_pub[i])); - } - - // C9 elastic: Σ (pubin + privin)[w][i] ≤ max_wildcards + slack_total[i]. - for i in 0..k { - let sum: Expression = (0..num_wildcards) - .map(|w| Expression::from(v.pubin[w][i]) + v.privin[w][i]) - .sum(); - model.add_constraint(constraint!(sum <= max_wildcards as f64 + slack_total[i])); - } - - let solution = model.solve().ok()?; - - let mut overshoots = Vec::new(); - for i in 0..k { - let pub_overflow = solution.value(slack_pub[i]).round() as usize; - let total_overflow = solution.value(slack_total[i]).round() as usize; - if pub_overflow == 0 && total_overflow == 0 { - continue; - } - let mut public_args_in = Vec::new(); - let mut private_args = Vec::new(); - for w in 0..num_wildcards { - if solution.value(v.pubin[w][i]) > SOLVER_BINARY_THRESHOLD { - public_args_in.push(input.wildcard_names[w].clone()); - } - if solution.value(v.privin[w][i]) > SOLVER_BINARY_THRESHOLD { - private_args.push(input.wildcard_names[w].clone()); - } - } - public_args_in.sort(); - private_args.sort(); - overshoots.push(LinkOvershoot { - link_index: i, - public_args_overflow: pub_overflow, - total_args_overflow: total_overflow, - public_args_in, - private_args, - }); - } - Some(overshoots) -} - -/// Diagnose why the splitter rejected `pred`. Runs an elastic version of the -/// MILP that allows the per-link caps to be violated by non-negative slack -/// and minimises total slack — the result tells you exactly which links -/// overshoot which caps and by how much. -/// -/// Only meaningful to call on inputs that produced -/// [`SplittingError::Infeasible`]. On feasible inputs the report's -/// `overshoots` will be empty. -pub fn analyze_infeasibility(pred: &CustomPredicateDef, params: &Params) -> InfeasibilityReport { - let input = prepare_milp_input(pred); - let k = compute_min_links(input.n); - let overshoots = solve_elastic_lp(k, &input, params).unwrap_or_default(); - InfeasibilityReport { - predicate: pred.name.name.clone(), - k, - overshoots, - } -} - -/// Split a predicate into a chain via MILP. Tries `K = K_min, K_min+1, ...`, -/// returning the first feasible chain or [`SplittingError::Infeasible`] if -/// no `K` up to `n` works. -fn split_into_chain( - pred: &CustomPredicateDef, - params: &Params, -) -> Result<(Vec, SplitChainInfo), SplittingError> { - let original_name = pred.name.name.clone(); - let conjunction = pred.conjunction_type; let real_statement_count = pred.statements.len(); - let input = prepare_milp_input(pred); - let n = input.n; + let ordering_result = order_constraints_optimally(pred.statements, &public_args_set); + let ordered_statements = ordering_result.statements; + let reorder_map = ordering_result.reorder_map; - let k_min = compute_min_links(n); - let mut found: Option<(usize, LinkAssignment)> = None; - for k in k_min..=n { - if let Some(assignment) = solve_milp_for_k( - n, - k, - &input.statements_using, - &input.is_original_public, - params, - ) { - found = Some((k, assignment)); - break; + let mut chain_links = Vec::new(); + let mut pos = 0; + let mut incoming_public = original_public_args.clone(); + + while pos < ordered_statements.len() { + let remaining = ordered_statements.len() - pos; + let is_last = remaining <= Params::max_custom_predicate_arity(); + + let bucket_size = if is_last { + remaining // Last predicate uses all remaining + } else { + Params::max_custom_predicate_arity() - 1 // Reserve slot for chain call + }; + + let end = pos + bucket_size; + + // Calculate liveness at this split boundary + let live_at_boundary = if is_last { + HashSet::new() + } else { + calculate_live_wildcards(&ordered_statements[pos..end], &ordered_statements[end..]) + }; + + // Check: Can we fit promoted wildcards in public args? + // Need to account for possible overlap between incoming_public and live_at_boundary + let incoming_set: HashSet<_> = incoming_public.iter().cloned().collect(); + let mut new_promotions: Vec<_> = live_at_boundary + .iter() + .filter(|w| !incoming_set.contains(*w)) + .cloned() + .collect(); + new_promotions.sort(); + let total_public = incoming_public.len() + new_promotions.len(); + if total_public > Params::max_statement_args() { + let context = crate::lang::error::SplitContext { + split_index: chain_links.len(), + statement_range: (pos, end), + incoming_public: incoming_public.clone(), + crossing_wildcards: new_promotions.clone(), + total_public, + }; + + let suggestion = generate_refactor_suggestion(&new_promotions, &ordered_statements); + + return Err(SplittingError::TooManyPublicArgsAtSplit { + predicate: original_name.clone(), + context: Box::new(context), + max_allowed: Params::max_statement_args(), + suggestion: suggestion.map(Box::new), + }); } + + // Calculate private args (used in this segment but not incoming and not outgoing) + let segment_wildcards: HashSet<_> = ordered_statements[pos..end] + .iter() + .flat_map(collect_wildcards_from_statement) + .collect(); + + let mut private_args: Vec = segment_wildcards + .difference(&incoming_set) + .filter(|w| !live_at_boundary.contains(*w)) + .cloned() + .collect(); + private_args.sort(); // Deterministic ordering + + // Check: Total args constraint (incoming + new promotions + private) + let public_count = incoming_public.len() + new_promotions.len(); + let private_count = private_args.len(); + let total_args = public_count + private_count; + if total_args > params.max_custom_predicate_wildcards { + return Err(SplittingError::TooManyTotalArgsInChainLink { + predicate: original_name.clone(), + link_index: chain_links.len(), + public_count, + private_count, + total_count: total_args, + max_allowed: params.max_custom_predicate_wildcards, + }); + } + + chain_links.push(ChainLink { + statements: ordered_statements[pos..end].to_vec(), + public_args_in: incoming_public.clone(), + private_args, + // new_promotions are already sorted and already filtered to exclude incoming_public + promoted_wildcards: new_promotions.clone(), + }); + + pos = end; + + // Extend incoming_public for the next link with the newly promoted wildcards. + // new_promotions is already filtered to exclude incoming_set, so no dedup needed. + incoming_public.extend(new_promotions); } - let (_k, assignment) = found.ok_or_else(|| SplittingError::Infeasible { - predicate: original_name.clone(), - max_links: n, - max_statement_args: Params::max_statement_args(), - max_wildcards: params.max_custom_predicate_wildcards, - })?; - - // Reorder map: original index → position in flattened chain. - let mut reorder_map = vec![0usize; n]; + // Backward pass: prune each continuation's public args to the minimal set needed. + // + // The forward pass accumulates incoming_public monotonically, so a continuation may + // inherit original public args that none of its statements (or downstream continuations) + // ever reference. A continuation must declare every public arg it receives, and the + // proof system constrains each declared arg - an arg that goes unused has no constraints + // and will not match the value the caller passes. + // + // Propagating from the last link backward ensures each continuation declares exactly the + // args it uses directly, plus any args its successor still needs. Link 0 (the original + // predicate) is left untouched - its public-arg signature is user-declared. { - let mut flat = 0usize; - for link in &assignment { - for &s in link { - reorder_map[s] = flat; - flat += 1; + let num_links = chain_links.len(); + if num_links > 1 { + // Collect wildcards referenced by each link's statements once. + let link_wildcards: Vec> = chain_links + .iter() + .map(|link| { + link.statements + .iter() + .flat_map(collect_wildcards_from_statement) + .collect() + }) + .collect(); + + let last = num_links - 1; + + // Seed: last link retains only args it directly references. + chain_links[last] + .public_args_in + .retain(|a| link_wildcards[last].contains(a)); + + // Propagate backward through intermediate continuation links (skip link 0). + for i in (1..last).rev() { + let needed_downstream: HashSet = + chain_links[i + 1].public_args_in.iter().cloned().collect(); + chain_links[i] + .public_args_in + .retain(|a| link_wildcards[i].contains(a) || needed_downstream.contains(a)); } } } - let chain_links = build_chain_links_from_assignment( - assignment, - &pred.statements, - &input.original_public_args, - ); - - // Build SplitChainInfo (execution order: innermost continuation first). + // Build SplitChainInfo from chain_links before generating predicates + // Pieces are in execution order: innermost continuation first, original last let num_links = chain_links.len(); let mut chain_pieces = Vec::new(); for i in (0..num_links).rev() { @@ -825,13 +647,14 @@ fn split_into_chain( validate_chain(&chain_predicates, params); // Reverse so continuations come before callers in declaration order. + // This ensures that when batched, continuations are in earlier batches + // and can be referenced by their callers. chain_predicates.reverse(); Ok((chain_predicates, chain_info)) } -/// Build the chain's [`CustomPredicateDef`]s from the per-link metadata, -/// inserting a chain call on every non-last link. +/// Phase 4: Generate synthetic predicates from chain links fn generate_chain_predicates( original_name: &str, chain_links: Vec, @@ -856,12 +679,15 @@ fn generate_chain_predicates( let is_last = i == chain_links.len() - 1; let mut statements = link.statements.clone(); + // Add chain call if not last if !is_last { let next_pred_name = Identifier { name: format!("{}_{}", original_name, i + 1), span: None, }; + // Create arguments for chain call: use next link's public_args_in + // which is current public_args_in extended with current promoted_wildcards let next_link = &chain_links[i + 1]; let chain_call_args: Vec = next_link .public_args_in @@ -883,13 +709,12 @@ fn generate_chain_predicates( statements.push(chain_call); } - // Build public args (incoming). - let public_args: Vec = link + // Build public args (incoming) + let public_args: Vec = link .public_args_in .iter() - .map(|name| TypedArg { + .map(|name| Identifier { name: name.clone(), - type_name: None, span: None, }) .collect(); @@ -908,11 +733,7 @@ fn generate_chain_predicates( Some( private_arg_names .into_iter() - .map(|name| TypedArg { - name, - type_name: None, - span: None, - }) + .map(|name| Identifier { name, span: None }) .collect(), ) }; @@ -1023,7 +844,7 @@ mod tests { let pred = parse_predicate(input); let params = Params::default(); - let result = split_predicate_if_needed(&pred, ¶ms); + let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); let split_result = result.unwrap(); @@ -1047,7 +868,7 @@ mod tests { let pred = parse_predicate(input); let params = Params::default(); // max_custom_predicate_arity = 5 - let result = split_predicate_if_needed(&pred, ¶ms); + let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); let split_result = result.unwrap(); @@ -1091,7 +912,7 @@ mod tests { let pred = parse_predicate(input); let params = Params::default(); // max_custom_predicate_arity = 5 - let result = split_predicate_if_needed(&pred, ¶ms); + let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); let split_result = result.unwrap(); @@ -1127,7 +948,7 @@ mod tests { let pred = parse_predicate(input); let params = Params::default(); // max_custom_predicate_arity = 5 - let result = split_predicate_if_needed(&pred, ¶ms); + let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); let split_result = result.unwrap(); @@ -1176,7 +997,7 @@ mod tests { let pred = parse_predicate(input); let params = Params::default(); - let result = split_predicate_if_needed(&pred, ¶ms); + let result = split_predicate_if_needed(pred, ¶ms); assert!(result.is_ok()); let split_result = result.unwrap(); @@ -1204,6 +1025,207 @@ mod tests { ); } + #[test] + fn test_statement_selection_prefers_lower_index_on_tie() { + // Two structurally symmetric statements produce identical heuristic scores. + // Determinism comes from the final index-based tie breaker. + let input = r#" + tie_break(A, B) = AND ( + Equal(A["x"], B["x"]) + Equal(A["y"], B["y"]) + ) + "#; + + let pred = parse_predicate(input); + let statements = pred.statements; + let remaining: HashSet = [0, 1].into_iter().collect(); + let active_wildcards = HashSet::new(); + + // A and B are the public args of tie_break(A, B) + let public_args: HashSet = ["A".to_string(), "B".to_string()].into_iter().collect(); + let key0 = statement_selection_key( + 0, + &statements, + &active_wildcards, + &remaining, + false, + &public_args, + ); + let key1 = statement_selection_key( + 1, + &statements, + &active_wildcards, + &remaining, + false, + &public_args, + ); + + assert_eq!(key0.0, key1.0, "Primary heuristic score should tie"); + assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie"); + assert!( + key0 > key1, + "Lower original index should win deterministic final tie-breaker" + ); + + let selected = + find_best_next_statement(&statements, &remaining, &active_wildcards, 0, &public_args); + assert_eq!(selected, 0); + } + + #[test] + fn test_greedy_ordering_reduces_liveness() { + // This test verifies that our greedy ordering algorithm reduces wildcard liveness + // by clustering statements that use the same wildcards together. + // + // The predicate has 8 statements using 3 private wildcards (T1, T2, T3): + // - T1 used in statements 1, 4, 7 + // - T2 used in statements 2, 5, 8 + // - T3 used in statements 3, 6 + // + // NAIVE ORDERING (original order): + // Would interleave T1, T2, T3 usage throughout the predicate. + // When splitting at statement limit (5 statements per predicate): + // Predicate 1: statements 1-5 (introduces T1, T2, T3 - none complete) + // Predicate 2: statements 6-8 (all 3 wildcards still live) + // Result: 2 public args (A, B) + 3 promoted wildcards = 5 total in predicate 2 + // + // GREEDY ORDERING (our algorithm): + // Clusters statements by wildcard to minimize liveness: + // Groups T1 statements together, then T2, then T3 + // Predicate 1: completes some wildcards before the split point + // Predicate 2: fewer wildcards need to cross the boundary + // Result: 2 public args (A, B) + 1-2 promoted wildcards = 3-4 total in predicate 2 + let input = r#" + clustered(A, B, private: T1, T2, T3) = AND ( + Equal(T1["x"], 1) + Equal(T2["y"], 2) + Equal(T3["z"], 3) + Equal(T1["a"], 4) + Equal(T2["b"], 5) + Equal(T3["c"], 6) + Equal(T1["d"], A["result"]) + Equal(T2["e"], B["value"]) + ) + "#; + + let pred = parse_predicate(input); + let params = Params::default(); + + let result = split_predicate_if_needed(pred, ¶ms); + assert!(result.is_ok()); + + let split_result = result.unwrap(); + let chain = &split_result.predicates; + assert_eq!(chain.len(), 2, "Predicate should split into 2 links"); + + let second_pred = &chain[1]; + let second_pred_public_count = second_pred.args.public_args.len(); + + // Verify greedy ordering achieves better results than naive ordering would + // Started with 2 public args (A, B) + // Naive would have: 2 + 3 promoted = 5 public args in second predicate + // Greedy achieves: 2 + 1-2 promoted = 3-4 public args in second predicate + assert!( + second_pred_public_count <= 4, + "Greedy ordering should reduce promotions to ≤4 public args, but got {}", + second_pred_public_count + ); + } + + #[test] + fn test_error_message_formatting() { + // Test that error messages format correctly with detailed context + // We'll manually construct the error to test the formatting + use crate::lang::error::{RefactorSuggestion, SplitContext}; + + let context = SplitContext { + split_index: 0, + statement_range: (0, 4), + incoming_public: vec!["A".to_string(), "B".to_string(), "C".to_string()], + crossing_wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()], + total_public: 6, + }; + + let suggestion = Some(RefactorSuggestion::GroupWildcardUsages { + wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()], + }); + + let error = SplittingError::TooManyPublicArgsAtSplit { + predicate: "test_pred".to_string(), + context: Box::new(context), + max_allowed: 5, + suggestion: suggestion.map(Box::new), + }; + + let error_msg = format!("{}", error); + + // Verify the error message contains all the key information + assert!(error_msg.contains("test_pred")); + assert!(error_msg.contains("split boundary 0")); + assert!(error_msg.contains("3 incoming public")); + assert!(error_msg.contains("3 crossing wildcards")); + assert!(error_msg.contains("= 6 total")); + assert!(error_msg.contains("exceeds max of 5")); + assert!(error_msg.contains("Statements 0-3")); + assert!(error_msg.contains("Incoming public args: A, B, C")); + assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3")); + assert!(error_msg.contains("Suggestion:")); + assert!(error_msg.contains("Group operations for wildcards")); + + eprintln!("\n=== Example Error Message ===\n{}\n", error_msg); + } + + #[test] + fn test_error_too_many_total_args_formatting() { + // Test the TooManyTotalArgsInChainLink error message formatting + let error = SplittingError::TooManyTotalArgsInChainLink { + predicate: "huge_pred".to_string(), + link_index: 1, + public_count: 5, + private_count: 6, + total_count: 11, + max_allowed: 10, + }; + + let error_msg = format!("{}", error); + + // Verify the error message includes breakdown + assert!(error_msg.contains("huge_pred")); + assert!(error_msg.contains("chain link 1")); + assert!(error_msg.contains("5 public")); + assert!(error_msg.contains("6 private")); + assert!(error_msg.contains("= 11 total")); + assert!(error_msg.contains("exceeds max of 10")); + + eprintln!("\n=== Example TooManyTotalArgs Error ===\n{}\n", error_msg); + } + + #[test] + fn test_refactor_suggestion_reduce_wildcard_span() { + // Test the "reduce wildcard span" suggestion formatting + use crate::lang::error::RefactorSuggestion; + + let suggestion = RefactorSuggestion::ReduceWildcardSpan { + wildcard: "T".to_string(), + first_use: 0, + last_use: 7, + span: 7, + }; + + let suggestion_text = suggestion.format(); + + // Verify the suggestion formats correctly + assert!(suggestion_text.contains("'T'")); + assert!(suggestion_text.contains("used across 7 statements")); + assert!(suggestion_text.contains("statements 0-7")); + assert!(suggestion_text.contains("grouping all 'T' operations together")); + + eprintln!( + "\n=== Example ReduceWildcardSpan Suggestion ===\n{}\n", + suggestion_text + ); + } + // --- Regression tests --- /// Statements that reference only public args should be deferred until private-wildcard @@ -1232,7 +1254,7 @@ mod tests { let pred = parse_predicate(input); let params = Params::default(); - let result = split_predicate_if_needed(&pred, ¶ms); + let result = split_predicate_if_needed(pred, ¶ms); assert!( result.is_ok(), "Should find a valid split with ≤1 crossing wildcard, got: {:?}", @@ -1261,7 +1283,7 @@ mod tests { let pred = parse_predicate(input); let params = Params::default(); - let result = split_predicate_if_needed(&pred, ¶ms).unwrap(); + let result = split_predicate_if_needed(pred, ¶ms).unwrap(); // chain[0] is the continuation (_1 suffix), chain[1] is the original let continuation = result .predicates @@ -1282,389 +1304,4 @@ mod tests { cont_public ); } - - // =================================================================== - // Completeness probe for the splitter. - // - // `build_pred` constructs a CustomPredicateDef from a "wildcard set per - // statement" specification (cheaper than parsing). `find_any_feasible_ordering` - // brute-forces all permutations and uses the same per-link constraints as - // `split_into_chain` to check whether a feasible chain exists at all. - // =================================================================== - - fn build_pred( - name: &str, - public_args: &[&str], - private_args: &[&str], - stmt_wildcards: &[&[&str]], - ) -> CustomPredicateDef { - let statements: Vec = stmt_wildcards - .iter() - .map(|wcs| { - let args: Vec = wcs - .iter() - .map(|n| { - StatementTmplArg::Wildcard(Identifier { - name: n.to_string(), - span: None, - }) - }) - .collect(); - StatementTmpl { - predicate: PredicateRef::Local(Identifier { - name: "Equal".to_string(), - span: None, - }), - args, - span: None, - } - }) - .collect(); - - let private_args = if private_args.is_empty() { - None - } else { - Some( - private_args - .iter() - .map(|n| TypedArg { - name: n.to_string(), - type_name: None, - span: None, - }) - .collect(), - ) - }; - - CustomPredicateDef { - name: Identifier { - name: name.to_string(), - span: None, - }, - args: ArgSection { - public_args: public_args - .iter() - .map(|n| TypedArg { - name: n.to_string(), - type_name: None, - span: None, - }) - .collect(), - private_args, - span: None, - }, - conjunction_type: ConjunctionType::And, - statements, - span: None, - } - } - - /// Replicates the bucket-fill constraint check from `split_into_chain` for - /// a *fixed* ordering of statements. Returns Ok if the ordering produces a - /// valid chain, Err otherwise. - fn check_ordering_feasible( - ordered: &[StatementTmpl], - original_public_args: &[String], - params: &Params, - ) -> bool { - if ordered.len() <= Params::max_custom_predicate_arity() { - return true; - } - - let mut pos = 0; - let mut incoming_public: Vec = original_public_args.to_vec(); - - while pos < ordered.len() { - let remaining = ordered.len() - pos; - let is_last = remaining <= Params::max_custom_predicate_arity(); - let bucket_size = if is_last { - remaining - } else { - Params::max_custom_predicate_arity() - 1 - }; - let end = pos + bucket_size; - - let live: HashSet = if is_last { - HashSet::new() - } else { - let before: HashSet = ordered[pos..end] - .iter() - .flat_map(collect_wildcards_from_statement) - .collect(); - let after: HashSet = ordered[end..] - .iter() - .flat_map(collect_wildcards_from_statement) - .collect(); - before.intersection(&after).cloned().collect() - }; - - let incoming_set: HashSet = incoming_public.iter().cloned().collect(); - let new_promotions: Vec = live - .iter() - .filter(|w| !incoming_set.contains(*w)) - .cloned() - .collect(); - let total_public = incoming_public.len() + new_promotions.len(); - if total_public > Params::max_statement_args() { - return false; - } - - let segment_wildcards: HashSet = ordered[pos..end] - .iter() - .flat_map(collect_wildcards_from_statement) - .collect(); - let private_args: Vec = segment_wildcards - .difference(&incoming_set) - .filter(|w| !live.contains(*w)) - .cloned() - .collect(); - let total_args = total_public + private_args.len(); - if total_args > params.max_custom_predicate_wildcards { - return false; - } - - pos = end; - incoming_public.extend(new_promotions); - } - - true - } - - /// Brute-force search over all permutations of the predicate's statements - /// for *any* ordering that produces a feasible split. Returns the - /// permutation if found, else None. Caps at 9! to keep tests cheap. - fn find_any_feasible_ordering( - pred: &CustomPredicateDef, - params: &Params, - ) -> Option> { - use itertools::Itertools; - - let n = pred.statements.len(); - assert!(n <= 9, "brute-force capped at 9! permutations"); - - let original_public_args: Vec = pred - .args - .public_args - .iter() - .map(|id| id.name.clone()) - .collect(); - - for perm in (0..n).permutations(n) { - let ordered: Vec = - perm.iter().map(|&i| pred.statements[i].clone()).collect(); - if check_ordering_feasible(&ordered, &original_public_args, params) { - return Some(perm); - } - } - None - } - - /// 6 statements with 2 public args (A0, A1) and 5 private wildcards - /// (T0..T4). A feasible 4+2 chain exists where exactly 3 wildcards cross - /// the boundary (3 promotions + 2 incoming = 5 total public, hitting the - /// cap). The splitter must find one — a partition that puts an extra - /// wildcard across the boundary fails the per-link public-arg cap. - /// - /// Found by random search with seed 0xC0FFEE; inlined for determinism. - #[test] - fn test_splitter_handles_tight_public_arg_cap() { - let pred = build_pred( - "p", - &["A0", "A1"], - &["T0", "T1", "T2", "T3", "T4"], - &[ - &["T0", "T4", "T2"], - &["T1", "T3", "T4"], - &["T2", "T3", "T1"], - &["T4", "A0", "A1"], - &["T3", "T0", "T2"], - &["T0", "A1", "T1"], - ], - ); - let params = Params::default(); - - // Sanity: brute force confirms a feasible ordering exists. - let feasible = find_any_feasible_ordering(&pred, ¶ms); - assert!( - feasible.is_some(), - "expected at least one feasible permutation" - ); - - let result = split_predicate_if_needed(&pred, ¶ms); - assert!( - result.is_ok(), - "splitter rejected an input with a feasible ordering ({:?}): {}", - feasible.unwrap(), - result.err().unwrap() - ); - } - - /// A predicate with one statement that references 9 distinct wildcards - /// is unsplittable: any link containing that statement declares ≥ 9 - /// wildcards, exceeding the per-link cap of 8. `analyze_infeasibility` - /// must surface this as a non-zero `total_args_overflow` and list the - /// crowded link's private args. - #[test] - fn test_analyze_infeasibility_reports_total_args_overflow() { - let pred = build_pred( - "dense", - &["A"], - &["W0", "W1", "W2", "W3", "W4", "W5", "W6", "W7", "W8"], - &[ - &["W0", "W1", "W2", "W3", "W4", "W5", "W6", "W7", "W8"], - &["W0"], - &["W0"], - &["W0"], - &["W0"], - &["W0"], - ], - ); - let params = Params::default(); - - // Sanity: regular splitter rejects this input. - assert!(matches!( - split_predicate_if_needed(&pred, ¶ms), - Err(SplittingError::Infeasible { .. }) - )); - - let report = analyze_infeasibility(&pred, ¶ms); - assert_eq!(report.predicate, "dense"); - assert_eq!(report.k, 2); - - let total_overflow: usize = report - .overshoots - .iter() - .map(|o| o.total_args_overflow) - .sum(); - assert!( - total_overflow >= 1, - "expected ≥1 total-args overflow, got {} (overshoots: {:?})", - total_overflow, - report.overshoots - ); - - // The dense statement forces W1..W8 into one link as private args. - let crowded_link_has_dense_privates = report - .overshoots - .iter() - .any(|o| o.private_args.iter().filter(|w| w.starts_with('W')).count() >= 8); - assert!( - crowded_link_has_dense_privates, - "expected a binding link to declare 8+ W-wildcards as private; got {:?}", - report.overshoots - ); - - // Display impl shouldn't panic and should mention the predicate name. - let formatted = format!("{}", report); - assert!(formatted.contains("dense")); - } - - /// Randomized counterexample search. Run with - /// `cargo test --release search_splitter -- --ignored --nocapture`. - #[test] - #[ignore] - fn search_splitter_counterexample() { - // Tiny LCG so we don't pull rand as a dep. - struct Lcg(u64); - impl Lcg { - fn next(&mut self) -> u64 { - self.0 = self - .0 - .wrapping_mul(6364136223846793005) - .wrapping_add(1442695040888963407); - self.0 - } - fn rand_in(&mut self, n: usize) -> usize { - (self.next() as usize) % n - } - } - - let params = Params::default(); - let mut rng = Lcg(0xC0FFEE); - let mut checked = 0; - let mut found = 0; - - // Sweep over (n_stmts, n_pub, n_priv) combos. - for &n_stmts in &[6usize, 7, 8, 9] { - for &n_pub in &[1usize, 2, 3, 4] { - for &n_priv in &[2usize, 3, 4, 5] { - let pub_names: Vec = (0..n_pub).map(|i| format!("A{}", i)).collect(); - let priv_names: Vec = (0..n_priv).map(|i| format!("T{}", i)).collect(); - let all_names: Vec = - pub_names.iter().chain(priv_names.iter()).cloned().collect(); - - // Generate 200 random predicates per shape. - for trial in 0..200 { - // Each statement gets 2-3 distinct wildcards drawn from all_names. - let stmt_wildcards: Vec> = (0..n_stmts) - .map(|_| { - let arity = 2 + rng.rand_in(2); // 2 or 3 - let mut chosen = Vec::new(); - let mut tries = 0; - while chosen.len() < arity && tries < 20 { - let pick = all_names[rng.rand_in(all_names.len())].clone(); - if !chosen.contains(&pick) { - chosen.push(pick); - } - tries += 1; - } - chosen - }) - .collect(); - - let stmt_refs: Vec> = stmt_wildcards - .iter() - .map(|v| v.iter().map(|s| s.as_str()).collect()) - .collect(); - let stmt_slices: Vec<&[&str]> = - stmt_refs.iter().map(|v| v.as_slice()).collect(); - let pub_refs: Vec<&str> = pub_names.iter().map(|s| s.as_str()).collect(); - let priv_refs: Vec<&str> = priv_names.iter().map(|s| s.as_str()).collect(); - - let pred = build_pred("p", &pub_refs, &priv_refs, &stmt_slices); - - // Skip inputs that fail early validation (e.g. too many public args). - if validate_predicate_is_splittable(&pred).is_err() { - continue; - } - - checked += 1; - let feasible = find_any_feasible_ordering(&pred, ¶ms); - let split = split_predicate_if_needed(&pred, ¶ms); - - if let (Err(err), Some(perm)) = (split, feasible) { - found += 1; - eprintln!( - "\n=== COUNTEREXAMPLE #{} ===\n\ - shape: n={}, n_pub={}, n_priv={}, trial={}\n\ - statements (original order):", - found, n_stmts, n_pub, n_priv, trial - ); - for (i, wcs) in stmt_wildcards.iter().enumerate() { - eprintln!(" s{}: {:?}", i, wcs); - } - eprintln!("feasible permutation: {:?}", perm); - eprintln!("splitter error: {}\n", err); - - if found >= 3 { - eprintln!( - "Found {} counterexamples (out of {} checked); stopping.", - found, checked - ); - return; - } - } - } - } - } - } - - eprintln!( - "Searched {} predicates; found {} counterexamples.", - checked, found - ); - if found == 0 { - eprintln!("No counterexamples found."); - } - } } diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index 41c0eff..ef3d395 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -13,7 +13,7 @@ use hex::ToHex; use crate::{ lang::{frontend_ast::*, Module}, - middleware::{CustomPredicateBatch, Hash, NativePredicate, Params}, + middleware::{CustomPredicateBatch, Hash, NativePredicate}, }; /// A validated AST document with symbol table and diagnostics @@ -51,55 +51,6 @@ pub struct SymbolTable { pub wildcard_scopes: HashMap, /// Imported modules (bound name → Module reference) pub imported_modules: HashMap>, - /// Records visible in this scope (local declarations + imports). - pub records: HashMap, -} - -/// Resolved record schema: ordered entries plus a name→index lookup, with -/// provenance for diagnostics. Lowering uses `entry_index` to translate -/// dot-access like `r.foo` into the integer key for an `AnchoredKey`. -#[derive(Debug, Clone)] -pub struct RecordSchema { - pub entries: Vec, - pub entry_index: HashMap, - pub source: RecordSource, - pub source_span: Option, -} - -impl RecordSchema { - /// Build a schema from already-deduplicated entries. Callers that need - /// to surface a per-entry span on duplicates (e.g. local declarations) - /// should detect duplicates themselves before calling this. - pub fn from_entries( - entries: Vec, - source: RecordSource, - source_span: Option, - ) -> Self { - let entry_index = entries - .iter() - .enumerate() - .map(|(i, e)| (e.clone(), i)) - .collect(); - Self { - entries, - entry_index, - source, - source_span, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum RecordSource { - Local, - Imported { module: String }, -} - -/// Build the `SymbolTable.records` key for a record imported via -/// `use module ... as alias`. Mirrors the `alias::Name` form used for -/// `TypeRef::Qualified`. -pub fn qualified_record_key(alias: &str, name: &str) -> String { - format!("{}::{}", alias, name) } /// Information about a predicate @@ -145,9 +96,6 @@ pub struct WildcardInfo { pub index: usize, pub is_public: bool, pub source_span: Option, - /// Record type tag for typed args (`name TypeName` syntax). The name - /// references an entry in `SymbolTable.records`. - pub record_type: Option, } /// Diagnostic message (warning or info) @@ -179,16 +127,14 @@ pub enum ParseMode { pub fn validate( document: Document, available_modules: &HashMap>, - params: &Params, mode: ParseMode, ) -> Result { - let validator = Validator::new(available_modules, params, mode); + let validator = Validator::new(available_modules, mode); validator.validate(document) } struct Validator { available_modules: HashMap>, - params: Params, symbols: SymbolTable, diagnostics: Vec, custom_predicate_count: usize, @@ -196,19 +142,13 @@ struct Validator { } impl Validator { - fn new( - available_modules: &HashMap>, - params: &Params, - mode: ParseMode, - ) -> Self { + fn new(available_modules: &HashMap>, mode: ParseMode) -> Self { Self { available_modules: available_modules.clone(), - params: params.clone(), symbols: SymbolTable { predicates: HashMap::new(), wildcard_scopes: HashMap::new(), imported_modules: HashMap::new(), - records: HashMap::new(), }, diagnostics: Vec::new(), custom_predicate_count: 0, @@ -241,13 +181,6 @@ impl Validator { } } - // Records before predicates so typed-arg resolution can find them. - for item in &document.items { - if let DocumentItem::RecordDef(record_def) = item { - self.process_record_def(record_def)?; - } - } - // Check mode constraints for predicate definitions let mut has_predicates = false; for item in &document.items { @@ -281,7 +214,7 @@ impl Validator { } } - // Enforce that modules have predicates and requests have a REQUEST block. + // Enforce that modules have predicates and requests have a REQUEST block match self.mode { ParseMode::Module if !has_predicates => { return Err(ValidationError::NoPredicatesInModule); @@ -311,22 +244,6 @@ impl Validator { span: use_stmt.span, })?; - // Flatten the imported module's locally-declared records into the - // symbol table under qualified keys (`alias::Name`). No transitive - // re-export — `Module.records` only carries local declarations. - for (record_name, entries) in &module.records { - self.symbols.records.insert( - qualified_record_key(alias, record_name), - RecordSchema::from_entries( - entries.clone(), - RecordSource::Imported { - module: alias.clone(), - }, - use_stmt.span, - ), - ); - } - // Store the module keyed by alias for later qualified name resolution self.symbols .imported_modules @@ -335,24 +252,6 @@ impl Validator { Ok(()) } - /// Returns the resolved `SymbolTable.records` key for a typed arg, or - /// `None` if the arg has no `type_name`. The key is the bare type name - /// for locals and `"alias::Name"` for qualified imports. Errors if the - /// tag doesn't refer to a known record. - fn resolve_typed_arg(&self, arg: &TypedArg) -> Result, ValidationError> { - let Some(type_ref) = &arg.type_name else { - return Ok(None); - }; - let key = type_ref.symbol_table_key(); - if !self.symbols.records.contains_key(&key) { - return Err(ValidationError::UnknownRecord { - name: key, - span: type_ref.span(), - }); - } - Ok(Some(key)) - } - fn process_use_intro_statement( &mut self, use_stmt: &UseIntroStatement, @@ -384,48 +283,6 @@ impl Validator { Ok(()) } - fn process_record_def(&mut self, record_def: &RecordDef) -> Result<(), ValidationError> { - let name = &record_def.name.name; - - if let Some(existing) = self.symbols.records.get(name) { - return Err(ValidationError::DuplicateRecord { - name: name.clone(), - first_span: existing.source_span, - second_span: record_def.name.span, - }); - } - - let max = self.params.max_record_entries(); - if record_def.entries.len() > max { - return Err(ValidationError::RecordTooManyEntries { - name: name.clone(), - count: record_def.entries.len(), - max, - span: record_def.span, - }); - } - - let mut seen = HashSet::with_capacity(record_def.entries.len()); - let mut entries = Vec::with_capacity(record_def.entries.len()); - for entry in &record_def.entries { - if !seen.insert(&entry.name) { - return Err(ValidationError::DuplicateRecordEntry { - record: name.clone(), - entry: entry.name.clone(), - span: entry.span, - }); - } - entries.push(entry.name.clone()); - } - - self.symbols.records.insert( - name.clone(), - RecordSchema::from_entries(entries, RecordSource::Local, record_def.name.span), - ); - - Ok(()) - } - fn process_custom_predicate_def( &mut self, pred_def: &CustomPredicateDef, @@ -461,14 +318,12 @@ impl Validator { span: arg.span, }); } - let record_type = self.resolve_typed_arg(arg)?; wildcards.insert( arg.name.clone(), WildcardInfo { index: wildcard_index, is_public: true, source_span: arg.span, - record_type, }, ); wildcard_index += 1; @@ -484,14 +339,12 @@ impl Validator { span: arg.span, }); } - let record_type = self.resolve_typed_arg(arg)?; wildcards.insert( arg.name.clone(), WildcardInfo { index: wildcard_index, is_public: false, source_span: arg.span, - record_type, }, ); wildcard_index += 1; @@ -590,7 +443,10 @@ impl Validator { wildcard_context: Option<(&str, &WildcardScope)>, ) -> Result<(), ValidationError> { let pred_name = stmt.predicate.predicate_name(); - let pred_span = stmt.predicate.span(); + let pred_span = match &stmt.predicate { + PredicateRef::Local(id) => id.span, + PredicateRef::Qualified { predicate, .. } => predicate.span, + }; let wc_names = match wildcard_context { Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(), @@ -691,44 +547,12 @@ impl Validator { } StatementTmplArg::AnchoredKey(ak) => { if let Some((pred_name, scope)) = wildcard_context { - let Some(wc_info) = scope.wildcards.get(&ak.root.name) else { + if !scope.wildcards.contains_key(&ak.root.name) { return Err(ValidationError::UndefinedWildcard { name: ak.root.name.clone(), pred_name: pred_name.to_string(), span: ak.root.span, }); - }; - // Records are integer-keyed, so string-key access on - // a typed wildcard is dead code at proof time. Reject - // dot access for unknown entries and bracket access - // outright; require `r.entry` for record-shaped data. - if let Some(record_name) = &wc_info.record_type { - match &ak.key { - AnchoredKeyPath::Dot(entry) => { - let schema = - self.symbols.records.get(record_name).expect( - "record_type was resolved at predicate-def time", - ); - if !schema.entry_index.contains_key(&entry.name) { - return Err(ValidationError::UnknownRecordEntry { - record: record_name.clone(), - entry: entry.name.clone(), - span: entry.span, - }); - } - } - AnchoredKeyPath::Bracket(_) => { - return Err(ValidationError::BracketAccessOnTypedWildcard { - wildcard: ak.root.name.clone(), - record: record_name.clone(), - span: ak.span, - }); - } - AnchoredKeyPath::Index(_) => unreachable!( - "AnchoredKeyPath::Index is introduced during lowering; \ - it cannot appear in the parsed AST that validation sees" - ), - } } } } @@ -814,51 +638,6 @@ impl Validator { } Ok(()) } - LiteralValue::Record(r) => { - let key = r.name.symbol_table_key(); - let Some(schema) = self.symbols.records.get(&key) else { - return Err(ValidationError::UnknownRecord { - name: key, - span: r.name.span(), - }); - }; - let mut seen: HashSet<&String> = HashSet::new(); - for entry in &r.entries { - if !schema.entry_index.contains_key(&entry.name.name) { - return Err(ValidationError::UnknownRecordEntry { - record: key.clone(), - entry: entry.name.name.clone(), - span: entry.name.span, - }); - } - if !seen.insert(&entry.name.name) { - return Err(ValidationError::DuplicateLiteralRecordEntry { - record: key.clone(), - entry: entry.name.name.clone(), - span: entry.name.span, - }); - } - self.validate_literal_value(&entry.value)?; - } - Ok(()) - } - LiteralValue::RecordEntryIndex { record, entry } => { - let key = record.symbol_table_key(); - let Some(schema) = self.symbols.records.get(&key) else { - return Err(ValidationError::UnknownRecord { - name: key, - span: record.span(), - }); - }; - if !schema.entry_index.contains_key(&entry.name) { - return Err(ValidationError::UnknownRecordEntry { - record: key, - entry: entry.name.clone(), - span: entry.span, - }); - } - Ok(()) - } _ => Ok(()), } } @@ -880,7 +659,7 @@ mod tests { ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - validate(document, modules, &Params::default(), ParseMode::Module) + validate(document, modules, ParseMode::Module) } fn parse_and_validate_request( @@ -889,7 +668,7 @@ mod tests { ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - validate(document, modules, &Params::default(), ParseMode::Request) + validate(document, modules, ParseMode::Request) } #[test] @@ -1067,9 +846,8 @@ mod tests { span: None, }, args: ArgSection { - public_args: vec![TypedArg { + public_args: vec![Identifier { name: "A".to_string(), - type_name: None, span: None, }], private_args: None, @@ -1080,12 +858,7 @@ mod tests { span: None, })], }; - let result = validate( - document, - &HashMap::new(), - &Params::default(), - ParseMode::Module, - ); + let result = validate(document, &HashMap::new(), ParseMode::Module); assert!(matches!( result, Err(ValidationError::EmptyStatementList { .. }) @@ -1163,247 +936,4 @@ mod tests { let result = parse_and_validate_request(input, &HashMap::new()); assert!(result.is_ok()); } - - // ----- Records ---------------------------------------------------------- - - #[test] - fn test_record_decl_accepted() { - let input = r#" - record ProcInputs = (foo, bar, baz) - my_pred(A) = AND(Equal(A["x"], 1)) - "#; - let validated = parse_and_validate_module(input, &HashMap::new()).unwrap(); - let schema = validated.symbols.records.get("ProcInputs").unwrap(); - assert_eq!(schema.entries, vec!["foo", "bar", "baz"]); - assert_eq!(schema.source, RecordSource::Local); - } - - #[test] - fn test_records_only_module_rejected() { - // A module needs at least one predicate; record-only modules are not - // a valid distribution unit. - let input = r#"record R = (x)"#; - assert!(matches!( - parse_and_validate_module(input, &HashMap::new()), - Err(ValidationError::NoPredicatesInModule) - )); - } - - #[test] - fn test_duplicate_record() { - let input = r#" - record R = (foo) - record R = (bar) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::DuplicateRecord { .. }) - )); - } - - #[test] - fn test_duplicate_entry_in_record() { - let input = r#" - record R = (foo, foo) - my_pred(A) = AND(Equal(A["x"], 1)) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::DuplicateRecordEntry { record, entry, .. }) - if record == "R" && entry == "foo" - )); - } - - #[test] - fn test_record_entry_cap() { - // Use a non-default depth so the cap reflects the parameter (not - // some hard-coded default). This pins three facts in one test: - // the param is wired through, the boundary is inclusive on accept, - // and cap + 1 is rejected. - let mut params = Params::default(); - params.containers.max_depth_small -= 1; - let cap = params.max_record_entries(); - let validate_with_n_entries = |n: usize| { - let entries: Vec = (0..n).map(|i| format!("f{i}")).collect(); - let input = format!( - "record Big = ({})\nmy_pred(A) = AND(Equal(A[\"x\"], 1))", - entries.join(", ") - ); - let parsed = parse_podlang(&input).expect("Failed to parse"); - let document = - parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - validate(document, &HashMap::new(), ¶ms, ParseMode::Module) - }; - assert!(validate_with_n_entries(cap).is_ok()); - let too_many = cap + 1; - assert!(matches!( - validate_with_n_entries(too_many), - Err(ValidationError::RecordTooManyEntries { count, max, .. }) - if count == too_many && max == cap - )); - } - - #[test] - fn test_typed_arg_resolves_known_record() { - let input = r#" - record R = (foo, bar) - my_pred(in R) = AND(Equal(in.foo, in.bar)) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(result.is_ok()); - let validated = result.unwrap(); - let scope = validated.symbols.wildcard_scopes.get("my_pred").unwrap(); - assert_eq!(scope.wildcards["in"].record_type.as_deref(), Some("R")); - } - - #[test] - fn test_typed_arg_unknown_record_rejected() { - let input = r#" - my_pred(in NonExistent) = AND(Equal(in.foo, 1)) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::UnknownRecord { name, .. }) if name == "NonExistent" - )); - } - - #[test] - fn test_dot_access_unknown_entry_rejected() { - let input = r#" - record R = (foo, bar) - my_pred(in R) = AND(Equal(in.quux, 1)) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::UnknownRecordEntry { record, entry, .. }) - if record == "R" && entry == "quux" - )); - } - - #[test] - fn test_dot_access_on_untyped_wildcard_unchecked() { - // r.foo on an untyped wildcard keeps current POD-string-key behavior; - // no record exists named anything that would constrain `foo`. - let input = r#" - my_pred(r) = AND(Equal(r.foo, 1)) - "#; - assert!(parse_and_validate_module(input, &HashMap::new()).is_ok()); - } - - #[test] - fn test_bracket_access_on_typed_wildcard_rejected() { - // Records are integer-keyed; string-key access on a record-typed - // wildcard is incoherent and would never resolve at proof time. - // Force the user to use `.entry` instead. - let input = r#" - record R = (foo) - my_pred(r R) = AND(Equal(r["foo"], 1)) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::BracketAccessOnTypedWildcard { wildcard, record, .. }) - if wildcard == "r" && record == "R" - )); - } - - #[test] - fn test_record_literal_unknown_record() { - let input = r#" - my_pred(A) = AND(Equal(A["x"], NotARecord(f: 1))) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::UnknownRecord { name, .. }) if name == "NotARecord" - )); - } - - #[test] - fn test_record_literal_unknown_entry() { - let input = r#" - record R = (foo, bar) - my_pred(A) = AND(Equal(A["x"], R(foo: 1, quux: 2))) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::UnknownRecordEntry { record, entry, .. }) - if record == "R" && entry == "quux" - )); - } - - #[test] - fn test_record_literal_nested() { - // Nested literals recurse through `validate_literal_value`: an unknown - // entry on the inner literal must still be caught. - let input = r#" - record Outer = (inner) - record Inner = (x, y) - my_pred(A) = AND(Equal(A["x"], Outer(inner: Inner(x: 1, z: 2)))) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::UnknownRecordEntry { record, entry, .. }) - if record == "Inner" && entry == "z" - )); - } - - #[test] - fn test_record_literal_duplicate_entry() { - let input = r#" - record R = (foo, bar) - my_pred(A) = AND(Equal(A["x"], R(foo: 1, foo: 2))) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::DuplicateLiteralRecordEntry { record, entry, .. }) - if record == "R" && entry == "foo" - )); - } - - #[test] - fn test_record_entry_index_resolves() { - // Validation accepts `R::bar` and the schema records bar at index 1 - // — the integer the literal will lower to. - let input = r#" - record R = (foo, bar) - my_pred(A) = AND(Contains(A, R::bar, 7)) - "#; - let validated = parse_and_validate_module(input, &HashMap::new()).unwrap(); - let schema = validated.symbols.records.get("R").unwrap(); - assert_eq!(schema.entry_index["bar"], 1); - } - - #[test] - fn test_record_entry_index_unknown_record() { - let input = r#" - my_pred(A) = AND(Contains(A, NotARecord::foo, 7)) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::UnknownRecord { name, .. }) if name == "NotARecord" - )); - } - - #[test] - fn test_record_entry_index_unknown_entry() { - let input = r#" - record R = (foo, bar) - my_pred(A) = AND(Contains(A, R::quux, 7)) - "#; - let result = parse_and_validate_module(input, &HashMap::new()); - assert!(matches!( - result, - Err(ValidationError::UnknownRecordEntry { record, entry, .. }) - if record == "R" && entry == "quux" - )); - } } diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index 0446b27..1c11baa 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -11,11 +11,7 @@ WHITESPACE = _{ (" " | "\t" | NEWLINE)+ } // COMMENT matches a line comment (//...\n) or block comment (/*...*/). COMMENT = _{ ("//" ~ (!NEWLINE ~ ANY)* | "/*" ~ (!"*/" ~ ANY)* ~ "*/" ) } -// Word-boundary anchor: the reserved word must not be followed by an -// identifier character, otherwise prefixes like `record_count` or `recorder` -// would be wrongly rejected by the `!reserved_identifier` lookahead in -// `identifier`. -reserved_identifier = @{ ("private" | "true" | "false" | "record") ~ !(ASCII_ALPHANUMERIC | "_") } +reserved_identifier = { "private" | "true" | "false" } // Define rules for identifiers (predicate names, wildcard names) // Must start with alpha or _, followed by alpha, numeric, or _ @@ -27,19 +23,10 @@ arg_section = { public_arg_list ~ ("," ~ private_kw ~ private_arg_list)? } -// `name` or `name TypeName` or `name module::TypeName`. The optional `type_tag` -// is a record type, either local or imported via a `use module ... as alias`. -typed_arg = { identifier ~ type_tag? } -type_tag = { qualified_type_ref | identifier } -qualified_type_ref = { identifier ~ "::" ~ identifier } -public_arg_list = { typed_arg ~ ("," ~ typed_arg)* } -private_arg_list = { typed_arg ~ ("," ~ typed_arg)* } +public_arg_list = { identifier ~ ("," ~ identifier)* } +private_arg_list = { identifier ~ ("," ~ identifier)* } -record_def = { - "record" ~ identifier ~ "=" ~ "(" ~ identifier ~ ("," ~ identifier)* ~ ")" -} - -document = { SOI ~ (use_module_statement | use_intro_statement | record_def | custom_predicate_def | request_def)* ~ EOI } +document = { SOI ~ (use_module_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI } use_module_statement = { "use" ~ "module" ~ hash_hex ~ "as" ~ identifier } @@ -96,27 +83,9 @@ literal_value = { literal_string | predicate_hash_native | predicate_hash_external | - literal_record | - record_entry_index | literal_int } -// Record literal: `Name(Field: value, ...)` or `module::Name(...)`. Ordering -// in `literal_value` matters: must come after `Name(...)`-shaped prefix -// literals (PublicKey, SecretKey, Raw) so PEG doesn't shadow them, and BEFORE -// `record_entry_index` so `module::R(...)` isn't consumed as a 2-segment -// entry index — Pest backtracks to `record_entry_index` when no `(` follows. -literal_record = { - type_tag ~ "(" ~ record_entry ~ ("," ~ record_entry)* ~ ")" -} -record_entry = { identifier ~ ":" ~ literal_value } - -// Compile-time entry-index lookup: `R::foo` or `module::R::foo`. Resolves to -// the integer index of the named entry in the (possibly imported) record. -record_entry_index = { - identifier ~ "::" ~ identifier ~ ("::" ~ identifier)? -} - // Primitive literal types literal_int = @{ "-"? ~ ASCII_DIGIT+ } literal_bool = @{ "true" | "false" } diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 729c05b..291f7a6 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -40,10 +40,7 @@ use std::sync::Arc; pub use diagnostics::render_error; pub use error::{LangError, LangErrorKind}; -pub use frontend_ast_split::{ - analyze_infeasibility, InfeasibilityReport, LinkOvershoot, SplitChainInfo, SplitChainPiece, - SplitResult, -}; +pub use frontend_ast_split::{SplitChainInfo, SplitChainPiece, SplitResult}; pub use module::{Module, MultiOperationError}; pub use parser::{parse_podlang, Pairs, ParseError, Rule}; pub use pretty_print::PrettyPrint; @@ -87,7 +84,6 @@ fn load_module_inner( let validated = frontend_ast_validate::validate( document, &available_modules_map, - params, frontend_ast_validate::ParseMode::Module, )?; let module = frontend_ast_lower::lower_module(validated, params, name)?; @@ -128,7 +124,6 @@ fn parse_request_inner( let validated = frontend_ast_validate::validate( document, &available_modules_map, - params, frontend_ast_validate::ParseMode::Request, )?; let request = frontend_ast_lower::lower_request(validated, params)?; @@ -1078,248 +1073,4 @@ mod tests { e => panic!("Expected LangError::Validation, but got {:?}", e), } } - - // ---- Records: cross-module export ------------------------------------- - - #[test] - fn test_e2e_record_imported_predicate_compiles() -> Result<(), LangError> { - // Module A defines a record + a predicate that uses it. - // Module B imports A and writes its own predicate using - // `in a::ProcInputs`. B should compile without errors. - let params = Params::default(); - let module_a_src = r#" - record ProcInputs = (foo, bar, baz) - uses_record(in ProcInputs) = AND( - Equal(in.foo, in.baz) - ) - "#; - let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); - let a_hash = module_a.id().encode_hex::(); - - let module_b_src = format!( - r#" - use module 0x{} as a - - wraps_a(in a::ProcInputs) = AND( - Equal(in.bar, 7) - ) - "#, - a_hash - ); - let module_b = load_module(&module_b_src, "module_b", ¶ms, &[module_a])?; - assert_eq!(module_b.batch.predicates().len(), 1); - assert_eq!(module_b.batch.predicates()[0].name, "wraps_a"); - Ok(()) - } - - #[test] - fn test_e2e_imported_record_predicate_hash_matches_handwritten() -> Result<(), LangError> { - // Module B's predicate using `in a::ProcInputs` must hash identically - // to the same predicate built directly with an integer-keyed - // anchored key. Schema lives in A; the integer index is what gets - // baked into B's predicate body. - use crate::{ - frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder}, - middleware::Key, - }; - - let params = Params::default(); - let module_a_src = r#" - record ProcInputs = (foo, bar, baz) - stub(A) = AND(Equal(A["x"], 1)) - "#; - let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); - let a_hash = module_a.id().encode_hex::(); - - let module_b_src = format!( - r#" - use module 0x{} as a - - uses(in a::ProcInputs) = AND( - Equal(in.bar, 7) - ) - "#, - a_hash - ); - let module_b = load_module(&module_b_src, "module_b", ¶ms, &[module_a])?; - - let mut hand = CustomPredicateBatchBuilder::new(params.clone(), "module_b".into()); - let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) - .arg(BuilderArg::Key("in".into(), Key::from(1i64))) - .arg(BuilderArg::Literal(Value::from(7i64))); - hand.predicate_and("uses", &["in"], &[], &[stb]) - .expect("predicate_and"); - let hand_batch = hand.finish().expect("finish"); - - assert_eq!(module_b.batch.id(), hand_batch.id()); - Ok(()) - } - - #[test] - fn test_e2e_imported_record_literal_matches_array_root() -> Result<(), LangError> { - // Qualified record literal: `a::ProcInputs(...)` in the importer's - // body lowers to the same `Array` root as a local record literal - // built from the same schema, with entries placed at their schema - // indices. - use crate::middleware::{containers::Array, StatementTmplArg}; - - let params = Params::default(); - let module_a_src = r#" - record ProcInputs = (foo, bar, baz) - stub(A) = AND(Equal(A["x"], 1)) - "#; - let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); - let a_hash = module_a.id().encode_hex::(); - - // Source order is intentionally not schema order — the schema lookup - // through the qualified key has to map each entry back to its index. - let module_b_src = format!( - r#" - use module 0x{} as a - - uses(A) = AND( - Equal(A["data"], a::ProcInputs(baz: 30, foo: 10, bar: 20)) - ) - "#, - a_hash - ); - let module_b = load_module(&module_b_src, "module_b", ¶ms, &[module_a])?; - - let pred = &module_b.batch.predicates()[0]; - let stmt = &pred.statements()[0]; - let lowered = match &stmt.args()[1] { - StatementTmplArg::Literal(v) => v.clone(), - other => panic!("expected Literal at arg 1, got {other:?}"), - }; - - let expected = Value::from(Array::new(vec![ - Value::from(10i64), - Value::from(20i64), - Value::from(30i64), - ])); - - assert_eq!(lowered.raw(), expected.raw()); - Ok(()) - } - - #[test] - fn test_e2e_imported_record_literal_unknown_module_rejected() -> Result<(), LangError> { - // A literal that names a module the importer didn't bind must be - // rejected — the qualified key never gets into the symbol table, - // so validation surfaces `UnknownRecord` rather than producing a - // bogus lowered value. - use crate::lang::frontend_ast_validate::ValidationError; - - let params = Params::default(); - let module_a_src = r#" - record ProcInputs = (foo, bar) - stub(A) = AND(Equal(A["x"], 1)) - "#; - let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); - let a_hash = module_a.id().encode_hex::(); - - // Imported as `a`, but the literal references `b::ProcInputs`. - let module_b_src = format!( - r#" - use module 0x{} as a - - uses(A) = AND( - Equal(A["data"], b::ProcInputs(foo: 1, bar: 2)) - ) - "#, - a_hash - ); - let err = load_module(&module_b_src, "module_b", ¶ms, &[module_a]).unwrap_err(); - match err.kind { - LangErrorKind::Validation(e) => match *e { - ValidationError::UnknownRecord { name, .. } => { - assert_eq!(name, "b::ProcInputs"); - } - other => panic!("expected UnknownRecord, got {other:?}"), - }, - other => panic!("expected Validation, got {other:?}"), - } - Ok(()) - } - - #[test] - fn test_e2e_record_entry_index_proves_via_mock() -> Result<(), LangError> { - // End-to-end: a record-using predicate is satisfied by an Array - // value, with `Inputs::x` resolving the entry's integer index. - // MockProver runs the full proving path. - use crate::{ - backends::plonky2::mock::mainpod::MockProver, - frontend::{MainPodBuilder, Operation}, - middleware::{containers::Array, VDSet}, - }; - - let params = Params::default(); - let module = load_module( - r#" - record Inputs = (x, y) - at_x_is(arr, val) = AND( - Contains(arr, Inputs::x, val) - ) - "#, - "records_e2e", - ¶ms, - &[], - )?; - let at_x_is = module.batch.predicate_ref_by_name("at_x_is").unwrap(); - - // Build a 2-entry Array; arr[0] = 7 (Inputs::x), arr[1] = 13 (Inputs::y). - let arr = Array::new(vec![Value::from(7i64), Value::from(13i64)]); - - let vd_set = VDSet::new(&[]); - let mut builder = MainPodBuilder::new(¶ms, &vd_set); - let contains_st = builder - .priv_op(Operation::array_contains(arr, 0i64, 7i64)) - .unwrap(); - builder - .pub_op(Operation::custom(at_x_is, [contains_st])) - .unwrap(); - let pod = builder.prove(&MockProver {}).unwrap(); - pod.pod.verify().unwrap(); - Ok(()) - } - - #[test] - fn test_e2e_record_typed_dot_proves_via_mock() -> Result<(), LangError> { - // End-to-end: typed dot access is satisfied by an Array entry opened - // with an integer key, then used as an AnchoredKey in an Equal statement. - use crate::{ - backends::plonky2::mock::mainpod::MockProver, - frontend::{MainPodBuilder, Operation}, - middleware::{containers::Array, VDSet}, - }; - - let params = Params::default(); - let module = load_module( - r#" - record Inputs = (x, y) - at_x_is(arr Inputs, val) = AND( - Equal(arr.x, val) - ) - "#, - "records_dot_e2e", - ¶ms, - &[], - )?; - let at_x_is = module.batch.predicate_ref_by_name("at_x_is").unwrap(); - - let arr = Array::new(vec![Value::from(7i64), Value::from(13i64)]); - - let vd_set = VDSet::new(&[]); - let mut builder = MainPodBuilder::new(¶ms, &vd_set); - let contains_st = builder - .priv_op(Operation::array_contains(arr, 0i64, 7i64)) - .unwrap(); - let equal_st = builder.priv_op(Operation::eq(contains_st, 7i64)).unwrap(); - builder - .pub_op(Operation::custom(at_x_is, [equal_st])) - .unwrap(); - let pod = builder.prove(&MockProver {}).unwrap(); - pod.pod.verify().unwrap(); - Ok(()) - } } diff --git a/src/lang/module.rs b/src/lang/module.rs index 3744874..b926871 100644 --- a/src/lang/module.rs +++ b/src/lang/module.rs @@ -53,12 +53,6 @@ pub struct Module { /// Split chain info for predicates that were split pub split_chains: HashMap, - - /// Records declared locally in this module's source: name → ordered entry - /// list. Frontend metadata only — the middleware batch knows nothing - /// about records. No transitive re-export: a downstream importer - /// inherits only the records declared in this module's own source. - pub records: HashMap>, } impl Module { @@ -66,15 +60,6 @@ impl Module { pub fn new( batch: Arc, split_chains: HashMap, - ) -> Self { - Self::with_records(batch, split_chains, HashMap::new()) - } - - /// Like `new`, but seeds the module's locally-declared records. - pub fn with_records( - batch: Arc, - split_chains: HashMap, - records: HashMap>, ) -> Self { let predicate_index = batch .predicates() @@ -86,7 +71,6 @@ impl Module { batch, predicate_index, split_chains, - records, } } @@ -280,7 +264,6 @@ pub fn build_module( params: &Params, module_name: &str, symbols: &SymbolTable, - records: HashMap>, ) -> Result { // Extract predicates and collect split chains let mut predicates = Vec::new(); @@ -298,7 +281,7 @@ pub fn build_module( if predicates.is_empty() { // Return an empty module let empty_batch = CustomPredicateBatch::new(module_name.to_string(), vec![]); - return Ok(Module::with_records(empty_batch, split_chains, records)); + return Ok(Module::new(empty_batch, split_chains)); } // Build reference map: name -> index @@ -311,7 +294,7 @@ pub fn build_module( // Build the batch let batch = build_single_batch(&predicates, &reference_map, symbols, params, module_name)?; - Ok(Module::with_records(batch, split_chains, records)) + Ok(Module::new(batch, split_chains)) } /// Build a batch with properly resolved references @@ -423,14 +406,8 @@ mod tests { fn parse_and_validate(input: &str) -> (Vec, ValidatedAST) { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - let params = Params::default(); - let validated = validate( - document.clone(), - &HashMap::new(), - ¶ms, - ParseMode::Module, - ) - .expect("Failed to validate"); + let validated = validate(document.clone(), &HashMap::new(), ParseMode::Module) + .expect("Failed to validate"); let predicates = document .items @@ -471,7 +448,6 @@ mod tests { ¶ms, "TestModule", validated.symbols(), - HashMap::new(), ); assert!(result.is_ok()); @@ -495,7 +471,6 @@ mod tests { ¶ms, "TestModule", validated.symbols(), - HashMap::new(), ); assert!(result.is_ok()); @@ -520,7 +495,6 @@ mod tests { ¶ms, "TestModule", validated.symbols(), - HashMap::new(), ); assert!(result.is_ok()); @@ -553,7 +527,6 @@ mod tests { ¶ms, "TestModule", validated.symbols(), - HashMap::new(), ); assert!(result.is_ok()); @@ -588,7 +561,6 @@ mod tests { ¶ms, "TestModule", validated.symbols(), - HashMap::new(), ) .unwrap(); @@ -617,7 +589,7 @@ mod tests { // Split the predicate let mut split_results = Vec::new(); - for pred in &predicates { + for pred in predicates { let result = split_predicate_if_needed(pred, ¶ms).expect("Split failed"); split_results.push(result); } @@ -627,14 +599,8 @@ mod tests { assert_eq!(split_results[0].predicates.len(), 2); assert!(split_results[0].chain_info.is_some()); - let module = build_module( - split_results, - ¶ms, - "TestModule", - validated.symbols(), - HashMap::new(), - ) - .unwrap(); + let module = + build_module(split_results, ¶ms, "TestModule", validated.symbols()).unwrap(); // Verify chain info is preserved let chain_info = module.split_chains.get("large_pred").unwrap(); diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index 0a507ab..7c8e744 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -18,7 +18,7 @@ use crate::{ backends::plonky2::primitives::merkletree::MerkleTreeStateTransitionProof, middleware::{ db::{mem::MemDB, DB}, - Error, Hash, RawValue, Result, StrKey, TypedValue, Value, EMPTY_HASH, + Error, Hash, Key, RawValue, Result, TypedValue, Value, EMPTY_HASH, }, }; @@ -264,22 +264,22 @@ macro_rules! dict { ); ({ $($key:expr => $val:expr),* }) => ({ let mut map = ::std::collections::HashMap::new(); - $( map.insert($crate::middleware::StrKey::from($key), $crate::middleware::Value::from($val)); )* + $( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )* $crate::middleware::containers::Dictionary::new(map) }); } -// TODO: Replace all methods that receive a `&StrKey` by either `impl Into` for write +// TODO: Replace all methods that receive a `&Key` by either `impl Into` for write // methods and `impl AsRef` for read methods. // TODO: Replace all methods that receive a `&Value` in write methods for `Value`. Consider a // trait? impl Dictionary { - pub fn new(kvs: HashMap) -> Self { + pub fn new(kvs: HashMap) -> Self { Self { inner: Container::new( kvs.into_iter() - .map(|(k, v)| (Value::from(k.into_name()), v)) + .map(|(k, v)| (Value::from(k.name), v)) .collect(), ), } @@ -297,37 +297,29 @@ impl Dictionary { pub fn commitment(&self) -> Hash { self.inner.commitment() } - pub fn get(&self, key: &StrKey) -> Result> { + pub fn get(&self, key: &Key) -> Result> { self.inner.get(key.raw()) } - pub fn prove(&self, key: &StrKey) -> Result<(Value, MerkleProof)> { + pub fn prove(&self, key: &Key) -> Result<(Value, MerkleProof)> { self.inner.prove(key.raw()) } - pub fn prove_nonexistence(&self, key: &StrKey) -> Result { + pub fn prove_nonexistence(&self, key: &Key) -> Result { self.inner.prove_nonexistence(key.raw()) } - pub fn insert( - &mut self, - key: &StrKey, - value: &Value, - ) -> Result { + pub fn insert(&mut self, key: &Key, value: &Value) -> Result { self.inner - .insert(Value::from(key.name().to_string()), value.clone()) + .insert(Value::from(key.name.clone()), value.clone()) } - pub fn update( - &mut self, - key: &StrKey, - value: &Value, - ) -> Result { + pub fn update(&mut self, key: &Key, value: &Value) -> Result { self.inner.update(key.raw(), value.clone()) } - pub fn delete(&mut self, key: &StrKey) -> Result { + pub fn delete(&mut self, key: &Key) -> Result { self.inner.delete(key.raw()) } - pub fn verify(root: Hash, proof: &MerkleProof, key: &StrKey, value: &Value) -> Result<()> { + pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> { Container::verify(root, proof, key.raw(), value.raw()) } - pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &StrKey) -> Result<()> { + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> { Container::verify_nonexistence(root, proof, key.raw()) } pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { @@ -532,11 +524,11 @@ mod tests { fn _test_dict(db: Box) { let mut dict0 = Dictionary::empty_with_db(db.clone()); - dict0.insert(&StrKey::from("a"), &Value::from(1)).unwrap(); - dict0.insert(&StrKey::from("b"), &Value::from(2)).unwrap(); - dict0.update(&StrKey::from("a"), &Value::from(3)).unwrap(); - dict0.insert(&StrKey::from("c"), &Value::from(4)).unwrap(); - dict0.delete(&StrKey::from("c")).unwrap(); + dict0.insert(&Key::from("a"), &Value::from(1)).unwrap(); + dict0.insert(&Key::from("b"), &Value::from(2)).unwrap(); + dict0.update(&Key::from("a"), &Value::from(3)).unwrap(); + dict0.insert(&Key::from("c"), &Value::from(4)).unwrap(); + dict0.delete(&Key::from("c")).unwrap(); let kvs0 = dict0.dump().unwrap(); assert_eq!( kvs0, @@ -587,14 +579,14 @@ mod tests { fn _test_nested(db: Box) { let mut nested = Dictionary::empty_with_db(db.clone()); - nested.insert(&StrKey::from("a"), &Value::from(1)).unwrap(); - nested.insert(&StrKey::from("b"), &Value::from(2)).unwrap(); + nested.insert(&Key::from("a"), &Value::from(1)).unwrap(); + nested.insert(&Key::from("b"), &Value::from(2)).unwrap(); let nested_kvs0 = nested.dump().unwrap(); let mut dict0 = Dictionary::empty_with_db(db.clone()); - dict0.insert(&StrKey::from("x"), &Value::from(1)).unwrap(); + dict0.insert(&Key::from("x"), &Value::from(1)).unwrap(); dict0 - .insert(&StrKey::from("y"), &Value::from(nested.clone())) + .insert(&Key::from("y"), &Value::from(nested.clone())) .unwrap(); let kvs0 = dict0.dump().unwrap(); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index f4f76cb..d212ca8 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -577,24 +577,21 @@ where } } -/// A key identified by a string name. Hash is computed via `hash_str`. #[derive(Clone, Debug, Eq)] -pub struct StrKey { +pub struct Key { name: String, hash: Hash, } -impl StrKey { +impl Key { pub fn new(name: String) -> Self { let hash = hash_str(&name); Self { name, hash } } + pub fn name(&self) -> &str { &self.name } - pub fn into_name(self) -> String { - self.name - } pub fn hash(&self) -> Hash { self.hash } @@ -603,31 +600,20 @@ impl StrKey { } } -impl PartialEq for StrKey { - fn eq(&self, other: &Self) -> bool { - self.hash == other.hash - } -} - -impl hash::Hash for StrKey { +impl hash::Hash for Key { fn hash(&self, state: &mut H) { self.hash.hash(state); } } -impl ToFields for StrKey { - fn to_fields(&self) -> Vec { - self.hash.to_fields() +impl PartialEq for Key { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash } } -impl fmt::Display for StrKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "\"{}\"", self.name) - } -} - -impl From for StrKey +// A Key can easily be created from a string-like type +impl From for Key where T: Into, { @@ -636,9 +622,30 @@ where } } -// `StrKey` serializes as a bare string. The cached hash is recomputed on -// deserialize via `StrKey::new`. -impl Serialize for StrKey { +impl ToFields for Key { + fn to_fields(&self) -> Vec { + self.hash.to_fields() + } +} + +impl fmt::Display for Key { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "\"{}\"", self.name)?; + Ok(()) + } +} + +impl From for RawValue { + fn from(key: Key) -> RawValue { + RawValue(key.hash.0) + } +} + +// When serializing a Key, we serialize only the name field, and not the hash. +// We can't directly tell Serde to render the whole struct as a string, so we +// implement our own serialization. It's important that if we change the +// structure of the Key struct, we update this implementation. +impl Serialize for Key { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -647,206 +654,29 @@ impl Serialize for StrKey { } } -impl<'de> Deserialize<'de> for StrKey { +impl<'de> Deserialize<'de> for Key { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - String::deserialize(deserializer).map(StrKey::new) + let name = String::deserialize(deserializer)?; + Ok(Key::new(name)) } } -impl JsonSchema for StrKey { +// As per the above, we implement custom serialization for the Key type, and +// Schemars can't automatically generate a schema for it. Instead, we tell it +// to use the standard String schema. +impl JsonSchema for Key { fn schema_name() -> String { - "StrKey".to_string() + "Key".to_string() } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { ::json_schema(gen) } } -/// A key identified by an integer index. The hash is taken directly from the -/// integer's `RawValue` encoding so that integer-keyed merkle trees (e.g. -/// `Array`, future shallow-MT variants) and integer-keyed `AnchoredKey`s -/// share the same leaf-key encoding. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct IndexKey { - value: i64, -} - -impl IndexKey { - pub fn new(value: i64) -> Self { - Self { value } - } - pub fn value(&self) -> i64 { - self.value - } - pub fn raw(&self) -> RawValue { - RawValue::from(self.value) - } - pub fn hash(&self) -> Hash { - Hash::from(self.raw()) - } -} - -impl ToFields for IndexKey { - fn to_fields(&self) -> Vec { - self.hash().to_fields() - } -} - -impl fmt::Display for IndexKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.value) - } -} - -impl From for IndexKey { - fn from(i: i64) -> Self { - Self::new(i) - } -} - -// `IndexKey` serializes as a bare integer. -impl Serialize for IndexKey { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.value.serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for IndexKey { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - i64::deserialize(deserializer).map(IndexKey::new) - } -} - -impl JsonSchema for IndexKey { - fn schema_name() -> String { - "IndexKey".to_string() - } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { - ::json_schema(gen) - } -} - -/// A key, either string-named or integer-indexed. -/// -/// APIs that only make sense for one variant (e.g. `Dictionary::insert` for -/// strings) take the inner type directly, lifting the variant check out to -/// the call site where the `match` on `Key` makes the missing arm visible at -/// compile time. -#[derive(Clone, Debug, Eq, Serialize, Deserialize, JsonSchema)] -#[serde(untagged)] -pub enum Key { - Str(StrKey), - Index(IndexKey), -} - -impl Key { - pub fn new(name: String) -> Self { - Key::Str(StrKey::new(name)) - } - pub fn as_str(&self) -> Option<&StrKey> { - match self { - Key::Str(k) => Some(k), - Key::Index(_) => None, - } - } - pub fn as_index(&self) -> Option<&IndexKey> { - match self { - Key::Str(_) => None, - Key::Index(k) => Some(k), - } - } - pub fn hash(&self) -> Hash { - match self { - Key::Str(k) => k.hash(), - Key::Index(k) => k.hash(), - } - } - pub fn raw(&self) -> RawValue { - match self { - Key::Str(k) => k.raw(), - Key::Index(k) => k.raw(), - } - } -} - -impl PartialEq for Key { - fn eq(&self, other: &Self) -> bool { - self.hash() == other.hash() - } -} - -impl hash::Hash for Key { - fn hash(&self, state: &mut H) { - self.hash().hash(state); - } -} - -impl ToFields for Key { - fn to_fields(&self) -> Vec { - self.hash().to_fields() - } -} - -impl fmt::Display for Key { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Key::Str(k) => k.fmt(f), - Key::Index(k) => k.fmt(f), - } - } -} - -impl From for Key { - fn from(k: StrKey) -> Self { - Key::Str(k) - } -} - -impl From for Key { - fn from(k: IndexKey) -> Self { - Key::Index(k) - } -} - -impl From<&str> for Key { - fn from(s: &str) -> Self { - Key::Str(StrKey::from(s)) - } -} - -impl From for Key { - fn from(s: String) -> Self { - Key::Str(StrKey::from(s)) - } -} - -impl From<&String> for Key { - fn from(s: &String) -> Self { - Key::Str(StrKey::from(s)) - } -} - -impl From for Key { - fn from(i: i64) -> Self { - Key::Index(IndexKey::from(i)) - } -} - -impl From for RawValue { - fn from(key: Key) -> RawValue { - key.raw() - } -} - #[derive(Clone, Debug, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub struct AnchoredKey { @@ -863,13 +693,13 @@ impl AnchoredKey { impl hash::Hash for AnchoredKey { fn hash(&self, state: &mut H) { self.root.hash(state); - self.key.hash().hash(state); + self.key.hash.hash(state); } } impl PartialEq for AnchoredKey { fn eq(&self, other: &Self) -> bool { - self.root == other.root && self.key.hash() == other.key.hash() + self.root == other.root && self.key.hash == other.key.hash } } @@ -1057,12 +887,6 @@ impl Params { self.max_statements - self.max_public_statements } - /// Maximum number of entries permitted in a `record` declaration: the - /// number of leaves in the small container merkle tree variant. - pub fn max_record_entries(&self) -> usize { - 2usize.pow(self.containers.max_depth_small as u32) - } - pub const fn statement_tmpl_arg_size() -> usize { 2 * HASH_SIZE + 1 } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 96f7fc1..8d3316c 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -860,12 +860,8 @@ impl fmt::Display for Operation { pub(crate) fn root_key_to_ak(root: &Value, key: &Value) -> Option { let root_hash = Hash::from(root.raw()); - if let Some(s) = key.as_str() { - Some(AnchoredKey::new(root_hash, Key::from(s))) - } else { - key.as_int() - .map(|i| AnchoredKey::new(root_hash, Key::from(i))) - } + key.as_str() + .map(|s| AnchoredKey::new(root_hash, Key::from(s))) } /// Returns the value associated with `output_ref`. diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index e66ec80..68e6efb 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -1,8 +1,12 @@ -use std::fmt::Write; +use std::{ + collections::{HashMap, HashSet}, + fmt::Write, +}; use plonky2::field::types::Field; -use serde::Deserialize; +use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; +use super::{Key, Value}; use crate::middleware::{F, HASH_SIZE, VALUE_SIZE}; fn serialize_field_tuple( @@ -100,3 +104,44 @@ where .parse() .map_err(serde::de::Error::custom) } + +// In order to serialize a Dictionary consistently, we want to order the +// key-value pairs by the key's name field. This has no effect on the hashes +// of the keys and therefore on the Merkle tree, but it makes the serialized +// output deterministic. +pub fn ordered_map( + value: &HashMap, + serializer: S, +) -> Result +where + S: Serializer, +{ + // Convert to Vec and sort by the key's name field + let mut pairs: Vec<_> = value.iter().collect(); + pairs.sort_by(|(k1, _), (k2, _)| k1.name.cmp(&k2.name)); + + // Serialize as a map + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(pairs.len()))?; + for (k, v) in pairs { + map.serialize_entry(k, v)?; + } + map.end() +} + +// Sets are serialized as sequences of elements, which are not ordered by +// default. We want to serialize them in a deterministic way, and we can +// achieve this by sorting the elements. This takes advantage of the fact that +// Value implements Ord. +pub fn ordered_set(value: &HashSet, serializer: S) -> Result +where + S: Serializer, +{ + let mut set = serializer.serialize_seq(Some(value.len()))?; + let mut sorted_values: Vec<&Value> = value.iter().collect(); + sorted_values.sort_by_key(|v| v.raw()); + for v in sorted_values { + set.serialize_element(v)?; + } + set.end() +}