From 8cc090c5e048995d00207b4f0e4e4210204e3ce5 Mon Sep 17 00:00:00 2001 From: Ahmad Afuni Date: Tue, 6 May 2025 19:14:53 +1000 Subject: [PATCH] Implement HashOf statement and op (#217) --- src/backends/plonky2/circuits/common.rs | 16 ++++- src/backends/plonky2/circuits/mainpod.rs | 84 +++++++++++++++++++++++- src/frontend/mod.rs | 27 +++++++- src/middleware/basetypes.rs | 6 +- src/middleware/operation.rs | 8 +++ src/middleware/statement.rs | 4 ++ 6 files changed, 139 insertions(+), 6 deletions(-) diff --git a/src/backends/plonky2/circuits/common.rs b/src/backends/plonky2/circuits/common.rs index 615ade7..c7ea061 100644 --- a/src/backends/plonky2/circuits/common.rs +++ b/src/backends/plonky2/circuits/common.rs @@ -7,7 +7,10 @@ use plonky2::{ extension::Extendable, types::{Field, PrimeField64}, }, - hash::hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, + hash::{ + hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}, + poseidon::PoseidonHash, + }, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, @@ -320,6 +323,9 @@ pub trait CircuitBuilderPod, const D: usize> { /// and `y` each consist of two `u32` limbs. fn assert_i64_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget); + /// Creates value target that is a hash of two given values. + fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget; + // Convenience methods for accessing and connecting elements of // (vectors of) flattenables. fn vec_ref(&mut self, ts: &[T], i: Target) -> T; @@ -455,6 +461,14 @@ impl CircuitBuilderPod for CircuitBuilder { assert_limb_lt(self, lhs, rhs); } + fn hash_values(&mut self, x: ValueTarget, y: ValueTarget) -> ValueTarget { + ValueTarget::from_slice( + &self + .hash_n_to_hash_no_pad::([x.elements, y.elements].concat()) + .elements, + ) + } + fn vec_ref(&mut self, ts: &[T], i: Target) -> T { // TODO: Revisit this when we need more than 64 statements. let vector_ref = |builder: &mut CircuitBuilder, v: &[Target], i| { diff --git a/src/backends/plonky2/circuits/mainpod.rs b/src/backends/plonky2/circuits/mainpod.rs index 2db0adc..dec3281 100644 --- a/src/backends/plonky2/circuits/mainpod.rs +++ b/src/backends/plonky2/circuits/mainpod.rs @@ -108,6 +108,7 @@ impl OperationVerifyGadget { self.eval_copy(builder, st, op, &resolved_op_args)?, self.eval_eq_from_entries(builder, st, op, &resolved_op_args), self.eval_lt_lteq_from_entries(builder, st, op, &resolved_op_args), + self.eval_hash_of(builder, st, op, &resolved_op_args), ] }, // Skip these if there are no resolved Merkle claims @@ -275,6 +276,40 @@ impl OperationVerifyGadget { builder.all([op_st_code_ok, arg_types_ok, st_args_ok]) } + fn eval_hash_of( + &self, + builder: &mut CircuitBuilder, + st: &StatementTarget, + op: &OperationTarget, + resolved_op_args: &[StatementTarget], + ) -> BoolTarget { + let op_code_ok = op.has_native_type(builder, NativeOperation::HashOf); + + let arg_types_ok = self.first_n_args_are_valueofs(builder, 3, resolved_op_args); + + let arg1_value = resolved_op_args[0].args[1].as_value(); + let arg2_value = resolved_op_args[1].args[1].as_value(); + let arg3_value = resolved_op_args[2].args[1].as_value(); + + let expected_hash_value = builder.hash_values(arg2_value, arg3_value); + + let hash_value_ok = + builder.is_equal_slice(&arg1_value.elements, &expected_hash_value.elements); + + let arg1_key = resolved_op_args[0].args[0].clone(); + let arg2_key = resolved_op_args[1].args[0].clone(); + let arg3_key = resolved_op_args[2].args[0].clone(); + let expected_statement = StatementTarget::new_native( + builder, + &self.params, + NativePredicate::HashOf, + &[arg1_key, arg2_key, arg3_key], + ); + let st_ok = builder.is_equal_flattenable(st, &expected_statement); + + builder.all([op_code_ok, arg_types_ok, hash_value_ok, st_ok]) + } + fn eval_none( &self, builder: &mut CircuitBuilder, @@ -533,7 +568,7 @@ mod tests { mainpod::{OperationArg, OperationAux}, primitives::merkletree::{MerkleClaimAndProof, MerkleTree}, }, - middleware::{Hash, OperationType, PodId, RawValue}, + middleware::{hash_values, Hash, OperationType, PodId, RawValue}, }; fn operation_verify( @@ -939,6 +974,53 @@ mod tests { ); operation_verify(st, op, prev_statements, merkle_proofs.clone())?; + // HashOf + let input_values = [ + Value::from(RawValue([ + GoldilocksField(1), + GoldilocksField(2), + GoldilocksField(3), + GoldilocksField(4), + ])), + Value::from(512), + ]; + let v1 = hash_values(&input_values); + let [v2, v3] = input_values; + + let st1: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + v1.into(), + ) + .into(); + let st2: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + v2, + ) + .into(); + let st3: mainpod::Statement = Statement::ValueOf( + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + v3, + ) + .into(); + + let st: mainpod::Statement = Statement::HashOf( + AnchoredKey::from((PodId(RawValue::from(88).into()), "hola")), + AnchoredKey::from((PodId(RawValue::from(128).into()), "mundo")), + AnchoredKey::from((PodId(RawValue::from(256).into()), "!")), + ) + .into(); + let op = mainpod::Operation( + OperationType::Native(NativeOperation::HashOf), + vec![ + OperationArg::Index(0), + OperationArg::Index(1), + OperationArg::Index(2), + ], + OperationAux::None, + ); + let prev_statements = vec![st1, st2, st3]; + operation_verify(st, op, prev_statements, merkle_proofs.clone())?; + // NotContainsFromEntries let kvs = [ (1.into(), 55.into()), diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 56d33ab..7113b72 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -7,9 +7,10 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use crate::middleware::{ - self, check_st_tmpl, hash_str, AnchoredKey, Key, MainPodInputs, NativeOperation, - NativePredicate, OperationAux, OperationType, Params, PodId, PodProver, PodSigner, Predicate, - Statement, StatementArg, Value, WildcardValue, EMPTY_VALUE, KEY_TYPE, SELF, + self, check_st_tmpl, hash_str, hash_values, AnchoredKey, Hash, Key, MainPodInputs, + NativeOperation, NativePredicate, OperationAux, OperationType, Params, PodId, PodProver, + PodSigner, Predicate, Statement, StatementArg, Value, WildcardValue, EMPTY_VALUE, KEY_TYPE, + SELF, }; mod custom; @@ -435,6 +436,26 @@ impl MainPodBuilder { return Err(Error::op_invalid_args("max-of".to_string())); } }, + HashOf => match (args[0].clone(), args[1].clone(), args[2].clone()) { + ( + OperationArg::Statement(Statement::ValueOf(ak0, v0)), + OperationArg::Statement(Statement::ValueOf(ak1, v1)), + OperationArg::Statement(Statement::ValueOf(ak2, v2)), + ) => { + if Hash::from(v0.raw()) == hash_values(&[v1, v2]) { + vec![ + StatementArg::Key(ak0), + StatementArg::Key(ak1), + StatementArg::Key(ak2), + ] + } else { + return Err(Error::op_invalid_args("hash-of".to_string())); + } + } + _ => { + return Err(Error::op_invalid_args("hash-of".to_string())); + } + }, ContainsFromEntries => self.op_args_entries(public, args)?, NotContainsFromEntries => self.op_args_entries(public, args)?, _ => Err(Error::custom(format!( diff --git a/src/middleware/basetypes.rs b/src/middleware/basetypes.rs index 9dbc513..f2c4827 100644 --- a/src/middleware/basetypes.rs +++ b/src/middleware/basetypes.rs @@ -58,7 +58,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use super::serialization::*; -use crate::middleware::{Params, ToFields}; +use crate::middleware::{Params, ToFields, Value}; /// F is the native field we use everywhere. Currently it's Goldilocks from plonky2 pub type F = GoldilocksField; @@ -164,6 +164,10 @@ pub fn hash_fields(input: &[F]) -> Hash { Hash(PoseidonHash::hash_no_pad(input).elements) } +pub fn hash_values(input: &[Value]) -> Hash { + hash_fields(&input.iter().flat_map(|v| v.raw().0).collect::>()) +} + impl From for Hash { fn from(v: RawValue) -> Self { Hash(v.0) diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 405a32c..bdb182e 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -69,6 +69,7 @@ pub enum NativeOperation { SumOf = 11, ProductOf = 12, MaxOf = 13, + HashOf = 14, // Syntactic sugar operations. These operations are not supported by the backend. The // frontend compiler is responsible of translating these operations into the operations above. @@ -119,6 +120,7 @@ impl OperationType { NativeOperation::SumOf => Some(Predicate::Native(NativePredicate::SumOf)), NativeOperation::ProductOf => Some(Predicate::Native(NativePredicate::ProductOf)), NativeOperation::MaxOf => Some(Predicate::Native(NativePredicate::MaxOf)), + NativeOperation::HashOf => Some(Predicate::Native(NativePredicate::HashOf)), no => unreachable!("Unexpected syntactic sugar op {:?}", no), }, OperationType::Custom(cpr) => Some(Predicate::Custom(cpr.clone())), @@ -152,6 +154,7 @@ pub enum Operation { SumOf(Statement, Statement, Statement), ProductOf(Statement, Statement, Statement), MaxOf(Statement, Statement, Statement), + HashOf(Statement, Statement, Statement), Custom(CustomPredicateRef, Vec), } @@ -174,6 +177,7 @@ impl Operation { Self::SumOf(_, _, _) => OT::Native(SumOf), Self::ProductOf(_, _, _) => OT::Native(ProductOf), Self::MaxOf(_, _, _) => OT::Native(MaxOf), + Self::HashOf(_, _, _) => OT::Native(HashOf), Self::Custom(cpr, _) => OT::Custom(cpr.clone()), } } @@ -194,6 +198,7 @@ impl Operation { Self::SumOf(s1, s2, s3) => vec![s1, s2, s3], Self::ProductOf(s1, s2, s3) => vec![s1, s2, s3], Self::MaxOf(s1, s2, s3) => vec![s1, s2, s3], + Self::HashOf(s1, s2, s3) => vec![s1, s2, s3], Self::Custom(_, args) => args, } } @@ -250,6 +255,9 @@ impl Operation { Self::ProductOf(s1, s2, s3) } (NO::MaxOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => Self::MaxOf(s1, s2, s3), + (NO::HashOf, (Some(s1), Some(s2), Some(s3)), OA::None, 3) => { + Self::HashOf(s1, s2, s3) + } _ => Err(Error::custom(format!( "Ill-formed operation {:?} with arguments {:?}.", op_code, args diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 2327335..2f6f278 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -32,6 +32,7 @@ pub enum NativePredicate { SumOf = 8, ProductOf = 9, MaxOf = 10, + HashOf = 11, // Syntactic sugar predicates. These predicates are not supported by the backend. The // frontend compiler is responsible of translating these predicates into the predicates above. @@ -102,6 +103,7 @@ pub enum Statement { SumOf(AnchoredKey, AnchoredKey, AnchoredKey), ProductOf(AnchoredKey, AnchoredKey, AnchoredKey), MaxOf(AnchoredKey, AnchoredKey, AnchoredKey), + HashOf(AnchoredKey, AnchoredKey, AnchoredKey), Custom(CustomPredicateRef, Vec), } @@ -123,6 +125,7 @@ impl Statement { Self::SumOf(_, _, _) => Native(NativePredicate::SumOf), Self::ProductOf(_, _, _) => Native(NativePredicate::ProductOf), Self::MaxOf(_, _, _) => Native(NativePredicate::MaxOf), + Self::HashOf(_, _, _) => Native(NativePredicate::HashOf), Self::Custom(cpr, _) => Custom(cpr.clone()), } } @@ -140,6 +143,7 @@ impl Statement { Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], + Self::HashOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)], Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(WildcardLiteral)), } }