diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 2644d92..4c0bbc5 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -5,6 +5,7 @@ pub use ::kem::{ Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Key, KeyExport, KeyInit, KeySizeUser, TryKeyInit, }; +use sha3::Digest; use crate::{ B32, Encoded, EncodedSizeUser, KemCore, Seed, @@ -169,8 +170,10 @@ where let (dk_pke, ek_pke, h, z) = P::split_dk(enc); let ek_pke = EncryptionKey::from_bytes(ek_pke)?; - // XXX(RLB): The encoding here is redundant, since `h` can be computed from `ek_pke`. - // Should we verify that the provided `h` value is valid? + let test = sha3::Sha3_256::digest(ek_pke.to_bytes()); + if test.as_slice() != h.as_slice() { + return Err(InvalidKey); + } Ok(Self { dk_pke: DecryptionKey::from_bytes(dk_pke), @@ -373,6 +376,7 @@ mod test { use super::*; use crate::{MlKem512Params, MlKem768Params, MlKem1024Params}; use ::kem::{Decapsulate, Encapsulate, Generate}; + use array::typenum::Unsigned; use getrandom::SysRng; use rand_core::UnwrapErr; @@ -421,6 +425,30 @@ mod test { expanded_key_test::(); } + fn invalid_hash_expanded_key_test

() + where + P: KemParams, + { + let mut rng = UnwrapErr(SysRng); + let dk_original = DecapsulationKey::

::generate_from_rng(&mut rng); + + let mut dk_encoded = dk_original.to_encoded_bytes(); + // Corrupt the hash value + let hash_offset = P::NttVectorSize::USIZE + P::EncryptionKeySize::USIZE; + dk_encoded[hash_offset] ^= 0xFF; + + let dk_decoded: Result, InvalidKey> = + DecapsulationKey::from_encoded_bytes(&dk_encoded); + assert!(dk_decoded.is_err()); + } + + #[test] + fn invalid_hash_expanded_key() { + invalid_hash_expanded_key_test::(); + invalid_hash_expanded_key_test::(); + invalid_hash_expanded_key_test::(); + } + fn seed_test

() where P: KemParams,