diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 8de6871..7ca4d8c 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -15,8 +15,8 @@ use crate::{ #[derive(Clone, Debug)] pub enum BuilderArg { Literal(Value), - /// Key: (origin, key), where origin is Wildcard and key is Key - Key(String, String), + /// Key: (origin, key), where origin is the wildcard name. + Key(String, Key), WildcardLiteral(String), /// Reference to a same-batch predicate's identity hash (resolved by name in finish()). SelfPredicateHash(String), @@ -29,7 +29,7 @@ pub enum BuilderArg { /// case i. impl From<(&str, &str)> for BuilderArg { fn from((origin, field): (&str, &str)) -> Self { - Self::Key(origin.to_string(), field.to_string()) + Self::Key(origin.to_string(), Key::from(field)) } } /// case ii. @@ -219,9 +219,9 @@ impl CustomPredicateBatchBuilder { .map(|(arg_idx, a)| { Ok::<_, Error>(match a { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), - BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey( + BuilderArg::Key(root_wc, key) => StatementTmplArg::AnchoredKey( resolve_wildcard(args, priv_args, root_wc)?, - Key::from(key_str), + key.clone(), ), BuilderArg::WildcardLiteral(v) => { StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?) diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index b6e8691..999a3a4 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -15,10 +15,10 @@ pub use serialization::SerializedMainPod; use crate::middleware::{ self, check_custom_pred, containers::{Container, Dictionary}, - fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, Key, + fill_wildcard_values, hash_op, max_op, prod_op, root_key_to_ak, sum_op, AnchoredKey, Hash, MainPodInputs, MainPodProver, NativeOperation, OperationAux, OperationType, Params, PublicKey, - RawValue, Signature, Signer, Statement, StatementArg, VDSet, Value, ValueRef, BASE_PARAMS, - EMPTY_VALUE, + RawValue, Signature, Signer, Statement, StatementArg, StrKey, VDSet, Value, ValueRef, + BASE_PARAMS, EMPTY_VALUE, }; mod custom; @@ -39,7 +39,7 @@ pub use pod_request::*; #[derive(Clone, Debug)] pub struct SignedDictBuilder { pub params: Params, - pub kvs: HashMap, + pub kvs: HashMap, } impl fmt::Display for SignedDictBuilder { @@ -60,7 +60,7 @@ impl SignedDictBuilder { } } - pub fn insert(&mut self, key: impl Into, value: impl Into) { + pub fn insert(&mut self, key: impl Into, value: impl Into) { self.kvs.insert(key.into(), value.into()); } @@ -111,12 +111,12 @@ impl SignedDict { .then_some(()) .ok_or(Error::custom("Invalid signature!")) } - pub fn get(&self, key: impl Into) -> Option { + pub fn get(&self, key: impl Into) -> Option { self.dict.get(&key.into()).unwrap() } // Returns the Contains statement that defines key if it exists. - pub fn get_statement(&self, key: impl Into) -> Option { - let key: Key = key.into(); + pub fn get_statement(&self, key: impl Into) -> Option { + let key: StrKey = key.into(); self.dict.get(&key).unwrap().map(|value| { Statement::Contains( ValueRef::Literal(Value::from(self.dict.clone())), @@ -1112,11 +1112,11 @@ pub mod tests { OperationArg::Statement(st1), OperationArg::Literal(Value::from(1)), ], - OperationAux::MerkleProof(dict.prove(&Key::from("a")).unwrap().1), + OperationAux::MerkleProof(dict.prove(&StrKey::from("a")).unwrap().1), ))?; let mut new_dict = dict.clone(); - new_dict.insert(&Key::from("d"), &Value::from(4))?; + new_dict.insert(&StrKey::from("d"), &Value::from(4))?; builder.pub_op(Operation( OperationType::Native(NativeOperation::DictInsertFromEntries), @@ -1130,7 +1130,7 @@ pub mod tests { ))?; let mut new_old_dict = new_dict.clone(); - new_old_dict.delete(&Key::from("d"))?; + new_old_dict.delete(&StrKey::from("d"))?; assert_eq!(new_old_dict, dict); @@ -1144,7 +1144,7 @@ pub mod tests { OperationAux::None, ))?; - new_old_dict.update(&Key::from("c"), &55.into())?; + new_old_dict.update(&StrKey::from("c"), &55.into())?; builder.pub_op(Operation( OperationType::Native(NativeOperation::DictUpdateFromEntries), diff --git a/src/lang/diagnostics.rs b/src/lang/diagnostics.rs index 7807318..8c33e40 100644 --- a/src/lang/diagnostics.rs +++ b/src/lang/diagnostics.rs @@ -286,6 +286,123 @@ fn render_validation_error( "not allowed here", ) } + + ValidationError::DuplicateRecord { + name, + first_span, + second_span, + } => { + let title = format!("duplicate record definition: {}", name); + render_dual_span( + renderer, + source, + path, + &title, + first_span.as_ref(), + "first definition here", + second_span.as_ref(), + "duplicate definition", + ) + } + + ValidationError::RecordTooManyEntries { + name, + count, + max, + span, + } => { + let title = format!( + "record `{}` has {} entries, exceeding the limit of {}", + name, count, max + ); + render_with_optional_span( + renderer, + source, + path, + &title, + span.as_ref(), + "too many entries", + ) + } + + ValidationError::DuplicateRecordEntry { + record, + entry, + span, + } => { + let title = format!("duplicate entry `{}` in record `{}`", entry, record); + render_with_optional_span( + renderer, + source, + path, + &title, + span.as_ref(), + "already declared", + ) + } + + ValidationError::UnknownRecord { name, span } => { + let title = format!("unknown record type: {}", name); + render_with_optional_span( + renderer, + source, + path, + &title, + span.as_ref(), + "no such record", + ) + } + + ValidationError::UnknownRecordEntry { + record, + entry, + span, + } => { + let title = format!("record `{}` has no entry `{}`", record, entry); + render_with_optional_span( + renderer, + source, + path, + &title, + span.as_ref(), + "unknown entry", + ) + } + + ValidationError::DuplicateLiteralRecordEntry { + record, + entry, + span, + } => { + let title = format!("duplicate entry `{}` in `{}` literal", entry, record); + render_with_optional_span( + renderer, + source, + path, + &title, + span.as_ref(), + "already given", + ) + } + + ValidationError::BracketAccessOnTypedWildcard { + wildcard, + record, + span, + } => { + let title = format!( + "bracket access on `{}` (typed as record `{}`); use `{}.entry` instead", + wildcard, record, wildcard + ); + render_with_optional_span( + renderer, + source, + path, + &title, + span.as_ref(), + "string-key access on integer-keyed record", + ) + } } } diff --git a/src/lang/error.rs b/src/lang/error.rs index 792d4d8..ca2dabc 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -164,6 +164,52 @@ pub enum ValidationError { #[error("Requests must contain a REQUEST block")] NoRequestBlock, + + #[error("Duplicate record definition: {name}")] + DuplicateRecord { + name: String, + first_span: Option, + second_span: Option, + }, + + #[error("Record '{name}' has {count} entries, exceeding the limit of {max}")] + RecordTooManyEntries { + name: String, + count: usize, + max: usize, + span: Option, + }, + + #[error("Duplicate entry name '{entry}' in record '{record}'")] + DuplicateRecordEntry { + record: String, + entry: String, + span: Option, + }, + + #[error("Unknown record type: {name}")] + UnknownRecord { name: String, span: Option }, + + #[error("Record '{record}' has no entry '{entry}'")] + UnknownRecordEntry { + record: String, + entry: String, + span: Option, + }, + + #[error("Duplicate entry '{entry}' in record literal '{record}'")] + DuplicateLiteralRecordEntry { + record: String, + entry: String, + span: Option, + }, + + #[error("Bracket access '{wildcard}[...]' is not allowed on a wildcard typed as record '{record}'; use `{wildcard}.entry` instead")] + BracketAccessOnTypedWildcard { + wildcard: String, + record: String, + span: Option, + }, } /// Lowering errors from frontend AST lowering to middleware diff --git a/src/lang/frontend_ast.rs b/src/lang/frontend_ast.rs index dd0052c..9843fcd 100644 --- a/src/lang/frontend_ast.rs +++ b/src/lang/frontend_ast.rs @@ -20,10 +20,19 @@ pub struct Document { pub enum DocumentItem { UseModuleStatement(UseModuleStatement), UseIntroStatement(UseIntroStatement), + RecordDef(RecordDef), CustomPredicateDef(CustomPredicateDef), RequestDef(RequestDef), } +/// Record definition: `record Name = (entry1, entry2, ...)` +#[derive(Debug, Clone, PartialEq)] +pub struct RecordDef { + pub name: Identifier, + pub entries: Vec, + pub span: Option, +} + /// Module import statement: `use module 0xHASH as alias` #[derive(Debug, Clone, PartialEq)] pub struct UseModuleStatement { @@ -68,11 +77,48 @@ pub struct RequestDef { /// Argument section with public and optional private arguments #[derive(Debug, Clone, PartialEq)] pub struct ArgSection { - pub public_args: Vec, - pub private_args: Option>, + pub public_args: Vec, + pub private_args: Option>, pub span: Option, } +/// Predicate argument: `name`, `name TypeName`, or `name module::TypeName`. +/// The optional `type_name` names a record type whose dot-access entries are +/// resolved at lowering time. +#[derive(Debug, Clone, PartialEq)] +pub struct TypedArg { + pub name: String, + pub type_name: Option, + pub span: Option, +} + +/// Reference to a record type — either a local declaration in this module +/// or an import via `use module ... as alias`. +#[derive(Debug, Clone, PartialEq)] +pub enum TypeRef { + Local(Identifier), + Qualified { + module: Identifier, + name: Identifier, + }, +} + +impl TypeRef { + pub fn span(&self) -> Option { + match self { + TypeRef::Local(id) => id.span, + TypeRef::Qualified { name, .. } => name.span, + } + } + + /// Key used to look this reference up in `SymbolTable.records`: the bare + /// name for locals, `alias::Name` for qualified imports. Pairs with + /// `qualified_record_key` (which builds the same shape from raw parts). + pub fn symbol_table_key(&self) -> String { + self.to_string() + } +} + /// Conjunction type for custom predicates #[derive(Debug, Clone, Copy, PartialEq)] pub enum ConjunctionType { @@ -108,6 +154,13 @@ impl PredicateRef { PredicateRef::Qualified { predicate, .. } => &predicate.name, } } + + pub fn span(&self) -> Option { + match self { + PredicateRef::Local(id) => id.span, + PredicateRef::Qualified { predicate, .. } => predicate.span, + } + } } /// Arguments that can be passed to statements @@ -128,20 +181,15 @@ pub struct AnchoredKey { pub span: Option, } -impl AnchoredKey { - pub fn key_str(&self) -> &str { - match &self.key { - AnchoredKeyPath::Bracket(ls) => &ls.value, - AnchoredKeyPath::Dot(id) => &id.name, - } - } -} - /// Key path in an anchored key #[derive(Debug, Clone, PartialEq)] pub enum AnchoredKeyPath { Bracket(LiteralString), // ["key"] Dot(Identifier), // .key + /// Integer-indexed key. Not produced by the parser; introduced by lowering + /// when a `Dot` access on a record-typed wildcard is resolved to an entry + /// index. + Index(i64), } /// Identifier (variable names, predicate names, etc.) @@ -170,6 +218,7 @@ pub enum LiteralValue { Array(LiteralArray), Set(LiteralSet), Dict(LiteralDict), + Record(LiteralRecord), /// Hash of a native predicate (resolved immediately). NativePredicateHash(Identifier), /// Hash of an external module's predicate (resolved immediately). @@ -177,6 +226,13 @@ pub enum LiteralValue { module: Identifier, predicate: Identifier, }, + /// Compile-time integer literal that resolves to the index of a named + /// entry in a record schema: `R::foo` (local) or `mod::R::foo` (imported). + /// Lowers to `Value::from(idx as i64)` after schema resolution. + RecordEntryIndex { + record: TypeRef, + entry: Identifier, + }, } /// Integer literal @@ -250,6 +306,23 @@ pub struct DictPair { pub span: Option, } +/// Record literal: `Name(Entry: value, ...)` (local) or +/// `module::Name(Entry: value, ...)` (imported). Entries may appear in any +/// order; the schema (resolved in validation) maps each to its index. +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralRecord { + pub name: TypeRef, + pub entries: Vec, + pub span: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct RecordEntryLiteral { + pub name: Identifier, + pub value: LiteralValue, + pub span: Option, +} + /// Source location information for error reporting and formatting #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Span { @@ -276,6 +349,7 @@ impl fmt::Display for DocumentItem { match self { DocumentItem::UseModuleStatement(u) => write!(f, "{}", u), DocumentItem::UseIntroStatement(u) => write!(f, "{}", u), + DocumentItem::RecordDef(r) => write!(f, "{}", r), DocumentItem::CustomPredicateDef(c) => write!(f, "{}", c), DocumentItem::RequestDef(r) => write!(f, "{}", r), } @@ -362,6 +436,38 @@ impl fmt::Display for ArgSection { } } +impl fmt::Display for TypedArg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name)?; + if let Some(t) = &self.type_name { + write!(f, " {}", t)?; + } + Ok(()) + } +} + +impl fmt::Display for TypeRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TypeRef::Local(id) => write!(f, "{}", id), + TypeRef::Qualified { module, name } => write!(f, "{}::{}", module, name), + } + } +} + +impl fmt::Display for RecordDef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "record {} = (", self.name)?; + for (i, entry) in self.entries.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", entry)?; + } + write!(f, ")") + } +} + impl fmt::Display for ConjunctionType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -418,6 +524,7 @@ impl fmt::Display for AnchoredKey { match &self.key { AnchoredKeyPath::Bracket(s) => write!(f, "{}[{}]", self.root, s), AnchoredKeyPath::Dot(id) => write!(f, "{}.{}", self.root, id), + AnchoredKeyPath::Index(i) => write!(f, "{}[{}]", self.root, i), } } } @@ -434,16 +541,39 @@ impl fmt::Display for LiteralValue { LiteralValue::Array(a) => write!(f, "{}", a), LiteralValue::Set(s) => write!(f, "{}", s), LiteralValue::Dict(d) => write!(f, "{}", d), + LiteralValue::Record(r) => write!(f, "{}", r), LiteralValue::NativePredicateHash(id) => { write!(f, "@native_predicate({})", id) } LiteralValue::ExternalPredicateHash { module, predicate, .. } => write!(f, "@external_predicate({}, {})", module, predicate), + LiteralValue::RecordEntryIndex { record, entry } => { + write!(f, "{}::{}", record, entry) + } } } } +impl fmt::Display for LiteralRecord { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}(", self.name)?; + for (i, entry) in self.entries.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", entry)?; + } + write!(f, ")") + } +} + +impl fmt::Display for RecordEntryLiteral { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}: {}", self.name, self.value) + } +} + impl fmt::Display for LiteralInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.value) @@ -562,6 +692,9 @@ pub mod parse { inner_pair, ))); } + Rule::record_def => { + items.push(DocumentItem::RecordDef(parse_record_def(inner_pair))); + } Rule::custom_predicate_def => { items.push(DocumentItem::CustomPredicateDef( parse_custom_predicate_def(inner_pair)?, @@ -687,16 +820,16 @@ pub mod parse { Rule::public_arg_list => { public_args = inner_pair .into_inner() - .filter(|p| p.as_rule() == Rule::identifier) - .map(parse_identifier) + .filter(|p| p.as_rule() == Rule::typed_arg) + .map(parse_typed_arg) .collect(); } Rule::private_arg_list => { private_args = Some( inner_pair .into_inner() - .filter(|p| p.as_rule() == Rule::identifier) - .map(parse_identifier) + .filter(|p| p.as_rule() == Rule::typed_arg) + .map(parse_typed_arg) .collect(), ); } @@ -711,6 +844,50 @@ pub mod parse { } } + fn parse_typed_arg(pair: Pair) -> TypedArg { + assert_eq!(pair.as_rule(), Rule::typed_arg); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + let name_pair = inner.next().unwrap(); + let name = name_pair.as_str().to_string(); + let type_name = inner.next().map(parse_type_tag); + TypedArg { + name, + type_name, + span: Some(span), + } + } + + fn parse_type_tag(pair: Pair) -> TypeRef { + assert_eq!(pair.as_rule(), Rule::type_tag); + let inner = pair.into_inner().next().expect("type_tag has one child"); + match inner.as_rule() { + Rule::identifier => TypeRef::Local(parse_identifier(inner)), + Rule::qualified_type_ref => { + let mut idents = inner.into_inner(); + let module = parse_identifier(idents.next().unwrap()); + let name = parse_identifier(idents.next().unwrap()); + TypeRef::Qualified { module, name } + } + other => unreachable!("unexpected type_tag inner rule: {other:?}"), + } + } + + fn parse_record_def(pair: Pair) -> RecordDef { + assert_eq!(pair.as_rule(), Rule::record_def); + let span = get_span(&pair); + let mut idents = pair + .into_inner() + .filter(|p| p.as_rule() == Rule::identifier); + let name = parse_identifier(idents.next().unwrap()); + let entries: Vec<_> = idents.map(parse_identifier).collect(); + RecordDef { + name, + entries, + span: Some(span), + } + } + fn parse_conjunction_type(pair: Pair) -> ConjunctionType { assert_eq!(pair.as_rule(), Rule::conjunction_type); match pair.as_str() { @@ -845,6 +1022,7 @@ pub mod parse { Rule::literal_array => Ok(LiteralValue::Array(parse_literal_array(inner)?)), Rule::literal_set => Ok(LiteralValue::Set(parse_literal_set(inner)?)), Rule::literal_dict => Ok(LiteralValue::Dict(parse_literal_dict(inner)?)), + Rule::literal_record => Ok(LiteralValue::Record(parse_literal_record(inner)?)), Rule::predicate_hash_native => { let id = parse_identifier(inner.into_inner().next().unwrap()); Ok(LiteralValue::NativePredicateHash(id)) @@ -855,10 +1033,55 @@ pub mod parse { let predicate = parse_identifier(parts.next().unwrap()); Ok(LiteralValue::ExternalPredicateHash { module, predicate }) } + Rule::record_entry_index => { + let mut parts = inner.into_inner(); + let first = parse_identifier(parts.next().unwrap()); + let second = parse_identifier(parts.next().unwrap()); + let (record, entry) = match parts.next().map(parse_identifier) { + Some(third) => ( + TypeRef::Qualified { + module: first, + name: second, + }, + third, + ), + None => (TypeRef::Local(first), second), + }; + Ok(LiteralValue::RecordEntryIndex { record, entry }) + } _ => unreachable!("Unexpected literal value rule: {:?}", inner.as_rule()), } } + fn parse_literal_record(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::literal_record); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + let name = parse_type_tag(inner.next().unwrap()); + let entries: Result, _> = inner + .filter(|p| p.as_rule() == Rule::record_entry) + .map(parse_record_entry) + .collect(); + Ok(LiteralRecord { + name, + entries: entries?, + span: Some(span), + }) + } + + fn parse_record_entry(pair: Pair) -> Result { + assert_eq!(pair.as_rule(), Rule::record_entry); + let span = get_span(&pair); + let mut inner = pair.into_inner(); + let name = parse_identifier(inner.next().unwrap()); + let value = parse_literal_value(inner.next().unwrap())?; + Ok(RecordEntryLiteral { + name, + value, + span: Some(span), + }) + } + fn parse_literal_int(pair: Pair) -> Result { assert_eq!(pair.as_rule(), Rule::literal_int); let value = pair @@ -1085,16 +1308,29 @@ mod tests { u.name.span = None; u.intro_hash.span = None; } + DocumentItem::RecordDef(r) => { + r.span = None; + r.name.span = None; + for e in &mut r.entries { + e.span = None; + } + } DocumentItem::CustomPredicateDef(c) => { c.span = None; c.name.span = None; c.args.span = None; for arg in &mut c.args.public_args { arg.span = None; + if let Some(t) = &mut arg.type_name { + clear_type_ref_spans(t); + } } if let Some(private) = &mut c.args.private_args { for arg in private { arg.span = None; + if let Some(t) = &mut arg.type_name { + clear_type_ref_spans(t); + } } } for stmt in &mut c.statements { @@ -1111,6 +1347,16 @@ mod tests { } } + fn clear_type_ref_spans(t: &mut TypeRef) { + match t { + TypeRef::Local(id) => id.span = None, + TypeRef::Qualified { module, name } => { + module.span = None; + name.span = None; + } + } + } + fn clear_predicate_ref_spans(pred_ref: &mut PredicateRef) { match pred_ref { PredicateRef::Local(id) => id.span = None, @@ -1134,6 +1380,7 @@ mod tests { match &mut ak.key { AnchoredKeyPath::Bracket(s) => s.span = None, AnchoredKeyPath::Dot(id) => id.span = None, + AnchoredKeyPath::Index(_) => {} } } StatementTmplArg::SelfPredicateHash(id) => id.span = None, @@ -1172,6 +1419,15 @@ mod tests { clear_literal_spans(&mut pair.value); } } + LiteralValue::Record(r) => { + r.span = None; + clear_type_ref_spans(&mut r.name); + for entry in &mut r.entries { + entry.span = None; + entry.name.span = None; + clear_literal_spans(&mut entry.value); + } + } LiteralValue::NativePredicateHash(id) => id.span = None, LiteralValue::ExternalPredicateHash { module, predicate, .. @@ -1179,6 +1435,10 @@ mod tests { module.span = None; predicate.span = None; } + LiteralValue::RecordEntryIndex { record, entry } => { + clear_type_ref_spans(record); + entry.span = None; + } } } @@ -1268,6 +1528,139 @@ mod tests { test_roundtrip(input); } + #[test] + fn test_record_decl() { + let input = r#"record ProcInputs = (foo, bar, baz)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_decl_single_entry() { + let input = r#"record Singleton = (only)"#; + test_roundtrip(input); + } + + #[test] + fn test_typed_arg_in_predicate() { + let input = r#"record ProcInputs = (foo, bar, baz) +my_pred(in ProcInputs, other) = AND ( + Equal(in.foo, other) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_typed_arg_mixed_with_untyped() { + let input = r#"record R = (x, y) +mixed(a, b R, c, private: d, e R) = AND ( + Equal(a, c) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_typed_arg_qualified() { + // Qualified type tag references an imported record; the parser + // accepts it without inspecting the import (validation is downstream). + let input = r#"my_pred(in some_module::ProcInputs) = AND ( + Equal(in.foo, in.bar) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_literal_full() { + let input = r#"REQUEST( + Equal(A["data"], ProcInputs(foo: 1, bar: 2, baz: 3)) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_literal_sparse() { + let input = r#"REQUEST( + Equal(A["data"], ProcInputs(bar: 42)) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_literal_empty_rejected() { + // Record literals require at least one entry — an empty literal + // would never validate (no schema has zero entries), so reject at + // parse time for a clearer error. + let input = r#"REQUEST( + Equal(A["data"], Empty()) +)"#; + let parsed = crate::lang::parser::parse_podlang(input); + assert!( + parsed.is_err(), + "expected empty record literal `Empty()` to be rejected" + ); + } + + #[test] + fn test_record_entry_index_local() { + // `R::foo` resolves to the entry's integer index at compile time. + let input = r#"REQUEST( + Contains(A, R::foo, 7) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_entry_index_qualified() { + // `mod::R::foo` for an imported record. + let input = r#"REQUEST( + Contains(A, some_mod::R::foo, 7) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_literal_nested_value() { + let input = r#"REQUEST( + Equal(A["data"], R(items: [1, 2, 3], other: {"k": "v"})) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_literal_qualified() { + // Imported record literal: `module::R(foo: 1)`. Parses with + // `TypeRef::Qualified` as the head; PEG ordering means the + // `module::R` prefix is consumed by `literal_record` rather than + // shadowed by `record_entry_index`. + let input = r#"REQUEST( + Equal(A["data"], some_mod::R(foo: 1, bar: 2)) +)"#; + test_roundtrip(input); + } + + #[test] + fn test_record_keyword_reserved() { + // `record` may not appear as an identifier name. + let input = r#"record record = (foo)"#; + let parsed = crate::lang::parser::parse_podlang(input); + assert!( + parsed.is_err(), + "expected `record` to be rejected as an identifier" + ); + } + + #[test] + fn test_reserved_word_prefix_allowed_as_identifier() { + // The reserved-word check is anchored at a word boundary, so only + // the exact keyword is rejected. Identifiers that merely begin with + // a reserved word (`record_count`, `recorder`, `record_field`, the + // predicate name `record_using_pred`) must parse normally. + let input = r#"record Outer = (record_field, recorder) +record_using_pred(record_count, recordOwner) = AND ( + Equal(record_count, recordOwner) +)"#; + test_roundtrip(input); + } + #[test] fn test_complete_document() { let input = r#"use module 0x0000000000000000000000000000000000000000000000000000000000000000 as imported diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index fb00def..c89d045 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -13,13 +13,13 @@ use crate::{ lang::{ frontend_ast::*, frontend_ast_split, - frontend_ast_validate::{PredicateKind, SymbolTable, ValidatedAST}, + frontend_ast_validate::{PredicateKind, RecordSource, SymbolTable, ValidatedAST}, module, Module, }, middleware::{ - self, containers, CustomPredicateRef, IntroPredicateRef, Key, NativePredicate, Params, - Predicate, StatementTmpl as MWStatementTmpl, StatementTmplArg as MWStatementTmplArg, Value, - Wildcard, + self, containers, db::mem::MemDB, CustomPredicateRef, IntroPredicateRef, Key, + NativePredicate, Params, Predicate, StatementTmpl as MWStatementTmpl, + StatementTmplArg as MWStatementTmplArg, StrKey, Value, Wildcard, }, }; @@ -158,8 +158,10 @@ fn resolve_local_predicate( /// Lower a literal value from AST to middleware Value. /// /// This is a pure conversion that cannot fail for context-free literals. -/// Panics on ExternalPredicateHash — use `lower_literal_with_context` when -/// external predicate references may appear (e.g. inside containers). +/// Panics on `ExternalPredicateHash`, `Record`, and `RecordEntryIndex` — +/// use `lower_literal_with_context` when any of those may appear (records +/// and entry indices need the symbol table to resolve the record schema; +/// external predicate hashes need the imported-module table). pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { match lit { LiteralValue::Int(i) => Value::from(i.value), @@ -184,7 +186,7 @@ pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { .pairs .iter() .map(|pair| { - let key = Key::from(pair.key.value.as_str()); + let key = StrKey::from(pair.key.value.as_str()); let value = lower_literal(&pair.value); (key, value) }) @@ -192,6 +194,11 @@ pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { let dict = containers::Dictionary::new(pairs); Value::from(dict) } + LiteralValue::Record(_) => { + unreachable!( + "Record literals must be lowered with context via lower_literal_with_context" + ) + } LiteralValue::NativePredicateHash(id) => { let np = NativePredicate::from_str(&id.name).expect("validated native predicate"); Value::from(Predicate::Native(np).hash()) @@ -201,6 +208,11 @@ pub(crate) fn lower_literal(lit: &LiteralValue) -> Value { "ExternalPredicateHash must be lowered with context via lower_literal_with_context" ) } + LiteralValue::RecordEntryIndex { .. } => { + unreachable!( + "RecordEntryIndex must be lowered with context via lower_literal_with_context" + ) + } } } @@ -252,13 +264,36 @@ pub fn lower_literal_with_context( .pairs .iter() .map(|pair| { - let key = Key::from(pair.key.value.as_str()); + let key = StrKey::from(pair.key.value.as_str()); let value = lower_literal_with_context(&pair.value, symbols, context)?; Ok((key, value)) }) .collect::>()?; Ok(Value::from(containers::Dictionary::new(pairs))) } + LiteralValue::Record(r) => { + // The schema fixes each entry's index, so source order doesn't + // affect the merkle root and missing entries stay missing. + let schema = symbols + .records + .get(&r.name.symbol_table_key()) + .expect("record schema validated"); + let mut arr = containers::Array::empty_with_db(Box::new(MemDB::new())); + for entry_lit in &r.entries { + let idx = schema.entry_index[&entry_lit.name.name]; + let value = lower_literal_with_context(&entry_lit.value, symbols, context)?; + arr.insert(idx, value)?; + } + Ok(Value::from(arr)) + } + LiteralValue::RecordEntryIndex { record, entry } => { + let schema = symbols + .records + .get(&record.symbol_table_key()) + .expect("record schema validated"); + let idx = schema.entry_index[&entry.name]; + Ok(Value::from(idx as i64)) + } // All other variants are context-free other => Ok(lower_literal(other)), } @@ -276,11 +311,12 @@ pub(crate) fn lower_statement_arg(arg: &StatementTmplArg) -> BuilderArg { } StatementTmplArg::Wildcard(id) => BuilderArg::WildcardLiteral(id.name.clone()), StatementTmplArg::AnchoredKey(ak) => { - let key_str = match &ak.key { - AnchoredKeyPath::Bracket(s) => s.value.clone(), - AnchoredKeyPath::Dot(id) => id.name.clone(), + let key = match &ak.key { + AnchoredKeyPath::Bracket(s) => Key::new(s.value.clone()), + AnchoredKeyPath::Dot(id) => Key::new(id.name.clone()), + AnchoredKeyPath::Index(i) => Key::from(*i), }; - BuilderArg::Key(ak.root.name.clone(), key_str) + BuilderArg::Key(ak.root.name.clone(), key) } StatementTmplArg::SelfPredicateHash(id) => BuilderArg::SelfPredicateHash(id.name.clone()), } @@ -350,6 +386,7 @@ impl<'a> Lowerer<'a> { fn lower_module(self, module_name: &str) -> Result { // Extract and split custom predicates from document let custom_predicates = self.extract_and_split_predicates()?; + let local_records = self.collect_local_records(); // Build the module from split predicates let module = module::build_module( @@ -357,11 +394,24 @@ impl<'a> Lowerer<'a> { self.params, module_name, self.validated.symbols(), + local_records, )?; Ok(module) } + /// Collect record declarations from this module's source. No transitive + /// re-export — imported records aren't included. + fn collect_local_records(&self) -> HashMap> { + self.validated + .symbols() + .records + .iter() + .filter(|(_, schema)| matches!(schema.source, RecordSource::Local)) + .map(|(name, schema)| (name.clone(), schema.entries.clone())) + .collect() + } + fn lower_request(self) -> Result { let doc = self.validated.document(); @@ -429,12 +479,11 @@ impl<'a> Lowerer<'a> { let index = wildcard_map.get(&name).expect("Wildcard not found"); MWStatementTmplArg::Wildcard(Wildcard::new(name, *index)) } - BuilderArg::Key(root_name, key_str) => { + BuilderArg::Key(root_name, key) => { let root_index = wildcard_map .get(&root_name) .expect("Root wildcard not found"); let wildcard = Wildcard::new(root_name, *root_index); - let key = Key::from(key_str.as_str()); MWStatementTmplArg::AnchoredKey(wildcard, key) } BuilderArg::SelfPredicateHash(_) => { @@ -511,15 +560,56 @@ impl<'a> Lowerer<'a> { }) .collect(); - // Apply splitting to each predicate as needed + // Apply splitting to each predicate as needed. The typed-key rewrite + // happens before splitting so split chain pieces inherit `Index` keys + // unchanged. let mut split_results = Vec::new(); - for pred in predicates { + for mut pred in predicates { + self.rewrite_typed_dot_access(&mut pred); let result = frontend_ast_split::split_predicate_if_needed(pred, self.params)?; split_results.push(result); } Ok(split_results) } + + /// Rewrite `r.foo` to `r[i]` when `r` is a typed wildcard, using the + /// record schema's entry-index map. Untyped wildcards keep + /// `Dot`/`Bracket` keys unchanged (POD-string-key semantics). + fn rewrite_typed_dot_access(&self, pred: &mut CustomPredicateDef) { + let symbols = self.validated.symbols(); + let scope = symbols + .wildcard_scopes + .get(&pred.name.name) + .expect("wildcard scope exists for every custom predicate after validation"); + // Skip the per-arg walk for predicates with no typed wildcards — + // the common case before records see widespread use. + if !scope.wildcards.values().any(|wc| wc.record_type.is_some()) { + return; + } + for stmt in &mut pred.statements { + for arg in &mut stmt.args { + let StatementTmplArg::AnchoredKey(ak) = arg else { + continue; + }; + let Some(wc_info) = scope.wildcards.get(&ak.root.name) else { + continue; + }; + let Some(record_name) = &wc_info.record_type else { + continue; + }; + let AnchoredKeyPath::Dot(entry) = &ak.key else { + continue; + }; + let schema = symbols + .records + .get(record_name) + .expect("record_type was resolved at predicate-def time"); + let idx = schema.entry_index[&entry.name]; + ak.key = AnchoredKeyPath::Index(idx as i64); + } + } + } } #[cfg(test)] @@ -539,8 +629,8 @@ mod tests { ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - let validated = - validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate"); + let validated = validate(document, &HashMap::new(), params, ParseMode::Module) + .expect("Failed to validate"); lower_module(validated, params, "test_batch") } @@ -769,8 +859,8 @@ mod tests { let parsed = parse_podlang(&input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse document"); - let validated = - validate(document, &HashMap::new(), ParseMode::Module).expect("Failed to validate"); + let validated = validate(document, &HashMap::new(), ¶ms, ParseMode::Module) + .expect("Failed to validate"); let result = lower_module(validated, ¶ms, "test_batch"); assert!(result.is_ok(), "Lowering failed: {:?}", result.err()); @@ -796,4 +886,251 @@ mod tests { other => panic!("Expected Intro predicate, got {:?}", other), } } + + // ---- Records: predicate-side dot-access lowering ----------------------- + + /// Pull the single `Key` out of statement N, arg N of the first predicate. + fn anchored_key_at( + module: &Module, + pred_idx: usize, + stmt_idx: usize, + arg_idx: usize, + ) -> middleware::Key { + let pred = &module.batch.predicates()[pred_idx]; + let stmt = &pred.statements()[stmt_idx]; + match &stmt.args()[arg_idx] { + middleware::StatementTmplArg::AnchoredKey(_, k) => k.clone(), + other => panic!("expected AnchoredKey at arg {arg_idx}, got {other:?}"), + } + } + + fn anchored_index_at(module: &Module, pred_idx: usize, stmt_idx: usize, arg_idx: usize) -> i64 { + anchored_key_at(module, pred_idx, stmt_idx, arg_idx) + .as_index() + .expect("expected Index key") + .value() + } + + #[test] + fn test_typed_dot_lowers_to_index_key() { + // Single entry on a typed wildcard becomes an integer-keyed + // AnchoredKey at the schema's entry index. + let input = r#" + record R = (foo, bar, baz) + my_pred(in R) = AND(Equal(in.bar, 0)) + "#; + let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); + assert_eq!(anchored_index_at(&module, 0, 0, 0), 1); + } + + #[test] + fn test_dot_on_untyped_wildcard_stays_str_key() { + // No type tag, no schema lookup: dot-access keeps POD-string-key + // semantics. + let input = r#" + my_pred(r) = AND(Equal(r.foo, 1)) + "#; + let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); + match anchored_key_at(&module, 0, 0, 0) { + middleware::Key::Str(sk) => assert_eq!(sk.name(), "foo"), + other => panic!("expected Str key, got {other:?}"), + } + } + + #[test] + fn test_typed_dot_multiple_entries_distinct_indices() { + let input = r#" + record R = (foo, bar, baz) + my_pred(in R) = AND( + Equal(in.foo, in.baz) + Equal(in.bar, 0) + ) + "#; + let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); + assert_eq!(anchored_index_at(&module, 0, 0, 0), 0); + assert_eq!(anchored_index_at(&module, 0, 0, 1), 2); + assert_eq!(anchored_index_at(&module, 0, 1, 0), 1); + } + + #[test] + fn test_typed_dot_in_or_predicate() { + // OR predicates: the lowering produces a single AnchoredKey per + // statement, no cross-statement coupling, so OR works the same as AND. + let input = r#" + record R = (foo, bar) + my_pred(in R) = OR( + Equal(in.foo, 1) + Equal(in.bar, 2) + ) + "#; + let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); + assert!(module.batch.predicates()[0].is_disjunction()); + assert_eq!(anchored_index_at(&module, 0, 0, 0), 0); + assert_eq!(anchored_index_at(&module, 0, 1, 0), 1); + } + + #[test] + fn test_record_predicate_hash_matches_handwritten_index_form() { + // Source-level records are syntactic sugar: the predicate hash for + // `record R = (foo, bar); p(in R) = AND(Equal(in.bar, 7))` must equal + // the hash of the same predicate built directly with an integer-keyed + // anchored key. There is no Podlang surface syntax for `in[1]`, so we + // build the reference batch via the builder API. + use crate::{ + frontend::{CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::NativePredicate, + }; + + let with_record = r#" + record R = (foo, bar) + p(in R) = AND(Equal(in.bar, 7)) + "#; + let params = Params::default(); + let m_record = parse_validate_and_lower_module(with_record, ¶ms).unwrap(); + + let mut b = CustomPredicateBatchBuilder::new(params.clone(), "test_batch".into()); + let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg(BuilderArg::Key("in".into(), Key::from(1i64))) + .arg(BuilderArg::Literal(Value::from(7i64))); + b.predicate_and("p", &["in"], &[], &[stb]).unwrap(); + let plain_batch = b.finish().unwrap(); + + assert_eq!(m_record.batch.id(), plain_batch.id()); + } + + // ---- Records: literal lowering ----------------------------------------- + + fn lower_literal_in_pred(input: &str) -> Value { + let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); + let pred = &module.batch.predicates()[0]; + let stmt = &pred.statements()[0]; + match &stmt.args()[1] { + middleware::StatementTmplArg::Literal(v) => v.clone(), + other => panic!("expected Literal at arg 1, got {other:?}"), + } + } + + #[test] + fn test_record_literal_full_matches_array_root() { + // A fully populated literal must hash identically to the same values + // packed into an `Array::new(...)` (which inserts at indices 0..n in + // order). + let input = r#" + record R = (foo, bar, baz) + my_pred(A) = AND(Equal(A["x"], R(foo: 1, bar: 2, baz: 3))) + "#; + let v = lower_literal_in_pred(input); + let expected = Value::from(containers::Array::new(vec![ + Value::from(1i64), + Value::from(2i64), + Value::from(3i64), + ])); + assert_eq!(v.raw(), expected.raw()); + } + + #[test] + fn test_record_literal_entry_order_doesnt_matter() { + // Schema fixes the index, so source order never affects the root. + let input_a = r#" + record R = (foo, bar) + my_pred(A) = AND(Equal(A["x"], R(foo: 1, bar: 2))) + "#; + let input_b = r#" + record R = (foo, bar) + my_pred(A) = AND(Equal(A["x"], R(bar: 2, foo: 1))) + "#; + assert_eq!( + lower_literal_in_pred(input_a).raw(), + lower_literal_in_pred(input_b).raw() + ); + } + + #[test] + fn test_record_literal_sparse_stays_sparse() { + // Missing entries stay missing (no zero-fill). Compare against an + // explicit sparse Array built the same way. + let input = r#" + record R = (foo, bar, baz) + my_pred(A) = AND(Equal(A["x"], R(bar: 42))) + "#; + let v = lower_literal_in_pred(input); + + let mut sparse = containers::Array::empty_with_db(Box::new(MemDB::new())); + sparse.insert(1, Value::from(42i64)).unwrap(); + let expected = Value::from(sparse); + + assert_eq!(v.raw(), expected.raw()); + } + + #[test] + fn test_record_literal_nested_record_value() { + // A record literal whose entry value is itself a record literal. + // The outer literal commits to whatever root the inner produces. + let input = r#" + record Inner = (x, y) + record Outer = (inner) + my_pred(A) = AND(Equal(A["x"], Outer(inner: Inner(x: 1, y: 2)))) + "#; + let v = lower_literal_in_pred(input); + + let inner = Value::from(containers::Array::new(vec![ + Value::from(1i64), + Value::from(2i64), + ])); + let expected = Value::from(containers::Array::new(vec![inner])); + + assert_eq!(v.raw(), expected.raw()); + } + + #[test] + fn test_typed_dot_survives_predicate_splitting() { + // The rewrite runs before splitting, so chain pieces inherit + // `Index` keys unchanged. Force a split by exceeding the + // per-predicate statement cap. + let input = r#" + record R = (a, b, c, d, e, f) + my_pred(in R) = AND( + Equal(in.a, 1) + Equal(in.b, 2) + Equal(in.c, 3) + Equal(in.d, 4) + Equal(in.e, 5) + Equal(in.f, 6) + ) + "#; + let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); + // Splitter ran (max_custom_predicate_arity = 5). + assert!(module.batch.predicates().len() > 1); + // Every AnchoredKey across all chain pieces is integer-keyed. + for pred in module.batch.predicates() { + for stmt in pred.statements() { + for arg in stmt.args() { + if let middleware::StatementTmplArg::AnchoredKey(_, k) = arg { + assert!( + matches!(k, middleware::Key::Index(_)), + "expected Index key in split chain piece, got {k:?}" + ); + } + } + } + } + } + + #[test] + fn test_record_entry_index_lowers_to_integer() { + // `R::bar` resolves to integer 1 (bar is the second entry of R). + let input = r#" + record R = (foo, bar, baz) + my_pred(A) = AND(Contains(A, R::bar, 7)) + "#; + let module = parse_validate_and_lower_module(input, &Params::default()).unwrap(); + let pred = &module.batch.predicates()[0]; + let stmt = &pred.statements()[0]; + match &stmt.args()[1] { + middleware::StatementTmplArg::Literal(v) => { + assert_eq!(v.raw(), Value::from(1i64).raw()); + } + other => panic!("expected Literal at arg 1, got {other:?}"), + } + } } diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index 482db7a..7889f67 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -709,12 +709,13 @@ fn generate_chain_predicates( statements.push(chain_call); } - // Build public args (incoming) - let public_args: Vec = link + // Build public args (incoming). + let public_args: Vec = link .public_args_in .iter() - .map(|name| Identifier { + .map(|name| TypedArg { name: name.clone(), + type_name: None, span: None, }) .collect(); @@ -733,7 +734,11 @@ fn generate_chain_predicates( Some( private_arg_names .into_iter() - .map(|name| Identifier { name, span: None }) + .map(|name| TypedArg { + name, + type_name: None, + span: None, + }) .collect(), ) }; diff --git a/src/lang/frontend_ast_validate.rs b/src/lang/frontend_ast_validate.rs index ef3d395..41c0eff 100644 --- a/src/lang/frontend_ast_validate.rs +++ b/src/lang/frontend_ast_validate.rs @@ -13,7 +13,7 @@ use hex::ToHex; use crate::{ lang::{frontend_ast::*, Module}, - middleware::{CustomPredicateBatch, Hash, NativePredicate}, + middleware::{CustomPredicateBatch, Hash, NativePredicate, Params}, }; /// A validated AST document with symbol table and diagnostics @@ -51,6 +51,55 @@ pub struct SymbolTable { pub wildcard_scopes: HashMap, /// Imported modules (bound name → Module reference) pub imported_modules: HashMap>, + /// Records visible in this scope (local declarations + imports). + pub records: HashMap, +} + +/// Resolved record schema: ordered entries plus a name→index lookup, with +/// provenance for diagnostics. Lowering uses `entry_index` to translate +/// dot-access like `r.foo` into the integer key for an `AnchoredKey`. +#[derive(Debug, Clone)] +pub struct RecordSchema { + pub entries: Vec, + pub entry_index: HashMap, + pub source: RecordSource, + pub source_span: Option, +} + +impl RecordSchema { + /// Build a schema from already-deduplicated entries. Callers that need + /// to surface a per-entry span on duplicates (e.g. local declarations) + /// should detect duplicates themselves before calling this. + pub fn from_entries( + entries: Vec, + source: RecordSource, + source_span: Option, + ) -> Self { + let entry_index = entries + .iter() + .enumerate() + .map(|(i, e)| (e.clone(), i)) + .collect(); + Self { + entries, + entry_index, + source, + source_span, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RecordSource { + Local, + Imported { module: String }, +} + +/// Build the `SymbolTable.records` key for a record imported via +/// `use module ... as alias`. Mirrors the `alias::Name` form used for +/// `TypeRef::Qualified`. +pub fn qualified_record_key(alias: &str, name: &str) -> String { + format!("{}::{}", alias, name) } /// Information about a predicate @@ -96,6 +145,9 @@ pub struct WildcardInfo { pub index: usize, pub is_public: bool, pub source_span: Option, + /// Record type tag for typed args (`name TypeName` syntax). The name + /// references an entry in `SymbolTable.records`. + pub record_type: Option, } /// Diagnostic message (warning or info) @@ -127,14 +179,16 @@ pub enum ParseMode { pub fn validate( document: Document, available_modules: &HashMap>, + params: &Params, mode: ParseMode, ) -> Result { - let validator = Validator::new(available_modules, mode); + let validator = Validator::new(available_modules, params, mode); validator.validate(document) } struct Validator { available_modules: HashMap>, + params: Params, symbols: SymbolTable, diagnostics: Vec, custom_predicate_count: usize, @@ -142,13 +196,19 @@ struct Validator { } impl Validator { - fn new(available_modules: &HashMap>, mode: ParseMode) -> Self { + fn new( + available_modules: &HashMap>, + params: &Params, + mode: ParseMode, + ) -> Self { Self { available_modules: available_modules.clone(), + params: params.clone(), symbols: SymbolTable { predicates: HashMap::new(), wildcard_scopes: HashMap::new(), imported_modules: HashMap::new(), + records: HashMap::new(), }, diagnostics: Vec::new(), custom_predicate_count: 0, @@ -181,6 +241,13 @@ impl Validator { } } + // Records before predicates so typed-arg resolution can find them. + for item in &document.items { + if let DocumentItem::RecordDef(record_def) = item { + self.process_record_def(record_def)?; + } + } + // Check mode constraints for predicate definitions let mut has_predicates = false; for item in &document.items { @@ -214,7 +281,7 @@ impl Validator { } } - // Enforce that modules have predicates and requests have a REQUEST block + // Enforce that modules have predicates and requests have a REQUEST block. match self.mode { ParseMode::Module if !has_predicates => { return Err(ValidationError::NoPredicatesInModule); @@ -244,6 +311,22 @@ impl Validator { span: use_stmt.span, })?; + // Flatten the imported module's locally-declared records into the + // symbol table under qualified keys (`alias::Name`). No transitive + // re-export — `Module.records` only carries local declarations. + for (record_name, entries) in &module.records { + self.symbols.records.insert( + qualified_record_key(alias, record_name), + RecordSchema::from_entries( + entries.clone(), + RecordSource::Imported { + module: alias.clone(), + }, + use_stmt.span, + ), + ); + } + // Store the module keyed by alias for later qualified name resolution self.symbols .imported_modules @@ -252,6 +335,24 @@ impl Validator { Ok(()) } + /// Returns the resolved `SymbolTable.records` key for a typed arg, or + /// `None` if the arg has no `type_name`. The key is the bare type name + /// for locals and `"alias::Name"` for qualified imports. Errors if the + /// tag doesn't refer to a known record. + fn resolve_typed_arg(&self, arg: &TypedArg) -> Result, ValidationError> { + let Some(type_ref) = &arg.type_name else { + return Ok(None); + }; + let key = type_ref.symbol_table_key(); + if !self.symbols.records.contains_key(&key) { + return Err(ValidationError::UnknownRecord { + name: key, + span: type_ref.span(), + }); + } + Ok(Some(key)) + } + fn process_use_intro_statement( &mut self, use_stmt: &UseIntroStatement, @@ -283,6 +384,48 @@ impl Validator { Ok(()) } + fn process_record_def(&mut self, record_def: &RecordDef) -> Result<(), ValidationError> { + let name = &record_def.name.name; + + if let Some(existing) = self.symbols.records.get(name) { + return Err(ValidationError::DuplicateRecord { + name: name.clone(), + first_span: existing.source_span, + second_span: record_def.name.span, + }); + } + + let max = self.params.max_record_entries(); + if record_def.entries.len() > max { + return Err(ValidationError::RecordTooManyEntries { + name: name.clone(), + count: record_def.entries.len(), + max, + span: record_def.span, + }); + } + + let mut seen = HashSet::with_capacity(record_def.entries.len()); + let mut entries = Vec::with_capacity(record_def.entries.len()); + for entry in &record_def.entries { + if !seen.insert(&entry.name) { + return Err(ValidationError::DuplicateRecordEntry { + record: name.clone(), + entry: entry.name.clone(), + span: entry.span, + }); + } + entries.push(entry.name.clone()); + } + + self.symbols.records.insert( + name.clone(), + RecordSchema::from_entries(entries, RecordSource::Local, record_def.name.span), + ); + + Ok(()) + } + fn process_custom_predicate_def( &mut self, pred_def: &CustomPredicateDef, @@ -318,12 +461,14 @@ impl Validator { span: arg.span, }); } + let record_type = self.resolve_typed_arg(arg)?; wildcards.insert( arg.name.clone(), WildcardInfo { index: wildcard_index, is_public: true, source_span: arg.span, + record_type, }, ); wildcard_index += 1; @@ -339,12 +484,14 @@ impl Validator { span: arg.span, }); } + let record_type = self.resolve_typed_arg(arg)?; wildcards.insert( arg.name.clone(), WildcardInfo { index: wildcard_index, is_public: false, source_span: arg.span, + record_type, }, ); wildcard_index += 1; @@ -443,10 +590,7 @@ impl Validator { wildcard_context: Option<(&str, &WildcardScope)>, ) -> Result<(), ValidationError> { let pred_name = stmt.predicate.predicate_name(); - let pred_span = match &stmt.predicate { - PredicateRef::Local(id) => id.span, - PredicateRef::Qualified { predicate, .. } => predicate.span, - }; + let pred_span = stmt.predicate.span(); let wc_names = match wildcard_context { Some((_, wc_scope)) => wc_scope.wildcards.keys().collect(), @@ -547,12 +691,44 @@ impl Validator { } StatementTmplArg::AnchoredKey(ak) => { if let Some((pred_name, scope)) = wildcard_context { - if !scope.wildcards.contains_key(&ak.root.name) { + let Some(wc_info) = scope.wildcards.get(&ak.root.name) else { return Err(ValidationError::UndefinedWildcard { name: ak.root.name.clone(), pred_name: pred_name.to_string(), span: ak.root.span, }); + }; + // Records are integer-keyed, so string-key access on + // a typed wildcard is dead code at proof time. Reject + // dot access for unknown entries and bracket access + // outright; require `r.entry` for record-shaped data. + if let Some(record_name) = &wc_info.record_type { + match &ak.key { + AnchoredKeyPath::Dot(entry) => { + let schema = + self.symbols.records.get(record_name).expect( + "record_type was resolved at predicate-def time", + ); + if !schema.entry_index.contains_key(&entry.name) { + return Err(ValidationError::UnknownRecordEntry { + record: record_name.clone(), + entry: entry.name.clone(), + span: entry.span, + }); + } + } + AnchoredKeyPath::Bracket(_) => { + return Err(ValidationError::BracketAccessOnTypedWildcard { + wildcard: ak.root.name.clone(), + record: record_name.clone(), + span: ak.span, + }); + } + AnchoredKeyPath::Index(_) => unreachable!( + "AnchoredKeyPath::Index is introduced during lowering; \ + it cannot appear in the parsed AST that validation sees" + ), + } } } } @@ -638,6 +814,51 @@ impl Validator { } Ok(()) } + LiteralValue::Record(r) => { + let key = r.name.symbol_table_key(); + let Some(schema) = self.symbols.records.get(&key) else { + return Err(ValidationError::UnknownRecord { + name: key, + span: r.name.span(), + }); + }; + let mut seen: HashSet<&String> = HashSet::new(); + for entry in &r.entries { + if !schema.entry_index.contains_key(&entry.name.name) { + return Err(ValidationError::UnknownRecordEntry { + record: key.clone(), + entry: entry.name.name.clone(), + span: entry.name.span, + }); + } + if !seen.insert(&entry.name.name) { + return Err(ValidationError::DuplicateLiteralRecordEntry { + record: key.clone(), + entry: entry.name.name.clone(), + span: entry.name.span, + }); + } + self.validate_literal_value(&entry.value)?; + } + Ok(()) + } + LiteralValue::RecordEntryIndex { record, entry } => { + let key = record.symbol_table_key(); + let Some(schema) = self.symbols.records.get(&key) else { + return Err(ValidationError::UnknownRecord { + name: key, + span: record.span(), + }); + }; + if !schema.entry_index.contains_key(&entry.name) { + return Err(ValidationError::UnknownRecordEntry { + record: key, + entry: entry.name.clone(), + span: entry.span, + }); + } + Ok(()) + } _ => Ok(()), } } @@ -659,7 +880,7 @@ mod tests { ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - validate(document, modules, ParseMode::Module) + validate(document, modules, &Params::default(), ParseMode::Module) } fn parse_and_validate_request( @@ -668,7 +889,7 @@ mod tests { ) -> Result { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - validate(document, modules, ParseMode::Request) + validate(document, modules, &Params::default(), ParseMode::Request) } #[test] @@ -846,8 +1067,9 @@ mod tests { span: None, }, args: ArgSection { - public_args: vec![Identifier { + public_args: vec![TypedArg { name: "A".to_string(), + type_name: None, span: None, }], private_args: None, @@ -858,7 +1080,12 @@ mod tests { span: None, })], }; - let result = validate(document, &HashMap::new(), ParseMode::Module); + let result = validate( + document, + &HashMap::new(), + &Params::default(), + ParseMode::Module, + ); assert!(matches!( result, Err(ValidationError::EmptyStatementList { .. }) @@ -936,4 +1163,247 @@ mod tests { let result = parse_and_validate_request(input, &HashMap::new()); assert!(result.is_ok()); } + + // ----- Records ---------------------------------------------------------- + + #[test] + fn test_record_decl_accepted() { + let input = r#" + record ProcInputs = (foo, bar, baz) + my_pred(A) = AND(Equal(A["x"], 1)) + "#; + let validated = parse_and_validate_module(input, &HashMap::new()).unwrap(); + let schema = validated.symbols.records.get("ProcInputs").unwrap(); + assert_eq!(schema.entries, vec!["foo", "bar", "baz"]); + assert_eq!(schema.source, RecordSource::Local); + } + + #[test] + fn test_records_only_module_rejected() { + // A module needs at least one predicate; record-only modules are not + // a valid distribution unit. + let input = r#"record R = (x)"#; + assert!(matches!( + parse_and_validate_module(input, &HashMap::new()), + Err(ValidationError::NoPredicatesInModule) + )); + } + + #[test] + fn test_duplicate_record() { + let input = r#" + record R = (foo) + record R = (bar) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::DuplicateRecord { .. }) + )); + } + + #[test] + fn test_duplicate_entry_in_record() { + let input = r#" + record R = (foo, foo) + my_pred(A) = AND(Equal(A["x"], 1)) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::DuplicateRecordEntry { record, entry, .. }) + if record == "R" && entry == "foo" + )); + } + + #[test] + fn test_record_entry_cap() { + // Use a non-default depth so the cap reflects the parameter (not + // some hard-coded default). This pins three facts in one test: + // the param is wired through, the boundary is inclusive on accept, + // and cap + 1 is rejected. + let mut params = Params::default(); + params.containers.max_depth_small -= 1; + let cap = params.max_record_entries(); + let validate_with_n_entries = |n: usize| { + let entries: Vec = (0..n).map(|i| format!("f{i}")).collect(); + let input = format!( + "record Big = ({})\nmy_pred(A) = AND(Equal(A[\"x\"], 1))", + entries.join(", ") + ); + let parsed = parse_podlang(&input).expect("Failed to parse"); + let document = + parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); + validate(document, &HashMap::new(), ¶ms, ParseMode::Module) + }; + assert!(validate_with_n_entries(cap).is_ok()); + let too_many = cap + 1; + assert!(matches!( + validate_with_n_entries(too_many), + Err(ValidationError::RecordTooManyEntries { count, max, .. }) + if count == too_many && max == cap + )); + } + + #[test] + fn test_typed_arg_resolves_known_record() { + let input = r#" + record R = (foo, bar) + my_pred(in R) = AND(Equal(in.foo, in.bar)) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(result.is_ok()); + let validated = result.unwrap(); + let scope = validated.symbols.wildcard_scopes.get("my_pred").unwrap(); + assert_eq!(scope.wildcards["in"].record_type.as_deref(), Some("R")); + } + + #[test] + fn test_typed_arg_unknown_record_rejected() { + let input = r#" + my_pred(in NonExistent) = AND(Equal(in.foo, 1)) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::UnknownRecord { name, .. }) if name == "NonExistent" + )); + } + + #[test] + fn test_dot_access_unknown_entry_rejected() { + let input = r#" + record R = (foo, bar) + my_pred(in R) = AND(Equal(in.quux, 1)) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::UnknownRecordEntry { record, entry, .. }) + if record == "R" && entry == "quux" + )); + } + + #[test] + fn test_dot_access_on_untyped_wildcard_unchecked() { + // r.foo on an untyped wildcard keeps current POD-string-key behavior; + // no record exists named anything that would constrain `foo`. + let input = r#" + my_pred(r) = AND(Equal(r.foo, 1)) + "#; + assert!(parse_and_validate_module(input, &HashMap::new()).is_ok()); + } + + #[test] + fn test_bracket_access_on_typed_wildcard_rejected() { + // Records are integer-keyed; string-key access on a record-typed + // wildcard is incoherent and would never resolve at proof time. + // Force the user to use `.entry` instead. + let input = r#" + record R = (foo) + my_pred(r R) = AND(Equal(r["foo"], 1)) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::BracketAccessOnTypedWildcard { wildcard, record, .. }) + if wildcard == "r" && record == "R" + )); + } + + #[test] + fn test_record_literal_unknown_record() { + let input = r#" + my_pred(A) = AND(Equal(A["x"], NotARecord(f: 1))) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::UnknownRecord { name, .. }) if name == "NotARecord" + )); + } + + #[test] + fn test_record_literal_unknown_entry() { + let input = r#" + record R = (foo, bar) + my_pred(A) = AND(Equal(A["x"], R(foo: 1, quux: 2))) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::UnknownRecordEntry { record, entry, .. }) + if record == "R" && entry == "quux" + )); + } + + #[test] + fn test_record_literal_nested() { + // Nested literals recurse through `validate_literal_value`: an unknown + // entry on the inner literal must still be caught. + let input = r#" + record Outer = (inner) + record Inner = (x, y) + my_pred(A) = AND(Equal(A["x"], Outer(inner: Inner(x: 1, z: 2)))) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::UnknownRecordEntry { record, entry, .. }) + if record == "Inner" && entry == "z" + )); + } + + #[test] + fn test_record_literal_duplicate_entry() { + let input = r#" + record R = (foo, bar) + my_pred(A) = AND(Equal(A["x"], R(foo: 1, foo: 2))) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::DuplicateLiteralRecordEntry { record, entry, .. }) + if record == "R" && entry == "foo" + )); + } + + #[test] + fn test_record_entry_index_resolves() { + // Validation accepts `R::bar` and the schema records bar at index 1 + // — the integer the literal will lower to. + let input = r#" + record R = (foo, bar) + my_pred(A) = AND(Contains(A, R::bar, 7)) + "#; + let validated = parse_and_validate_module(input, &HashMap::new()).unwrap(); + let schema = validated.symbols.records.get("R").unwrap(); + assert_eq!(schema.entry_index["bar"], 1); + } + + #[test] + fn test_record_entry_index_unknown_record() { + let input = r#" + my_pred(A) = AND(Contains(A, NotARecord::foo, 7)) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::UnknownRecord { name, .. }) if name == "NotARecord" + )); + } + + #[test] + fn test_record_entry_index_unknown_entry() { + let input = r#" + record R = (foo, bar) + my_pred(A) = AND(Contains(A, R::quux, 7)) + "#; + let result = parse_and_validate_module(input, &HashMap::new()); + assert!(matches!( + result, + Err(ValidationError::UnknownRecordEntry { record, entry, .. }) + if record == "R" && entry == "quux" + )); + } } diff --git a/src/lang/grammar.pest b/src/lang/grammar.pest index 1c11baa..0446b27 100644 --- a/src/lang/grammar.pest +++ b/src/lang/grammar.pest @@ -11,7 +11,11 @@ WHITESPACE = _{ (" " | "\t" | NEWLINE)+ } // COMMENT matches a line comment (//...\n) or block comment (/*...*/). COMMENT = _{ ("//" ~ (!NEWLINE ~ ANY)* | "/*" ~ (!"*/" ~ ANY)* ~ "*/" ) } -reserved_identifier = { "private" | "true" | "false" } +// Word-boundary anchor: the reserved word must not be followed by an +// identifier character, otherwise prefixes like `record_count` or `recorder` +// would be wrongly rejected by the `!reserved_identifier` lookahead in +// `identifier`. +reserved_identifier = @{ ("private" | "true" | "false" | "record") ~ !(ASCII_ALPHANUMERIC | "_") } // Define rules for identifiers (predicate names, wildcard names) // Must start with alpha or _, followed by alpha, numeric, or _ @@ -23,10 +27,19 @@ arg_section = { public_arg_list ~ ("," ~ private_kw ~ private_arg_list)? } -public_arg_list = { identifier ~ ("," ~ identifier)* } -private_arg_list = { identifier ~ ("," ~ identifier)* } +// `name` or `name TypeName` or `name module::TypeName`. The optional `type_tag` +// is a record type, either local or imported via a `use module ... as alias`. +typed_arg = { identifier ~ type_tag? } +type_tag = { qualified_type_ref | identifier } +qualified_type_ref = { identifier ~ "::" ~ identifier } +public_arg_list = { typed_arg ~ ("," ~ typed_arg)* } +private_arg_list = { typed_arg ~ ("," ~ typed_arg)* } -document = { SOI ~ (use_module_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI } +record_def = { + "record" ~ identifier ~ "=" ~ "(" ~ identifier ~ ("," ~ identifier)* ~ ")" +} + +document = { SOI ~ (use_module_statement | use_intro_statement | record_def | custom_predicate_def | request_def)* ~ EOI } use_module_statement = { "use" ~ "module" ~ hash_hex ~ "as" ~ identifier } @@ -83,9 +96,27 @@ literal_value = { literal_string | predicate_hash_native | predicate_hash_external | + literal_record | + record_entry_index | literal_int } +// Record literal: `Name(Field: value, ...)` or `module::Name(...)`. Ordering +// in `literal_value` matters: must come after `Name(...)`-shaped prefix +// literals (PublicKey, SecretKey, Raw) so PEG doesn't shadow them, and BEFORE +// `record_entry_index` so `module::R(...)` isn't consumed as a 2-segment +// entry index — Pest backtracks to `record_entry_index` when no `(` follows. +literal_record = { + type_tag ~ "(" ~ record_entry ~ ("," ~ record_entry)* ~ ")" +} +record_entry = { identifier ~ ":" ~ literal_value } + +// Compile-time entry-index lookup: `R::foo` or `module::R::foo`. Resolves to +// the integer index of the named entry in the (possibly imported) record. +record_entry_index = { + identifier ~ "::" ~ identifier ~ ("::" ~ identifier)? +} + // Primitive literal types literal_int = @{ "-"? ~ ASCII_DIGIT+ } literal_bool = @{ "true" | "false" } diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 291f7a6..3ae23d8 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -84,6 +84,7 @@ fn load_module_inner( let validated = frontend_ast_validate::validate( document, &available_modules_map, + params, frontend_ast_validate::ParseMode::Module, )?; let module = frontend_ast_lower::lower_module(validated, params, name)?; @@ -124,6 +125,7 @@ fn parse_request_inner( let validated = frontend_ast_validate::validate( document, &available_modules_map, + params, frontend_ast_validate::ParseMode::Request, )?; let request = frontend_ast_lower::lower_request(validated, params)?; @@ -1073,4 +1075,248 @@ mod tests { e => panic!("Expected LangError::Validation, but got {:?}", e), } } + + // ---- Records: cross-module export ------------------------------------- + + #[test] + fn test_e2e_record_imported_predicate_compiles() -> Result<(), LangError> { + // Module A defines a record + a predicate that uses it. + // Module B imports A and writes its own predicate using + // `in a::ProcInputs`. B should compile without errors. + let params = Params::default(); + let module_a_src = r#" + record ProcInputs = (foo, bar, baz) + uses_record(in ProcInputs) = AND( + Equal(in.foo, in.baz) + ) + "#; + let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); + let a_hash = module_a.id().encode_hex::(); + + let module_b_src = format!( + r#" + use module 0x{} as a + + wraps_a(in a::ProcInputs) = AND( + Equal(in.bar, 7) + ) + "#, + a_hash + ); + let module_b = load_module(&module_b_src, "module_b", ¶ms, &[module_a])?; + assert_eq!(module_b.batch.predicates().len(), 1); + assert_eq!(module_b.batch.predicates()[0].name, "wraps_a"); + Ok(()) + } + + #[test] + fn test_e2e_imported_record_predicate_hash_matches_handwritten() -> Result<(), LangError> { + // Module B's predicate using `in a::ProcInputs` must hash identically + // to the same predicate built directly with an integer-keyed + // anchored key. Schema lives in A; the integer index is what gets + // baked into B's predicate body. + use crate::{ + frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder}, + middleware::Key, + }; + + let params = Params::default(); + let module_a_src = r#" + record ProcInputs = (foo, bar, baz) + stub(A) = AND(Equal(A["x"], 1)) + "#; + let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); + let a_hash = module_a.id().encode_hex::(); + + let module_b_src = format!( + r#" + use module 0x{} as a + + uses(in a::ProcInputs) = AND( + Equal(in.bar, 7) + ) + "#, + a_hash + ); + let module_b = load_module(&module_b_src, "module_b", ¶ms, &[module_a])?; + + let mut hand = CustomPredicateBatchBuilder::new(params.clone(), "module_b".into()); + let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg(BuilderArg::Key("in".into(), Key::from(1i64))) + .arg(BuilderArg::Literal(Value::from(7i64))); + hand.predicate_and("uses", &["in"], &[], &[stb]) + .expect("predicate_and"); + let hand_batch = hand.finish().expect("finish"); + + assert_eq!(module_b.batch.id(), hand_batch.id()); + Ok(()) + } + + #[test] + fn test_e2e_imported_record_literal_matches_array_root() -> Result<(), LangError> { + // Qualified record literal: `a::ProcInputs(...)` in the importer's + // body lowers to the same `Array` root as a local record literal + // built from the same schema, with entries placed at their schema + // indices. + use crate::middleware::{containers::Array, StatementTmplArg}; + + let params = Params::default(); + let module_a_src = r#" + record ProcInputs = (foo, bar, baz) + stub(A) = AND(Equal(A["x"], 1)) + "#; + let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); + let a_hash = module_a.id().encode_hex::(); + + // Source order is intentionally not schema order — the schema lookup + // through the qualified key has to map each entry back to its index. + let module_b_src = format!( + r#" + use module 0x{} as a + + uses(A) = AND( + Equal(A["data"], a::ProcInputs(baz: 30, foo: 10, bar: 20)) + ) + "#, + a_hash + ); + let module_b = load_module(&module_b_src, "module_b", ¶ms, &[module_a])?; + + let pred = &module_b.batch.predicates()[0]; + let stmt = &pred.statements()[0]; + let lowered = match &stmt.args()[1] { + StatementTmplArg::Literal(v) => v.clone(), + other => panic!("expected Literal at arg 1, got {other:?}"), + }; + + let expected = Value::from(Array::new(vec![ + Value::from(10i64), + Value::from(20i64), + Value::from(30i64), + ])); + + assert_eq!(lowered.raw(), expected.raw()); + Ok(()) + } + + #[test] + fn test_e2e_imported_record_literal_unknown_module_rejected() -> Result<(), LangError> { + // A literal that names a module the importer didn't bind must be + // rejected — the qualified key never gets into the symbol table, + // so validation surfaces `UnknownRecord` rather than producing a + // bogus lowered value. + use crate::lang::frontend_ast_validate::ValidationError; + + let params = Params::default(); + let module_a_src = r#" + record ProcInputs = (foo, bar) + stub(A) = AND(Equal(A["x"], 1)) + "#; + let module_a = Arc::new(load_module(module_a_src, "module_a", ¶ms, &[])?); + let a_hash = module_a.id().encode_hex::(); + + // Imported as `a`, but the literal references `b::ProcInputs`. + let module_b_src = format!( + r#" + use module 0x{} as a + + uses(A) = AND( + Equal(A["data"], b::ProcInputs(foo: 1, bar: 2)) + ) + "#, + a_hash + ); + let err = load_module(&module_b_src, "module_b", ¶ms, &[module_a]).unwrap_err(); + match err.kind { + LangErrorKind::Validation(e) => match *e { + ValidationError::UnknownRecord { name, .. } => { + assert_eq!(name, "b::ProcInputs"); + } + other => panic!("expected UnknownRecord, got {other:?}"), + }, + other => panic!("expected Validation, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_e2e_record_entry_index_proves_via_mock() -> Result<(), LangError> { + // End-to-end: a record-using predicate is satisfied by an Array + // value, with `Inputs::x` resolving the entry's integer index. + // MockProver runs the full proving path. + use crate::{ + backends::plonky2::mock::mainpod::MockProver, + frontend::{MainPodBuilder, Operation}, + middleware::{containers::Array, VDSet}, + }; + + let params = Params::default(); + let module = load_module( + r#" + record Inputs = (x, y) + at_x_is(arr, val) = AND( + Contains(arr, Inputs::x, val) + ) + "#, + "records_e2e", + ¶ms, + &[], + )?; + let at_x_is = module.batch.predicate_ref_by_name("at_x_is").unwrap(); + + // Build a 2-entry Array; arr[0] = 7 (Inputs::x), arr[1] = 13 (Inputs::y). + let arr = Array::new(vec![Value::from(7i64), Value::from(13i64)]); + + let vd_set = VDSet::new(&[]); + let mut builder = MainPodBuilder::new(¶ms, &vd_set); + let contains_st = builder + .priv_op(Operation::array_contains(arr, 0i64, 7i64)) + .unwrap(); + builder + .pub_op(Operation::custom(at_x_is, [contains_st])) + .unwrap(); + let pod = builder.prove(&MockProver {}).unwrap(); + pod.pod.verify().unwrap(); + Ok(()) + } + + #[test] + fn test_e2e_record_typed_dot_proves_via_mock() -> Result<(), LangError> { + // End-to-end: typed dot access is satisfied by an Array entry opened + // with an integer key, then used as an AnchoredKey in an Equal statement. + use crate::{ + backends::plonky2::mock::mainpod::MockProver, + frontend::{MainPodBuilder, Operation}, + middleware::{containers::Array, VDSet}, + }; + + let params = Params::default(); + let module = load_module( + r#" + record Inputs = (x, y) + at_x_is(arr Inputs, val) = AND( + Equal(arr.x, val) + ) + "#, + "records_dot_e2e", + ¶ms, + &[], + )?; + let at_x_is = module.batch.predicate_ref_by_name("at_x_is").unwrap(); + + let arr = Array::new(vec![Value::from(7i64), Value::from(13i64)]); + + let vd_set = VDSet::new(&[]); + let mut builder = MainPodBuilder::new(¶ms, &vd_set); + let contains_st = builder + .priv_op(Operation::array_contains(arr, 0i64, 7i64)) + .unwrap(); + let equal_st = builder.priv_op(Operation::eq(contains_st, 7i64)).unwrap(); + builder + .pub_op(Operation::custom(at_x_is, [equal_st])) + .unwrap(); + let pod = builder.prove(&MockProver {}).unwrap(); + pod.pod.verify().unwrap(); + Ok(()) + } } diff --git a/src/lang/module.rs b/src/lang/module.rs index b926871..f5b5ecd 100644 --- a/src/lang/module.rs +++ b/src/lang/module.rs @@ -53,6 +53,12 @@ pub struct Module { /// Split chain info for predicates that were split pub split_chains: HashMap, + + /// Records declared locally in this module's source: name → ordered entry + /// list. Frontend metadata only — the middleware batch knows nothing + /// about records. No transitive re-export: a downstream importer + /// inherits only the records declared in this module's own source. + pub records: HashMap>, } impl Module { @@ -60,6 +66,15 @@ impl Module { pub fn new( batch: Arc, split_chains: HashMap, + ) -> Self { + Self::with_records(batch, split_chains, HashMap::new()) + } + + /// Like `new`, but seeds the module's locally-declared records. + pub fn with_records( + batch: Arc, + split_chains: HashMap, + records: HashMap>, ) -> Self { let predicate_index = batch .predicates() @@ -71,6 +86,7 @@ impl Module { batch, predicate_index, split_chains, + records, } } @@ -264,6 +280,7 @@ pub fn build_module( params: &Params, module_name: &str, symbols: &SymbolTable, + records: HashMap>, ) -> Result { // Extract predicates and collect split chains let mut predicates = Vec::new(); @@ -281,7 +298,7 @@ pub fn build_module( if predicates.is_empty() { // Return an empty module let empty_batch = CustomPredicateBatch::new(module_name.to_string(), vec![]); - return Ok(Module::new(empty_batch, split_chains)); + return Ok(Module::with_records(empty_batch, split_chains, records)); } // Build reference map: name -> index @@ -294,7 +311,7 @@ pub fn build_module( // Build the batch let batch = build_single_batch(&predicates, &reference_map, symbols, params, module_name)?; - Ok(Module::new(batch, split_chains)) + Ok(Module::with_records(batch, split_chains, records)) } /// Build a batch with properly resolved references @@ -406,8 +423,14 @@ mod tests { fn parse_and_validate(input: &str) -> (Vec, ValidatedAST) { let parsed = parse_podlang(input).expect("Failed to parse"); let document = parse_document(parsed.into_iter().next().unwrap()).expect("Failed to parse"); - let validated = validate(document.clone(), &HashMap::new(), ParseMode::Module) - .expect("Failed to validate"); + let params = Params::default(); + let validated = validate( + document.clone(), + &HashMap::new(), + ¶ms, + ParseMode::Module, + ) + .expect("Failed to validate"); let predicates = document .items @@ -448,6 +471,7 @@ mod tests { ¶ms, "TestModule", validated.symbols(), + HashMap::new(), ); assert!(result.is_ok()); @@ -471,6 +495,7 @@ mod tests { ¶ms, "TestModule", validated.symbols(), + HashMap::new(), ); assert!(result.is_ok()); @@ -495,6 +520,7 @@ mod tests { ¶ms, "TestModule", validated.symbols(), + HashMap::new(), ); assert!(result.is_ok()); @@ -527,6 +553,7 @@ mod tests { ¶ms, "TestModule", validated.symbols(), + HashMap::new(), ); assert!(result.is_ok()); @@ -561,6 +588,7 @@ mod tests { ¶ms, "TestModule", validated.symbols(), + HashMap::new(), ) .unwrap(); @@ -599,8 +627,14 @@ mod tests { assert_eq!(split_results[0].predicates.len(), 2); assert!(split_results[0].chain_info.is_some()); - let module = - build_module(split_results, ¶ms, "TestModule", validated.symbols()).unwrap(); + let module = build_module( + split_results, + ¶ms, + "TestModule", + validated.symbols(), + HashMap::new(), + ) + .unwrap(); // Verify chain info is preserved let chain_info = module.split_chains.get("large_pred").unwrap(); diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index 7c8e744..0a507ab 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -18,7 +18,7 @@ use crate::{ backends::plonky2::primitives::merkletree::MerkleTreeStateTransitionProof, middleware::{ db::{mem::MemDB, DB}, - Error, Hash, Key, RawValue, Result, TypedValue, Value, EMPTY_HASH, + Error, Hash, RawValue, Result, StrKey, TypedValue, Value, EMPTY_HASH, }, }; @@ -264,22 +264,22 @@ macro_rules! dict { ); ({ $($key:expr => $val:expr),* }) => ({ let mut map = ::std::collections::HashMap::new(); - $( map.insert($crate::middleware::Key::from($key), $crate::middleware::Value::from($val)); )* + $( map.insert($crate::middleware::StrKey::from($key), $crate::middleware::Value::from($val)); )* $crate::middleware::containers::Dictionary::new(map) }); } -// TODO: Replace all methods that receive a `&Key` by either `impl Into` for write +// TODO: Replace all methods that receive a `&StrKey` by either `impl Into` for write // methods and `impl AsRef` for read methods. // TODO: Replace all methods that receive a `&Value` in write methods for `Value`. Consider a // trait? impl Dictionary { - pub fn new(kvs: HashMap) -> Self { + pub fn new(kvs: HashMap) -> Self { Self { inner: Container::new( kvs.into_iter() - .map(|(k, v)| (Value::from(k.name), v)) + .map(|(k, v)| (Value::from(k.into_name()), v)) .collect(), ), } @@ -297,29 +297,37 @@ impl Dictionary { pub fn commitment(&self) -> Hash { self.inner.commitment() } - pub fn get(&self, key: &Key) -> Result> { + pub fn get(&self, key: &StrKey) -> Result> { self.inner.get(key.raw()) } - pub fn prove(&self, key: &Key) -> Result<(Value, MerkleProof)> { + pub fn prove(&self, key: &StrKey) -> Result<(Value, MerkleProof)> { self.inner.prove(key.raw()) } - pub fn prove_nonexistence(&self, key: &Key) -> Result { + pub fn prove_nonexistence(&self, key: &StrKey) -> Result { self.inner.prove_nonexistence(key.raw()) } - pub fn insert(&mut self, key: &Key, value: &Value) -> Result { + pub fn insert( + &mut self, + key: &StrKey, + value: &Value, + ) -> Result { self.inner - .insert(Value::from(key.name.clone()), value.clone()) + .insert(Value::from(key.name().to_string()), value.clone()) } - pub fn update(&mut self, key: &Key, value: &Value) -> Result { + pub fn update( + &mut self, + key: &StrKey, + value: &Value, + ) -> Result { self.inner.update(key.raw(), value.clone()) } - pub fn delete(&mut self, key: &Key) -> Result { + pub fn delete(&mut self, key: &StrKey) -> Result { self.inner.delete(key.raw()) } - pub fn verify(root: Hash, proof: &MerkleProof, key: &Key, value: &Value) -> Result<()> { + pub fn verify(root: Hash, proof: &MerkleProof, key: &StrKey, value: &Value) -> Result<()> { Container::verify(root, proof, key.raw(), value.raw()) } - pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Key) -> Result<()> { + pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &StrKey) -> Result<()> { Container::verify_nonexistence(root, proof, key.raw()) } pub fn verify_state_transition(proof: &MerkleTreeStateTransitionProof) -> Result<()> { @@ -524,11 +532,11 @@ mod tests { fn _test_dict(db: Box) { let mut dict0 = Dictionary::empty_with_db(db.clone()); - dict0.insert(&Key::from("a"), &Value::from(1)).unwrap(); - dict0.insert(&Key::from("b"), &Value::from(2)).unwrap(); - dict0.update(&Key::from("a"), &Value::from(3)).unwrap(); - dict0.insert(&Key::from("c"), &Value::from(4)).unwrap(); - dict0.delete(&Key::from("c")).unwrap(); + dict0.insert(&StrKey::from("a"), &Value::from(1)).unwrap(); + dict0.insert(&StrKey::from("b"), &Value::from(2)).unwrap(); + dict0.update(&StrKey::from("a"), &Value::from(3)).unwrap(); + dict0.insert(&StrKey::from("c"), &Value::from(4)).unwrap(); + dict0.delete(&StrKey::from("c")).unwrap(); let kvs0 = dict0.dump().unwrap(); assert_eq!( kvs0, @@ -579,14 +587,14 @@ mod tests { fn _test_nested(db: Box) { let mut nested = Dictionary::empty_with_db(db.clone()); - nested.insert(&Key::from("a"), &Value::from(1)).unwrap(); - nested.insert(&Key::from("b"), &Value::from(2)).unwrap(); + nested.insert(&StrKey::from("a"), &Value::from(1)).unwrap(); + nested.insert(&StrKey::from("b"), &Value::from(2)).unwrap(); let nested_kvs0 = nested.dump().unwrap(); let mut dict0 = Dictionary::empty_with_db(db.clone()); - dict0.insert(&Key::from("x"), &Value::from(1)).unwrap(); + dict0.insert(&StrKey::from("x"), &Value::from(1)).unwrap(); dict0 - .insert(&Key::from("y"), &Value::from(nested.clone())) + .insert(&StrKey::from("y"), &Value::from(nested.clone())) .unwrap(); let kvs0 = dict0.dump().unwrap(); diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index d212ca8..f4f76cb 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -577,21 +577,24 @@ where } } +/// A key identified by a string name. Hash is computed via `hash_str`. #[derive(Clone, Debug, Eq)] -pub struct Key { +pub struct StrKey { name: String, hash: Hash, } -impl Key { +impl StrKey { pub fn new(name: String) -> Self { let hash = hash_str(&name); Self { name, hash } } - pub fn name(&self) -> &str { &self.name } + pub fn into_name(self) -> String { + self.name + } pub fn hash(&self) -> Hash { self.hash } @@ -600,20 +603,31 @@ impl Key { } } -impl hash::Hash for Key { - fn hash(&self, state: &mut H) { - self.hash.hash(state); - } -} - -impl PartialEq for Key { +impl PartialEq for StrKey { fn eq(&self, other: &Self) -> bool { self.hash == other.hash } } -// A Key can easily be created from a string-like type -impl From for Key +impl hash::Hash for StrKey { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } +} + +impl ToFields for StrKey { + fn to_fields(&self) -> Vec { + self.hash.to_fields() + } +} + +impl fmt::Display for StrKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "\"{}\"", self.name) + } +} + +impl From for StrKey where T: Into, { @@ -622,30 +636,9 @@ where } } -impl ToFields for Key { - fn to_fields(&self) -> Vec { - self.hash.to_fields() - } -} - -impl fmt::Display for Key { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "\"{}\"", self.name)?; - Ok(()) - } -} - -impl From for RawValue { - fn from(key: Key) -> RawValue { - RawValue(key.hash.0) - } -} - -// When serializing a Key, we serialize only the name field, and not the hash. -// We can't directly tell Serde to render the whole struct as a string, so we -// implement our own serialization. It's important that if we change the -// structure of the Key struct, we update this implementation. -impl Serialize for Key { +// `StrKey` serializes as a bare string. The cached hash is recomputed on +// deserialize via `StrKey::new`. +impl Serialize for StrKey { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -654,29 +647,206 @@ impl Serialize for Key { } } -impl<'de> Deserialize<'de> for Key { +impl<'de> Deserialize<'de> for StrKey { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - let name = String::deserialize(deserializer)?; - Ok(Key::new(name)) + String::deserialize(deserializer).map(StrKey::new) } } -// As per the above, we implement custom serialization for the Key type, and -// Schemars can't automatically generate a schema for it. Instead, we tell it -// to use the standard String schema. -impl JsonSchema for Key { +impl JsonSchema for StrKey { fn schema_name() -> String { - "Key".to_string() + "StrKey".to_string() } - fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { ::json_schema(gen) } } +/// A key identified by an integer index. The hash is taken directly from the +/// integer's `RawValue` encoding so that integer-keyed merkle trees (e.g. +/// `Array`, future shallow-MT variants) and integer-keyed `AnchoredKey`s +/// share the same leaf-key encoding. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct IndexKey { + value: i64, +} + +impl IndexKey { + pub fn new(value: i64) -> Self { + Self { value } + } + pub fn value(&self) -> i64 { + self.value + } + pub fn raw(&self) -> RawValue { + RawValue::from(self.value) + } + pub fn hash(&self) -> Hash { + Hash::from(self.raw()) + } +} + +impl ToFields for IndexKey { + fn to_fields(&self) -> Vec { + self.hash().to_fields() + } +} + +impl fmt::Display for IndexKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +impl From for IndexKey { + fn from(i: i64) -> Self { + Self::new(i) + } +} + +// `IndexKey` serializes as a bare integer. +impl Serialize for IndexKey { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.value.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for IndexKey { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + i64::deserialize(deserializer).map(IndexKey::new) + } +} + +impl JsonSchema for IndexKey { + fn schema_name() -> String { + "IndexKey".to_string() + } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema { + ::json_schema(gen) + } +} + +/// A key, either string-named or integer-indexed. +/// +/// APIs that only make sense for one variant (e.g. `Dictionary::insert` for +/// strings) take the inner type directly, lifting the variant check out to +/// the call site where the `match` on `Key` makes the missing arm visible at +/// compile time. +#[derive(Clone, Debug, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(untagged)] +pub enum Key { + Str(StrKey), + Index(IndexKey), +} + +impl Key { + pub fn new(name: String) -> Self { + Key::Str(StrKey::new(name)) + } + pub fn as_str(&self) -> Option<&StrKey> { + match self { + Key::Str(k) => Some(k), + Key::Index(_) => None, + } + } + pub fn as_index(&self) -> Option<&IndexKey> { + match self { + Key::Str(_) => None, + Key::Index(k) => Some(k), + } + } + pub fn hash(&self) -> Hash { + match self { + Key::Str(k) => k.hash(), + Key::Index(k) => k.hash(), + } + } + pub fn raw(&self) -> RawValue { + match self { + Key::Str(k) => k.raw(), + Key::Index(k) => k.raw(), + } + } +} + +impl PartialEq for Key { + fn eq(&self, other: &Self) -> bool { + self.hash() == other.hash() + } +} + +impl hash::Hash for Key { + fn hash(&self, state: &mut H) { + self.hash().hash(state); + } +} + +impl ToFields for Key { + fn to_fields(&self) -> Vec { + self.hash().to_fields() + } +} + +impl fmt::Display for Key { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Key::Str(k) => k.fmt(f), + Key::Index(k) => k.fmt(f), + } + } +} + +impl From for Key { + fn from(k: StrKey) -> Self { + Key::Str(k) + } +} + +impl From for Key { + fn from(k: IndexKey) -> Self { + Key::Index(k) + } +} + +impl From<&str> for Key { + fn from(s: &str) -> Self { + Key::Str(StrKey::from(s)) + } +} + +impl From for Key { + fn from(s: String) -> Self { + Key::Str(StrKey::from(s)) + } +} + +impl From<&String> for Key { + fn from(s: &String) -> Self { + Key::Str(StrKey::from(s)) + } +} + +impl From for Key { + fn from(i: i64) -> Self { + Key::Index(IndexKey::from(i)) + } +} + +impl From for RawValue { + fn from(key: Key) -> RawValue { + key.raw() + } +} + #[derive(Clone, Debug, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] pub struct AnchoredKey { @@ -693,13 +863,13 @@ impl AnchoredKey { impl hash::Hash for AnchoredKey { fn hash(&self, state: &mut H) { self.root.hash(state); - self.key.hash.hash(state); + self.key.hash().hash(state); } } impl PartialEq for AnchoredKey { fn eq(&self, other: &Self) -> bool { - self.root == other.root && self.key.hash == other.key.hash + self.root == other.root && self.key.hash() == other.key.hash() } } @@ -887,6 +1057,12 @@ impl Params { self.max_statements - self.max_public_statements } + /// Maximum number of entries permitted in a `record` declaration: the + /// number of leaves in the small container merkle tree variant. + pub fn max_record_entries(&self) -> usize { + 2usize.pow(self.containers.max_depth_small as u32) + } + pub const fn statement_tmpl_arg_size() -> usize { 2 * HASH_SIZE + 1 } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 8d3316c..96f7fc1 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -860,8 +860,12 @@ impl fmt::Display for Operation { pub(crate) fn root_key_to_ak(root: &Value, key: &Value) -> Option { let root_hash = Hash::from(root.raw()); - key.as_str() - .map(|s| AnchoredKey::new(root_hash, Key::from(s))) + if let Some(s) = key.as_str() { + Some(AnchoredKey::new(root_hash, Key::from(s))) + } else { + key.as_int() + .map(|i| AnchoredKey::new(root_hash, Key::from(i))) + } } /// Returns the value associated with `output_ref`. diff --git a/src/middleware/serialization.rs b/src/middleware/serialization.rs index 68e6efb..e66ec80 100644 --- a/src/middleware/serialization.rs +++ b/src/middleware/serialization.rs @@ -1,12 +1,8 @@ -use std::{ - collections::{HashMap, HashSet}, - fmt::Write, -}; +use std::fmt::Write; use plonky2::field::types::Field; -use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; +use serde::Deserialize; -use super::{Key, Value}; use crate::middleware::{F, HASH_SIZE, VALUE_SIZE}; fn serialize_field_tuple( @@ -104,44 +100,3 @@ where .parse() .map_err(serde::de::Error::custom) } - -// In order to serialize a Dictionary consistently, we want to order the -// key-value pairs by the key's name field. This has no effect on the hashes -// of the keys and therefore on the Merkle tree, but it makes the serialized -// output deterministic. -pub fn ordered_map( - value: &HashMap, - serializer: S, -) -> Result -where - S: Serializer, -{ - // Convert to Vec and sort by the key's name field - let mut pairs: Vec<_> = value.iter().collect(); - pairs.sort_by(|(k1, _), (k2, _)| k1.name.cmp(&k2.name)); - - // Serialize as a map - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(pairs.len()))?; - for (k, v) in pairs { - map.serialize_entry(k, v)?; - } - map.end() -} - -// Sets are serialized as sequences of elements, which are not ordered by -// default. We want to serialize them in a deterministic way, and we can -// achieve this by sorting the elements. This takes advantage of the fact that -// Value implements Ord. -pub fn ordered_set(value: &HashSet, serializer: S) -> Result -where - S: Serializer, -{ - let mut set = serializer.serialize_seq(Some(value.len()))?; - let mut sorted_values: Vec<&Value> = value.iter().collect(); - sorted_values.sort_by_key(|v| v.raw()); - for v in sorted_values { - set.serialize_element(v)?; - } - set.end() -}