Self-referential predicate hashes as statement template args (#494)
* Support quoted predicate hashes, including self-referential predicates * Clippy * Review feedback
This commit is contained in:
parent
13cabdb511
commit
1e592e11cf
9 changed files with 573 additions and 31 deletions
|
|
@ -771,7 +771,8 @@ impl CustomPredicateEntryTarget {
|
||||||
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
|
pw.set_target_arr(&self.id.elements, &predicate.batch.id().0)?;
|
||||||
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
|
pw.set_target(self.index, F::from_canonical_usize(predicate.index))?;
|
||||||
|
|
||||||
// Replace statement templates of batch-self with (id,index)
|
// Replace BatchSelf predicates with Custom(batch, i), and
|
||||||
|
// SelfPredicateHash args with Literal(hash(Custom(batch, i)))
|
||||||
let batch = &predicate.batch;
|
let batch = &predicate.batch;
|
||||||
let predicate = predicate.predicate();
|
let predicate = predicate.predicate();
|
||||||
let statements = predicate
|
let statements = predicate
|
||||||
|
|
@ -788,10 +789,22 @@ impl CustomPredicateEntryTarget {
|
||||||
}
|
}
|
||||||
x => x.clone(),
|
x => x.clone(),
|
||||||
};
|
};
|
||||||
StatementTmpl {
|
let args = st_tmpl
|
||||||
pred_or_wc,
|
.args
|
||||||
args: st_tmpl.args,
|
.into_iter()
|
||||||
}
|
.map(|arg| match arg {
|
||||||
|
StatementTmplArg::SelfPredicateHash(i) => {
|
||||||
|
let pred_hash = Predicate::Custom(CustomPredicateRef {
|
||||||
|
batch: batch.clone(),
|
||||||
|
index: i,
|
||||||
|
})
|
||||||
|
.hash();
|
||||||
|
StatementTmplArg::Literal(Value::from(pred_hash))
|
||||||
|
}
|
||||||
|
other => other,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
StatementTmpl { pred_or_wc, args }
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
let predicate = CustomPredicate {
|
let predicate = CustomPredicate {
|
||||||
|
|
@ -2012,7 +2025,7 @@ pub(crate) mod tests {
|
||||||
// Empty case
|
// Empty case
|
||||||
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
|
let mut cpb_builder = CustomPredicateBatchBuilder::new(params.clone(), "empty".into());
|
||||||
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
|
_ = cpb_builder.predicate_and("empty", &[], &[], &[])?;
|
||||||
let custom_predicate_batch = cpb_builder.finish();
|
let custom_predicate_batch = cpb_builder.finish()?;
|
||||||
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
|
helper_custom_predicate_in_batch_target(&custom_predicate_batch).unwrap();
|
||||||
|
|
||||||
// Some cases from the examples
|
// Some cases from the examples
|
||||||
|
|
|
||||||
|
|
@ -54,8 +54,8 @@ use crate::{
|
||||||
measure_gates_begin, measure_gates_end,
|
measure_gates_begin, measure_gates_end,
|
||||||
middleware::{
|
middleware::{
|
||||||
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
|
||||||
NativePredicate, Params, PredicatePrefix, RawValue, Statement, ToFields, Value, F,
|
NativePredicate, Params, PredicatePrefix, RawValue, Statement, StatementTmplArgPrefix,
|
||||||
HASH_SIZE,
|
ToFields, Value, F, HASH_SIZE,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
//
|
//
|
||||||
|
|
@ -1534,8 +1534,8 @@ pub fn calculate_statements_hash_circuit(
|
||||||
sts_hash
|
sts_hash
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace predicates of batch-self with the corresponding global custom predicate batch_id and
|
// Replace BatchSelf predicates with the corresponding Custom(batch_id, index), and
|
||||||
// index
|
// SelfPredicateHash args with Literal(hash(Custom(batch_id, index))).
|
||||||
fn normalize_st_tmpl_circuit(
|
fn normalize_st_tmpl_circuit(
|
||||||
params: &Params,
|
params: &Params,
|
||||||
builder: &mut CircuitBuilder,
|
builder: &mut CircuitBuilder,
|
||||||
|
|
@ -1564,7 +1564,41 @@ fn normalize_st_tmpl_circuit(
|
||||||
);
|
);
|
||||||
let pred_hash_or_wc =
|
let pred_hash_or_wc =
|
||||||
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
|
PredicateHashOrWildcardTarget::new(st_tmpl.pred_hash_or_wc().elements[0], data);
|
||||||
StatementTmplTarget::new(pred_hash_or_wc, st_tmpl.args.clone())
|
|
||||||
|
// Normalize SelfPredicateHash args: replace prefix 4 with Literal containing the resolved
|
||||||
|
// predicate hash. Same pattern as the predicate normalization above.
|
||||||
|
let prefix_sph = builder.constant(F::from(StatementTmplArgPrefix::SelfPredicateHash));
|
||||||
|
let prefix_literal = builder.constant(F::from(StatementTmplArgPrefix::Literal));
|
||||||
|
let zero = builder.zero();
|
||||||
|
let normalized_args = st_tmpl
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|arg| {
|
||||||
|
let is_sph = builder.is_equal(arg.elements[0], prefix_sph);
|
||||||
|
|
||||||
|
// The predicate index is in elements[1] (same slot as WildcardLiteral).
|
||||||
|
let pred_index = arg.elements[1];
|
||||||
|
|
||||||
|
// Compute hash(Custom(batch_id, pred_index))
|
||||||
|
let pred_target = PredicateTarget::new_custom(builder, id, pred_index);
|
||||||
|
let pred_hash = pred_target.hash(builder);
|
||||||
|
|
||||||
|
// Build a Literal-encoded arg: [1, hash[0..4], 0, 0, 0, 0]
|
||||||
|
let mut literal_elements = [zero; Params::statement_tmpl_arg_size()];
|
||||||
|
literal_elements[0] = prefix_literal;
|
||||||
|
literal_elements[1] = pred_hash.elements[0];
|
||||||
|
literal_elements[2] = pred_hash.elements[1];
|
||||||
|
literal_elements[3] = pred_hash.elements[2];
|
||||||
|
literal_elements[4] = pred_hash.elements[3];
|
||||||
|
let normalized = StatementTmplArgTarget {
|
||||||
|
elements: literal_elements,
|
||||||
|
};
|
||||||
|
|
||||||
|
builder.select_flattenable(params, is_sph, &normalized, arg)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
StatementTmplTarget::new(pred_hash_or_wc, normalized_args)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
/// Build a table of [batch_id, custom_predicate_index, custom_predicate] with queryable part as
|
||||||
|
|
@ -3262,7 +3296,7 @@ mod tests {
|
||||||
&[stb0.clone(), stb1.clone()],
|
&[stb0.clone(), stb1.clone()],
|
||||||
)?;
|
)?;
|
||||||
let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?;
|
let _ = builder.predicate_or("pred_or", &["id"], &["secret"], &[stb0, stb1])?;
|
||||||
let batch = builder.finish();
|
let batch = builder.finish()?;
|
||||||
|
|
||||||
let dict = Hash([F(6), F(7), F(8), F(9)]);
|
let dict = Hash([F(6), F(7), F(8), F(9)]);
|
||||||
|
|
||||||
|
|
@ -3352,7 +3386,7 @@ mod tests {
|
||||||
&[stb0.clone(), stb1.clone()],
|
&[stb0.clone(), stb1.clone()],
|
||||||
)?;
|
)?;
|
||||||
let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?;
|
let _ = builder.predicate_or("pred_or", &["id"], &["secret_id"], &[stb0, stb1])?;
|
||||||
let batch = builder.finish();
|
let batch = builder.finish()?;
|
||||||
|
|
||||||
let dict = Hash([F(1), F(2), F(3), F(4)]);
|
let dict = Hash([F(1), F(2), F(3), F(4)]);
|
||||||
let secret_dict = Hash([F(6), F(7), F(8), F(9)]);
|
let secret_dict = Hash([F(6), F(7), F(8), F(9)]);
|
||||||
|
|
@ -3570,4 +3604,100 @@ mod tests {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_st_tmpl_self_predicate_hash() -> Result<()> {
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
// Build a batch with two predicates:
|
||||||
|
// pred_A: Equal(x, y)
|
||||||
|
// pred_B: Equal(x, SelfPredicateHash(0)), references pred_A's hash
|
||||||
|
use NativePredicate as NP;
|
||||||
|
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
||||||
|
let stb_a = StatementTmplBuilder::new_from_pred(NP::Equal)
|
||||||
|
.arg("x")
|
||||||
|
.arg("y");
|
||||||
|
cpb.predicate_and("pred_A", &["x", "y"], &[], &[stb_a])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Build pred_B's template manually with SelfPredicateHash(0)
|
||||||
|
let stb_b_tmpl = StatementTmpl {
|
||||||
|
pred_or_wc: PredicateOrWildcard::Predicate(Predicate::Native(NP::Equal)),
|
||||||
|
args: vec![
|
||||||
|
StatementTmplArg::Wildcard(Wildcard::new("x".to_string(), 0)),
|
||||||
|
StatementTmplArg::SelfPredicateHash(0),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
let pred_b = CustomPredicate::new(
|
||||||
|
¶ms,
|
||||||
|
"pred_B".into(),
|
||||||
|
true,
|
||||||
|
vec![stb_b_tmpl],
|
||||||
|
1,
|
||||||
|
vec!["x".to_string()],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
cpb.predicates.push(pred_b);
|
||||||
|
let batch = cpb.finish().unwrap();
|
||||||
|
|
||||||
|
// Compute the expected resolved hash of pred_A
|
||||||
|
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
|
||||||
|
let pred_a_hash = Predicate::Custom(pred_a_ref).hash();
|
||||||
|
let expected_pred_a_value = Value::from(pred_a_hash);
|
||||||
|
|
||||||
|
// Test: normalize_st_tmpl_circuit should convert SelfPredicateHash(0) to
|
||||||
|
// Literal(pred_a_hash). Then make_statement_from_template_circuit should produce
|
||||||
|
// a statement with that literal value.
|
||||||
|
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
|
||||||
|
let pred_b_tmpl = &pred_b_ref.predicate().statements[0];
|
||||||
|
|
||||||
|
let config = CircuitConfig::standard_recursion_config();
|
||||||
|
let mut builder = CircuitBuilder::new(config);
|
||||||
|
|
||||||
|
// Create the template target and batch id target
|
||||||
|
let st_tmpl_target = builder.add_virtual_statement_tmpl(true);
|
||||||
|
let batch_id = builder.add_virtual_hash();
|
||||||
|
|
||||||
|
// Normalize the template (this is what we're testing)
|
||||||
|
let normalized =
|
||||||
|
normalize_st_tmpl_circuit(¶ms, &mut builder, &st_tmpl_target, batch_id);
|
||||||
|
|
||||||
|
// Feed normalized template into statement generation
|
||||||
|
let args_target: Vec<_> = (0..params.max_custom_predicate_wildcards)
|
||||||
|
.map(|_| builder.add_virtual_value())
|
||||||
|
.collect();
|
||||||
|
let st_target =
|
||||||
|
make_statement_from_template_circuit(¶ms, &mut builder, &normalized, &args_target);
|
||||||
|
|
||||||
|
// Connect to expected output
|
||||||
|
let expected_st_target = builder.add_virtual_statement(false);
|
||||||
|
builder.connect_flattenable(&expected_st_target, &st_target);
|
||||||
|
|
||||||
|
// Set witness
|
||||||
|
let mut pw = PartialWitness::<F>::new();
|
||||||
|
st_tmpl_target.set_targets(&mut pw, pred_b_tmpl)?;
|
||||||
|
pw.set_target_arr(&batch_id.elements, &batch.id().0)?;
|
||||||
|
|
||||||
|
let some_value = Value::from(42);
|
||||||
|
// args: first wildcard is "x" = some_value, rest are padding
|
||||||
|
let mut args_values = vec![some_value.clone()];
|
||||||
|
for _ in 1..params.max_custom_predicate_wildcards {
|
||||||
|
args_values.push(Value::from(EMPTY_VALUE));
|
||||||
|
}
|
||||||
|
for (target, value) in args_target.iter().zip(args_values.iter()) {
|
||||||
|
target.set_targets(&mut pw, value)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expected statement: Equal(Literal(some_value), Literal(pred_a_hash))
|
||||||
|
let expected_st: crate::backends::plonky2::mainpod::Statement =
|
||||||
|
Statement::equal(some_value, expected_pred_a_value).into();
|
||||||
|
expected_st_target.set_targets(&mut pw, &expected_st)?;
|
||||||
|
|
||||||
|
// Build and verify
|
||||||
|
let data = builder.build::<C>();
|
||||||
|
let proof = data.prove(pw)?;
|
||||||
|
data.verify(proof)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -104,8 +104,9 @@ pub(crate) fn extract_custom_predicate_verifications(
|
||||||
if let middleware::Operation::Custom(cpr, sts) = op {
|
if let middleware::Operation::Custom(cpr, sts) = op {
|
||||||
if let middleware::Statement::Custom(st_cpr, st_args) = st {
|
if let middleware::Statement::Custom(st_cpr, st_args) = st {
|
||||||
assert_eq!(cpr, st_cpr);
|
assert_eq!(cpr, st_cpr);
|
||||||
|
let normalized_pred = cpr.normalized_predicate();
|
||||||
let wildcard_values =
|
let wildcard_values =
|
||||||
wildcard_values_from_op_st(params, cpr.predicate(), sts, st_args)
|
wildcard_values_from_op_st(params, &normalized_pred, sts, st_args)
|
||||||
.expect("resolved wildcards");
|
.expect("resolved wildcards");
|
||||||
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
|
let sts = sts.iter().map(|s| Statement::from(s.clone())).collect();
|
||||||
let custom_predicate_table_index = custom_predicates
|
let custom_predicate_table_index = custom_predicates
|
||||||
|
|
@ -1096,7 +1097,7 @@ pub mod tests {
|
||||||
&[stb0.clone(), stb1.clone()],
|
&[stb0.clone(), stb1.clone()],
|
||||||
)?;
|
)?;
|
||||||
let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?;
|
let _ = cpb_builder.predicate_or("pred_or", &["dict"], &["secret_dict"], &[stb0, stb1])?;
|
||||||
let cpb = cpb_builder.finish();
|
let cpb = cpb_builder.finish()?;
|
||||||
|
|
||||||
let cpb_and = CustomPredicateRef::new(cpb.clone(), 0);
|
let cpb_and = CustomPredicateRef::new(cpb.clone(), 0);
|
||||||
let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1);
|
let _cpb_or = CustomPredicateRef::new(cpb.clone(), 1);
|
||||||
|
|
@ -1130,6 +1131,63 @@ pub mod tests {
|
||||||
Ok(pod.verify()?)
|
Ok(pod.verify()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_main_self_predicate_hash() -> frontend::Result<()> {
|
||||||
|
use frontend::BuilderArg;
|
||||||
|
|
||||||
|
let params = Params {
|
||||||
|
max_signed_by: 0,
|
||||||
|
max_input_pods: 0,
|
||||||
|
max_statements: 6,
|
||||||
|
max_public_statements: 2,
|
||||||
|
max_operation_args: 5,
|
||||||
|
max_custom_predicate_wildcards: 4,
|
||||||
|
max_custom_predicate_verifications: 2,
|
||||||
|
max_merkle_proofs_containers: 0,
|
||||||
|
max_merkle_tree_state_transition_proofs_containers: 0,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut vds = DEFAULT_VD_LIST.clone();
|
||||||
|
vds.push(rec_main_pod_circuit_data(¶ms).1.verifier_only.clone());
|
||||||
|
let vd_set = VDSet::new(&vds);
|
||||||
|
|
||||||
|
// Build a batch: pred_A references pred_B's hash, pred_B references pred_A's hash
|
||||||
|
let mut cpb = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
||||||
|
let stb_a = STB::new_from_pred(NP::Equal)
|
||||||
|
.arg("x")
|
||||||
|
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
|
||||||
|
cpb.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
|
||||||
|
|
||||||
|
let stb_b = STB::new_from_pred(NP::Equal)
|
||||||
|
.arg("x")
|
||||||
|
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
|
||||||
|
cpb.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
|
||||||
|
|
||||||
|
let batch = cpb.finish()?;
|
||||||
|
|
||||||
|
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
|
||||||
|
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
|
||||||
|
let pred_b_hash = middleware::Value::from(middleware::Predicate::Custom(pred_b_ref).hash());
|
||||||
|
|
||||||
|
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
|
||||||
|
let mut pod_builder = MainPodBuilder::new(¶ms, &vd_set);
|
||||||
|
let eq_st =
|
||||||
|
pod_builder.priv_op(frontend::Operation::eq(pred_b_hash.clone(), pred_b_hash))?;
|
||||||
|
pod_builder.pub_op(frontend::Operation::custom(pred_a_ref, [eq_st]))?;
|
||||||
|
|
||||||
|
// Mock
|
||||||
|
let prover = MockProver {};
|
||||||
|
let pod = pod_builder.prove(&prover)?;
|
||||||
|
assert!(pod.pod.verify().is_ok());
|
||||||
|
|
||||||
|
// Real
|
||||||
|
let prover = Prover {};
|
||||||
|
let pod = pod_builder.prove(&prover)?;
|
||||||
|
let pod = (pod.pod as Box<dyn Any>).downcast::<MainPod>().unwrap();
|
||||||
|
|
||||||
|
Ok(pod.verify()?)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_set_contains() -> frontend::Result<()> {
|
fn test_set_contains() -> frontend::Result<()> {
|
||||||
let params = Params::default();
|
let params = Params::default();
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ pub enum BuilderArg {
|
||||||
/// Key: (origin, key), where origin is Wildcard and key is Key
|
/// Key: (origin, key), where origin is Wildcard and key is Key
|
||||||
Key(String, String),
|
Key(String, String),
|
||||||
WildcardLiteral(String),
|
WildcardLiteral(String),
|
||||||
|
/// Reference to a same-batch predicate's identity hash (resolved by name in finish()).
|
||||||
|
SelfPredicateHash(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// When defining a `BuilderArg`, it can be done from 3 different inputs:
|
/// When defining a `BuilderArg`, it can be done from 3 different inputs:
|
||||||
|
|
@ -130,6 +132,8 @@ pub struct CustomPredicateBatchBuilder {
|
||||||
params: Params,
|
params: Params,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub predicates: Vec<CustomPredicate>,
|
pub predicates: Vec<CustomPredicate>,
|
||||||
|
/// Forward references to resolve in finish(): (predicate_idx, statement_idx, arg_idx, name)
|
||||||
|
pending_self_pred_hashes: Vec<(usize, usize, usize, String)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CustomPredicateBatchBuilder {
|
impl CustomPredicateBatchBuilder {
|
||||||
|
|
@ -138,6 +142,7 @@ impl CustomPredicateBatchBuilder {
|
||||||
params,
|
params,
|
||||||
name,
|
name,
|
||||||
predicates: Vec::new(),
|
predicates: Vec::new(),
|
||||||
|
pending_self_pred_hashes: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -194,14 +199,18 @@ impl CustomPredicateBatchBuilder {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let pred_idx = self.predicates.len();
|
||||||
|
let mut pending = Vec::new();
|
||||||
let statements = sts
|
let statements = sts
|
||||||
.iter()
|
.iter()
|
||||||
.map(|sb| {
|
.enumerate()
|
||||||
|
.map(|(stmt_idx, sb)| {
|
||||||
let stb = sb.clone().desugar();
|
let stb = sb.clone().desugar();
|
||||||
let st_tmpl_args = stb
|
let st_tmpl_args = stb
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|a| {
|
.enumerate()
|
||||||
|
.map(|(arg_idx, a)| {
|
||||||
Ok::<_, Error>(match a {
|
Ok::<_, Error>(match a {
|
||||||
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
|
BuilderArg::Literal(v) => StatementTmplArg::Literal(v.clone()),
|
||||||
BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey(
|
BuilderArg::Key(root_wc, key_str) => StatementTmplArg::AnchoredKey(
|
||||||
|
|
@ -211,6 +220,22 @@ impl CustomPredicateBatchBuilder {
|
||||||
BuilderArg::WildcardLiteral(v) => {
|
BuilderArg::WildcardLiteral(v) => {
|
||||||
StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?)
|
StatementTmplArg::Wildcard(resolve_wildcard(args, priv_args, v)?)
|
||||||
}
|
}
|
||||||
|
BuilderArg::SelfPredicateHash(pred_name) => {
|
||||||
|
// Try backward reference first
|
||||||
|
match self.predicates.iter().position(|p| p.name == *pred_name) {
|
||||||
|
Some(index) => StatementTmplArg::SelfPredicateHash(index),
|
||||||
|
None => {
|
||||||
|
// Forward reference - placeholder, resolved in finish()
|
||||||
|
pending.push((
|
||||||
|
pred_idx,
|
||||||
|
stmt_idx,
|
||||||
|
arg_idx,
|
||||||
|
pred_name.clone(),
|
||||||
|
));
|
||||||
|
StatementTmplArg::SelfPredicateHash(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Result<_>>()?;
|
.collect::<Result<_>>()?;
|
||||||
|
|
@ -240,11 +265,27 @@ impl CustomPredicateBatchBuilder {
|
||||||
.collect(),
|
.collect(),
|
||||||
)?;
|
)?;
|
||||||
self.predicates.push(custom_predicate);
|
self.predicates.push(custom_predicate);
|
||||||
|
self.pending_self_pred_hashes.extend(pending);
|
||||||
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
|
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn finish(self) -> Arc<CustomPredicateBatch> {
|
pub fn finish(mut self) -> Result<Arc<CustomPredicateBatch>> {
|
||||||
CustomPredicateBatch::new(self.name, self.predicates)
|
// Resolve forward references for SelfPredicateHash
|
||||||
|
for (pred_idx, stmt_idx, arg_idx, ref name) in &self.pending_self_pred_hashes {
|
||||||
|
let target_idx = self
|
||||||
|
.predicates
|
||||||
|
.iter()
|
||||||
|
.position(|p| p.name == *name)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
Error::custom(format!(
|
||||||
|
"SelfPredicateHash references unknown predicate '{}'",
|
||||||
|
name
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
self.predicates[*pred_idx].statements[*stmt_idx].args[*arg_idx] =
|
||||||
|
StatementTmplArg::SelfPredicateHash(target_idx);
|
||||||
|
}
|
||||||
|
Ok(CustomPredicateBatch::new(self.name, self.predicates))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -306,7 +347,7 @@ mod tests {
|
||||||
.arg("s2");
|
.arg("s2");
|
||||||
|
|
||||||
builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?;
|
builder.predicate_and("gt_custom_pred", &["s1", "s2"], &[], &[gt_stb])?;
|
||||||
let batch = builder.finish();
|
let batch = builder.finish()?;
|
||||||
let batch_clone = batch.clone();
|
let batch_clone = batch.clone();
|
||||||
let gt_custom_pred = CustomPredicateRef::new(batch, 0);
|
let gt_custom_pred = CustomPredicateRef::new(batch, 0);
|
||||||
|
|
||||||
|
|
@ -356,7 +397,7 @@ mod tests {
|
||||||
&[],
|
&[],
|
||||||
&[set_contains_stb],
|
&[set_contains_stb],
|
||||||
)?;
|
)?;
|
||||||
let batch = builder.finish();
|
let batch = builder.finish()?;
|
||||||
let batch_clone = batch.clone();
|
let batch_clone = batch.clone();
|
||||||
|
|
||||||
let mut mp_builder = MainPodBuilder::new(¶ms, vd_set);
|
let mut mp_builder = MainPodBuilder::new(¶ms, vd_set);
|
||||||
|
|
@ -386,4 +427,83 @@ mod tests {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_builder_self_predicate_hash_unknown_ref() {
|
||||||
|
let params = Params::default();
|
||||||
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
||||||
|
|
||||||
|
let stb = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
|
||||||
|
.arg("x")
|
||||||
|
.arg(BuilderArg::SelfPredicateHash("nonexistent".into()));
|
||||||
|
builder
|
||||||
|
.predicate_and("pred_A", &["x"], &[], &[stb])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// finish() should fail because "nonexistent" was never defined
|
||||||
|
assert!(builder.finish().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tests cyclic SelfPredicateHash references end-to-end:
|
||||||
|
/// pred_A references pred_B's hash (forward ref), pred_B references pred_A's hash (backward
|
||||||
|
/// ref). Exercises forward reference resolution in finish(), then builds and verifies a POD
|
||||||
|
/// using pred_A via MockProver.
|
||||||
|
#[test]
|
||||||
|
fn test_builder_self_predicate_hash_e2e() -> Result<()> {
|
||||||
|
let params = Params::default();
|
||||||
|
let vd_set = &*MOCK_VD_SET;
|
||||||
|
|
||||||
|
let mut builder = CustomPredicateBatchBuilder::new(params.clone(), "batch".into());
|
||||||
|
|
||||||
|
// pred_A references pred_B's hash (forward ref, pred_B not yet defined)
|
||||||
|
let stb_a = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
|
||||||
|
.arg("x")
|
||||||
|
.arg(BuilderArg::SelfPredicateHash("pred_B".into()));
|
||||||
|
builder.predicate_and("pred_A", &["x"], &[], &[stb_a])?;
|
||||||
|
|
||||||
|
// pred_B references pred_A's hash (backward ref, pred_A already defined)
|
||||||
|
let stb_b = StatementTmplBuilder::new_from_pred(NativePredicate::Equal)
|
||||||
|
.arg("x")
|
||||||
|
.arg(BuilderArg::SelfPredicateHash("pred_A".into()));
|
||||||
|
builder.predicate_and("pred_B", &["x"], &[], &[stb_b])?;
|
||||||
|
|
||||||
|
let batch = builder.finish()?;
|
||||||
|
|
||||||
|
// Verify resolution: pred_A references pred_B (index 1), pred_B references pred_A (index 0)
|
||||||
|
assert_eq!(
|
||||||
|
batch.predicates()[0].statements[0].args[1],
|
||||||
|
StatementTmplArg::SelfPredicateHash(1)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
batch.predicates()[1].statements[0].args[1],
|
||||||
|
StatementTmplArg::SelfPredicateHash(0)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Compute concrete hashes
|
||||||
|
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
|
||||||
|
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
|
||||||
|
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
|
||||||
|
|
||||||
|
// Build a POD using pred_A: Equal(pred_b_hash, pred_b_hash)
|
||||||
|
let mut mp_builder = MainPodBuilder::new(¶ms, vd_set);
|
||||||
|
let eq_st = mp_builder.priv_op(Operation::eq(pred_b_hash.clone(), pred_b_hash.clone()))?;
|
||||||
|
mp_builder.pub_op(Operation::custom(pred_a_ref, [eq_st]))?;
|
||||||
|
|
||||||
|
// Prove and verify
|
||||||
|
let prover = MockProver {};
|
||||||
|
let proof = mp_builder.prove(&prover)?;
|
||||||
|
proof.pod.verify()?;
|
||||||
|
|
||||||
|
// Verify the public statement contains pred_b_hash as its argument
|
||||||
|
let pub_sts = proof.pod.pub_self_statements();
|
||||||
|
let custom_st = pub_sts
|
||||||
|
.iter()
|
||||||
|
.find(|s| matches!(s, middleware::Statement::Custom(_, _)))
|
||||||
|
.expect("should have a custom statement");
|
||||||
|
if let middleware::Statement::Custom(_, args) = custom_st {
|
||||||
|
assert_eq!(args[0], pred_b_hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -578,7 +578,7 @@ impl MainPodBuilder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
OperationType::Custom(cpr) => {
|
OperationType::Custom(cpr) => {
|
||||||
let pred = &cpr.batch.predicates()[cpr.index];
|
let pred = cpr.normalized_predicate();
|
||||||
if pred.statements.len() != op.1.len() {
|
if pred.statements.len() != op.1.len() {
|
||||||
return Err(Error::custom(format!(
|
return Err(Error::custom(format!(
|
||||||
"Custom predicate operation needs {} statements but has {}.",
|
"Custom predicate operation needs {} statements but has {}.",
|
||||||
|
|
@ -606,7 +606,7 @@ impl MainPodBuilder {
|
||||||
}
|
}
|
||||||
wildcard_map[index] = Some(value);
|
wildcard_map[index] = Some(value);
|
||||||
}
|
}
|
||||||
fill_wildcard_values(pred, &args, &mut wildcard_map)?;
|
fill_wildcard_values(&pred, &args, &mut wildcard_map)?;
|
||||||
let v_default = Value::from(0);
|
let v_default = Value::from(0);
|
||||||
let st_args: Vec<_> = wildcard_map
|
let st_args: Vec<_> = wildcard_map
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
|
||||||
|
|
@ -346,6 +346,9 @@ impl<'a> Lowerer<'a> {
|
||||||
let key = Key::from(key_str.as_str());
|
let key = Key::from(key_str.as_str());
|
||||||
MWStatementTmplArg::AnchoredKey(wildcard, key)
|
MWStatementTmplArg::AnchoredKey(wildcard, key)
|
||||||
}
|
}
|
||||||
|
BuilderArg::SelfPredicateHash(_) => {
|
||||||
|
unreachable!("SelfPredicateHash should not appear in request lowering")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
mw_args.push(mw_arg);
|
mw_args.push(mw_arg);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -345,7 +345,9 @@ fn build_single_batch(
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(builder.finish())
|
builder.finish().map_err(|e| BatchingError::Internal {
|
||||||
|
message: format!("Failed to finalize batch '{}': {}", batch_name, e),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a statement template with properly resolved predicate references
|
/// Build a statement template with properly resolved predicate references
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,9 @@ pub enum StatementTmplArg {
|
||||||
// AnchoredKey where the origin is a wildcard
|
// AnchoredKey where the origin is a wildcard
|
||||||
AnchoredKey(Wildcard, Key),
|
AnchoredKey(Wildcard, Key),
|
||||||
Wildcard(Wildcard),
|
Wildcard(Wildcard),
|
||||||
|
/// Reference to a same-batch predicate's identity hash, resolved at verification time.
|
||||||
|
/// The usize is the predicate index within the batch.
|
||||||
|
SelfPredicateHash(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
|
|
@ -57,6 +60,7 @@ pub enum StatementTmplArgPrefix {
|
||||||
Literal = 1,
|
Literal = 1,
|
||||||
AnchoredKey = 2,
|
AnchoredKey = 2,
|
||||||
WildcardLiteral = 3,
|
WildcardLiteral = 3,
|
||||||
|
SelfPredicateHash = 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<StatementTmplArgPrefix> for F {
|
impl From<StatementTmplArgPrefix> for F {
|
||||||
|
|
@ -68,11 +72,12 @@ impl From<StatementTmplArgPrefix> for F {
|
||||||
impl ToFields for StatementTmplArg {
|
impl ToFields for StatementTmplArg {
|
||||||
fn to_fields(&self) -> Vec<F> {
|
fn to_fields(&self) -> Vec<F> {
|
||||||
// Encoding:
|
// Encoding:
|
||||||
// None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
|
// None => (0, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||||
// Literal(v) => (1, [v ], 0, 0, 0, 0)
|
// Literal(v) => (1, [v ], 0, 0, 0, 0)
|
||||||
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
|
// Key(wc_index, key_or_wc) => (2, [wc_index], 0, 0, 0, [key_or_wc])
|
||||||
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
|
// WildcardLiteral(wc_index) => (3, [wc_index], 0, 0, 0, 0, 0, 0, 0)
|
||||||
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
|
// SelfPredicateHash(pred_index) => (4, pred_index, 0, 0, 0, 0, 0, 0, 0)
|
||||||
|
// In all cases, we pad to 2 * hash_size + 1 = 9 field elements
|
||||||
match self {
|
match self {
|
||||||
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
|
StatementTmplArg::None => iter::once(F::from(StatementTmplArgPrefix::None))
|
||||||
.chain(iter::repeat(F::ZERO))
|
.chain(iter::repeat(F::ZERO))
|
||||||
|
|
@ -97,6 +102,13 @@ impl ToFields for StatementTmplArg {
|
||||||
.take(Params::statement_tmpl_arg_size())
|
.take(Params::statement_tmpl_arg_size())
|
||||||
.collect_vec()
|
.collect_vec()
|
||||||
}
|
}
|
||||||
|
StatementTmplArg::SelfPredicateHash(index) => {
|
||||||
|
iter::once(F::from(StatementTmplArgPrefix::SelfPredicateHash))
|
||||||
|
.chain(iter::once(F::from_canonical_usize(*index)))
|
||||||
|
.chain(iter::repeat(F::ZERO))
|
||||||
|
.take(Params::statement_tmpl_arg_size())
|
||||||
|
.collect_vec()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -113,6 +125,7 @@ impl fmt::Display for StatementTmplArg {
|
||||||
write!(f, "]")
|
write!(f, "]")
|
||||||
}
|
}
|
||||||
Self::Wildcard(v) => v.fmt(f),
|
Self::Wildcard(v) => v.fmt(f),
|
||||||
|
Self::SelfPredicateHash(i) => write!(f, "::self.{}", i),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -569,6 +582,44 @@ impl CustomPredicateRef {
|
||||||
pub fn predicate(&self) -> &CustomPredicate {
|
pub fn predicate(&self) -> &CustomPredicate {
|
||||||
&self.batch.predicates()[self.index]
|
&self.batch.predicates()[self.index]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a copy of this predicate with all `SelfPredicateHash(i)` args
|
||||||
|
/// resolved to `Literal(hash(Custom(batch, i)))`.
|
||||||
|
pub fn normalized_predicate(&self) -> CustomPredicate {
|
||||||
|
let pred = self.predicate();
|
||||||
|
let normalized_statements = pred
|
||||||
|
.statements
|
||||||
|
.iter()
|
||||||
|
.map(|st_tmpl| {
|
||||||
|
let args = st_tmpl
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|arg| match arg {
|
||||||
|
StatementTmplArg::SelfPredicateHash(i) => {
|
||||||
|
let pred_hash = Predicate::Custom(CustomPredicateRef {
|
||||||
|
batch: self.batch.clone(),
|
||||||
|
index: *i,
|
||||||
|
})
|
||||||
|
.hash();
|
||||||
|
StatementTmplArg::Literal(Value::from(pred_hash))
|
||||||
|
}
|
||||||
|
other => other.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
StatementTmpl {
|
||||||
|
pred_or_wc: st_tmpl.pred_or_wc.clone(),
|
||||||
|
args,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
CustomPredicate {
|
||||||
|
name: pred.name.clone(),
|
||||||
|
conjunction: pred.conjunction,
|
||||||
|
statements: normalized_statements,
|
||||||
|
args_len: pred.args_len,
|
||||||
|
wildcard_names: pred.wildcard_names.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -823,4 +874,164 @@ mod tests {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalized_predicate() -> Result<()> {
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
|
||||||
|
let pred_a = CustomPredicate::and(
|
||||||
|
¶ms,
|
||||||
|
"pred_A".into(),
|
||||||
|
vec![st(
|
||||||
|
P::Native(NP::Equal),
|
||||||
|
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
|
||||||
|
)],
|
||||||
|
2,
|
||||||
|
names(&["x", "y"]),
|
||||||
|
)?;
|
||||||
|
let pred_b = CustomPredicate::and(
|
||||||
|
¶ms,
|
||||||
|
"pred_B".into(),
|
||||||
|
vec![st(
|
||||||
|
P::Native(NP::Equal),
|
||||||
|
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
|
||||||
|
)],
|
||||||
|
1,
|
||||||
|
names(&["x"]),
|
||||||
|
)?;
|
||||||
|
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
|
||||||
|
|
||||||
|
// Compute expected pred_A hash
|
||||||
|
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
|
||||||
|
let expected_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
|
||||||
|
|
||||||
|
// Normalize pred_B
|
||||||
|
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
|
||||||
|
let normalized = pred_b_ref.normalized_predicate();
|
||||||
|
|
||||||
|
// The second arg should be resolved to Literal(pred_A_hash)
|
||||||
|
assert_eq!(
|
||||||
|
normalized.statements[0].args[1],
|
||||||
|
STA::Literal(expected_hash)
|
||||||
|
);
|
||||||
|
|
||||||
|
// First arg should be unchanged (still a wildcard)
|
||||||
|
assert_eq!(normalized.statements[0].args[0], STA::Wildcard(wc(0)));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_self_predicate_hash_check() -> Result<()> {
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
// Build a batch: pred_A = Equal(x, y), pred_B = Equal(x, SelfPredicateHash(0))
|
||||||
|
let pred_a = CustomPredicate::and(
|
||||||
|
¶ms,
|
||||||
|
"pred_A".into(),
|
||||||
|
vec![st(
|
||||||
|
P::Native(NP::Equal),
|
||||||
|
vec![STA::Wildcard(wc(0)), STA::Wildcard(wc(1))],
|
||||||
|
)],
|
||||||
|
2,
|
||||||
|
names(&["x", "y"]),
|
||||||
|
)?;
|
||||||
|
let pred_b = CustomPredicate::and(
|
||||||
|
¶ms,
|
||||||
|
"pred_B".into(),
|
||||||
|
vec![st(
|
||||||
|
P::Native(NP::Equal),
|
||||||
|
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
|
||||||
|
)],
|
||||||
|
1,
|
||||||
|
names(&["x"]),
|
||||||
|
)?;
|
||||||
|
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
|
||||||
|
|
||||||
|
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
|
||||||
|
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref).hash());
|
||||||
|
|
||||||
|
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
|
||||||
|
|
||||||
|
// Construct a valid operation: Equal(some_value, pred_a_hash)
|
||||||
|
let some_value = Value::from(42);
|
||||||
|
let op_args = vec![Statement::equal(some_value.clone(), pred_a_hash.clone())];
|
||||||
|
|
||||||
|
// The output statement
|
||||||
|
let output_st = Statement::Custom(pred_b_ref.clone(), vec![some_value.clone()]);
|
||||||
|
|
||||||
|
// This should pass
|
||||||
|
assert!(Operation::Custom(pred_b_ref.clone(), op_args).check(¶ms, &output_st)?);
|
||||||
|
|
||||||
|
// Now try with wrong hash, should fail
|
||||||
|
let wrong_hash = Value::from(999);
|
||||||
|
let bad_op_args = vec![Statement::equal(some_value.clone(), wrong_hash)];
|
||||||
|
assert!(Operation::Custom(pred_b_ref, bad_op_args)
|
||||||
|
.check(¶ms, &output_st)
|
||||||
|
.is_err());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_self_predicate_hash_cyclic() -> Result<()> {
|
||||||
|
let params = Params::default();
|
||||||
|
|
||||||
|
// Build a batch where pred_A references pred_B's hash and vice versa
|
||||||
|
// pred_A = Equal(x, SelfPredicateHash(1))
|
||||||
|
// pred_B = Equal(x, SelfPredicateHash(0))
|
||||||
|
let pred_a = CustomPredicate::and(
|
||||||
|
¶ms,
|
||||||
|
"pred_A".into(),
|
||||||
|
vec![st(
|
||||||
|
P::Native(NP::Equal),
|
||||||
|
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(1)],
|
||||||
|
)],
|
||||||
|
1,
|
||||||
|
names(&["x"]),
|
||||||
|
)?;
|
||||||
|
let pred_b = CustomPredicate::and(
|
||||||
|
¶ms,
|
||||||
|
"pred_B".into(),
|
||||||
|
vec![st(
|
||||||
|
P::Native(NP::Equal),
|
||||||
|
vec![STA::Wildcard(wc(0)), STA::SelfPredicateHash(0)],
|
||||||
|
)],
|
||||||
|
1,
|
||||||
|
names(&["x"]),
|
||||||
|
)?;
|
||||||
|
let batch = CustomPredicateBatch::new("batch".into(), vec![pred_a, pred_b]);
|
||||||
|
|
||||||
|
let pred_a_ref = CustomPredicateRef::new(batch.clone(), 0);
|
||||||
|
let pred_b_ref = CustomPredicateRef::new(batch.clone(), 1);
|
||||||
|
let pred_a_hash = Value::from(Predicate::Custom(pred_a_ref.clone()).hash());
|
||||||
|
let pred_b_hash = Value::from(Predicate::Custom(pred_b_ref.clone()).hash());
|
||||||
|
|
||||||
|
// pred_A's normalized form should reference pred_B's hash
|
||||||
|
let norm_a = pred_a_ref.normalized_predicate();
|
||||||
|
assert_eq!(
|
||||||
|
norm_a.statements[0].args[1],
|
||||||
|
STA::Literal(pred_b_hash.clone())
|
||||||
|
);
|
||||||
|
|
||||||
|
// pred_B's normalized form should reference pred_A's hash
|
||||||
|
let norm_b = pred_b_ref.normalized_predicate();
|
||||||
|
assert_eq!(
|
||||||
|
norm_b.statements[0].args[1],
|
||||||
|
STA::Literal(pred_a_hash.clone())
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify pred_A: Equal(pred_b_hash, pred_b_hash) should pass
|
||||||
|
let op_a = vec![Statement::equal(pred_b_hash.clone(), pred_b_hash.clone())];
|
||||||
|
let st_a = Statement::Custom(pred_a_ref.clone(), vec![pred_b_hash.clone()]);
|
||||||
|
assert!(Operation::Custom(pred_a_ref, op_a).check(¶ms, &st_a)?);
|
||||||
|
|
||||||
|
// Verify pred_B: Equal(pred_a_hash, pred_a_hash) should pass
|
||||||
|
let op_b = vec![Statement::equal(pred_a_hash.clone(), pred_a_hash.clone())];
|
||||||
|
let st_b = Statement::Custom(pred_b_ref.clone(), vec![pred_a_hash.clone()]);
|
||||||
|
assert!(Operation::Custom(pred_b_ref, op_b).check(¶ms, &st_b)?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -595,6 +595,11 @@ pub fn check_st_tmpl(
|
||||||
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
|
(StatementTmplArg::Wildcard(wc), StatementArg::Literal(v)) => {
|
||||||
wc_check_or_set(v.clone(), wc, wildcard_map)
|
wc_check_or_set(v.clone(), wc, wildcard_map)
|
||||||
}
|
}
|
||||||
|
(StatementTmplArg::SelfPredicateHash(_), _) => {
|
||||||
|
unreachable!(
|
||||||
|
"SelfPredicateHash should be normalized to Literal before template matching"
|
||||||
|
)
|
||||||
|
}
|
||||||
_ => Err(Error::mismatched_statement_tmpl_arg(
|
_ => Err(Error::mismatched_statement_tmpl_arg(
|
||||||
st_tmpl_arg.clone(),
|
st_tmpl_arg.clone(),
|
||||||
st_arg.clone(),
|
st_arg.clone(),
|
||||||
|
|
@ -712,7 +717,7 @@ pub(crate) fn check_custom_pred(
|
||||||
args: &[Statement],
|
args: &[Statement],
|
||||||
s_args: &[Value],
|
s_args: &[Value],
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let pred = custom_pred_ref.predicate();
|
let pred = custom_pred_ref.normalized_predicate();
|
||||||
if pred.statements.len() != args.len() {
|
if pred.statements.len() != args.len() {
|
||||||
return Err(Error::diff_amount(
|
return Err(Error::diff_amount(
|
||||||
"custom predicate operation".to_string(),
|
"custom predicate operation".to_string(),
|
||||||
|
|
@ -731,7 +736,7 @@ pub(crate) fn check_custom_pred(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the resolved wildcards match the statement arguments.
|
// Check that the resolved wildcards match the statement arguments.
|
||||||
let wc_values = match wildcard_values_from_op_st(params, pred, args, s_args) {
|
let wc_values = match wildcard_values_from_op_st(params, &pred, args, s_args) {
|
||||||
Ok(wc_values) => wc_values,
|
Ok(wc_values) => wc_values,
|
||||||
Err(Error::Inner { inner, backtrace }) => match *inner {
|
Err(Error::Inner { inner, backtrace }) => match *inner {
|
||||||
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
|
MiddlewareInnerError::InvalidWildcardAssignment(wc, v, prev)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue