Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 124 additions & 91 deletions ml-kem/src/algebra.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use array::{Array, typenum::U256};
use core::ops::{Add, Mul};
use module_lattice::{algebra::Field, util::Truncate};
use module_lattice::{
algebra::{Field, MultiplyNtt},
util::Truncate,
};
use sha3::digest::XofReader;
use subtle::{Choice, ConstantTimeEq};

Expand Down Expand Up @@ -28,7 +31,7 @@ pub type PolynomialVector<K> = module_lattice::algebra::Vector<BaseField, K>;
/// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`.
pub type NttPolynomial = module_lattice::algebra::NttPolynomial<BaseField>;

// Algorithm 7 SampleNTT(B)
/// Algorithm 7: `SampleNTT(B)`
pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
struct FieldElementReader<'a> {
xof: &'a mut dyn XofReader,
Expand Down Expand Up @@ -89,11 +92,11 @@ pub fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
NttPolynomial::new(Array::from_fn(|_| reader.next()))
}

// Algorithm 8. SamplePolyCBD_eta(B)
//
// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in
// ByteDecode. We decode the PRF output into integers with eta bits, then use
// `count_ones` to perform the summation described in the algorithm.
/// Algorithm 8: `SamplePolyCBD_eta(B)`
///
/// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in
/// `ByteDecode`. We decode the PRF output into integers with eta bits, then use
/// `count_ones` to perform the summation described in the algorithm.
pub(crate) fn sample_poly_cbd<Eta>(B: &PrfOutput<Eta>) -> Polynomial
where
Eta: CbdSamplingSize,
Expand All @@ -102,76 +105,118 @@ where
Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
}

// Algorithm 9. NTT
pub(crate) fn ntt(poly: &Polynomial) -> NttPolynomial {
let mut k = 1;
/// Sample a vector of `K` polynomials from the centered binomial distribution.
///
/// Entry `i` of the result is `sample_poly_cbd` applied to the PRF output for the seed
/// `sigma` and the counter `start_n + i`, so consecutive entries use consecutive counters.
///
/// The counter is formed with plain `u8` addition after truncating the index to `u8`;
/// callers are expected to pick `start_n` so that `start_n + K` stays within `u8`.
pub(crate) fn sample_poly_vec_cbd<Eta, K>(sigma: &B32, start_n: u8) -> PolynomialVector<K>
where
    Eta: CbdSamplingSize,
    K: ArraySize,
{
    // Derive the polynomial for a single position of the vector.
    let sample_at = |index: usize| {
        let counter = start_n + u8::truncate(index);
        sample_poly_cbd::<Eta>(&PRF::<Eta>(sigma, counter))
    };

    PolynomialVector::new(Array::from_fn(sample_at))
}

/// The Number Theoretic Transform (NTT) is a variant of the Discrete Fourier Transform (DFT)
/// defined over a finite field that turns costly polynomial multiplications into simple
/// coefficient-wise multiplications modulo a fixed prime.
///
/// Implemented in this file for `Polynomial` (yielding `NttPolynomial`) and for
/// `PolynomialVector<K>` (yielding `NttVector<K>`).
pub(crate) trait Ntt {
/// The NTT-domain type produced by [`Ntt::ntt`].
type Output;
/// Transform `self` into its NTT-domain representation; `self` is only borrowed.
fn ntt(&self) -> Self::Output;
}

/// Algorithm 9: `NTT`
impl Ntt for Polynomial {
type Output = NttPolynomial;

fn ntt(&self) -> NttPolynomial {
let mut k = 1;

let mut f = poly.0;
for len in [128, 64, 32, 16, 8, 4, 2] {
for start in (0..256).step_by(2 * len) {
let zeta = ZETA_POW_BITREV[k];
k += 1;
let mut f = self.0;
for len in [128, 64, 32, 16, 8, 4, 2] {
for start in (0..256).step_by(2 * len) {
let zeta = ZETA_POW_BITREV[k];
k += 1;

for j in start..(start + len) {
let t = zeta * f[j + len];
f[j + len] = f[j] - t;
f[j] = f[j] + t;
for j in start..(start + len) {
let t = zeta * f[j + len];
f[j + len] = f[j] - t;
f[j] = f[j] + t;
}
}
}

f.into()
}
}

impl<K: ArraySize> Ntt for PolynomialVector<K> {
type Output = NttVector<K>;

f.into()
fn ntt(&self) -> NttVector<K> {
NttVector(self.0.iter().map(Ntt::ntt).collect())
}
}

pub(crate) fn ntt_vector<K: ArraySize>(poly: &PolynomialVector<K>) -> NttVector<K> {
NttVector(poly.0.iter().map(ntt).collect())
/// The inverse NTT is the reverse of the Number Theoretic Transform, converting coefficient-wise
/// products back into standard polynomial form while preserving correctness modulo the same prime.
///
/// Implemented in this file for `NttPolynomial`, yielding a `Polynomial`.
#[allow(clippy::module_name_repetitions)]
pub(crate) trait NttInverse {
/// The type produced by [`NttInverse::ntt_inverse`], i.e. the value outside the NTT domain.
type Output;
/// Transform `self` out of the NTT domain; `self` is only borrowed.
fn ntt_inverse(&self) -> Self::Output;
}

// Algorithm 10. NTT^{-1}
pub(crate) fn ntt_inverse(poly: &NttPolynomial) -> Polynomial {
let mut f: Array<FieldElement, U256> = poly.0.clone();
/// Algorithm 10: `NTT^{-1}`
impl NttInverse for NttPolynomial {
type Output = Polynomial;

fn ntt_inverse(&self) -> Polynomial {
let mut f: Array<FieldElement, U256> = self.0.clone();

let mut k = 127;
for len in [2, 4, 8, 16, 32, 64, 128] {
for start in (0..256).step_by(2 * len) {
let zeta = ZETA_POW_BITREV[k];
k -= 1;
let mut k = 127;
for len in [2, 4, 8, 16, 32, 64, 128] {
for start in (0..256).step_by(2 * len) {
let zeta = ZETA_POW_BITREV[k];
k -= 1;

for j in start..(start + len) {
let t = f[j];
f[j] = t + f[j + len];
f[j + len] = zeta * (f[j + len] - t);
for j in start..(start + len) {
let t = f[j];
f[j] = t + f[j + len];
f[j + len] = zeta * (f[j + len] - t);
}
}
}
}

FieldElement::new(3303) * &Polynomial::new(f)
FieldElement::new(3303) * &Polynomial::new(f)
}
}

// Algorithm 11. MultiplyNTTs
fn multiply_ntts(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial {
let mut out = NttPolynomial::new(Array::default());

for i in 0..128 {
let (c0, c1) = base_case_multiply(
lhs.0[2 * i],
lhs.0[2 * i + 1],
rhs.0[2 * i],
rhs.0[2 * i + 1],
i,
);

out.0[2 * i] = c0;
out.0[2 * i + 1] = c1;
}
/// Algorithm 11: `MultiplyNTTs`
impl MultiplyNtt for BaseField {
fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial {
let mut out = NttPolynomial::new(Array::default());

for i in 0..128 {
let (c0, c1) = base_case_multiply(
lhs.0[2 * i],
lhs.0[2 * i + 1],
rhs.0[2 * i],
rhs.0[2 * i + 1],
i,
);

out.0[2 * i] = c0;
out.0[2 * i + 1] = c1;
}

out
out
}
}

// Algorithm 12. BaseCaseMultiply
//
// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of
// modular reductions, since these are the expensive operation.
/// Algorithm 12: `BaseCaseMultiply`
///
/// This is a hot loop. We promote to u64 so that we can do the absolute minimum number of
/// modular reductions, since those are the expensive operations.
#[inline]
fn base_case_multiply(
a0: FieldElement,
Expand All @@ -193,30 +238,18 @@ fn base_case_multiply(
(FieldElement::new(c0), FieldElement::new(c1))
}

pub(crate) fn sample_poly_vec_cbd<Eta, K>(sigma: &B32, start_n: u8) -> PolynomialVector<K>
where
Eta: CbdSamplingSize,
K: ArraySize,
{
PolynomialVector::new(Array::from_fn(|i| {
let N = start_n + u8::truncate(i);
let prf_output = PRF::<Eta>(sigma, N);
sample_poly_cbd::<Eta>(&prf_output)
}))
}

// Since the powers of zeta used in the NTT and MultiplyNTTs are fixed, we use pre-computed tables
// to avoid the need to compute the exponetiations at runtime.
//
// * ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)}
// * GAMMA[i] = zeta^{2 BitRev_7(i) + 1}
//
// Note that the const environment here imposes some annoying conditions. Because operator
// overloading can't be const, we have to do all the reductions here manually. Because `for` loops
// are forbidden in `const` functions, we do them manually with `while` loops.
//
// The values computed here match those provided in Appendix A of FIPS 203. ZETA_POW_BITREV
// corresponds to the first table, and GAMMA to the second table.
/// Since the powers of zeta used in the `NTT` and `MultiplyNTTs` are fixed, we use pre-computed
/// tables to avoid the need to compute the exponentiations at runtime.
///
/// * `ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)}`
/// * `GAMMA[i] = zeta^{2 BitRev_7(i) + 1}`
///
/// Note that the const environment here imposes some annoying conditions. Because operator
/// overloading can't be const, we have to do all the reductions here manually. Because `for` loops
/// are forbidden in `const` functions, we do them manually with `while` loops.
///
/// The values computed here match those provided in Appendix A of FIPS 203.
/// `ZETA_POW_BITREV` corresponds to the first table, and `GAMMA` to the second table.
#[allow(clippy::cast_possible_truncation)]
const ZETA_POW_BITREV: [FieldElement; 128] = {
const ZETA: u64 = 17;
Expand Down Expand Up @@ -328,14 +361,14 @@ impl<K: ArraySize> Mul<&NttVector<K>> for &NttVector<K> {
self.0
.iter()
.zip(rhs.0.iter())
.map(|(x, y)| multiply_ntts(x, y))
.map(|(x, y)| x * y)
.fold(NttPolynomial::default(), |x, y| &x + &y)
}
}

impl<K: ArraySize> NttVector<K> {
pub fn ntt_inverse(&self) -> PolynomialVector<K> {
PolynomialVector::new(self.0.iter().map(ntt_inverse).collect())
PolynomialVector::new(self.0.iter().map(NttInverse::ntt_inverse).collect())
}
}

Expand Down Expand Up @@ -372,7 +405,7 @@ mod test {
use array::typenum::{U2, U3, U8};
use module_lattice::util::Flatten;

// Multiplication in R_q, modulo X^256 + 1
/// Multiplication in `R_q`, modulo X^256 + 1
fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial {
let mut out = Polynomial::default();
for (i, x) in lhs.0.iter().enumerate() {
Expand All @@ -393,7 +426,7 @@ mod test {
fn const_ntt(x: Integer) -> NttPolynomial {
let mut p = Polynomial::default();
p.0[0] = FieldElement::new(x);
super::ntt(&p)
p.ntt()
}

#[test]
Expand All @@ -412,23 +445,23 @@ mod test {
fn ntt() {
let f = Polynomial::new(Array::from_fn(|i| FieldElement::new(i as Integer)));
let g = Polynomial::new(Array::from_fn(|i| FieldElement::new(2 * i as Integer)));
let f_hat = super::ntt(&f);
let g_hat = super::ntt(&g);
let f_hat = f.ntt();
let g_hat = g.ntt();

// Verify that NTT and NTT^-1 are actually inverses
let f_unhat = ntt_inverse(&f_hat);
let f_unhat = f_hat.ntt_inverse();
assert_eq!(f, f_unhat);

// Verify that NTT is a homomorphism with regard to addition
let fg = &f + &g;
let f_hat_g_hat = &f_hat + &g_hat;
let fg_unhat = ntt_inverse(&f_hat_g_hat);
let fg_unhat = f_hat_g_hat.ntt_inverse();
assert_eq!(fg, fg_unhat);

// Verify that NTT is a homomorphism with regard to multiplication
let fg = poly_mul(&f, &g);
let f_hat_g_hat = multiply_ntts(&f_hat, &g_hat);
let fg_unhat = ntt_inverse(&f_hat_g_hat);
let f_hat_g_hat = &f_hat * &g_hat;
let fg_unhat = f_hat_g_hat.ntt_inverse();
assert_eq!(fg, fg_unhat);
}

Expand Down
14 changes: 7 additions & 7 deletions ml-kem/src/pke.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::B32;
use crate::algebra::{
NttMatrix, NttVector, Polynomial, PolynomialVector, ntt_inverse, ntt_vector, sample_poly_cbd,
Ntt, NttInverse, NttMatrix, NttVector, Polynomial, PolynomialVector, sample_poly_cbd,
sample_poly_vec_cbd,
};
use crate::compress::Compress;
Expand Down Expand Up @@ -71,8 +71,8 @@ where
let e: PolynomialVector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, P::K::U8);

// NTT the vectors
let s_hat = ntt_vector(&s);
let e_hat = ntt_vector(&e);
let s_hat = s.ntt();
let e_hat = e.ntt();

// Compute the public value
let t_hat = &(&A_hat * &s_hat) + &e_hat;
Expand All @@ -94,8 +94,8 @@ where
let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
v.decompress::<P::Dv>();

let u_hat = ntt_vector(&u);
let sTu = ntt_inverse(&(&self.s_hat * &u_hat));
let u_hat = u.ntt();
let sTu = (&self.s_hat * &u_hat).ntt_inverse();
let mut w = &v - &sTu;
Encode::<U1>::encode(w.compress::<U1>())
}
Expand Down Expand Up @@ -137,14 +137,14 @@ where
let e2: Polynomial = sample_poly_cbd::<P::Eta2>(&prf_output);

let A_hat_t = NttMatrix::<P::K>::sample_uniform(&self.rho, true);
let r_hat: NttVector<P::K> = ntt_vector(&r);
let r_hat: NttVector<P::K> = r.ntt();
let ATr: PolynomialVector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
let mut u = ATr + e1;

let mut mu: Polynomial = Encode::<U1>::decode(message);
mu.decompress::<U1>();

let tTr: Polynomial = ntt_inverse(&(&self.t_hat * &r_hat));
let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
let mut v = &(&tTr + &e2) + &mu;

let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
Expand Down
19 changes: 10 additions & 9 deletions module-lattice/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,21 +327,22 @@ impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
}
}

impl<F: Field> Mul<&NttPolynomial<F>> for &NttPolynomial<F> {
impl<F> Mul<&NttPolynomial<F>> for &NttPolynomial<F>
where
F: Field + MultiplyNtt,
{
type Output = NttPolynomial<F>;

// Algorithm 45 MultiplyNTT
fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
NttPolynomial::new(
self.0
.iter()
.zip(rhs.0.iter())
.map(|(&x, &y)| x * y)
.collect(),
)
F::multiply_ntt(self, rhs)
}
}

/// Perform multiplication in the NTT domain.
///
/// The `Mul` implementation for `&NttPolynomial<F>` delegates to this trait, so each field
/// type supplies its own NTT-domain product.
pub trait MultiplyNtt: Field {
/// Multiply `lhs` by `rhs`, both already in the NTT domain, and return the NTT-domain product.
fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
}

impl<F: Field> Neg for &NttPolynomial<F> {
type Output = NttPolynomial<F>;

Expand Down