Self-referential predicate hashes as statement template args (#494)

* Support quoted predicate hashes, including self-referential predicates

* Clippy

* Review feedback
This commit is contained in:
Rob Knight 2026-03-24 14:25:11 +00:00 committed by GitHub
parent 13cabdb511
commit 1e592e11cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 573 additions and 31 deletions

View file

@ -54,8 +54,8 @@ use crate::{
measure_gates_begin, measure_gates_end,
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, Params, PredicatePrefix, RawValue, Statement, ToFields, Value, F,
HASH_SIZE,
NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix,
ToFields, Value, F, HASH_SIZE,
},
};
//
@ -1534,8 +1534,8 @@ pub fn calculate_statements_hash_circuit(
sts_hash
}
// Replace predicates of batch-self with the corresponding global custom predicate batch_id and
// index
// Replace BatchSelf predicates with the corresponding Custom(batch_id, index), and
// SelfPredicateHash args with Literal(hash(Custom(batch_id, index))).
fn normalize_st_tmpl_circuit(
params: &Params,
builder: &mut CircuitBuilder,
@ -1564,7 +1564,41 @@ fn normalize_st_tmpl_circuit(
);
let pred_hash_or_wc =
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone())
// Normalize SelfPredicateHash args: replace prefix 4 with Literal containing the resolved
// predicate hash. Same pattern as the predicate normalization above.
let prefix_sph = builder.constant(F::from(StatementTmplArgPrefix::SelfPredicateHash));
let prefix_literal = builder.constant(F::from(StatementTmplArgPrefix::Literal));
let zero = builder.zero();
let normalized_args = st_tmpl
.args
.iter()
.map(|arg| {
let is_sph = builder.is_equal(arg.elements[0], prefix_sph);
// The predicate index is in elements[1] (same slot as WildcardLiteral).
let pred_index = arg.elements[1];
// Compute hash(Custom(batch_id, pred_index))
let pred_target = PredicateTarget::new_custom(builder, id, pred_index);
let pred_hash = pred_target.hash(builder);
// Build a Literal-encoded arg: [1, hash[0..4], 0, 0, 0, 0]
let mut literal_elements = [zero; Params::statement_tmpl_arg_size()];
literal_elements[0] = prefix_literal;
literal_elements[1] = pred_hash.elements[0];
literal_elements[2] = pred_hash.elements[1];
literal_elements[3] = pred_hash.elements[2];
literal_elements[4] = pred_hash.elements[3];
let normalized = StatementTmplArgTarget {
elements: literal_elements,
};
builder.select_flattenable(params, is_sph, &normalized, arg)
})
.collect();
StatementTmplTarget::new(pred_hash_or_wc, normalized_args)
}
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
@ -3262,7 +3296,7 @@ mod tests {
&[stb0.clone(), stb1.clone()],
)?;
let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?;
let batch = builder.finish();
let batch = builder.finish()?;
let dict = Hash([F(6), F(7), F(8), F(9)]);
@ -3352,7 +3386,7 @@ mod tests {
&[stb0.clone(), stb1.clone()],
)?;
let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?;
let batch = builder.finish();
let batch = builder.finish()?;
let dict = Hash([F(1), F(2), F(3), F(4)]);
let secret_dict = Hash([F(6), F(7), F(8), F(9)]);
@ -3570,4 +3604,100 @@ mod tests {
Ok(())
}
#[test]
fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> {
let params = Params::default();
// Build a batch with two predicates:
// pred_A: Equal(x, y)
// pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash
use NativePredicate as NP;
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal)
.arg("x")
.arg("y");
cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a])
.unwrap();
// Build pred_B's template manually with SelfPredicateHash(0)
let stb_b_tmpl = StatementTmpl {
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)),
args: vec![
StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)),
StatementTmplArg::SelfPredicateHash(0),
],
};
let pred_b = CustomPredicate::new(
&params,
"pred_B".into(),
true,
vec![stb_b_tmpl],
1,
vec!["x".to_string()],
)
.unwrap();
cpb.predicates.push(pred_b);
let batch = cpb.finish().unwrap();
// Compute the expected resolved hash of pred_A
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
let pred_a_hash = Predicate::Custom(pred_a_ref).hash();
let expected_pred_a_value = Value::from(pred_a_hash);
// Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to
// Literal(pred_a_hash). Then make_statement_from_template_circuit should produce
// a statement with that literal value.
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
let pred_b_tmpl = &pred_b_ref.predicate().statements[0];
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
// Create the template target and batch id target
let st_tmpl_target = builder.add_virtual_statement_tmpl(true);
let batch_id = builder.add_virtual_hash();
// Normalize the template (this is what we're testing)
let normalized =
normalize_st_tmpl_circuit(&params, &mut builder, &st_tmpl_target, batch_id);
// Feed normalized template into statement generation
let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards)
.map(|_| builder.add_virtual_value())
.collect();
let st_target =
make_statement_from_template_circuit(&params, &mut builder, &normalized, &args_target);
// Connect to expected output
let expected_st_target = builder.add_virtual_statement(false);
builder.connect_flattenable(&expected_st_target, &st_target);
// Set witness
let mut pw = PartialWitness::<F>::new();
st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?;
pw.set_target_arr(&batch_id.elements, &batch.id().0)?;
let some_value = Value::from(42);
// args: first wildcard is "x" = some_value, rest are padding
let mut args_values = vec![some_value.clone()];
for _ in 1..params.max_custom_predicate_wildcards {
args_values.push(Value::from(EMPTY_VALUE));
}
for (target, value) in args_target.iter().zip(args_values.iter()) {
target.set_targets(&mut pw, value)?;
}
// Expected statement: Equal(Literal(some_value), Literal(pred_a_hash))
let expected_st: crate::backends::plonky2::mainpod::Statement =
Statement::equal(some_value, expected_pred_a_value).into();
expected_st_target.set_targets(&mut pw, &expected_st)?;
// Build and verify
let data = builder.build::<C>();
let proof = data.prove(pw)?;
data.verify(proof)?;
Ok(())
}
}