diff --git a/src/lang/error.rs b/src/lang/error.rs index b24d8d7..96161de 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -60,9 +60,10 @@ pub enum ProcessorError { Internal(String), #[error("Middleware error: {0}")] Middleware(middleware::Error), - #[error("Undefined wildcard: '?{name}' at {span:?}")] + #[error("Undefined wildcard: '?{name}' in predicate '{pred_name}' at {span:?}")] UndefinedWildcard { name: String, + pred_name: String, span: Option<(usize, usize)>, }, #[error("Invalid literal format for {kind}: '{value}' at {span:?}")] diff --git a/src/lang/mod.rs b/src/lang/mod.rs index 1f6dc8c..0966656 100644 --- a/src/lang/mod.rs +++ b/src/lang/mod.rs @@ -31,7 +31,7 @@ mod tests { middleware::{ hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl, - StatementTmplArg, Value, Wildcard, SELF_ID_HASH, + StatementTmplArg, Value, Wildcard, KEY_SIGNER, KEY_TYPE, SELF_ID_HASH, }, }; @@ -954,4 +954,43 @@ mod tests { e => panic!("Expected LangError::Processor, but got {:?}", e), } } + + #[test] + fn test_e2e_undefined_wildcard() { + let params = Params::default(); + let available_batches = &[]; + + let input = format!( + r#" + identity_verified(username, private: identity_pod) = AND( + Equal(?identity_pod["{key_type}"], {signed_pod_type}) + Equal(?identity_pod["{key_signer}"], {identity_server_pk}) + Equal(?identity_pod["username"], ?username) + Equal(?identity_pod["user_public_key"], ?user_public_key) + ) + "#, + key_type = KEY_TYPE, + signed_pod_type = PodType::Signed as u32, + key_signer = KEY_SIGNER, + identity_server_pk = + "0x0000000000000000000000000000000000000000000000000000000000000000" + ); + + let result = parse(&input, ¶ms, available_batches); + + assert!(result.is_err()); + + match result.err().unwrap() { + LangError::Processor(e) => match *e { + ProcessorError::UndefinedWildcard { + name, pred_name, .. + } => { + assert_eq!(name, "user_public_key"); + assert_eq!(pred_name, "identity_verified"); + } + _ => panic!("Expected UndefinedWildcard error, but got {:?}", e), + }, + e => panic!("Expected LangError::Processor, but got {:?}", e), + } + } } diff --git a/src/lang/processor.rs b/src/lang/processor.rs index f0b5316..3ac244e 100644 --- a/src/lang/processor.rs +++ b/src/lang/processor.rs @@ -259,7 +259,10 @@ fn process_use_statement( } enum StatementContext<'a> { - CustomPredicate, + CustomPredicate { + pred_name: &'a str, + argument_names: &'a HashSet, + }, Request { custom_batch: &'a Arc, wildcard_names: &'a mut Vec, @@ -295,6 +298,7 @@ fn second_pass( fn pest_pair_to_builder_arg( params: &Params, arg_content_pair: &Pair, + context: &StatementContext, ) -> Result { match arg_content_pair.as_rule() { Rule::literal_value => { @@ -302,14 +306,41 @@ fn pest_pair_to_builder_arg( Ok(BuilderArg::Literal(value)) } Rule::wildcard => { - let name = arg_content_pair.as_str().strip_prefix("?").unwrap(); - Ok(BuilderArg::WildcardLiteral(name.to_string())) + let wc_str = arg_content_pair.as_str().strip_prefix("?").unwrap(); + if let StatementContext::CustomPredicate { + argument_names, + pred_name, + } = context + { + if !argument_names.contains(wc_str) { + return Err(ProcessorError::UndefinedWildcard { + name: wc_str.to_string(), + pred_name: pred_name.to_string(), + span: Some(get_span(arg_content_pair)), + }); + } + } + Ok(BuilderArg::WildcardLiteral(wc_str.to_string())) } Rule::anchored_key => { let mut inner_ak_pairs = arg_content_pair.clone().into_inner(); let pod_id_pair = inner_ak_pairs.next().unwrap(); let pod_id_wc_str = pod_id_pair.as_str().strip_prefix("?").unwrap(); + if let StatementContext::CustomPredicate { + argument_names, + pred_name, + } = context + { + if !argument_names.contains(pod_id_wc_str) { + return Err(ProcessorError::UndefinedWildcard { + name: pod_id_wc_str.to_string(), + pred_name: pred_name.to_string(), + span: Some(get_span(arg_content_pair)), + }); + } + } + let key_part_pair = inner_ak_pairs.next().unwrap(); let key_str = parse_pest_string_literal(&key_part_pair)?; Ok(BuilderArg::Key(pod_id_wc_str.to_string(), key_str)) @@ -516,7 +547,10 @@ fn process_and_add_custom_predicate_to_batch( params, &stmt_pair, processing_ctx, - StatementContext::CustomPredicate, + &mut StatementContext::CustomPredicate { + pred_name: &name, + argument_names: &defined_arg_names, + }, )?; statement_builders.push(stb); } @@ -557,7 +591,7 @@ fn process_request_def( params, &stmt_pair, processing_ctx, - StatementContext::Request { + &mut StatementContext::Request { custom_batch, wildcard_names: &mut request_wildcard_names, defined_wildcards: &mut defined_request_wildcards, @@ -582,7 +616,7 @@ fn process_statement_template( params: &Params, stmt_pair: &Pair, processing_ctx: &ProcessingContext, - mut context: StatementContext, + context: &mut StatementContext, ) -> Result { let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); let name_pair = inner_stmt_pairs @@ -590,13 +624,13 @@ fn process_statement_template( .unwrap(); let stmt_name_str = name_pair.as_str(); - let builder_args = parse_statement_args(params, stmt_pair)?; + let builder_args = parse_statement_args(params, stmt_pair, context)?; if let StatementContext::Request { wildcard_names, defined_wildcards, .. - } = &mut context + } = context { let mut temp_stmt_wildcard_names: Vec = Vec::new(); for arg in &builder_args { @@ -626,7 +660,7 @@ fn process_statement_template( .get(stmt_name_str) { match context { - StatementContext::CustomPredicate => Predicate::BatchSelf(*pred_index), + StatementContext::CustomPredicate { .. } => Predicate::BatchSelf(*pred_index), StatementContext::Request { custom_batch, .. } => { let custom_pred_ref = CustomPredicateRef::new(custom_batch.clone(), *pred_index); Predicate::Custom(custom_pred_ref) @@ -861,6 +895,7 @@ fn resolve_wildcard( .map(|index| Wildcard::new(name_to_resolve.to_string(), index)) .ok_or_else(|| ProcessorError::UndefinedWildcard { name: name_to_resolve.to_string(), + pred_name: "REQUEST".to_string(), span: None, }) } @@ -906,6 +941,7 @@ fn resolve_request_statement_builder( fn parse_statement_args( params: &Params, stmt_pair: &Pair, + context: &StatementContext, ) -> Result, ProcessorError> { let mut builder_args = Vec::new(); let mut inner_stmt_pairs = stmt_pair.clone().into_inner(); @@ -917,7 +953,7 @@ fn parse_statement_args( .filter(|p| p.as_rule() == Rule::statement_arg) { let arg_content_pair = arg_pair.into_inner().next().unwrap(); - let builder_arg = pest_pair_to_builder_arg(params, &arg_content_pair)?; + let builder_arg = pest_pair_to_builder_arg(params, &arg_content_pair, context)?; builder_args.push(builder_arg); } }