Self-referential predicate hashes as statement template args (#494)

* Support quoted predicate hashes, including self-referential predicates

* Clippy

* Review feedback
This commit is contained in:
Rob Knight 2026-03-24 14:25:11 +00:00 committed by GitHub
parent 13cabdb511
commit 1e592e11cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 573 additions and 31 deletions

View file

@ -771,7 +771,8 @@ impl CustomPredicateEntryTarget {
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?; pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?; pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
// Replace statement templates of batch-self with (id,index) // Replace BatchSelf predicates with Custom(batch, i), and
// SelfPredicateHash args with Literal(hash(Custom(batch, i)))
let batch = &predicate.batch; let batch = &predicate.batch;
let predicate = predicate.predicate(); let predicate = predicate.predicate();
let statements = predicate let statements = predicate
@ -788,10 +789,22 @@ impl CustomPredicateEntryTarget {
} }
x => x.clone(), x => x.clone(),
}; };
StatementTmpl { let args = st_tmpl
pred_or_wc, .args
args: st_tmpl.args, .into_iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: batch.clone(),
index: i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
} }
other => other,
})
.collect();
StatementTmpl { pred_or_wc, args }
}) })
.collect_vec(); .collect_vec();
let predicate = CustomPredicate { let predicate = CustomPredicate {
@ -2012,7 +2025,7 @@ pub(crate) mod tests {
// Empty case // Empty case
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into()); let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?; _ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
let custom_predicate_batch = cpb_builder.finish(); let custom_predicate_batch = cpb_builder.finish()?;
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap(); helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
// Some cases from the examples // Some cases from the examples

View file

@ -54,8 +54,8 @@ use crate::{
measure_gates_begin, measure_gates_end, measure_gates_begin, measure_gates_end,
middleware::{ middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, Params, PredicatePrefix, RawValue, Statement, ToFields, Value, F, NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix,
HASH_SIZE, ToFields, Value, F, HASH_SIZE,
}, },
}; };
// //
@ -1534,8 +1534,8 @@ pub fn calculate_statements_hash_circuit(
sts_hash sts_hash
} }
// Replace predicates of batch-self with the corresponding global custom predicate batch_id and // Replace BatchSelf predicates with the corresponding Custom(batch_id, index), and
// index // SelfPredicateHash args with Literal(hash(Custom(batch_id, index))).
fn normalize_st_tmpl_circuit( fn normalize_st_tmpl_circuit(
params: &Params, params: &Params,
builder: &mut CircuitBuilder, builder: &mut CircuitBuilder,
@ -1564,7 +1564,41 @@ fn normalize_st_tmpl_circuit(
); );
let pred_hash_or_wc = let pred_hash_or_wc =
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data); PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone())
// Normalize SelfPredicateHash args: replace prefix 4 with Literal containing the resolved
// predicate hash. Same pattern as the predicate normalization above.
let prefix_sph = builder.constant(F::from(StatementTmplArgPrefix::SelfPredicateHash));
let prefix_literal = builder.constant(F::from(StatementTmplArgPrefix::Literal));
let zero = builder.zero();
let normalized_args = st_tmpl
.args
.iter()
.map(|arg| {
let is_sph = builder.is_equal(arg.elements[0], prefix_sph);
// The predicate index is in elements[1] (same slot as WildcardLiteral).
let pred_index = arg.elements[1];
// Compute hash(Custom(batch_id, pred_index))
let pred_target = PredicateTarget::new_custom(builder, id, pred_index);
let pred_hash = pred_target.hash(builder);
// Build a Literal-encoded arg: [1, hash[0..4], 0, 0, 0, 0]
let mut literal_elements = [zero; Params::statement_tmpl_arg_size()];
literal_elements[0] = prefix_literal;
literal_elements[1] = pred_hash.elements[0];
literal_elements[2] = pred_hash.elements[1];
literal_elements[3] = pred_hash.elements[2];
literal_elements[4] = pred_hash.elements[3];
let normalized = StatementTmplArgTarget {
elements: literal_elements,
};
builder.select_flattenable(params, is_sph, &normalized, arg)
})
.collect();
StatementTmplTarget::new(pred_hash_or_wc, normalized_args)
} }
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as /// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
@ -3262,7 +3296,7 @@ mod tests {
&[stb0.clone(), stb1.clone()], &[stb0.clone(), stb1.clone()],
)?; )?;
let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?; let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?;
let batch = builder.finish(); let batch = builder.finish()?;
let dict = Hash([F(6), F(7), F(8), F(9)]); let dict = Hash([F(6), F(7), F(8), F(9)]);
@ -3352,7 +3386,7 @@ mod tests {
&[stb0.clone(), stb1.clone()], &[stb0.clone(), stb1.clone()],
)?; )?;
let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?; let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?;
let batch = builder.finish(); let batch = builder.finish()?;
let dict = Hash([F(1), F(2), F(3), F(4)]); let dict = Hash([F(1), F(2), F(3), F(4)]);
let secret_dict = Hash([F(6), F(7), F(8), F(9)]); let secret_dict = Hash([F(6), F(7), F(8), F(9)]);
@ -3570,4 +3604,100 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> {
let params = Params::default();
// Build a batch with two predicates:
// pred_A: Equal(x, y)
// pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash
use NativePredicate as NP;
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal)
.arg("x")
.arg("y");
cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a])
.unwrap();
// Build pred_B's template manually with SelfPredicateHash(0)
let stb_b_tmpl = StatementTmpl {
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)),
args: vec![
StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)),
StatementTmplArg::SelfPredicateHash(0),
],
};
let pred_b = CustomPredicate::new(
&params,
"pred_B".into(),
true,
vec![stb_b_tmpl],
1,
vec!["x".to_string()],
)
.unwrap();
cpb.predicates.push(pred_b);
let batch = cpb.finish().unwrap();
// Compute the expected resolved hash of pred_A
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = Predicate::Custom(pred_a_ref).hash();
let expected_pred_a_value = Value::from(pred_a_hash);
// Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to
// Literal(pred_a_hash). Then make_statement_from_template_circuit should produce
// a statement with that literal value.
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_tmpl = &pred_b_ref.predicate().statements[0];
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
// Create the template target and batch id target
let st_tmpl_target = builder.add_virtual_statement_tmpl(true);
let batch_id = builder.add_virtual_hash();
// Normalize the template (this is what we're testing)
let normalized =
normalize_st_tmpl_circuit(&params, &mut builder, &st_tmpl_target, batch_id);
// Feed normalized template into statement generation
let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect();
let st_target =
make_statement_from_template_circuit(&params, &mut builder, &normalized, &args_target);
// Connect to expected output
let expected_st_target = builder.add_virtual_statement(false);
builder.connect_flattenable(&expected_st_target, &st_target);
// Set witness
let mut pw = PartialWitness::<F>::new();
st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?;
pw.set_target_arr(&batch_id.elements, &batch.id().0)?;
let some_value = Value::from(42);
// args: first wildcard is "x" = some_value, rest are padding
let mut args_values = vec![some_value.clone()];
for _ in 1..params.max_custom_predicate_wildcards {
args_values.push(Value::from(EMPTY_VALUE));
}
for (target, value) in args_target.iter().zip(args_values.iter()) {
target.set_targets(&mut pw, value)?;
}
// Expected statement: Equal(Literal(some_value), Literal(pred_a_hash))
let expected_st: crate::backends::plonky2::mainpod::Statement =
Statement::equal(some_value, expected_pred_a_value).into();
expected_st_target.set_targets(&mut pw, &expected_st)?;
// Build and verify
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
} }

View file

@ -104,8 +104,9 @@ pub(crate) fn extract_custom_predicate_verifications(
if let middleware::Operation::Custom(cpr, sts) = op { if let middleware::Operation::Custom(cpr, sts) = op {
if let middleware::Statement::Custom(st_cpr, st_args) = st { if let middleware::Statement::Custom(st_cpr, st_args) = st {
assert_eq!(cpr, st_cpr); assert_eq!(cpr, st_cpr);
let normalized_pred = cpr.normalized_predicate();
let wildcard_values = let wildcard_values =
wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args) wildcard_values_from_op_st(params, &normalized_pred, sts, st_args)
.expect("resolved wildcards"); .expect("resolved wildcards");
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect(); let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
let custom_predicate_table_index = custom_predicates let custom_predicate_table_index = custom_predicates
@ -1096,7 +1097,7 @@ pub mod tests {
&[stb0.clone(), stb1.clone()], &[stb0.clone(), stb1.clone()],
)?; )?;
let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?; let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?;
let cpb = cpb_builder.finish(); let cpb = cpb_builder.finish()?;
let cpb_and = CustomPredicateRef::new(cpb.clone(), 0); let cpb_and = CustomPredicateRef::new(cpb.clone(), 0);
let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1); let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1);
@ -1130,6 +1131,63 @@ pub mod tests {
Ok(pod.verify()?) Ok(pod.verify()?)
} }
#[test]
fn test_main_self_predicate_hash() -> frontend::Result<()> {
use frontend::BuilderArg;
let params = Params {
max_signed_by: 0,
max_input_pods: 0,
max_statements: 6,
max_public_statements: 2,
max_operation_args: 5,
max_custom_predicate_wildcards: 4,
max_custom_predicate_verifications: 2,
max_merkle_proofs_containers: 0,
max_merkle_tree_state_transition_proofs_containers: 0,
..Default::default()
};
let mut vds = DEFAULT_VD_LIST.clone();
vds.push(rec_main_pod_circuit_data(&params).1.verifier_only.clone());
let vd_set = VDSet::new(&vds);
// Build a batch: pred_A references pred_B's hash, pred_B references pred_A's hash
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb_a = STB::new_from_pred(NP::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
cpb.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
let stb_b = STB::new_from_pred(NP::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
cpb.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
let batch = cpb.finish()?;
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_hash = middleware::Value::from(middleware::Predicate::Custom(pred_b_ref).hash());
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
let mut pod_builder = MainPodBuilder::new(&params, &vd_set);
let eq_st =
pod_builder.priv_op(frontend::Operation::eq(pred_b_hash.clone(), pred_b_hash))?;
pod_builder.pub_op(frontend::Operation::custom(pred_a_ref, [eq_st]))?;
// Mock
let prover = MockProver {};
let pod = pod_builder.prove(&prover)?;
assert!(pod.pod.verify().is_ok());
// Real
let prover = Prover {};
let pod = pod_builder.prove(&prover)?;
let pod = (pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
Ok(pod.verify()?)
}
#[test] #[test]
fn test_set_contains() -> frontend::Result<()> { fn test_set_contains() -> frontend::Result<()> {
let params = Params::default(); let params = Params::default();

View file

@ -18,6 +18,8 @@ pub enum BuilderArg {
/// Key: (origin, key), where origin is Wildcard and key is Key /// Key: (origin, key), where origin is Wildcard and key is Key
Key(String, String), Key(String, String),
WildcardLiteral(String), WildcardLiteral(String),
/// Reference to a same-batch predicate's identity hash (resolved by name in finish()).
SelfPredicateHash(String),
} }
/// When defining a `BuilderArg`, it can be done from 3 different inputs: /// When defining a `BuilderArg`, it can be done from 3 different inputs:
@ -130,6 +132,8 @@ pub struct CustomPredicateBatchBuilder {
params: Params, params: Params,
pub name: String, pub name: String,
pub predicates: Vec<CustomPredicate>, pub predicates: Vec<CustomPredicate>,
/// Forward references to resolve in finish(): (predicate_idx, statement_idx, arg_idx, name)
pending_self_pred_hashes: Vec<(usize, usize, usize, String)>,
} }
impl CustomPredicateBatchBuilder { impl CustomPredicateBatchBuilder {
@ -138,6 +142,7 @@ impl CustomPredicateBatchBuilder {
params, params,
name, name,
predicates: Vec::new(), predicates: Vec::new(),
pending_self_pred_hashes: Vec::new(),
} }
} }
@ -194,14 +199,18 @@ impl CustomPredicateBatchBuilder {
)); ));
} }
let pred_idx = self.predicates.len();
let mut pending = Vec::new();
let statements = sts let statements = sts
.iter() .iter()
.map(|sb| { .enumerate()
.map(|(stmt_idx, sb)| {
let stb = sb.clone().desugar(); let stb = sb.clone().desugar();
let st_tmpl_args = stb let st_tmpl_args = stb
.args .args
.iter() .iter()
.map(|a| { .enumerate()
.map(|(arg_idx, a)| {
Ok::<_, Error>(match a { Ok::<_, Error>(match a {
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()), BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey( BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey(
@ -211,6 +220,22 @@ impl CustomPredicateBatchBuilder {
BuilderArg::WildcardLiteral(v) => { BuilderArg::WildcardLiteral(v) => {
StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?) StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?)
} }
BuilderArg::SelfPredicateHash(pred_name) => {
// Try backward reference first
match self.predicates.iter().position(|p| p.name == *pred_name) {
Some(index) => StatementTmplArg::SelfPredicateHash(index),
None => {
// Forward reference - placeholder, resolved in finish()
pending.push((
pred_idx,
stmt_idx,
arg_idx,
pred_name.clone(),
));
StatementTmplArg::SelfPredicateHash(0)
}
}
}
}) })
}) })
.collect::<Result<_>>()?; .collect::<Result<_>>()?;
@ -240,11 +265,27 @@ impl CustomPredicateBatchBuilder {
.collect(), .collect(),
)?; )?;
self.predicates.push(custom_predicate); self.predicates.push(custom_predicate);
self.pending_self_pred_hashes.extend(pending);
Ok(Predicate::BatchSelf(self.predicates.len() - 1)) Ok(Predicate::BatchSelf(self.predicates.len() - 1))
} }
pub fn finish(self) -> Arc<CustomPredicateBatch> { pub fn finish(mut self) -> Result<Arc<CustomPredicateBatch>> {
CustomPredicateBatch::new(self.name, self.predicates) // Resolve forward references for SelfPredicateHash
for (pred_idx, stmt_idx, arg_idx, ref name) in &self.pending_self_pred_hashes {
let target_idx = self
.predicates
.iter()
.position(|p| p.name == *name)
.ok_or_else(|| {
Error::custom(format!(
"SelfPredicateHash references unknown predicate '{}'",
name
))
})?;
self.predicates[*pred_idx].statements[*stmt_idx].args[*arg_idx] =
StatementTmplArg::SelfPredicateHash(target_idx);
}
Ok(CustomPredicateBatch::new(self.name, self.predicates))
} }
} }
@ -306,7 +347,7 @@ mod tests {
.arg("s2"); .arg("s2");
builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?; builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?;
let batch = builder.finish(); let batch = builder.finish()?;
let batch_clone = batch.clone(); let batch_clone = batch.clone();
let gt_custom_pred = CustomPredicateRef::new(batch, 0); let gt_custom_pred = CustomPredicateRef::new(batch, 0);
@ -356,7 +397,7 @@ mod tests {
&[], &[],
&[set_contains_stb], &[set_contains_stb],
)?; )?;
let batch = builder.finish(); let batch = builder.finish()?;
let batch_clone = batch.clone(); let batch_clone = batch.clone();
let mut mp_builder = MainPodBuilder::new(&params, vd_set); let mut mp_builder = MainPodBuilder::new(&params, vd_set);
@ -386,4 +427,83 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_builder_self_predicate_hash_unknown_ref() {
let params = Params::default();
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("nonexistent".into()));
builder
.predicate_and("pred_A", &["x"], &[], &[stb])
.unwrap();
// finish() should fail because "nonexistent" was never defined
assert!(builder.finish().is_err());
}
/// Tests cyclic SelfPredicateHash references end-to-end:
/// pred_A references pred_B's hash (forward ref), pred_B references pred_A's hash (backward
/// ref). Exercises forward reference resolution in finish(), then builds and verifies a POD
/// using pred_A via MockProver.
#[test]
fn test_builder_self_predicate_hash_e2e() -> Result<()> {
let params = Params::default();
let vd_set = &*MOCK_VD_SET;
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
// pred_A references pred_B's hash (forward ref, pred_B not yet defined)
let stb_a = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
builder.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
// pred_B references pred_A's hash (backward ref, pred_A already defined)
let stb_b = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
.arg("x")
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
builder.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
let batch = builder.finish()?;
// Verify resolution: pred_A references pred_B (index 1), pred_B references pred_A (index 0)
assert_eq!(
batch.predicates()[0].statements[0].args[1],
StatementTmplArg::SelfPredicateHash(1)
);
assert_eq!(
batch.predicates()[1].statements[0].args[1],
StatementTmplArg::SelfPredicateHash(0)
);
// Compute concrete hashes
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
let mut mp_builder = MainPodBuilder::new(&params, vd_set);
let eq_st = mp_builder.priv_op(Operation::eq(pred_b_hash.clone(), pred_b_hash.clone()))?;
mp_builder.pub_op(Operation::custom(pred_a_ref, [eq_st]))?;
// Prove and verify
let prover = MockProver {};
let proof = mp_builder.prove(&prover)?;
proof.pod.verify()?;
// Verify the public statement contains pred_b_hash as its argument
let pub_sts = proof.pod.pub_self_statements();
let custom_st = pub_sts
.iter()
.find(|s| matches!(s, middleware::Statement::Custom(_, _)))
.expect("should have a custom statement");
if let middleware::Statement::Custom(_, args) = custom_st {
assert_eq!(args[0], pred_b_hash);
}
Ok(())
}
} }

View file

@ -578,7 +578,7 @@ impl MainPodBuilder {
} }
} }
OperationType::Custom(cpr) => { OperationType::Custom(cpr) => {
let pred = &cpr.batch.predicates()[cpr.index]; let pred = cpr.normalized_predicate();
if pred.statements.len() != op.1.len() { if pred.statements.len() != op.1.len() {
return Err(Error::custom(format!( return Err(Error::custom(format!(
"Custom predicate operation needs {} statements but has {}.", "Custom predicate operation needs {} statements but has {}.",
@ -606,7 +606,7 @@ impl MainPodBuilder {
} }
wildcard_map[index] = Some(value); wildcard_map[index] = Some(value);
} }
fill_wildcard_values(pred, &args, &mut wildcard_map)?; fill_wildcard_values(&pred, &args, &mut wildcard_map)?;
let v_default = Value::from(0); let v_default = Value::from(0);
let st_args: Vec<_> = wildcard_map let st_args: Vec<_> = wildcard_map
.into_iter() .into_iter()

View file

@ -346,6 +346,9 @@ impl<'a> Lowerer<'a> {
let key = Key::from(key_str.as_str()); let key = Key::from(key_str.as_str());
MWStatementTmplArg::AnchoredKey(wildcard, key) MWStatementTmplArg::AnchoredKey(wildcard, key)
} }
BuilderArg::SelfPredicateHash(_) => {
unreachable!("SelfPredicateHash should not appear in request lowering")
}
}; };
mw_args.push(mw_arg); mw_args.push(mw_arg);
} }

View file

@ -345,7 +345,9 @@ fn build_single_batch(
})?; })?;
} }
Ok(builder.finish()) builder.finish().map_err(|e| BatchingError::Internal {
message: format!("Failed to finalize batch '{}': {}", batch_name, e),
})
} }
/// Build a statement template with properly resolved predicate references /// Build a statement template with properly resolved predicate references

View file

@ -49,6 +49,9 @@ pub enum StatementTmplArg {
// AnchoredKey where the origin is a wildcard // AnchoredKey where the origin is a wildcard
AnchoredKey(Wildcard, Key), AnchoredKey(Wildcard, Key),
Wildcard(Wildcard), Wildcard(Wildcard),
/// Reference to a same-batch predicate's identity hash, resolved at verification time.
/// The usize is the predicate index within the batch.
SelfPredicateHash(usize),
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
@ -57,6 +60,7 @@ pub enum StatementTmplArgPrefix {
Literal = 1, Literal = 1,
AnchoredKey = 2, AnchoredKey = 2,
WildcardLiteral = 3, WildcardLiteral = 3,
SelfPredicateHash = 4,
} }
impl From<StatementTmplArgPrefix> for F { impl From<StatementTmplArgPrefix> for F {
@ -72,7 +76,8 @@ impl ToFields for StatementTmplArg {
// Literal(v) => (1, [v ], 0, 0, 0, 0) // Literal(v) => (1, [v ], 0, 0, 0, 0)
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc]) // Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0) // WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements // SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0)
// In all cases, we pad to 2 * hash_size + 1 = 9 field elements
match self { match self {
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None)) StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
.chain(iter::repeat(F::ZERO)) .chain(iter::repeat(F::ZERO))
@ -97,6 +102,13 @@ impl ToFields for StatementTmplArg {
.take(Params::statement_tmpl_arg_size()) .take(Params::statement_tmpl_arg_size())
.collect_vec() .collect_vec()
} }
StatementTmplArg::SelfPredicateHash(index) => {
iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash))
.chain(iter::once(F::from_canonical_usize(*index)))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect_vec()
}
} }
} }
} }
@ -113,6 +125,7 @@ impl fmt::Display for StatementTmplArg {
write!(f, "]") write!(f, "]")
} }
Self::Wildcard(v) => v.fmt(f), Self::Wildcard(v) => v.fmt(f),
Self::SelfPredicateHash(i) => write!(f, "::self.{}", i),
} }
} }
} }
@ -569,6 +582,44 @@ impl CustomPredicateRef {
pub fn predicate(&self) -> &CustomPredicate { pub fn predicate(&self) -> &CustomPredicate {
&self.batch.predicates()[self.index] &self.batch.predicates()[self.index]
} }
/// Returns a copy of this predicate with all `SelfPredicateHash(i)` args
/// resolved to `Literal(hash(Custom(batch, i)))`.
pub fn normalized_predicate(&self) -> CustomPredicate {
let pred = self.predicate();
let normalized_statements = pred
.statements
.iter()
.map(|st_tmpl| {
let args = st_tmpl
.args
.iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: self.batch.clone(),
index: *i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
}
other => other.clone(),
})
.collect();
StatementTmpl {
pred_or_wc: st_tmpl.pred_or_wc.clone(),
args,
}
})
.collect();
CustomPredicate {
name: pred.name.clone(),
conjunction: pred.conjunction,
statements: normalized_statements,
args_len: pred.args_len,
wildcard_names: pred.wildcard_names.clone(),
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -823,4 +874,164 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_normalized_predicate() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
// Compute expected pred_A hash
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
// Normalize pred_B
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let normalized = pred_b_ref.normalized_predicate();
// The second arg should be resolved to Literal(pred_A_hash)
assert_eq!(
normalized.statements[0].args[1],
STA::Literal(expected_hash)
);
// First arg should be unchanged (still a wildcard)
assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0)));
Ok(())
}
#[test]
fn test_self_predicate_hash_check() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
// Construct a valid operation: Equal(some_value, pred_a_hash)
let some_value = Value::from(42);
let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())];
// The output statement
let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]);
// This should pass
assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(&params, &output_st)?);
// Now try with wrong hash, should fail
let wrong_hash = Value::from(999);
let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)];
assert!(Operation::Custom(pred_b_ref, bad_op_args)
.check(&params, &output_st)
.is_err());
Ok(())
}
#[test]
fn test_self_predicate_hash_cyclic() -> Result<()> {
let params = Params::default();
// Build a batch where pred_A references pred_B's hash and vice versa
// pred_A = Equal(x, SelfPredicateHash(1))
// pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)],
)],
1,
names(&["x"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash());
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// pred_A's normalized form should reference pred_B's hash
let norm_a = pred_a_ref.normalized_predicate();
assert_eq!(
norm_a.statements[0].args[1],
STA::Literal(pred_b_hash.clone())
);
// pred_B's normalized form should reference pred_A's hash
let norm_b = pred_b_ref.normalized_predicate();
assert_eq!(
norm_b.statements[0].args[1],
STA::Literal(pred_a_hash.clone())
);
// Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass
let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())];
let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]);
assert!(Operation::Custom(pred_a_ref, op_a).check(&params, &st_a)?);
// Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass
let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())];
let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]);
assert!(Operation::Custom(pred_b_ref, op_b).check(&params, &st_b)?);
Ok(())
}
} }

View file

@ -595,6 +595,11 @@ pub fn check_st_tmpl(
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => { (StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
wc_check_or_set(v.clone(), wc, wildcard_map) wc_check_or_set(v.clone(), wc, wildcard_map)
} }
(StatementTmplArg::SelfPredicateHash(_), _) => {
unreachable!(
"SelfPredicateHash should be normalized to Literal before template matching"
)
}
_ => Err(Error::mismatched_statement_tmpl_arg( _ => Err(Error::mismatched_statement_tmpl_arg(
st_tmpl_arg.clone(), st_tmpl_arg.clone(),
st_arg.clone(), st_arg.clone(),
@ -712,7 +717,7 @@ pub(crate) fn check_custom_pred(
args: &[Statement], args: &[Statement],
s_args: &[Value], s_args: &[Value],
) -> Result<()> { ) -> Result<()> {
let pred = custom_pred_ref.predicate(); let pred = custom_pred_ref.normalized_predicate();
if pred.statements.len() != args.len() { if pred.statements.len() != args.len() {
return Err(Error::diff_amount( return Err(Error::diff_amount(
"custom predicate operation".to_string(), "custom predicate operation".to_string(),
@ -731,7 +736,7 @@ pub(crate) fn check_custom_pred(
} }
// Check that the resolved wildcards match the statement arguments. // Check that the resolved wildcards match the statement arguments.
let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) { let wc_values = match wildcard_values_from_op_st(params, &pred, args, s_args) {
Ok(wc_values) => wc_values, Ok(wc_values) => wc_values,
Err(Error::Inner { inner, backtrace }) => match *inner { Err(Error::Inner { inner, backtrace }) => match *inner {
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev) MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)