diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 51535a8..1e409b9 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -1,6 +1,9 @@ use array::{Array, typenum::U256}; use core::ops::{Add, Mul}; -use module_lattice::{algebra::Field, util::Truncate}; +use module_lattice::{ + algebra::{Field, MultiplyNtt}, + util::Truncate, +}; use sha3::digest::XofReader; use subtle::{Choice, ConstantTimeEq}; @@ -28,7 +31,7 @@ pub type PolynomialVector = module_lattice::algebra::Vector; /// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`. pub type NttPolynomial = module_lattice::algebra::NttPolynomial; -// Algorithm 7 SampleNTT(B) +/// Algorithm 7: `SampleNTT(B)` pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { struct FieldElementReader<'a> { xof: &'a mut dyn XofReader, @@ -89,11 +92,11 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { NttPolynomial::new(Array::from_fn(|_| reader.next())) } -// Algorithm 8. SamplePolyCBD_eta(B) -// -// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in -// ByteDecode. We decode the PRF output into integers with eta bits, then use -// `count_ones` to perform the summation described in the algorithm. +/// Algorithm 8: `SamplePolyCBD_eta(B)` +/// +/// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in +/// `ByteDecode`. We decode the PRF output into integers with eta bits, then use +/// `count_ones` to perform the summation described in the algorithm. pub(crate) fn sample_poly_cbd(B: &PrfOutput) -> Polynomial where Eta: CbdSamplingSize, @@ -102,76 +105,118 @@ where Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect()) } -// Algorithm 9. 
NTT -pub(crate) fn ntt(poly: &Polynomial) -> NttPolynomial { - let mut k = 1; +pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector +where + Eta: CbdSamplingSize, + K: ArraySize, +{ + PolynomialVector::new(Array::from_fn(|i| { + let N = start_n + u8::truncate(i); + let prf_output = PRF::(sigma, N); + sample_poly_cbd::(&prf_output) + })) +} + +/// The Number Theoretic Transform (NTT) is a variant of the Discrete Fourier Transform (DFT) +/// defined over a finite field that turns costly polynomial multiplications into simple +/// coefficient-wise multiplications modulo a fixed prime. +pub(crate) trait Ntt { + type Output; + fn ntt(&self) -> Self::Output; +} + +/// Algorithm 9: `NTT` +impl Ntt for Polynomial { + type Output = NttPolynomial; + + fn ntt(&self) -> NttPolynomial { + let mut k = 1; - let mut f = poly.0; - for len in [128, 64, 32, 16, 8, 4, 2] { - for start in (0..256).step_by(2 * len) { - let zeta = ZETA_POW_BITREV[k]; - k += 1; + let mut f = self.0; + for len in [128, 64, 32, 16, 8, 4, 2] { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW_BITREV[k]; + k += 1; - for j in start..(start + len) { - let t = zeta * f[j + len]; - f[j + len] = f[j] - t; - f[j] = f[j] + t; + for j in start..(start + len) { + let t = zeta * f[j + len]; + f[j + len] = f[j] - t; + f[j] = f[j] + t; + } } } + + f.into() } +} + +impl Ntt for PolynomialVector { + type Output = NttVector; - f.into() + fn ntt(&self) -> NttVector { + NttVector(self.0.iter().map(Ntt::ntt).collect()) + } } -pub(crate) fn ntt_vector(poly: &PolynomialVector) -> NttVector { - NttVector(poly.0.iter().map(ntt).collect()) +/// The inverse NTT is the reverse of the Number Theoretic Transform, converting coefficient-wise +/// products back into standard polynomial form while preserving correctness modulo the same prime. +#[allow(clippy::module_name_repetitions)] +pub(crate) trait NttInverse { + type Output; + fn ntt_inverse(&self) -> Self::Output; } -// Algorithm 10. 
NTT^{-1} -pub(crate) fn ntt_inverse(poly: &NttPolynomial) -> Polynomial { - let mut f: Array = poly.0.clone(); +/// Algorithm 10: `NTT^{-1}` +impl NttInverse for NttPolynomial { + type Output = Polynomial; + + fn ntt_inverse(&self) -> Polynomial { + let mut f: Array = self.0.clone(); - let mut k = 127; - for len in [2, 4, 8, 16, 32, 64, 128] { - for start in (0..256).step_by(2 * len) { - let zeta = ZETA_POW_BITREV[k]; - k -= 1; + let mut k = 127; + for len in [2, 4, 8, 16, 32, 64, 128] { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW_BITREV[k]; + k -= 1; - for j in start..(start + len) { - let t = f[j]; - f[j] = t + f[j + len]; - f[j + len] = zeta * (f[j + len] - t); + for j in start..(start + len) { + let t = f[j]; + f[j] = t + f[j + len]; + f[j + len] = zeta * (f[j + len] - t); + } } } - } - FieldElement::new(3303) * &Polynomial::new(f) + FieldElement::new(3303) * &Polynomial::new(f) + } } -// Algorithm 11. MultiplyNTTs -fn multiply_ntts(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial { - let mut out = NttPolynomial::new(Array::default()); - - for i in 0..128 { - let (c0, c1) = base_case_multiply( - lhs.0[2 * i], - lhs.0[2 * i + 1], - rhs.0[2 * i], - rhs.0[2 * i + 1], - i, - ); - - out.0[2 * i] = c0; - out.0[2 * i + 1] = c1; - } +/// Algorithm 11: `MultiplyNTTs` +impl MultiplyNtt for BaseField { + fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial { + let mut out = NttPolynomial::new(Array::default()); + + for i in 0..128 { + let (c0, c1) = base_case_multiply( + lhs.0[2 * i], + lhs.0[2 * i + 1], + rhs.0[2 * i], + rhs.0[2 * i + 1], + i, + ); + + out.0[2 * i] = c0; + out.0[2 * i + 1] = c1; + } - out + out + } } -// Algorithm 12. BaseCaseMultiply -// -// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of -// modular reductions, since these are the expensive operation. +/// Algorithm 12: `BaseCaseMultiply` +/// +/// This is a hot loop. 
We promote to u64 so that we can do the absolute minimum number of +/// modular reductions, since these are the expensive operation. #[inline] fn base_case_multiply( a0: FieldElement, @@ -193,30 +238,18 @@ fn base_case_multiply( (FieldElement::new(c0), FieldElement::new(c1)) } -pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector -where - Eta: CbdSamplingSize, - K: ArraySize, -{ - PolynomialVector::new(Array::from_fn(|i| { - let N = start_n + u8::truncate(i); - let prf_output = PRF::(sigma, N); - sample_poly_cbd::(&prf_output) - })) -} - -// Since the powers of zeta used in the NTT and MultiplyNTTs are fixed, we use pre-computed tables -// to avoid the need to compute the exponetiations at runtime. -// -// * ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)} -// * GAMMA[i] = zeta^{2 BitRev_7(i) + 1} -// -// Note that the const environment here imposes some annoying conditions. Because operator -// overloading can't be const, we have to do all the reductions here manually. Because `for` loops -// are forbidden in `const` functions, we do them manually with `while` loops. -// -// The values computed here match those provided in Appendix A of FIPS 203. ZETA_POW_BITREV -// corresponds to the first table, and GAMMA to the second table. +/// Since the powers of zeta used in the `NTT` and `MultiplyNTTs` are fixed, we use pre-computed +/// tables to avoid the need to compute the exponentiations at runtime. +/// +/// * `ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)}` +/// * `GAMMA[i] = zeta^{2 BitRev_7(i) + 1}` +/// +/// Note that the const environment here imposes some annoying conditions. Because operator +/// overloading can't be const, we have to do all the reductions here manually. Because `for` loops +/// are forbidden in `const` functions, we do them manually with `while` loops. +/// +/// The values computed here match those provided in Appendix A of FIPS 203. +/// `ZETA_POW_BITREV` corresponds to the first table, and `GAMMA` to the second table. 
#[allow(clippy::cast_possible_truncation)] const ZETA_POW_BITREV: [FieldElement; 128] = { const ZETA: u64 = 17; @@ -328,14 +361,14 @@ impl Mul<&NttVector> for &NttVector { self.0 .iter() .zip(rhs.0.iter()) - .map(|(x, y)| multiply_ntts(x, y)) + .map(|(x, y)| x * y) .fold(NttPolynomial::default(), |x, y| &x + &y) } } impl NttVector { pub fn ntt_inverse(&self) -> PolynomialVector { - PolynomialVector::new(self.0.iter().map(ntt_inverse).collect()) + PolynomialVector::new(self.0.iter().map(NttInverse::ntt_inverse).collect()) } } @@ -372,7 +405,7 @@ mod test { use array::typenum::{U2, U3, U8}; use module_lattice::util::Flatten; - // Multiplication in R_q, modulo X^256 + 1 + /// Multiplication in `R_q`, modulo X^256 + 1 fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial { let mut out = Polynomial::default(); for (i, x) in lhs.0.iter().enumerate() { @@ -393,7 +426,7 @@ mod test { fn const_ntt(x: Integer) -> NttPolynomial { let mut p = Polynomial::default(); p.0[0] = FieldElement::new(x); - super::ntt(&p) + p.ntt() } #[test] @@ -412,23 +445,23 @@ mod test { fn ntt() { let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer))); let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); - let f_hat = super::ntt(&f); - let g_hat = super::ntt(&g); + let f_hat = f.ntt(); + let g_hat = g.ntt(); // Verify that NTT and NTT^-1 are actually inverses - let f_unhat = ntt_inverse(&f_hat); + let f_unhat = f_hat.ntt_inverse(); assert_eq!(f, f_unhat); // Verify that NTT is a homomorphism with regard to addition let fg = &f + &g; let f_hat_g_hat = &f_hat + &g_hat; - let fg_unhat = ntt_inverse(&f_hat_g_hat); + let fg_unhat = f_hat_g_hat.ntt_inverse(); assert_eq!(fg, fg_unhat); // Verify that NTT is a homomorphism with regard to multiplication let fg = poly_mul(&f, &g); - let f_hat_g_hat = multiply_ntts(&f_hat, &g_hat); - let fg_unhat = ntt_inverse(&f_hat_g_hat); + let f_hat_g_hat = &f_hat * &g_hat; + let fg_unhat = 
f_hat_g_hat.ntt_inverse(); assert_eq!(fg, fg_unhat); } diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index 0373fbc..f0ee1da 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -1,6 +1,6 @@ use crate::B32; use crate::algebra::{ - NttMatrix, NttVector, Polynomial, PolynomialVector, ntt_inverse, ntt_vector, sample_poly_cbd, + Ntt, NttInverse, NttMatrix, NttVector, Polynomial, PolynomialVector, sample_poly_cbd, sample_poly_vec_cbd, }; use crate::compress::Compress; @@ -71,8 +71,8 @@ where let e: PolynomialVector = sample_poly_vec_cbd::(&sigma, P::K::U8); // NTT the vectors - let s_hat = ntt_vector(&s); - let e_hat = ntt_vector(&e); + let s_hat = s.ntt(); + let e_hat = e.ntt(); // Compute the public value let t_hat = &(&A_hat * &s_hat) + &e_hat; @@ -94,8 +94,8 @@ where let mut v: Polynomial = Encode::::decode(c2); v.decompress::(); - let u_hat = ntt_vector(&u); - let sTu = ntt_inverse(&(&self.s_hat * &u_hat)); + let u_hat = u.ntt(); + let sTu = (&self.s_hat * &u_hat).ntt_inverse(); let mut w = &v - &sTu; Encode::::encode(w.compress::()) } @@ -137,14 +137,14 @@ where let e2: Polynomial = sample_poly_cbd::(&prf_output); let A_hat_t = NttMatrix::::sample_uniform(&self.rho, true); - let r_hat: NttVector = ntt_vector(&r); + let r_hat: NttVector = r.ntt(); let ATr: PolynomialVector = (&A_hat_t * &r_hat).ntt_inverse(); let mut u = ATr + e1; let mut mu: Polynomial = Encode::::decode(message); mu.decompress::(); - let tTr: Polynomial = ntt_inverse(&(&self.t_hat * &r_hat)); + let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse(); let mut v = &(&tTr + &e2) + &mu; let c1 = Encode::::encode(u.compress::()); diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index c91d3d1..74e4602 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -327,21 +327,22 @@ impl Mul<&NttPolynomial> for Elem { } } -impl Mul<&NttPolynomial> for &NttPolynomial { +impl Mul<&NttPolynomial> for &NttPolynomial +where + F: Field + MultiplyNtt, +{ 
type Output = NttPolynomial; - // Algorithm 45 MultiplyNTT fn mul(self, rhs: &NttPolynomial) -> NttPolynomial { - NttPolynomial::new( - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(&x, &y)| x * y) - .collect(), - ) + F::multiply_ntt(self, rhs) } } +/// Perform multiplication in the NTT domain. +pub trait MultiplyNtt: Field { + fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial; +} + impl Neg for &NttPolynomial { type Output = NttPolynomial;