Detect invalid wildcards in the language processor (#321)

This commit is contained in:
Rob Knight 2025-07-09 00:31:15 +02:00 committed by GitHub
parent 2c41a6c554
commit 0750dbeaff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 88 additions and 12 deletions

View file

@ -60,9 +60,10 @@ pub enum ProcessorError {
Internal(String), Internal(String),
#[error("Middleware error: {0}")] #[error("Middleware error: {0}")]
Middleware(middleware::Error), Middleware(middleware::Error),
#[error("Undefined wildcard: '?{name}' at {span:?}")] #[error("Undefined wildcard: '?{name}' in predicate '{pred_name}' at {span:?}")]
UndefinedWildcard { UndefinedWildcard {
name: String, name: String,
pred_name: String,
span: Option<(usize, usize)>, span: Option<(usize, usize)>,
}, },
#[error("Invalid literal format for {kind}: '{value}' at {span:?}")] #[error("Invalid literal format for {kind}: '{value}' at {span:?}")]

View file

@ -31,7 +31,7 @@ mod tests {
middleware::{ middleware::{
hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl, NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl,
StatementTmplArg, Value, Wildcard, SELF_ID_HASH, StatementTmplArg, Value, Wildcard, KEY_SIGNER, KEY_TYPE, SELF_ID_HASH,
}, },
}; };
@ -954,4 +954,43 @@ mod tests {
e => panic!("Expected LangError::Processor, but got {:?}", e), e => panic!("Expected LangError::Processor, but got {:?}", e),
} }
} }
#[test]
fn test_e2e_undefined_wildcard() {
let params = Params::default();
let available_batches = &[];
let input = format!(
r#"
identity_verified(username, private: identity_pod) = AND(
Equal(?identity_pod["{key_type}"], {signed_pod_type})
Equal(?identity_pod["{key_signer}"], {identity_server_pk})
Equal(?identity_pod["username"], ?username)
Equal(?identity_pod["user_public_key"], ?user_public_key)
)
"#,
key_type = KEY_TYPE,
signed_pod_type = PodType::Signed as u32,
key_signer = KEY_SIGNER,
identity_server_pk =
"0x0000000000000000000000000000000000000000000000000000000000000000"
);
let result = parse(&input, &params, available_batches);
assert!(result.is_err());
match result.err().unwrap() {
LangError::Processor(e) => match *e {
ProcessorError::UndefinedWildcard {
name, pred_name, ..
} => {
assert_eq!(name, "user_public_key");
assert_eq!(pred_name, "identity_verified");
}
_ => panic!("Expected UndefinedWildcard error, but got {:?}", e),
},
e => panic!("Expected LangError::Processor, but got {:?}", e),
}
}
} }

View file

@ -259,7 +259,10 @@ fn process_use_statement(
} }
enum StatementContext<'a> { enum StatementContext<'a> {
CustomPredicate, CustomPredicate {
pred_name: &'a str,
argument_names: &'a HashSet<String>,
},
Request { Request {
custom_batch: &'a Arc<CustomPredicateBatch>, custom_batch: &'a Arc<CustomPredicateBatch>,
wildcard_names: &'a mut Vec<String>, wildcard_names: &'a mut Vec<String>,
@ -295,6 +298,7 @@ fn second_pass(
fn pest_pair_to_builder_arg( fn pest_pair_to_builder_arg(
params: &Params, params: &Params,
arg_content_pair: &Pair<Rule>, arg_content_pair: &Pair<Rule>,
context: &StatementContext,
) -> Result<BuilderArg, ProcessorError> { ) -> Result<BuilderArg, ProcessorError> {
match arg_content_pair.as_rule() { match arg_content_pair.as_rule() {
Rule::literal_value => { Rule::literal_value => {
@ -302,14 +306,41 @@ fn pest_pair_to_builder_arg(
Ok(BuilderArg::Literal(value)) Ok(BuilderArg::Literal(value))
} }
Rule::wildcard => { Rule::wildcard => {
let name = arg_content_pair.as_str().strip_prefix("?").unwrap(); let wc_str = arg_content_pair.as_str().strip_prefix("?").unwrap();
Ok(BuilderArg::WildcardLiteral(name.to_string())) if let StatementContext::CustomPredicate {
argument_names,
pred_name,
} = context
{
if !argument_names.contains(wc_str) {
return Err(ProcessorError::UndefinedWildcard {
name: wc_str.to_string(),
pred_name: pred_name.to_string(),
span: Some(get_span(arg_content_pair)),
});
}
}
Ok(BuilderArg::WildcardLiteral(wc_str.to_string()))
} }
Rule::anchored_key => { Rule::anchored_key => {
let mut inner_ak_pairs = arg_content_pair.clone().into_inner(); let mut inner_ak_pairs = arg_content_pair.clone().into_inner();
let pod_id_pair = inner_ak_pairs.next().unwrap(); let pod_id_pair = inner_ak_pairs.next().unwrap();
let pod_id_wc_str = pod_id_pair.as_str().strip_prefix("?").unwrap(); let pod_id_wc_str = pod_id_pair.as_str().strip_prefix("?").unwrap();
if let StatementContext::CustomPredicate {
argument_names,
pred_name,
} = context
{
if !argument_names.contains(pod_id_wc_str) {
return Err(ProcessorError::UndefinedWildcard {
name: pod_id_wc_str.to_string(),
pred_name: pred_name.to_string(),
span: Some(get_span(arg_content_pair)),
});
}
}
let key_part_pair = inner_ak_pairs.next().unwrap(); let key_part_pair = inner_ak_pairs.next().unwrap();
let key_str = parse_pest_string_literal(&key_part_pair)?; let key_str = parse_pest_string_literal(&key_part_pair)?;
Ok(BuilderArg::Key(pod_id_wc_str.to_string(), key_str)) Ok(BuilderArg::Key(pod_id_wc_str.to_string(), key_str))
@ -516,7 +547,10 @@ fn process_and_add_custom_predicate_to_batch(
params, params,
&stmt_pair, &stmt_pair,
processing_ctx, processing_ctx,
StatementContext::CustomPredicate, &mut StatementContext::CustomPredicate {
pred_name: &name,
argument_names: &defined_arg_names,
},
)?; )?;
statement_builders.push(stb); statement_builders.push(stb);
} }
@ -557,7 +591,7 @@ fn process_request_def(
params, params,
&stmt_pair, &stmt_pair,
processing_ctx, processing_ctx,
StatementContext::Request { &mut StatementContext::Request {
custom_batch, custom_batch,
wildcard_names: &mut request_wildcard_names, wildcard_names: &mut request_wildcard_names,
defined_wildcards: &mut defined_request_wildcards, defined_wildcards: &mut defined_request_wildcards,
@ -582,7 +616,7 @@ fn process_statement_template(
params: &Params, params: &Params,
stmt_pair: &Pair<Rule>, stmt_pair: &Pair<Rule>,
processing_ctx: &ProcessingContext, processing_ctx: &ProcessingContext,
mut context: StatementContext, context: &mut StatementContext,
) -> Result<StatementTmplBuilder, ProcessorError> { ) -> Result<StatementTmplBuilder, ProcessorError> {
let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
let name_pair = inner_stmt_pairs let name_pair = inner_stmt_pairs
@ -590,13 +624,13 @@ fn process_statement_template(
.unwrap(); .unwrap();
let stmt_name_str = name_pair.as_str(); let stmt_name_str = name_pair.as_str();
let builder_args = parse_statement_args(params, stmt_pair)?; let builder_args = parse_statement_args(params, stmt_pair, context)?;
if let StatementContext::Request { if let StatementContext::Request {
wildcard_names, wildcard_names,
defined_wildcards, defined_wildcards,
.. ..
} = &mut context } = context
{ {
let mut temp_stmt_wildcard_names: Vec<String> = Vec::new(); let mut temp_stmt_wildcard_names: Vec<String> = Vec::new();
for arg in &builder_args { for arg in &builder_args {
@ -626,7 +660,7 @@ fn process_statement_template(
.get(stmt_name_str) .get(stmt_name_str)
{ {
match context { match context {
StatementContext::CustomPredicate => Predicate::BatchSelf(*pred_index), StatementContext::CustomPredicate { .. } => Predicate::BatchSelf(*pred_index),
StatementContext::Request { custom_batch, .. } => { StatementContext::Request { custom_batch, .. } => {
let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index); let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index);
Predicate::Custom(custom_pred_ref) Predicate::Custom(custom_pred_ref)
@ -861,6 +895,7 @@ fn resolve_wildcard(
.map(|index| Wildcard::new(name_to_resolve.to_string(), index)) .map(|index| Wildcard::new(name_to_resolve.to_string(), index))
.ok_or_else(|| ProcessorError::UndefinedWildcard { .ok_or_else(|| ProcessorError::UndefinedWildcard {
name: name_to_resolve.to_string(), name: name_to_resolve.to_string(),
pred_name: "REQUEST".to_string(),
span: None, span: None,
}) })
} }
@ -906,6 +941,7 @@ fn resolve_request_statement_builder(
fn parse_statement_args( fn parse_statement_args(
params: &Params, params: &Params,
stmt_pair: &Pair<Rule>, stmt_pair: &Pair<Rule>,
context: &StatementContext,
) -> Result<Vec<BuilderArg>, ProcessorError> { ) -> Result<Vec<BuilderArg>, ProcessorError> {
let mut builder_args = Vec::new(); let mut builder_args = Vec::new();
let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
@ -917,7 +953,7 @@ fn parse_statement_args(
.filter(|p| p.as_rule() == Rule::statement_arg) .filter(|p| p.as_rule() == Rule::statement_arg)
{ {
let arg_content_pair = arg_pair.into_inner().next().unwrap(); let arg_content_pair = arg_pair.into_inner().next().unwrap();
let builder_arg = pest_pair_to_builder_arg(params, &arg_content_pair)?; let builder_arg = pest_pair_to_builder_arg(params, &arg_content_pair, context)?;
builder_args.push(builder_arg); builder_args.push(builder_arg);
} }
} }