diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index db8c32a..7d25786 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -771,7 +771,8 @@ impl CustomPredicateEntryTarget { pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; pw.set_target(self.index, F::from_canonical_usize(predicate.index))?; - // Replace statement templates of batch-self with (id,index) + // Replace BatchSelf predicates with Custom(batch, i), and + // SelfPredicateHash args with Literal(hash(Custom(batch, i))) let batch = &predicate.batch; let predicate = predicate.predicate(); let statements = predicate @@ -788,10 +789,22 @@ impl CustomPredicateEntryTarget { } x => x.clone(), }; - StatementTmpl { - pred_or_wc, - args: st_tmpl.args, - } + let args = st_tmpl + .args + .into_iter() + .map(|arg| match arg { + StatementTmplArg::SelfPredicateHash(i) => { + let pred_hash = Predicate::Custom(CustomPredicateRef { + batch: batch.clone(), + index: i, + }) + .hash(); + StatementTmplArg::Literal(Value::from(pred_hash)) + } + other => other, + }) + .collect(); + StatementTmpl { pred_or_wc, args } }) .collect_vec(); let predicate = CustomPredicate { @@ -2012,7 +2025,7 @@ pub(crate) mod tests { // Empty case let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); _ = cpb_builder.predicate_and("empty", &[], &[], &[])?; - let custom_predicate_batch = cpb_builder.finish(); + let custom_predicate_batch = cpb_builder.finish()?; helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); // Some cases from the examples diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index b0c8f48..68114d2 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -54,8 +54,8 @@ use crate::{ measure_gates_begin, measure_gates_end, middleware::{ CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, - NativePredicate, Params, PredicatePrefix, RawValue, Statement, ToFields, Value, F, - HASH_SIZE, + NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix, + ToFields, Value, F, HASH_SIZE, }, }; // @@ -1534,8 +1534,8 @@ pub fn calculate_statements_hash_circuit( sts_hash } -// Replace predicates of batch-self with the corresponding global custom predicate batch_id and -// index +// Replace BatchSelf predicates with the corresponding Custom(batch_id, index), and +// SelfPredicateHash args with Literal(hash(Custom(batch_id, index))). fn normalize_st_tmpl_circuit( params: &Params, builder: &mut CircuitBuilder, @@ -1564,7 +1564,41 @@ fn normalize_st_tmpl_circuit( ); let pred_hash_or_wc = PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data); - StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone()) + + // Normalize SelfPredicateHash args: replace prefix 4 with Literal containing the resolved + // predicate hash. Same pattern as the predicate normalization above. + let prefix_sph = builder.constant(F::from(StatementTmplArgPrefix::SelfPredicateHash)); + let prefix_literal = builder.constant(F::from(StatementTmplArgPrefix::Literal)); + let zero = builder.zero(); + let normalized_args = st_tmpl + .args + .iter() + .map(|arg| { + let is_sph = builder.is_equal(arg.elements[0], prefix_sph); + + // The predicate index is in elements[1] (same slot as WildcardLiteral). + let pred_index = arg.elements[1]; + + // Compute hash(Custom(batch_id, pred_index)) + let pred_target = PredicateTarget::new_custom(builder, id, pred_index); + let pred_hash = pred_target.hash(builder); + + // Build a Literal-encoded arg: [1, hash[0..4], 0, 0, 0, 0] + let mut literal_elements = [zero; Params::statement_tmpl_arg_size()]; + literal_elements[0] = prefix_literal; + literal_elements[1] = pred_hash.elements[0]; + literal_elements[2] = pred_hash.elements[1]; + literal_elements[3] = pred_hash.elements[2]; + literal_elements[4] = pred_hash.elements[3]; + let normalized = StatementTmplArgTarget { + elements: literal_elements, + }; + + builder.select_flattenable(params, is_sph, &normalized, arg) + }) + .collect(); + + StatementTmplTarget::new(pred_hash_or_wc, normalized_args) } /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as @@ -3262,7 +3296,7 @@ mod tests { &[stb0.clone(), stb1.clone()], )?; let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; - let batch = builder.finish(); + let batch = builder.finish()?; let dict = Hash([F(6), F(7), F(8), F(9)]); @@ -3352,7 +3386,7 @@ mod tests { &[stb0.clone(), stb1.clone()], )?; let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; - let batch = builder.finish(); + let batch = builder.finish()?; let dict = Hash([F(1), F(2), F(3), F(4)]); let secret_dict = Hash([F(6), F(7), F(8), F(9)]); @@ -3570,4 +3604,100 @@ mod tests { Ok(()) } + + #[test] + fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> { + let params = Params::default(); + + // Build a batch with two predicates: + // pred_A: Equal(x, y) + // pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash + use NativePredicate as NP; + let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal) + .arg("x") + .arg("y"); + cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a]) + .unwrap(); + + // Build pred_B's template manually with SelfPredicateHash(0) + let stb_b_tmpl = StatementTmpl { + pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)), + args: vec![ + StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)), + StatementTmplArg::SelfPredicateHash(0), + ], + }; + let pred_b = CustomPredicate::new( + ¶ms, + "pred_B".into(), + true, + vec![stb_b_tmpl], + 1, + vec!["x".to_string()], + ) + .unwrap(); + cpb.predicates.push(pred_b); + let batch = cpb.finish().unwrap(); + + // Compute the expected resolved hash of pred_A + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_a_hash = Predicate::Custom(pred_a_ref).hash(); + let expected_pred_a_value = Value::from(pred_a_hash); + + // Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to + // Literal(pred_a_hash). Then make_statement_from_template_circuit should produce + // a statement with that literal value. + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_tmpl = &pred_b_ref.predicate().statements[0]; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + // Create the template target and batch id target + let st_tmpl_target = builder.add_virtual_statement_tmpl(true); + let batch_id = builder.add_virtual_hash(); + + // Normalize the template (this is what we're testing) + let normalized = + normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id); + + // Feed normalized template into statement generation + let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards) + .map(|_| builder.add_virtual_value()) + .collect(); + let st_target = + make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target); + + // Connect to expected output + let expected_st_target = builder.add_virtual_statement(false); + builder.connect_flattenable(&expected_st_target, &st_target); + + // Set witness + let mut pw = PartialWitness::::new(); + st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?; + pw.set_target_arr(&batch_id.elements, &batch.id().0)?; + + let some_value = Value::from(42); + // args: first wildcard is "x" = some_value, rest are padding + let mut args_values = vec![some_value.clone()]; + for _ in 1..params.max_custom_predicate_wildcards { + args_values.push(Value::from(EMPTY_VALUE)); + } + for (target, value) in args_target.iter().zip(args_values.iter()) { + target.set_targets(&mut pw, value)?; + } + + // Expected statement: Equal(Literal(some_value), Literal(pred_a_hash)) + let expected_st: crate::backends::plonky2::mainpod::Statement = + Statement::equal(some_value, expected_pred_a_value).into(); + expected_st_target.set_targets(&mut pw, &expected_st)?; + + // Build and verify + let data = builder.build::(); + let proof = data.prove(pw)?; + data.verify(proof)?; + + Ok(()) + } } diff --git a/src/backends/plonky2/mainpod/mod.rs b/src/backends/plonky2/mainpod/mod.rs index 8e6ed46..4968316 100644 --- a/src/backends/plonky2/mainpod/mod.rs +++ b/src/backends/plonky2/mainpod/mod.rs @@ -104,8 +104,9 @@ pub(crate) fn extract_custom_predicate_verifications( if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Statement::Custom(st_cpr, st_args) = st { assert_eq!(cpr, st_cpr); + let normalized_pred = cpr.normalized_predicate(); let wildcard_values = - wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) + wildcard_values_from_op_st(params, &normalized_pred, sts, st_args) .expect("resolved wildcards"); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let custom_predicate_table_index = custom_predicates @@ -1096,7 +1097,7 @@ pub mod tests { &[stb0.clone(), stb1.clone()], )?; let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?; - let cpb = cpb_builder.finish(); + let cpb = cpb_builder.finish()?; let cpb_and = CustomPredicateRef::new(cpb.clone(), 0); let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1); @@ -1130,6 +1131,63 @@ pub mod tests { Ok(pod.verify()?) } + #[test] + fn test_main_self_predicate_hash() -> frontend::Result<()> { + use frontend::BuilderArg; + + let params = Params { + max_signed_by: 0, + max_input_pods: 0, + max_statements: 6, + max_public_statements: 2, + max_operation_args: 5, + max_custom_predicate_wildcards: 4, + max_custom_predicate_verifications: 2, + max_merkle_proofs_containers: 0, + max_merkle_tree_state_transition_proofs_containers: 0, + ..Default::default() + }; + let mut vds = DEFAULT_VD_LIST.clone(); + vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone()); + let vd_set = VDSet::new(&vds); + + // Build a batch: pred_A references pred_B's hash, pred_B references pred_A's hash + let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + let stb_a = STB::new_from_pred(NP::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_B".into())); + cpb.predicate_and("pred_A", &["x"], &[], &[stb_a])?; + + let stb_b = STB::new_from_pred(NP::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_A".into())); + cpb.predicate_and("pred_B", &["x"], &[], &[stb_b])?; + + let batch = cpb.finish()?; + + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_hash = middleware::Value::from(middleware::Predicate::Custom(pred_b_ref).hash()); + + // Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash) + let mut pod_builder = MainPodBuilder::new(¶ms, &vd_set); + let eq_st = + pod_builder.priv_op(frontend::Operation::eq(pred_b_hash.clone(), pred_b_hash))?; + pod_builder.pub_op(frontend::Operation::custom(pred_a_ref, [eq_st]))?; + + // Mock + let prover = MockProver {}; + let pod = pod_builder.prove(&prover)?; + assert!(pod.pod.verify().is_ok()); + + // Real + let prover = Prover {}; + let pod = pod_builder.prove(&prover)?; + let pod = (pod.pod as Box).downcast::().unwrap(); + + Ok(pod.verify()?) + } + #[test] fn test_set_contains() -> frontend::Result<()> { let params = Params::default(); diff --git a/src/frontend/custom.rs b/src/frontend/custom.rs index 92fdc4f..f3a8115 100644 --- a/src/frontend/custom.rs +++ b/src/frontend/custom.rs @@ -18,6 +18,8 @@ pub enum BuilderArg { /// Key: (origin, key), where origin is Wildcard and key is Key Key(String, String), WildcardLiteral(String), + /// Reference to a same-batch predicate's identity hash (resolved by name in finish()). + SelfPredicateHash(String), } /// When defining a `BuilderArg`, it can be done from 3 different inputs: @@ -130,6 +132,8 @@ pub struct CustomPredicateBatchBuilder { params: Params, pub name: String, pub predicates: Vec, + /// Forward references to resolve in finish(): (predicate_idx, statement_idx, arg_idx, name) + pending_self_pred_hashes: Vec<(usize, usize, usize, String)>, } impl CustomPredicateBatchBuilder { @@ -138,6 +142,7 @@ impl CustomPredicateBatchBuilder { params, name, predicates: Vec::new(), + pending_self_pred_hashes: Vec::new(), } } @@ -194,14 +199,18 @@ impl CustomPredicateBatchBuilder { )); } + let pred_idx = self.predicates.len(); + let mut pending = Vec::new(); let statements = sts .iter() - .map(|sb| { + .enumerate() + .map(|(stmt_idx, sb)| { let stb = sb.clone().desugar(); let st_tmpl_args = stb .args .iter() - .map(|a| { + .enumerate() + .map(|(arg_idx, a)| { Ok::<_, Error>(match a { BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey( @@ -211,6 +220,22 @@ impl CustomPredicateBatchBuilder { BuilderArg::WildcardLiteral(v) => { StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?) } + BuilderArg::SelfPredicateHash(pred_name) => { + // Try backward reference first + match self.predicates.iter().position(|p| p.name == *pred_name) { + Some(index) => StatementTmplArg::SelfPredicateHash(index), + None => { + // Forward reference - placeholder, resolved in finish() + pending.push(( + pred_idx, + stmt_idx, + arg_idx, + pred_name.clone(), + )); + StatementTmplArg::SelfPredicateHash(0) + } + } + } }) }) .collect::>()?; @@ -240,11 +265,27 @@ impl CustomPredicateBatchBuilder { .collect(), )?; self.predicates.push(custom_predicate); + self.pending_self_pred_hashes.extend(pending); Ok(Predicate::BatchSelf(self.predicates.len() - 1)) } - pub fn finish(self) -> Arc { - CustomPredicateBatch::new(self.name, self.predicates) + pub fn finish(mut self) -> Result> { + // Resolve forward references for SelfPredicateHash + for (pred_idx, stmt_idx, arg_idx, ref name) in &self.pending_self_pred_hashes { + let target_idx = self + .predicates + .iter() + .position(|p| p.name == *name) + .ok_or_else(|| { + Error::custom(format!( + "SelfPredicateHash references unknown predicate '{}'", + name + )) + })?; + self.predicates[*pred_idx].statements[*stmt_idx].args[*arg_idx] = + StatementTmplArg::SelfPredicateHash(target_idx); + } + Ok(CustomPredicateBatch::new(self.name, self.predicates)) } } @@ -306,7 +347,7 @@ mod tests { .arg("s2"); builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?; - let batch = builder.finish(); + let batch = builder.finish()?; let batch_clone = batch.clone(); let gt_custom_pred = CustomPredicateRef::new(batch, 0); @@ -356,7 +397,7 @@ mod tests { &[], &[set_contains_stb], )?; - let batch = builder.finish(); + let batch = builder.finish()?; let batch_clone = batch.clone(); let mut mp_builder = MainPodBuilder::new(¶ms, vd_set); @@ -386,4 +427,83 @@ mod tests { Ok(()) } + + #[test] + fn test_builder_self_predicate_hash_unknown_ref() { + let params = Params::default(); + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + + let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("nonexistent".into())); + builder + .predicate_and("pred_A", &["x"], &[], &[stb]) + .unwrap(); + + // finish() should fail because "nonexistent" was never defined + assert!(builder.finish().is_err()); + } + + /// Tests cyclic SelfPredicateHash references end-to-end: + /// pred_A references pred_B's hash (forward ref), pred_B references pred_A's hash (backward + /// ref). Exercises forward reference resolution in finish(), then builds and verifies a POD + /// using pred_A via MockProver. + #[test] + fn test_builder_self_predicate_hash_e2e() -> Result<()> { + let params = Params::default(); + let vd_set = &*MOCK_VD_SET; + + let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into()); + + // pred_A references pred_B's hash (forward ref, pred_B not yet defined) + let stb_a = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_B".into())); + builder.predicate_and("pred_A", &["x"], &[], &[stb_a])?; + + // pred_B references pred_A's hash (backward ref, pred_A already defined) + let stb_b = StatementTmplBuilder::new_from_pred(NativePredicate::Equal) + .arg("x") + .arg(BuilderArg::SelfPredicateHash("pred_A".into())); + builder.predicate_and("pred_B", &["x"], &[], &[stb_b])?; + + let batch = builder.finish()?; + + // Verify resolution: pred_A references pred_B (index 1), pred_B references pred_A (index 0) + assert_eq!( + batch.predicates()[0].statements[0].args[1], + StatementTmplArg::SelfPredicateHash(1) + ); + assert_eq!( + batch.predicates()[1].statements[0].args[1], + StatementTmplArg::SelfPredicateHash(0) + ); + + // Compute concrete hashes + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); + + // Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash) + let mut mp_builder = MainPodBuilder::new(¶ms, vd_set); + let eq_st = mp_builder.priv_op(Operation::eq(pred_b_hash.clone(), pred_b_hash.clone()))?; + mp_builder.pub_op(Operation::custom(pred_a_ref, [eq_st]))?; + + // Prove and verify + let prover = MockProver {}; + let proof = mp_builder.prove(&prover)?; + proof.pod.verify()?; + + // Verify the public statement contains pred_b_hash as its argument + let pub_sts = proof.pod.pub_self_statements(); + let custom_st = pub_sts + .iter() + .find(|s| matches!(s, middleware::Statement::Custom(_, _))) + .expect("should have a custom statement"); + if let middleware::Statement::Custom(_, args) = custom_st { + assert_eq!(args[0], pred_b_hash); + } + + Ok(()) + } } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 98f280e..1ce2795 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -578,7 +578,7 @@ impl MainPodBuilder { } } OperationType::Custom(cpr) => { - let pred = &cpr.batch.predicates()[cpr.index]; + let pred = cpr.normalized_predicate(); if pred.statements.len() != op.1.len() { return Err(Error::custom(format!( "Custom predicate operation needs {} statements but has {}.", @@ -606,7 +606,7 @@ impl MainPodBuilder { } wildcard_map[index] = Some(value); } - fill_wildcard_values(pred, &args, &mut wildcard_map)?; + fill_wildcard_values(&pred, &args, &mut wildcard_map)?; let v_default = Value::from(0); let st_args: Vec<_> = wildcard_map .into_iter() diff --git a/src/lang/frontend_ast_lower.rs b/src/lang/frontend_ast_lower.rs index b429f4a..fe9b745 100644 --- a/src/lang/frontend_ast_lower.rs +++ b/src/lang/frontend_ast_lower.rs @@ -346,6 +346,9 @@ impl<'a> Lowerer<'a> { let key = Key::from(key_str.as_str()); MWStatementTmplArg::AnchoredKey(wildcard, key) } + BuilderArg::SelfPredicateHash(_) => { + unreachable!("SelfPredicateHash should not appear in request lowering") + } }; mw_args.push(mw_arg); } diff --git a/src/lang/module.rs b/src/lang/module.rs index 3ff3d6b..78fb22e 100644 --- a/src/lang/module.rs +++ b/src/lang/module.rs @@ -345,7 +345,9 @@ fn build_single_batch( })?; } - Ok(builder.finish()) + builder.finish().map_err(|e| BatchingError::Internal { + message: format!("Failed to finalize batch '{}': {}", batch_name, e), + }) } /// Build a statement template with properly resolved predicate references diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs index 13cc387..cf6d9be 100644 --- a/src/middleware/custom.rs +++ b/src/middleware/custom.rs @@ -49,6 +49,9 @@ pub enum StatementTmplArg { // AnchoredKey where the origin is a wildcard AnchoredKey(Wildcard, Key), Wildcard(Wildcard), + /// Reference to a same-batch predicate's identity hash, resolved at verification time. + /// The usize is the predicate index within the batch. + SelfPredicateHash(usize), } #[derive(Clone, Copy)] @@ -57,6 +60,7 @@ pub enum StatementTmplArgPrefix { Literal = 1, AnchoredKey = 2, WildcardLiteral = 3, + SelfPredicateHash = 4, } impl From for F { @@ -68,11 +72,12 @@ impl From for F { impl ToFields for StatementTmplArg { fn to_fields(&self) -> Vec { // Encoding: - // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) - // Literal(v) => (1, [v ], 0, 0, 0, 0) - // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) - // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) - // In all three cases, we pad to 2 * hash_size + 1 = 9 field elements + // None => (0, 0, 0, 0, 0, 0, 0, 0, 0) + // Literal(v) => (1, [v ], 0, 0, 0, 0) + // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) + // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) + // SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0) + // In all cases, we pad to 2 * hash_size + 1 = 9 field elements match self { StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None)) .chain(iter::repeat(F::ZERO)) @@ -97,6 +102,13 @@ impl ToFields for StatementTmplArg { .take(Params::statement_tmpl_arg_size()) .collect_vec() } + StatementTmplArg::SelfPredicateHash(index) => { + iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash)) + .chain(iter::once(F::from_canonical_usize(*index))) + .chain(iter::repeat(F::ZERO)) + .take(Params::statement_tmpl_arg_size()) + .collect_vec() + } } } } @@ -113,6 +125,7 @@ impl fmt::Display for StatementTmplArg { write!(f, "]") } Self::Wildcard(v) => v.fmt(f), + Self::SelfPredicateHash(i) => write!(f, "::self.{}", i), } } } @@ -569,6 +582,44 @@ impl CustomPredicateRef { pub fn predicate(&self) -> &CustomPredicate { &self.batch.predicates()[self.index] } + + /// Returns a copy of this predicate with all `SelfPredicateHash(i)` args + /// resolved to `Literal(hash(Custom(batch, i)))`. + pub fn normalized_predicate(&self) -> CustomPredicate { + let pred = self.predicate(); + let normalized_statements = pred + .statements + .iter() + .map(|st_tmpl| { + let args = st_tmpl + .args + .iter() + .map(|arg| match arg { + StatementTmplArg::SelfPredicateHash(i) => { + let pred_hash = Predicate::Custom(CustomPredicateRef { + batch: self.batch.clone(), + index: *i, + }) + .hash(); + StatementTmplArg::Literal(Value::from(pred_hash)) + } + other => other.clone(), + }) + .collect(); + StatementTmpl { + pred_or_wc: st_tmpl.pred_or_wc.clone(), + args, + } + }) + .collect(); + CustomPredicate { + name: pred.name.clone(), + conjunction: pred.conjunction, + statements: normalized_statements, + args_len: pred.args_len, + wildcard_names: pred.wildcard_names.clone(), + } + } } #[cfg(test)] @@ -823,4 +874,164 @@ mod tests { Ok(()) } + + #[test] + fn test_normalized_predicate() -> Result<()> { + let params = Params::default(); + + // Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0)) + let pred_a = CustomPredicate::and( + ¶ms, + "pred_A".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))], + )], + 2, + names(&["x", "y"]), + )?; + let pred_b = CustomPredicate::and( + ¶ms, + "pred_B".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], + )], + 1, + names(&["x"]), + )?; + let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); + + // Compute expected pred_A hash + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash()); + + // Normalize pred_B + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let normalized = pred_b_ref.normalized_predicate(); + + // The second arg should be resolved to Literal(pred_A_hash) + assert_eq!( + normalized.statements[0].args[1], + STA::Literal(expected_hash) + ); + + // First arg should be unchanged (still a wildcard) + assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0))); + + Ok(()) + } + + #[test] + fn test_self_predicate_hash_check() -> Result<()> { + let params = Params::default(); + + // Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0)) + let pred_a = CustomPredicate::and( + ¶ms, + "pred_A".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))], + )], + 2, + names(&["x", "y"]), + )?; + let pred_b = CustomPredicate::and( + ¶ms, + "pred_B".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], + )], + 1, + names(&["x"]), + )?; + let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); + + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash()); + + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + + // Construct a valid operation: Equal(some_value, pred_a_hash) + let some_value = Value::from(42); + let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())]; + + // The output statement + let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]); + + // This should pass + assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(¶ms, &output_st)?); + + // Now try with wrong hash, should fail + let wrong_hash = Value::from(999); + let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)]; + assert!(Operation::Custom(pred_b_ref, bad_op_args) + .check(¶ms, &output_st) + .is_err()); + + Ok(()) + } + + #[test] + fn test_self_predicate_hash_cyclic() -> Result<()> { + let params = Params::default(); + + // Build a batch where pred_A references pred_B's hash and vice versa + // pred_A = Equal(x, SelfPredicateHash(1)) + // pred_B = Equal(x, SelfPredicateHash(0)) + let pred_a = CustomPredicate::and( + ¶ms, + "pred_A".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)], + )], + 1, + names(&["x"]), + )?; + let pred_b = CustomPredicate::and( + ¶ms, + "pred_B".into(), + vec![st( + P::Native(NP::Equal), + vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)], + )], + 1, + names(&["x"]), + )?; + let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]); + + let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0); + let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1); + let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash()); + let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash()); + + // pred_A's normalized form should reference pred_B's hash + let norm_a = pred_a_ref.normalized_predicate(); + assert_eq!( + norm_a.statements[0].args[1], + STA::Literal(pred_b_hash.clone()) + ); + + // pred_B's normalized form should reference pred_A's hash + let norm_b = pred_b_ref.normalized_predicate(); + assert_eq!( + norm_b.statements[0].args[1], + STA::Literal(pred_a_hash.clone()) + ); + + // Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass + let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())]; + let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]); + assert!(Operation::Custom(pred_a_ref, op_a).check(¶ms, &st_a)?); + + // Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass + let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())]; + let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]); + assert!(Operation::Custom(pred_b_ref, op_b).check(¶ms, &st_b)?); + + Ok(()) + } } diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index dfdfcfc..1793e4d 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -595,6 +595,11 @@ pub fn check_st_tmpl( (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { wc_check_or_set(v.clone(), wc, wildcard_map) } + (StatementTmplArg::SelfPredicateHash(_), _) => { + unreachable!( + "SelfPredicateHash should be normalized to Literal before template matching" + ) + } _ => Err(Error::mismatched_statement_tmpl_arg( st_tmpl_arg.clone(), st_arg.clone(), @@ -712,7 +717,7 @@ pub(crate) fn check_custom_pred( args: &[Statement], s_args: &[Value], ) -> Result<()> { - let pred = custom_pred_ref.predicate(); + let pred = custom_pred_ref.normalized_predicate(); if pred.statements.len() != args.len() { return Err(Error::diff_amount( "custom predicate operation".to_string(), @@ -731,7 +736,7 @@ pub(crate) fn check_custom_pred( } // Check that the resolved wildcards match the statement arguments. - let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) { + let wc_values = match wildcard_values_from_op_st(params, &pred, args, s_args) { Ok(wc_values) => wc_values, Err(Error::Inner { inner, backtrace }) => match *inner { MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)