New 'use' syntax with support for intro predicates (#431)

* New 'use' syntax with support for intro predicates

* Use empty statement in test

* Review feedback
This commit is contained in:
Rob Knight 2025-10-17 12:27:11 +02:00 committed by GitHub
parent ffed5b4fbd
commit aa4b531ac7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 118 additions and 13 deletions

View file

@ -26,13 +26,16 @@ arg_section = {
public_arg_list = { identifier ~ ("," ~ identifier)* }
private_arg_list = { identifier ~ ("," ~ identifier)* }
document = { SOI ~ (use_statement | custom_predicate_def | request_def)* ~ EOI }
document = { SOI ~ (use_batch_statement | use_intro_statement | custom_predicate_def | request_def)* ~ EOI }
use_statement = { "use" ~ use_predicate_list ~ "from" ~ batch_ref }
use_batch_statement = { "use" ~ "batch" ~ use_predicate_list ~ "from" ~ batch_ref }
use_predicate_list = { import_name ~ ("," ~ import_name)* }
import_name = { identifier | "_" }
batch_ref = { hash_hex }
use_intro_statement = { "use" ~ "intro" ~ identifier ~ "(" ~ use_intro_arg_list? ~ ")" ~ "from" ~ batch_ref }
use_intro_arg_list = { identifier ~ ("," ~ identifier)* }
request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" }
// Define conjunction type explicitly

View file

@ -34,6 +34,7 @@ mod tests {
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
Params, Predicate, RawValue, StatementTmpl, StatementTmplArg, Value, Wildcard,
EMPTY_HASH,
},
};
@ -687,7 +688,7 @@ mod tests {
let batch_id_str = available_batch.id().encode_hex::<String>();
let input = format!(
r#"
use imported_pred from 0x{}
use batch imported_pred from 0x{}
REQUEST(
imported_pred(Pod1, Pod2)
@ -738,7 +739,7 @@ mod tests {
let input = format!(
r#"
use pred_one, _, pred_three from 0x{}
use batch pred_one, _, pred_three from 0x{}
REQUEST(
pred_one(Pod1)
@ -796,7 +797,7 @@ mod tests {
let input = format!(
r#"
use imported_eq from 0x{}
use batch imported_eq from 0x{}
wrapper_pred(X, Y) = AND(
imported_eq(X, Y)
@ -836,6 +837,38 @@ mod tests {
Ok(())
}
#[test]
fn test_e2e_intro_import_parsing() -> Result<(), LangError> {
let params = Params::default();
let intro_hash = EMPTY_HASH.encode_hex::<String>();
let input = format!(
r#"
use intro empty() from 0x{intro_hash}
REQUEST(
empty()
)
"#,
);
let processed = parse(&input, &params, &[])?;
let request_templates = processed.request.templates();
assert_eq!(request_templates.len(), 1);
if let Predicate::Intro(intro_ref) = &request_templates[0].pred {
assert_eq!(intro_ref.name, "empty");
assert_eq!(intro_ref.args_len, 0);
assert_eq!(intro_ref.verifier_data_hash, EMPTY_HASH);
} else {
panic!("Expected Intro predicate");
}
assert!(request_templates[0].args.is_empty());
Ok(())
}
#[test]
fn test_e2e_literals() -> Result<(), LangError> {
let pk = crate::backends::plonky2::primitives::ec::curve::Point::generator();
@ -920,7 +953,7 @@ mod tests {
let input = format!(
r#"
use some_pred from {}
use batch some_pred from {}
"#,
unknown_batch_id
);

View file

@ -15,8 +15,9 @@ use crate::{
frontend::{BuilderArg, CustomPredicateBatchBuilder, PodRequest, StatementTmplBuilder},
lang::parser::Rule,
middleware::{
self, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, Params, Predicate,
StatementTmpl, StatementTmplArg, Value, Wildcard, F, VALUE_SIZE,
self, CustomPredicateBatch, CustomPredicateRef, Hash, IntroPredicateRef, Key,
NativePredicate, Params, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard, F,
VALUE_SIZE,
},
};
@ -74,6 +75,8 @@ struct ProcessingContext<'a> {
params: &'a Params,
/// Maps imported predicate names to their full reference (batch and index)
imported_predicates: HashMap<String, CustomPredicateRef>,
/// Maps imported intro predicate names to their intro refs
imported_intro_predicates: HashMap<String, IntroPredicateRef>,
/// Maps predicate names to their batch index and public argument count (from Pass 1)
custom_predicate_signatures: HashMap<String, (usize, usize)>,
/// Stores the original Pest pairs for custom predicate definitions for Pass 2
@ -87,6 +90,7 @@ impl<'a> ProcessingContext<'a> {
ProcessingContext {
params,
imported_predicates: HashMap::new(),
imported_intro_predicates: HashMap::new(),
custom_predicate_signatures: HashMap::new(),
custom_predicate_pairs: Vec::new(),
request_pair: None,
@ -139,8 +143,11 @@ fn first_pass<'a>(
for pair in document_pairs {
match pair.as_rule() {
Rule::use_statement => {
process_use_statement(&pair, ctx, available_batches)?;
Rule::use_batch_statement => {
process_use_batch_statement(&pair, ctx, available_batches)?;
}
Rule::use_intro_statement => {
process_use_intro_statement(&pair, ctx)?;
}
Rule::custom_predicate_def => {
let pred_name_pair = pair
@ -152,6 +159,7 @@ fn first_pass<'a>(
if defined_custom_names.contains(&pred_name)
|| ctx.imported_predicates.contains_key(&pred_name)
|| ctx.imported_intro_predicates.contains_key(&pred_name)
{
return Err(ProcessorError::DuplicateDefinition {
name: pred_name,
@ -205,7 +213,7 @@ fn count_public_args(pred_def_pair: &Pair<Rule>) -> Result<usize, ProcessorError
.count())
}
fn process_use_statement(
fn process_use_batch_statement(
use_pair: &Pair<Rule>,
ctx: &mut ProcessingContext,
available_batches: &[Arc<CustomPredicateBatch>],
@ -258,7 +266,10 @@ fn process_use_statement(
let name = import_name_pair.as_str().to_string();
if ctx.imported_predicates.contains_key(&name) {
if ctx.imported_predicates.contains_key(&name)
|| ctx.imported_intro_predicates.contains_key(&name)
|| ctx.custom_predicate_signatures.contains_key(&name)
{
return Err(ProcessorError::DuplicateImportName {
name,
span: Some(get_span(&import_name_pair)),
@ -272,6 +283,61 @@ fn process_use_statement(
Ok(())
}
fn process_use_intro_statement(
use_pair: &Pair<Rule>,
ctx: &mut ProcessingContext,
) -> Result<(), ProcessorError> {
let mut inner = use_pair.clone().into_inner();
// Structure: identifier, '(', optional arg list, ')', 'from', batch_ref
let name_pair = inner.find(|p| p.as_rule() == Rule::identifier).unwrap();
let pred_name = name_pair.as_str().to_string();
if ctx.imported_predicates.contains_key(&pred_name)
|| ctx.imported_intro_predicates.contains_key(&pred_name)
|| ctx.custom_predicate_signatures.contains_key(&pred_name)
{
return Err(ProcessorError::DuplicateImportName {
name: pred_name,
span: Some(get_span(&name_pair)),
});
}
let args_len = inner
.clone()
.find(|p| p.as_rule() == Rule::use_intro_arg_list)
.map(|arg_list| {
arg_list
.into_inner()
.filter(|p| p.as_rule() == Rule::identifier)
.count()
})
.unwrap_or(0);
let batch_ref_pair = inner.find(|p| p.as_rule() == Rule::batch_ref).unwrap();
let hash_hex_pair = batch_ref_pair.into_inner().next().unwrap();
let hash_str_full = hash_hex_pair.as_str();
let hex_no_prefix = hash_str_full.strip_prefix("0x").unwrap_or(hash_str_full);
let raw_val = parse_hex_str_to_raw_value(hex_no_prefix).map_err(|_| {
ProcessorError::InvalidLiteralFormat {
kind: "intro verifier hash".to_string(),
value: hash_str_full.to_string(),
span: Some(get_span(&hash_hex_pair)),
}
})?;
let verifier_hash: Hash = Hash::from(raw_val);
let intro_ref = IntroPredicateRef {
name: pred_name.clone(),
args_len,
verifier_data_hash: verifier_hash,
};
ctx.imported_intro_predicates.insert(pred_name, intro_ref);
Ok(())
}
enum StatementContext<'a> {
CustomPredicate {
pred_name: &'a str,
@ -714,6 +780,8 @@ fn process_statement_template(
Predicate::Native(native_pred)
} else if let Some(custom_ref) = processing_ctx.imported_predicates.get(stmt_name_str) {
Predicate::Custom(custom_ref.clone())
} else if let Some(intro_ref) = processing_ctx.imported_intro_predicates.get(stmt_name_str) {
Predicate::Intro(intro_ref.clone())
} else if let Some((pred_index, _expected_arity)) = processing_ctx
.custom_predicate_signatures
.get(stmt_name_str)
@ -1145,6 +1213,7 @@ mod processor_tests {
let mut ctx = ProcessingContext {
params: &params,
imported_predicates: HashMap::new(),
imported_intro_predicates: HashMap::new(),
custom_predicate_signatures: HashMap::new(),
custom_predicate_pairs: Vec::new(),
request_pair: None,