always replace SELF when copying statements (#345)

This commit is contained in:
Daniel Gulotta 2025-07-22 14:56:37 -07:00 committed by GitHub
parent 5cdf53576b
commit 89dfc4e214
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 12 additions and 14 deletions

View file

@ -2,17 +2,16 @@
//! is enabled. //! is enabled.
//! See src/middleware/basetypes.rs for more details. //! See src/middleware/basetypes.rs for more details.
/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2
pub use plonky2::field::goldilocks_field::GoldilocksField as F;
use plonky2::{ use plonky2::{
field::{extension::quadratic::QuadraticExtension, goldilocks_field::GoldilocksField}, field::extension::quadratic::QuadraticExtension,
hash::{hash_types, poseidon::PoseidonHash}, hash::{hash_types, poseidon::PoseidonHash},
plonk::{circuit_builder, circuit_data, config::GenericConfig, proof}, plonk::{circuit_builder, circuit_data, config::GenericConfig, proof},
}; };
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2
pub type F = GoldilocksField;
/// D defines the extension degree of the field used in the Plonky2 proofs (quadratic extension). /// D defines the extension degree of the field used in the Plonky2 proofs (quadratic extension).
pub const D: usize = 2; pub const D: usize = 2;

View file

@ -41,7 +41,7 @@ use crate::{
middleware::{ middleware::{
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation, AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields, NativePredicate, Params, PodType, PredicatePrefix, Statement, StatementArg, ToFields,
Value, ValueRef, EMPTY_VALUE, F, HASH_SIZE, KEY_TYPE, SELF, VALUE_SIZE, Value, ValueRef, F, HASH_SIZE, KEY_TYPE, SELF, VALUE_SIZE,
}, },
}; };
@ -947,7 +947,6 @@ fn normalize_statement_circuit(
statement: &StatementTarget, statement: &StatementTarget,
self_id: &ValueTarget, self_id: &ValueTarget,
) -> StatementTarget { ) -> StatementTarget {
let zero_value = builder.constant_value(EMPTY_VALUE);
let self_value = builder.constant_value(SELF.0.into()); let self_value = builder.constant_value(SELF.0.into());
let args = statement let args = statement
.args .args
@ -955,11 +954,8 @@ fn normalize_statement_circuit(
.map(|arg| { .map(|arg| {
let first = ValueTarget::from_slice(&arg.elements[..VALUE_SIZE]); let first = ValueTarget::from_slice(&arg.elements[..VALUE_SIZE]);
let second = ValueTarget::from_slice(&arg.elements[VALUE_SIZE..]); let second = ValueTarget::from_slice(&arg.elements[VALUE_SIZE..]);
let is_not_ak = builder.is_equal_flattenable(&zero_value, &second);
let is_ak = builder.not(is_not_ak);
let is_self = builder.is_equal_flattenable(&self_value, &first); let is_self = builder.is_equal_flattenable(&self_value, &first);
let normalize = builder.and(is_ak, is_self); let first_normalized = builder.select_flattenable(params, is_self, self_id, &first);
let first_normalized = builder.select_flattenable(params, normalize, self_id, &first);
StatementArgTarget::new(first_normalized, second) StatementArgTarget::new(first_normalized, second)
}) })
.collect_vec(); .collect_vec();

View file

@ -31,7 +31,7 @@ mod tests {
middleware::{ middleware::{
hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl, NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl,
StatementTmplArg, Value, Wildcard, KEY_SIGNER, KEY_TYPE, SELF_ID_HASH, StatementTmplArg, Value, Wildcard, KEY_SIGNER, KEY_TYPE,
}, },
}; };
@ -125,7 +125,7 @@ mod tests {
pred: Predicate::Native(NativePredicate::Equal), pred: Predicate::Native(NativePredicate::Equal),
args: vec![ args: vec![
sta_ak(("ConstPod", 0), "my_val"), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val") sta_ak(("ConstPod", 0), "my_val"), // ?ConstPod["my_val"] -> Wildcard(0), Key("my_val")
sta_lit(SELF_ID_HASH), sta_lit(RawValue::from(1)),
], ],
}, },
StatementTmpl { StatementTmpl {

View file

@ -55,7 +55,7 @@ pub const HASH_SIZE: usize = 4;
pub const VALUE_SIZE: usize = 4; pub const VALUE_SIZE: usize = 4;
pub const EMPTY_VALUE: RawValue = RawValue([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); pub const EMPTY_VALUE: RawValue = RawValue([F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
pub const SELF_ID_HASH: Hash = Hash([F::ONE, F::ZERO, F::ZERO, F::ZERO]); pub const SELF_ID_HASH: Hash = Hash([F(0x5), F(0xe), F(0x1), F(0xf)]);
pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]); pub const EMPTY_HASH: Hash = Hash([F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]

View file

@ -766,7 +766,7 @@ impl Params {
} }
} }
/// Replace references to SELF by `self_id` in anchored keys of the statement. /// Replace references to SELF by `self_id`.
pub fn normalize_statement(statement: &Statement, self_id: PodId) -> Statement { pub fn normalize_statement(statement: &Statement, self_id: PodId) -> Statement {
let predicate = statement.predicate(); let predicate = statement.predicate();
let args = statement let args = statement
@ -776,6 +776,9 @@ pub fn normalize_statement(statement: &Statement, self_id: PodId) -> Statement {
StatementArg::Key(AnchoredKey { pod_id, key }) if *pod_id == SELF => { StatementArg::Key(AnchoredKey { pod_id, key }) if *pod_id == SELF => {
StatementArg::Key(AnchoredKey::new(self_id, key.clone())) StatementArg::Key(AnchoredKey::new(self_id, key.clone()))
} }
StatementArg::Literal(value) if value.raw.0 == SELF.0 .0 => {
StatementArg::Literal(self_id.into())
}
_ => sa.clone(), _ => sa.clone(),
}) })
.collect(); .collect();