Detect invalid wildcards in the language processor (#321)

This commit is contained in:
Rob Knight 2025-07-09 00:31:15 +02:00 committed by GitHub
parent 2c41a6c554
commit 0750dbeaff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 88 additions and 12 deletions

View file

@ -60,9 +60,10 @@ pub enum ProcessorError {
Internal(String),
#[error("Middleware error: {0}")]
Middleware(middleware::Error),
#[error("Undefined wildcard: '?{name}' at {span:?}")]
#[error("Undefined wildcard: '?{name}' in predicate '{pred_name}' at {span:?}")]
UndefinedWildcard {
name: String,
pred_name: String,
span: Option<(usize, usize)>,
},
#[error("Invalid literal format for {kind}: '{value}' at {span:?}")]

View file

@ -31,7 +31,7 @@ mod tests {
middleware::{
hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl,
StatementTmplArg, Value, Wildcard, SELF_ID_HASH,
StatementTmplArg, Value, Wildcard, KEY_SIGNER, KEY_TYPE, SELF_ID_HASH,
},
};
@ -954,4 +954,43 @@ mod tests {
e => panic!("Expected LangError::Processor, but got {:?}", e),
}
}
#[test]
fn test_e2e_undefined_wildcard() {
let params = Params::default();
let available_batches = &[];
let input = format!(
r#"
identity_verified(username, private: identity_pod) = AND(
Equal(?identity_pod["{key_type}"], {signed_pod_type})
Equal(?identity_pod["{key_signer}"], {identity_server_pk})
Equal(?identity_pod["username"], ?username)
Equal(?identity_pod["user_public_key"], ?user_public_key)
)
"#,
key_type = KEY_TYPE,
signed_pod_type = PodType::Signed as u32,
key_signer = KEY_SIGNER,
identity_server_pk =
"0x0000000000000000000000000000000000000000000000000000000000000000"
);
let result = parse(&input, &params, available_batches);
assert!(result.is_err());
match result.err().unwrap() {
LangError::Processor(e) => match *e {
ProcessorError::UndefinedWildcard {
name, pred_name, ..
} => {
assert_eq!(name, "user_public_key");
assert_eq!(pred_name, "identity_verified");
}
_ => panic!("Expected UndefinedWildcard error, but got {:?}", e),
},
e => panic!("Expected LangError::Processor, but got {:?}", e),
}
}
}

View file

@ -259,7 +259,10 @@ fn process_use_statement(
}
enum StatementContext<'a> {
CustomPredicate,
CustomPredicate {
pred_name: &'a str,
argument_names: &'a HashSet<String>,
},
Request {
custom_batch: &'a Arc<CustomPredicateBatch>,
wildcard_names: &'a mut Vec<String>,
@ -295,6 +298,7 @@ fn second_pass(
fn pest_pair_to_builder_arg(
params: &Params,
arg_content_pair: &Pair<Rule>,
context: &StatementContext,
) -> Result<BuilderArg, ProcessorError> {
match arg_content_pair.as_rule() {
Rule::literal_value => {
@ -302,14 +306,41 @@ fn pest_pair_to_builder_arg(
Ok(BuilderArg::Literal(value))
}
Rule::wildcard => {
let name = arg_content_pair.as_str().strip_prefix("?").unwrap();
Ok(BuilderArg::WildcardLiteral(name.to_string()))
let wc_str = arg_content_pair.as_str().strip_prefix("?").unwrap();
if let StatementContext::CustomPredicate {
argument_names,
pred_name,
} = context
{
if !argument_names.contains(wc_str) {
return Err(ProcessorError::UndefinedWildcard {
name: wc_str.to_string(),
pred_name: pred_name.to_string(),
span: Some(get_span(arg_content_pair)),
});
}
}
Ok(BuilderArg::WildcardLiteral(wc_str.to_string()))
}
Rule::anchored_key => {
let mut inner_ak_pairs = arg_content_pair.clone().into_inner();
let pod_id_pair = inner_ak_pairs.next().unwrap();
let pod_id_wc_str = pod_id_pair.as_str().strip_prefix("?").unwrap();
if let StatementContext::CustomPredicate {
argument_names,
pred_name,
} = context
{
if !argument_names.contains(pod_id_wc_str) {
return Err(ProcessorError::UndefinedWildcard {
name: pod_id_wc_str.to_string(),
pred_name: pred_name.to_string(),
span: Some(get_span(arg_content_pair)),
});
}
}
let key_part_pair = inner_ak_pairs.next().unwrap();
let key_str = parse_pest_string_literal(&key_part_pair)?;
Ok(BuilderArg::Key(pod_id_wc_str.to_string(), key_str))
@ -516,7 +547,10 @@ fn process_and_add_custom_predicate_to_batch(
params,
&stmt_pair,
processing_ctx,
StatementContext::CustomPredicate,
&mut StatementContext::CustomPredicate {
pred_name: &name,
argument_names: &defined_arg_names,
},
)?;
statement_builders.push(stb);
}
@ -557,7 +591,7 @@ fn process_request_def(
params,
&stmt_pair,
processing_ctx,
StatementContext::Request {
&mut StatementContext::Request {
custom_batch,
wildcard_names: &mut request_wildcard_names,
defined_wildcards: &mut defined_request_wildcards,
@ -582,7 +616,7 @@ fn process_statement_template(
params: &Params,
stmt_pair: &Pair<Rule>,
processing_ctx: &ProcessingContext,
mut context: StatementContext,
context: &mut StatementContext,
) -> Result<StatementTmplBuilder, ProcessorError> {
let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
let name_pair = inner_stmt_pairs
@ -590,13 +624,13 @@ fn process_statement_template(
.unwrap();
let stmt_name_str = name_pair.as_str();
let builder_args = parse_statement_args(params, stmt_pair)?;
let builder_args = parse_statement_args(params, stmt_pair, context)?;
if let StatementContext::Request {
wildcard_names,
defined_wildcards,
..
} = &mut context
} = context
{
let mut temp_stmt_wildcard_names: Vec<String> = Vec::new();
for arg in &builder_args {
@ -626,7 +660,7 @@ fn process_statement_template(
.get(stmt_name_str)
{
match context {
StatementContext::CustomPredicate => Predicate::BatchSelf(*pred_index),
StatementContext::CustomPredicate { .. } => Predicate::BatchSelf(*pred_index),
StatementContext::Request { custom_batch, .. } => {
let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index);
Predicate::Custom(custom_pred_ref)
@ -861,6 +895,7 @@ fn resolve_wildcard(
.map(|index| Wildcard::new(name_to_resolve.to_string(), index))
.ok_or_else(|| ProcessorError::UndefinedWildcard {
name: name_to_resolve.to_string(),
pred_name: "REQUEST".to_string(),
span: None,
})
}
@ -906,6 +941,7 @@ fn resolve_request_statement_builder(
fn parse_statement_args(
params: &Params,
stmt_pair: &Pair<Rule>,
context: &StatementContext,
) -> Result<Vec<BuilderArg>, ProcessorError> {
let mut builder_args = Vec::new();
let mut inner_stmt_pairs = stmt_pair.clone().into_inner();
@ -917,7 +953,7 @@ fn parse_statement_args(
.filter(|p| p.as_rule() == Rule::statement_arg)
{
let arg_content_pair = arg_pair.into_inner().next().unwrap();
let builder_arg = pest_pair_to_builder_arg(params, &arg_content_pair)?;
let builder_arg = pest_pair_to_builder_arg(params, &arg_content_pair, context)?;
builder_args.push(builder_arg);
}
}