From 8a9b7be74035827f60dcf30d3891f13c0e56c820 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Thu, 29 Jan 2026 07:49:49 -0700 Subject: [PATCH] ml-kem/module-lattice: add traits for NTT operations Adds `MultiplyNtt` to `module-lattice`, and `Ntt` and `NttInverse` to `ml-kem`. Previously `module-lattice` defined the `Mul` impl on `NttPolynomial` in terms of Algorithm 45 from FIPS 204, and this is used by the `Mul` impls on `NttVector` and `NttMatrix` as well. For ML-KEM we instead need to plug in Algorithm 11 from FIPS 203 (MultiplyNTTs), so we need a trait that lets each crate define the `Mul` impl on `NttPolynomial`. This adds a `MultiplyNtt` trait to `module-lattice`, currently defined on a `Field` (perhaps not great but it's the generic parameter we have), which allows each construction to plug in their own NTT multiplication algorithm. The existing implementation of Algorithm 45 from FIPS 204 will need to move to the `ml-dsa` crate. Also, following the structure of the `ml-dsa` crate, this defines two same-shaped traits, `Ntt` and `NttInverse`, that can be impl'd on types from `module-lattice` to use the same method syntax that was being used previously before #210. Unfortunately we can't share traits with `ml-dsa`, because the only way we can impl these traits for types from `module-lattice` is if we define them. But they're small and there aren't that many of them. --- ml-kem/src/algebra.rs | 215 ++++++++++++++++++++-------------- ml-kem/src/pke.rs | 14 +-- module-lattice/src/algebra.rs | 19 +-- 3 files changed, 141 insertions(+), 107 deletions(-) diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 51535a8..1e409b9 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -1,6 +1,9 @@ use array::{Array, typenum::U256}; use core::ops::{Add, Mul}; -use module_lattice::{algebra::Field, util::Truncate}; +use module_lattice::{ + algebra::{Field, MultiplyNtt}, + util::Truncate, +}; use sha3::digest::XofReader; use subtle::{Choice, ConstantTimeEq}; @@ -28,7 +31,7 @@ pub type PolynomialVector = module_lattice::algebra::Vector; /// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`. pub type NttPolynomial = module_lattice::algebra::NttPolynomial; -// Algorithm 7 SampleNTT(B) +/// Algorithm 7: `SampleNTT(B)` pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { struct FieldElementReader<'a> { xof: &'a mut dyn XofReader, @@ -89,11 +92,11 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial { NttPolynomial::new(Array::from_fn(|_| reader.next())) } -// Algorithm 8. SamplePolyCBD_eta(B) -// -// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in -// ByteDecode. We decode the PRF output into integers with eta bits, then use -// `count_ones` to perform the summation described in the algorithm. +/// Algorithm 8: `SamplePolyCBD_eta(B)` +/// +/// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in +/// `ByteDecode`. We decode the PRF output into integers with eta bits, then use +/// `count_ones` to perform the summation described in the algorithm. pub(crate) fn sample_poly_cbd(B: &PrfOutput) -> Polynomial where Eta: CbdSamplingSize, @@ -102,76 +105,118 @@ where Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect()) } -// Algorithm 9. NTT -pub(crate) fn ntt(poly: &Polynomial) -> NttPolynomial { - let mut k = 1; +pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector +where + Eta: CbdSamplingSize, + K: ArraySize, +{ + PolynomialVector::new(Array::from_fn(|i| { + let N = start_n + u8::truncate(i); + let prf_output = PRF::(sigma, N); + sample_poly_cbd::(&prf_output) + })) +} + +/// The Number Theoretic Transform (NTT) is a variant of the Discrete Fourier Transform (DFT) +/// defined over a finite field that turns costly polynomial multiplications into simple +/// coefficient-wise multiplications modulo a fixed prime. +pub(crate) trait Ntt { + type Output; + fn ntt(&self) -> Self::Output; +} + +/// Algorithm 9: `NTT` +impl Ntt for Polynomial { + type Output = NttPolynomial; + + fn ntt(&self) -> NttPolynomial { + let mut k = 1; - let mut f = poly.0; - for len in [128, 64, 32, 16, 8, 4, 2] { - for start in (0..256).step_by(2 * len) { - let zeta = ZETA_POW_BITREV[k]; - k += 1; + let mut f = self.0; + for len in [128, 64, 32, 16, 8, 4, 2] { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW_BITREV[k]; + k += 1; - for j in start..(start + len) { - let t = zeta * f[j + len]; - f[j + len] = f[j] - t; - f[j] = f[j] + t; + for j in start..(start + len) { + let t = zeta * f[j + len]; + f[j + len] = f[j] - t; + f[j] = f[j] + t; + } } } + + f.into() } +} + +impl Ntt for PolynomialVector { + type Output = NttVector; - f.into() + fn ntt(&self) -> NttVector { + NttVector(self.0.iter().map(Ntt::ntt).collect()) + } } -pub(crate) fn ntt_vector(poly: &PolynomialVector) -> NttVector { - NttVector(poly.0.iter().map(ntt).collect()) +/// The inverse NTT is the reverse of the Number Theoretic Transform, converting coefficient-wise +/// products back into standard polynomial form while preserving correctness modulo the same prime. +#[allow(clippy::module_name_repetitions)] +pub(crate) trait NttInverse { + type Output; + fn ntt_inverse(&self) -> Self::Output; } -// Algorithm 10. NTT^{-1} -pub(crate) fn ntt_inverse(poly: &NttPolynomial) -> Polynomial { - let mut f: Array = poly.0.clone(); +/// Algorithm 10: `NTT^{-1}` +impl NttInverse for NttPolynomial { + type Output = Polynomial; + + fn ntt_inverse(&self) -> Polynomial { + let mut f: Array = self.0.clone(); - let mut k = 127; - for len in [2, 4, 8, 16, 32, 64, 128] { - for start in (0..256).step_by(2 * len) { - let zeta = ZETA_POW_BITREV[k]; - k -= 1; + let mut k = 127; + for len in [2, 4, 8, 16, 32, 64, 128] { + for start in (0..256).step_by(2 * len) { + let zeta = ZETA_POW_BITREV[k]; + k -= 1; - for j in start..(start + len) { - let t = f[j]; - f[j] = t + f[j + len]; - f[j + len] = zeta * (f[j + len] - t); + for j in start..(start + len) { + let t = f[j]; + f[j] = t + f[j + len]; + f[j + len] = zeta * (f[j + len] - t); + } } } - } - FieldElement::new(3303) * &Polynomial::new(f) + FieldElement::new(3303) * &Polynomial::new(f) + } } -// Algorithm 11. MultiplyNTTs -fn multiply_ntts(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial { - let mut out = NttPolynomial::new(Array::default()); - - for i in 0..128 { - let (c0, c1) = base_case_multiply( - lhs.0[2 * i], - lhs.0[2 * i + 1], - rhs.0[2 * i], - rhs.0[2 * i + 1], - i, - ); - - out.0[2 * i] = c0; - out.0[2 * i + 1] = c1; - } +/// Algorithm 11: `MultiplyNTTs` +impl MultiplyNtt for BaseField { + fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial { + let mut out = NttPolynomial::new(Array::default()); + + for i in 0..128 { + let (c0, c1) = base_case_multiply( + lhs.0[2 * i], + lhs.0[2 * i + 1], + rhs.0[2 * i], + rhs.0[2 * i + 1], + i, + ); + + out.0[2 * i] = c0; + out.0[2 * i + 1] = c1; + } - out + out + } } -// Algorithm 12. BaseCaseMultiply -// -// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of -// modular reductions, since these are the expensive operation. +/// Algorithm 12: `BaseCaseMultiply` +/// +/// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of +/// modular reductions, since these are the expensive operation. #[inline] fn base_case_multiply( a0: FieldElement, @@ -193,30 +238,18 @@ fn base_case_multiply( (FieldElement::new(c0), FieldElement::new(c1)) } -pub(crate) fn sample_poly_vec_cbd(sigma: &B32, start_n: u8) -> PolynomialVector -where - Eta: CbdSamplingSize, - K: ArraySize, -{ - PolynomialVector::new(Array::from_fn(|i| { - let N = start_n + u8::truncate(i); - let prf_output = PRF::(sigma, N); - sample_poly_cbd::(&prf_output) - })) -} - -// Since the powers of zeta used in the NTT and MultiplyNTTs are fixed, we use pre-computed tables -// to avoid the need to compute the exponetiations at runtime. -// -// * ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)} -// * GAMMA[i] = zeta^{2 BitRev_7(i) + 1} -// -// Note that the const environment here imposes some annoying conditions. Because operator -// overloading can't be const, we have to do all the reductions here manually. Because `for` loops -// are forbidden in `const` functions, we do them manually with `while` loops. -// -// The values computed here match those provided in Appendix A of FIPS 203. ZETA_POW_BITREV -// corresponds to the first table, and GAMMA to the second table. +/// Since the powers of zeta used in the `NTT` and `MultiplyNTTs` are fixed, we use pre-computed +/// tables to avoid the need to compute the exponentiations at runtime. +/// +/// * `ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)}` +/// * `GAMMA[i] = zeta^{2 BitRev_7(i) + 1}` +/// +/// Note that the const environment here imposes some annoying conditions. Because operator +/// overloading can't be const, we have to do all the reductions here manually. Because `for` loops +/// are forbidden in `const` functions, we do them manually with `while` loops. +/// +/// The values computed here match those provided in Appendix A of FIPS 203. +/// `ZETA_POW_BITREV` corresponds to the first table, and `GAMMA` to the second table. #[allow(clippy::cast_possible_truncation)] const ZETA_POW_BITREV: [FieldElement; 128] = { const ZETA: u64 = 17; @@ -328,14 +361,14 @@ impl Mul<&NttVector> for &NttVector { self.0 .iter() .zip(rhs.0.iter()) - .map(|(x, y)| multiply_ntts(x, y)) + .map(|(x, y)| x * y) .fold(NttPolynomial::default(), |x, y| &x + &y) } } impl NttVector { pub fn ntt_inverse(&self) -> PolynomialVector { - PolynomialVector::new(self.0.iter().map(ntt_inverse).collect()) + PolynomialVector::new(self.0.iter().map(NttInverse::ntt_inverse).collect()) } } @@ -372,7 +405,7 @@ mod test { use array::typenum::{U2, U3, U8}; use module_lattice::util::Flatten; - // Multiplication in R_q, modulo X^256 + 1 + /// Multiplication in `R_q`, modulo X^256 + 1 fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial { let mut out = Polynomial::default(); for (i, x) in lhs.0.iter().enumerate() { @@ -393,7 +426,7 @@ mod test { fn const_ntt(x: Integer) -> NttPolynomial { let mut p = Polynomial::default(); p.0[0] = FieldElement::new(x); - super::ntt(&p) + p.ntt() } #[test] @@ -412,23 +445,23 @@ mod test { fn ntt() { let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer))); let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer))); - let f_hat = super::ntt(&f); - let g_hat = super::ntt(&g); + let f_hat = f.ntt(); + let g_hat = g.ntt(); // Verify that NTT and NTT^-1 are actually inverses - let f_unhat = ntt_inverse(&f_hat); + let f_unhat = f_hat.ntt_inverse(); assert_eq!(f, f_unhat); // Verify that NTT is a homomorphism with regard to addition let fg = &f + &g; let f_hat_g_hat = &f_hat + &g_hat; - let fg_unhat = ntt_inverse(&f_hat_g_hat); + let fg_unhat = f_hat_g_hat.ntt_inverse(); assert_eq!(fg, fg_unhat); // Verify that NTT is a homomorphism with regard to multiplication let fg = poly_mul(&f, &g); - let f_hat_g_hat = multiply_ntts(&f_hat, &g_hat); - let fg_unhat = ntt_inverse(&f_hat_g_hat); + let f_hat_g_hat = &f_hat * &g_hat; + let fg_unhat = f_hat_g_hat.ntt_inverse(); assert_eq!(fg, fg_unhat); } diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index 0373fbc..f0ee1da 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -1,6 +1,6 @@ use crate::B32; use crate::algebra::{ - NttMatrix, NttVector, Polynomial, PolynomialVector, ntt_inverse, ntt_vector, sample_poly_cbd, + Ntt, NttInverse, NttMatrix, NttVector, Polynomial, PolynomialVector, sample_poly_cbd, sample_poly_vec_cbd, }; use crate::compress::Compress; @@ -71,8 +71,8 @@ where let e: PolynomialVector = sample_poly_vec_cbd::(&sigma, P::K::U8); // NTT the vectors - let s_hat = ntt_vector(&s); - let e_hat = ntt_vector(&e); + let s_hat = s.ntt(); + let e_hat = e.ntt(); // Compute the public value let t_hat = &(&A_hat * &s_hat) + &e_hat; @@ -94,8 +94,8 @@ where let mut v: Polynomial = Encode::::decode(c2); v.decompress::(); - let u_hat = ntt_vector(&u); - let sTu = ntt_inverse(&(&self.s_hat * &u_hat)); + let u_hat = u.ntt(); + let sTu = (&self.s_hat * &u_hat).ntt_inverse(); let mut w = &v - &sTu; Encode::::encode(w.compress::()) } @@ -137,14 +137,14 @@ where let e2: Polynomial = sample_poly_cbd::(&prf_output); let A_hat_t = NttMatrix::::sample_uniform(&self.rho, true); - let r_hat: NttVector = ntt_vector(&r); + let r_hat: NttVector = r.ntt(); let ATr: PolynomialVector = (&A_hat_t * &r_hat).ntt_inverse(); let mut u = ATr + e1; let mut mu: Polynomial = Encode::::decode(message); mu.decompress::(); - let tTr: Polynomial = ntt_inverse(&(&self.t_hat * &r_hat)); + let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse(); let mut v = &(&tTr + &e2) + μ let c1 = Encode::::encode(u.compress::()); diff --git a/module-lattice/src/algebra.rs b/module-lattice/src/algebra.rs index c91d3d1..74e4602 100644 --- a/module-lattice/src/algebra.rs +++ b/module-lattice/src/algebra.rs @@ -327,21 +327,22 @@ impl Mul<&NttPolynomial> for Elem { } } -impl Mul<&NttPolynomial> for &NttPolynomial { +impl Mul<&NttPolynomial> for &NttPolynomial +where + F: Field + MultiplyNtt, +{ type Output = NttPolynomial; - // Algorithm 45 MultiplyNTT fn mul(self, rhs: &NttPolynomial) -> NttPolynomial { - NttPolynomial::new( - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(&x, &y)| x * y) - .collect(), - ) + F::multiply_ntt(self, rhs) } } +/// Perform multiplication in the NTT domain. +pub trait MultiplyNtt: Field { + fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial; +} + impl Neg for &NttPolynomial { type Output = NttPolynomial;