Self-referential predicate hashes as statement template args (#494)

* Support quoted predicate hashes, including self-referential predicates

* Clippy

* Review feedback
This commit is contained in:
Rob Knight 2026-03-24 14:25:11 +00:00 committed by GitHub
parent 13cabdb511
commit 1e592e11cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 573 additions and 31 deletions

View file

@ -49,6 +49,9 @@ pub enum StatementTmplArg {
// AnchoredKey where the origin is a wildcard
AnchoredKey(Wildcard, Key),
Wildcard(Wildcard),
/// Reference to a same-batch predicate's identity hash, resolved at verification time.
/// The usize is the predicate index within the batch.
SelfPredicateHash(usize),
}
#[derive(Clone, Copy)]
@ -57,6 +60,7 @@ pub enum StatementTmplArgPrefix {
Literal = 1,
AnchoredKey = 2,
WildcardLiteral = 3,
SelfPredicateHash = 4,
}
impl From<StatementTmplArgPrefix> for F {
@ -68,11 +72,12 @@ impl From<StatementTmplArgPrefix> for F {
impl ToFields for StatementTmplArg {
fn to_fields(&self) -> Vec<F> {
// Encoding:
// None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
// Literal(v) => (1, [v ], 0, 0, 0, 0)
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
// None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
// Literal(v) => (1, [v ], 0, 0, 0, 0)
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
// SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0)
// In all cases, we pad to 2 * hash_size + 1 = 9 field elements
match self {
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
.chain(iter::repeat(F::ZERO))
@ -97,6 +102,13 @@ impl ToFields for StatementTmplArg {
.take(Params::statement_tmpl_arg_size())
.collect_vec()
}
StatementTmplArg::SelfPredicateHash(index) => {
iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash))
.chain(iter::once(F::from_canonical_usize(*index)))
.chain(iter::repeat(F::ZERO))
.take(Params::statement_tmpl_arg_size())
.collect_vec()
}
}
}
}
@ -113,6 +125,7 @@ impl fmt::Display for StatementTmplArg {
write!(f, "]")
}
Self::Wildcard(v) => v.fmt(f),
Self::SelfPredicateHash(i) => write!(f, "::self.{}", i),
}
}
}
@ -569,6 +582,44 @@ impl CustomPredicateRef {
pub fn predicate(&self) -> &CustomPredicate {
&self.batch.predicates()[self.index]
}
/// Returns a copy of this predicate with all `SelfPredicateHash(i)` args
/// resolved to `Literal(hash(Custom(batch, i)))`.
pub fn normalized_predicate(&self) -> CustomPredicate {
let pred = self.predicate();
let normalized_statements = pred
.statements
.iter()
.map(|st_tmpl| {
let args = st_tmpl
.args
.iter()
.map(|arg| match arg {
StatementTmplArg::SelfPredicateHash(i) => {
let pred_hash = Predicate::Custom(CustomPredicateRef {
batch: self.batch.clone(),
index: *i,
})
.hash();
StatementTmplArg::Literal(Value::from(pred_hash))
}
other => other.clone(),
})
.collect();
StatementTmpl {
pred_or_wc: st_tmpl.pred_or_wc.clone(),
args,
}
})
.collect();
CustomPredicate {
name: pred.name.clone(),
conjunction: pred.conjunction,
statements: normalized_statements,
args_len: pred.args_len,
wildcard_names: pred.wildcard_names.clone(),
}
}
}
#[cfg(test)]
@ -823,4 +874,164 @@ mod tests {
Ok(())
}
#[test]
fn test_normalized_predicate() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
// Compute expected pred_A hash
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
// Normalize pred_B
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let normalized = pred_b_ref.normalized_predicate();
// The second arg should be resolved to Literal(pred_A_hash)
assert_eq!(
normalized.statements[0].args[1],
STA::Literal(expected_hash)
);
// First arg should be unchanged (still a wildcard)
assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0)));
Ok(())
}
#[test]
fn test_self_predicate_hash_check() -> Result<()> {
let params = Params::default();
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
)],
2,
names(&["x", "y"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
// Construct a valid operation: Equal(some_value, pred_a_hash)
let some_value = Value::from(42);
let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())];
// The output statement
let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]);
// This should pass
assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(&params, &output_st)?);
// Now try with wrong hash, should fail
let wrong_hash = Value::from(999);
let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)];
assert!(Operation::Custom(pred_b_ref, bad_op_args)
.check(&params, &output_st)
.is_err());
Ok(())
}
#[test]
fn test_self_predicate_hash_cyclic() -> Result<()> {
let params = Params::default();
// Build a batch where pred_A references pred_B's hash and vice versa
// pred_A = Equal(x, SelfPredicateHash(1))
// pred_B = Equal(x, SelfPredicateHash(0))
let pred_a = CustomPredicate::and(
&params,
"pred_A".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)],
)],
1,
names(&["x"]),
)?;
let pred_b = CustomPredicate::and(
&params,
"pred_B".into(),
vec![st(
P::Native(NP::Equal),
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
)],
1,
names(&["x"]),
)?;
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash());
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
// pred_A's normalized form should reference pred_B's hash
let norm_a = pred_a_ref.normalized_predicate();
assert_eq!(
norm_a.statements[0].args[1],
STA::Literal(pred_b_hash.clone())
);
// pred_B's normalized form should reference pred_A's hash
let norm_b = pred_b_ref.normalized_predicate();
assert_eq!(
norm_b.statements[0].args[1],
STA::Literal(pred_a_hash.clone())
);
// Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass
let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())];
let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]);
assert!(Operation::Custom(pred_a_ref, op_a).check(&params, &st_a)?);
// Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass
let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())];
let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]);
assert!(Operation::Custom(pred_b_ref, op_b).check(&params, &st_b)?);
Ok(())
}
}