From: Eric Biggers <ebiggers@kernel.org>
To: linux-crypto@vger.kernel.org
Cc: linux-kernel@vger.kernel.org, Ard Biesheuvel <ardb@kernel.org>,
"Jason A . Donenfeld" <Jason@zx2c4.com>,
Herbert Xu <herbert@gondor.apana.org.au>,
Ryan Appel <ryan.appel.333@gmail.com>,
Chris Leech <cleech@redhat.com>,
Eric Biggers <ebiggers@kernel.org>
Subject: [PATCH 1/5] lib/crypto: mlkem: Add ML-KEM-768 and ML-KEM-1024 support
Date: Mon, 25 May 2026 13:43:59 -0500 [thread overview]
Message-ID: <20260525184403.101818-2-ebiggers@kernel.org> (raw)
In-Reply-To: <20260525184403.101818-1-ebiggers@kernel.org>
Add support for ML-KEM-768 and ML-KEM-1024 to lib/crypto/. This is a
proof-of-concept that won't be merged until there's an in-kernel user.
ML-KEM is a "post-quantum" key encapsulation mechanism that is specified
in FIPS 203 and is based on CRYSTALS-Kyber. As a key encapsulation
mechanism (KEM), it consists of a set of algorithms that can be used to
establish a shared secret key over a public channel. As a
"post-quantum" algorithm, it's conjectured to be secure even against
adversaries in possession of a large-scale quantum computer.
Current users of classical key agreement schemes in the kernel include
NVMe in-band authentication, Bluetooth, and WireGuard. It's likely that
at least some of these will eventually be upgraded to a hybrid KEM that
uses ML-KEM, such as X-Wing which uses X25519 and ML-KEM-768.
This commit just gets things started by implementing ML-KEM-768 and
ML-KEM-1024. A later commit adds X-Wing support on top of it.
ML-KEM-512 has been omitted, since it has seen limited adoption and
remains somewhat controversial. It wouldn't be that much work to
support it too, but there's no clear need to do so yet.
Signed-off-by: Eric Biggers <ebiggers@kernel.org>
---
Documentation/crypto/libcrypto-asymmetric.rst | 20 +
Documentation/crypto/libcrypto-signature.rst | 11 -
Documentation/crypto/libcrypto.rst | 2 +-
include/crypto/mlkem.h | 151 +++
lib/crypto/Kconfig | 8 +
lib/crypto/Makefile | 5 +
lib/crypto/mlkem.c | 894 ++++++++++++++++++
7 files changed, 1079 insertions(+), 12 deletions(-)
create mode 100644 Documentation/crypto/libcrypto-asymmetric.rst
delete mode 100644 Documentation/crypto/libcrypto-signature.rst
create mode 100644 include/crypto/mlkem.h
create mode 100644 lib/crypto/mlkem.c
diff --git a/Documentation/crypto/libcrypto-asymmetric.rst b/Documentation/crypto/libcrypto-asymmetric.rst
new file mode 100644
index 000000000000..6e71c5ce823f
--- /dev/null
+++ b/Documentation/crypto/libcrypto-asymmetric.rst
@@ -0,0 +1,20 @@
+.. SPDX-License-Identifier: GPL-2.0-or-later
+
+Asymmetric cryptography
+=======================
+
+ML-DSA
+------
+
+Support for the ML-DSA digital signature algorithm.
+
+.. kernel-doc:: include/crypto/mldsa.h
+
+ML-KEM
+------
+
+Support for the ML-KEM key encapsulation mechanism.
+
+This shall be used as part of a hybrid scheme such as X-Wing, not by itself.
+
+.. kernel-doc:: include/crypto/mlkem.h
diff --git a/Documentation/crypto/libcrypto-signature.rst b/Documentation/crypto/libcrypto-signature.rst
deleted file mode 100644
index e80d59fa51b6..000000000000
--- a/Documentation/crypto/libcrypto-signature.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-.. SPDX-License-Identifier: GPL-2.0-or-later
-
-Digital signature algorithms
-============================
-
-ML-DSA
-------
-
-Support for the ML-DSA digital signature algorithm.
-
-.. kernel-doc:: include/crypto/mldsa.h
diff --git a/Documentation/crypto/libcrypto.rst b/Documentation/crypto/libcrypto.rst
index a1557d45b0e5..8c61b6513e06 100644
--- a/Documentation/crypto/libcrypto.rst
+++ b/Documentation/crypto/libcrypto.rst
@@ -156,10 +156,10 @@ API documentation
=================
.. toctree::
:maxdepth: 2
+ libcrypto-asymmetric
libcrypto-blockcipher
libcrypto-hash
- libcrypto-signature
libcrypto-utils
sha3
diff --git a/include/crypto/mlkem.h b/include/crypto/mlkem.h
new file mode 100644
index 000000000000..e33d065c5442
--- /dev/null
+++ b/include/crypto/mlkem.h
@@ -0,0 +1,151 @@
+/* SPDX-License-Identifier: GPL-2.0-or-later */
+/*
+ * Copyright 2026 Google LLC
+ */
+
+/**
+ * DOC: ML-KEM (Module-Lattice-Based Key Encapsulation Mechanism)
+ *
+ * This is an implementation of ML-KEM, a "post-quantum" key encapsulation
+ * mechanism that is specified in FIPS 203 and is based on CRYSTALS-Kyber.
+ *
+ * Specifically, the ML-KEM-768 and ML-KEM-1024 parameter sets are supported.
+ *
+ * This shall be used as part of a hybrid scheme such as X-Wing, not by itself.
+ *
+ * This implementation is designed to be constant-time, compact, and
+ * memory-efficient, and to reuse the kernel's SHA-3 routines. For simplicity,
+ * it stores integers mod Q as their standard representatives in the interval
+ * [0, Q - 1] across function boundaries. (This makes it more similar to e.g.
+ * BoringSSL than to the Kyber reference code, which uses a slightly more
+ * optimized but harder-to-understand approach.)
+ */
+
+#ifndef _CRYPTO_MLKEM_H
+#define _CRYPTO_MLKEM_H
+
+#include <linux/types.h>
+
+#define MLKEM_SEED_BYTES 64 /* Length of seed for KeyGen */
+#define MLKEM_ESEED_BYTES 32 /* Length of seed for Encaps */
+#define MLKEM_SHARED_SECRET_BYTES 32
+
+#define MLKEM768_PUBLIC_KEY_BYTES 1184
+#define MLKEM768_SECRET_KEY_BYTES 2400
+#define MLKEM768_CIPHERTEXT_BYTES 1088
+
+/**
+ * mlkem768_keygen() - Generate an ML-KEM-768 key pair
+ * @pk: (output) The public key (encapsulation key)
+ * @sk: (output) The secret key (decapsulation key)
+ *
+ * Context: Might sleep
+ *
+ * Return: 0 on success, or -ENOMEM if out of memory.
+ */
+int mlkem768_keygen(u8 pk[MLKEM768_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM768_SECRET_KEY_BYTES]);
+
+/**
+ * mlkem768_encaps() - Generate and encapsulate shared secret with ML-KEM-768
+ * @ct: (output) The ciphertext
+ * @ss: (output) The generated shared secret
+ * @pk: The public key (encapsulation key)
+ *
+ * Context: Might sleep
+ *
+ * Return:
+ * * 0 on success
+ * * -EBADMSG if the public key is malformed
+ * * -ENOMEM if out of memory
+ */
+int mlkem768_encaps(u8 ct[MLKEM768_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM768_PUBLIC_KEY_BYTES]);
+
+/**
+ * mlkem768_decaps() - Decapsulate shared secret with ML-KEM-768
+ * @ss: (output) The decapsulated shared secret
+ * @ct: The ciphertext
+ * @sk: The secret key (decapsulation key)
+ *
+ * Context: Might sleep
+ *
+ * Return:
+ * * 0 on success, including the "implicit rejection" case where the ciphertext
+ * is invalid and a randomized shared secret is returned
+ * * -EBADMSG if the secret key is malformed
+ * * -ENOMEM if out of memory
+ */
+int mlkem768_decaps(u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 ct[MLKEM768_CIPHERTEXT_BYTES],
+ const u8 sk[MLKEM768_SECRET_KEY_BYTES]);
+
+#define MLKEM1024_PUBLIC_KEY_BYTES 1568
+#define MLKEM1024_SECRET_KEY_BYTES 3168
+#define MLKEM1024_CIPHERTEXT_BYTES 1568
+
+/**
+ * mlkem1024_keygen() - Generate an ML-KEM-1024 key pair
+ * @pk: (output) The public key (encapsulation key)
+ * @sk: (output) The secret key (decapsulation key)
+ *
+ * Context: Might sleep
+ *
+ * Return: 0 on success, or -ENOMEM if out of memory.
+ */
+int mlkem1024_keygen(u8 pk[MLKEM1024_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM1024_SECRET_KEY_BYTES]);
+
+/**
+ * mlkem1024_encaps() - Generate and encapsulate shared secret with ML-KEM-1024
+ * @ct: (output) The ciphertext
+ * @ss: (output) The generated shared secret
+ * @pk: The public key (encapsulation key)
+ *
+ * Context: Might sleep
+ *
+ * Return:
+ * * 0 on success
+ * * -EBADMSG if the public key is malformed
+ * * -ENOMEM if out of memory
+ */
+int mlkem1024_encaps(u8 ct[MLKEM1024_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM1024_PUBLIC_KEY_BYTES]);
+
+/**
+ * mlkem1024_decaps() - Decapsulate shared secret with ML-KEM-1024
+ * @ss: (output) The decapsulated shared secret
+ * @ct: The ciphertext
+ * @sk: The secret key (decapsulation key)
+ *
+ * Context: Might sleep
+ *
+ * Return:
+ * * 0 on success, including the "implicit rejection" case where the ciphertext
+ * is invalid and a randomized shared secret is returned
+ * * -EBADMSG if the secret key is malformed
+ * * -ENOMEM if out of memory
+ */
+int mlkem1024_decaps(u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 ct[MLKEM1024_CIPHERTEXT_BYTES],
+ const u8 sk[MLKEM1024_SECRET_KEY_BYTES]);
+
+/* Functions taking explicit seeds, only for KUnit testing and hybrid KEMs */
+int mlkem768_keygen_internal(u8 pk[MLKEM768_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM768_SECRET_KEY_BYTES],
+ const u8 seed[MLKEM_SEED_BYTES]);
+int mlkem768_encaps_internal(u8 ct[MLKEM768_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM768_PUBLIC_KEY_BYTES],
+ const u8 eseed[MLKEM_ESEED_BYTES]);
+int mlkem1024_keygen_internal(u8 pk[MLKEM1024_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM1024_SECRET_KEY_BYTES],
+ const u8 seed[MLKEM_SEED_BYTES]);
+int mlkem1024_encaps_internal(u8 ct[MLKEM1024_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM1024_PUBLIC_KEY_BYTES],
+ const u8 eseed[MLKEM_ESEED_BYTES]);
+
+#endif /* _CRYPTO_MLKEM_H */
diff --git a/lib/crypto/Kconfig b/lib/crypto/Kconfig
index d3904b72dae7..acaa64af4c6d 100644
--- a/lib/crypto/Kconfig
+++ b/lib/crypto/Kconfig
@@ -141,10 +141,18 @@ config CRYPTO_LIB_MLDSA
select CRYPTO_LIB_SHA3
help
The ML-DSA library functions. Select this if your module uses any of
the functions from <crypto/mldsa.h>.
+config CRYPTO_LIB_MLKEM
+ tristate
+ select CRYPTO_LIB_SHA3
+ select CRYPTO_LIB_UTILS
+ help
+ The ML-KEM library functions. Select this if your module uses any of
+ the functions from <crypto/mlkem.h>.
+
config CRYPTO_LIB_NH
tristate
help
Implementation of the NH almost-universal hash function, specifically
the variant of NH used in Adiantum.
diff --git a/lib/crypto/Makefile b/lib/crypto/Makefile
index 4ad91f390038..94cef4e574cd 100644
--- a/lib/crypto/Makefile
+++ b/lib/crypto/Makefile
@@ -197,10 +197,15 @@ endif # CONFIG_CRYPTO_LIB_MD5_ARCH
obj-$(CONFIG_CRYPTO_LIB_MLDSA) += libmldsa.o
libmldsa-y := mldsa.o
################################################################################
+obj-$(CONFIG_CRYPTO_LIB_MLKEM) += libmlkem.o
+libmlkem-y := mlkem.o
+
+################################################################################
+
obj-$(CONFIG_CRYPTO_LIB_NH) += libnh.o
libnh-y := nh.o
ifeq ($(CONFIG_CRYPTO_LIB_NH_ARCH),y)
CFLAGS_nh.o += -I$(src)/$(SRCARCH)
libnh-$(CONFIG_ARM) += arm/nh-neon-core.o
diff --git a/lib/crypto/mlkem.c b/lib/crypto/mlkem.c
new file mode 100644
index 000000000000..b800ecb49f24
--- /dev/null
+++ b/lib/crypto/mlkem.c
@@ -0,0 +1,894 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ * ML-KEM (Module-Lattice-Based Key Encapsulation Mechanism)
+ *
+ * See include/crypto/mlkem.h for the documentation.
+ *
+ * Copyright 2026 Google LLC
+ */
+
+#include <crypto/mlkem.h>
+#include <crypto/sha3.h>
+#include <crypto/utils.h>
+#include <linux/export.h>
+#include <linux/module.h>
+#include <linux/random.h>
+#include <linux/slab.h>
+#include <linux/string.h>
+#include <linux/unaligned.h>
+
+#define Q 3329 /* The prime q = 2^8 * 13 + 1 */
+#define N 256 /* Number of coefficients per ring element */
+#define MAX_K 4 /* Max matrix dimension (k parameter) for any parameter set */
+#define MAX_CIPHERTEXT_BYTES 1568 /* Max ciphertext length for any param set */
+#define ETA 2 /* Value of eta_1 and eta_2 for ML-KEM-768 and ML-KEM-1024 */
+
+/*
+ * This array contains zeta^BitRev_7(i) for 0 <= i < 128.
+ * Generated by the following Python code:
+ * [pow(17, int(f'{i:07b}'[::-1], 2), 3329) for i in range(128)]
+ * Also matches the first table in FIPS 203 Appendix A.
+ */
+static const u16 zetas[128] = {
+ 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
+ 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
+ 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
+ 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
+ 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
+ 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
+ 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
+ 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
+ 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
+ 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
+ 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
+};
+
+/*
+ * This array contains zeta^(2*BitRev_7(i)+1) for 0 <= i < 128.
+ * Generated by the following Python code:
+ * [pow(17, 2*int(f'{i:07b}'[::-1], 2)+1, 3329) for i in range(128)]
+ * Also matches the second table in FIPS 203 Appendix A, with the values
+ * canonicalized to their standard representatives in [0, Q - 1].
+ */
+static const u16 gammas[128] = {
+ 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
+ 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
+ 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
+ 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
+ 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
+ 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
+ 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
+ 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
+ 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
+ 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
+ 2110, 1219, 2935, 394, 885, 2444, 2154, 1175
+};
+
+struct mlkem_parameter_set {
+ u8 k; /* Rank of the module (number of polynomials in vectors) */
+ u8 du; /* Compression parameter for vector u */
+ u8 dv; /* Compression parameter for polynomial v */
+ u16 pk_len; /* Length of public keys in bytes */
+ u16 sk_len; /* Length of secret keys in bytes */
+ u16 ct_len; /* Length of ciphertexts in bytes */
+};
+
+/* ML-KEM-768 parameters. Reference: FIPS 203 Section 8, "Parameter Sets" */
+static const struct mlkem_parameter_set mlkem768 = {
+ .k = 3,
+ .du = 10,
+ .dv = 4,
+ .pk_len = MLKEM768_PUBLIC_KEY_BYTES,
+ .sk_len = MLKEM768_SECRET_KEY_BYTES,
+ .ct_len = MLKEM768_CIPHERTEXT_BYTES,
+};
+
+/* ML-KEM-1024 parameters. Reference: FIPS 203 Section 8, "Parameter Sets" */
+static const struct mlkem_parameter_set mlkem1024 = {
+ .k = 4,
+ .du = 11,
+ .dv = 5,
+ .pk_len = MLKEM1024_PUBLIC_KEY_BYTES,
+ .sk_len = MLKEM1024_SECRET_KEY_BYTES,
+ .ct_len = MLKEM1024_CIPHERTEXT_BYTES,
+};
+
+/*
+ * An element of the ring R_q (normal form) or the ring T_q (NTT form). In both
+ * cases it consists of N integers in the interval [0, Q - 1]. In R_q, these
+ * are the coefficients of one polynomial of maximum degree N - 1. In T_q,
+ * these are the coefficients of N / 2 polynomials of maximum degree 1.
+ */
+struct mlkem_ring_elem {
+ u16 x[N];
+};
+
+struct k_pke_keygen_workspace {
+ union { /* These aren't used at the same time, so they can overlap. */
+ struct shake_ctx shake;
+ struct sha3_ctx sha3;
+ };
+ union {
+ struct {
+ u8 rho[32];
+ u8 sigma[32];
+ };
+ u8 rho_and_sigma[64];
+ };
+ u8 block[MAX(64 * ETA, SHAKE128_BLOCK_SIZE)];
+ struct mlkem_ring_elem tmp, t;
+ struct mlkem_ring_elem s[MAX_K];
+};
+
+struct k_pke_encrypt_workspace {
+ struct shake_ctx shake;
+ u8 block[MAX(64 * ETA, SHAKE128_BLOCK_SIZE)];
+ struct mlkem_ring_elem tmp, y[MAX_K];
+ union { /* These aren't used at the same time, so they can overlap. */
+ struct mlkem_ring_elem u, v;
+ };
+};
+
+struct k_pke_decrypt_workspace {
+ struct mlkem_ring_elem su, s, tmp;
+};
+
+struct mlkem_encap_workspace {
+ union { /* These aren't used at the same time, so they can overlap. */
+ struct sha3_ctx sha3;
+ struct k_pke_encrypt_workspace k_pke_enc;
+ };
+ union {
+ struct {
+ u8 K[32];
+ u8 r[32];
+ };
+ u8 K_and_r[64];
+ };
+ u8 pk_hash[SHA3_256_DIGEST_SIZE];
+};
+
+struct mlkem_decap_workspace {
+ union { /* These aren't used at the same time, so they can overlap. */
+ struct shake_ctx shake;
+ struct sha3_ctx sha3;
+ struct k_pke_encrypt_workspace k_pke_enc;
+ struct k_pke_decrypt_workspace k_pke_dec;
+ };
+ union {
+ struct {
+ u8 K_prime[32];
+ u8 r_prime[32];
+ };
+ u8 K_prime_and_r_prime[64];
+ };
+ u8 h[SHA3_256_DIGEST_SIZE];
+ u8 m_prime[32];
+ u8 K_bar[32];
+ u8 ct_prime[MAX_CIPHERTEXT_BYTES];
+};
+
+/* Reduce @x to its standard representative mod Q, where 0 <= @x < 2*Q */
+static u16 reduce_once(u16 x)
+{
+ u16 x_minus_q, mask;
+
+ x_minus_q = x - Q; /* This wraps around, setting bit 15, if @x < Q */
+ OPTIMIZER_HIDE_VAR(x_minus_q);
+ mask = (s16)x_minus_q >> 15; /* mask = 0xffff if @x < Q, else 0 */
+ OPTIMIZER_HIDE_VAR(mask);
+
+ return (mask & x) | (~mask & x_minus_q);
+}
+
+/* Reduce @x to its standard representative mod Q, where 0 <= @x < Q + 2*Q*Q */
+static u16 reduce(u32 x)
+{
+ /* Barrett reduction: x - floor((x * floor(2^24 / Q)) / 2^24) * Q */
+ u64 product = (u64)x * 5039;
+ u32 quotient = (u32)(product >> 24);
+ u32 remainder = x - quotient * Q; /* 0 <= remainder < 2*Q */
+
+ return reduce_once(remainder);
+}
+
+/*
+ * Convert @f to its number-theoretically-transformed representation in-place.
+ * Reference: FIPS 203 Algorithm 9, NTT
+ */
+static void ntt(struct mlkem_ring_elem *f)
+{
+ int i = 1; /* index in zetas */
+
+ for (int len = 128; len >= 2; len /= 2) {
+ for (int start = 0; start < 256; start += 2 * len) {
+ u32 zeta = zetas[i++];
+
+ for (int j = start; j < start + len; j++) {
+ u16 t = reduce(zeta * f->x[j + len]);
+
+ f->x[j + len] = reduce_once(Q + f->x[j] - t);
+ f->x[j] = reduce_once(f->x[j] + t);
+ }
+ }
+ }
+}
+
+/*
+ * Convert @f from its number-theoretically-transformed representation in-place.
+ *
+ * Reference: FIPS 203 Algorithm 10, NTT^-1
+ */
+static void invntt(struct mlkem_ring_elem *f)
+{
+ int i = 127; /* index in zetas */
+
+ for (int len = 2; len <= 128; len *= 2) {
+ for (int start = 0; start < 256; start += 2 * len) {
+ u32 zeta = zetas[i--];
+
+ for (int j = start; j < start + len; j++) {
+ u16 t = f->x[j];
+
+ f->x[j] = reduce_once(t + f->x[j + len]);
+ f->x[j + len] =
+ reduce(zeta * (Q + f->x[j + len] - t));
+ }
+ }
+ }
+
+ /* Multiply by 128^-1 = 3303 mod Q. */
+ for (int j = 0; j < 256; j++)
+ f->x[j] = reduce((u32)f->x[j] * 3303);
+}
+
+/* Compute @a += @b, where @a and @b are both elements of either R_q or T_q. */
+static void poly_add(struct mlkem_ring_elem *a, const struct mlkem_ring_elem *b)
+{
+ for (int i = 0; i < N; i++)
+ a->x[i] = reduce_once(a->x[i] + b->x[i]);
+}
+
+/* Compute @a -= @b, where @a and @b are both elements of either R_q or T_q. */
+static void poly_sub(struct mlkem_ring_elem *a, const struct mlkem_ring_elem *b)
+{
+ for (int i = 0; i < N; i++)
+ a->x[i] = reduce_once(Q + a->x[i] - b->x[i]);
+}
+
+/*
+ * Compute @h += @f * @g in the ring T_q.
+ *
+ * Reference: FIPS 203 Algorithm 11, MultiplyNTTs
+ */
+static void poly_mul_acc(struct mlkem_ring_elem *h,
+ const struct mlkem_ring_elem *f,
+ const struct mlkem_ring_elem *g)
+{
+ for (int i = 0; i < N / 2; i++) {
+ u32 a0 = f->x[2 * i];
+ u32 a1 = f->x[2 * i + 1];
+ u32 b0 = g->x[2 * i];
+ u32 b1 = g->x[2 * i + 1];
+ u16 c0 = reduce(a0 * b0 + reduce(a1 * b1) * (u32)gammas[i]);
+ u16 c1 = reduce(a0 * b1 + a1 * b0);
+
+ h->x[2 * i] = reduce_once(h->x[2 * i] + c0);
+ h->x[2 * i + 1] = reduce_once(h->x[2 * i + 1] + c1);
+ }
+}
+
+/*
+ * "Compress" @x in the interval [0, Q - 1] into the smaller interval [0, 2^d
+ * - 1] for the given 1 <= @d <= 11 by computing:
+ *
+ * round((2^d / Q) * x) mod 2^d
+ * = floor((x * 2^d + floor(Q / 2)) / Q) mod 2^d
+ *
+ * Reference: FIPS 203 Section 4.2.1, "Conversion and Compression Algorithms"
+ */
+static u16 compress_d(u16 x, int d)
+{
+ u32 t = ((u32)x << d) + (Q / 2);
+
+ /*
+ * t = floor(t / Q). Avoid potentially variable-time division by using
+ * the equivalent (for the input ranges) reciprocal multiplication
+ * floor((t * ceil(2^32 / Q)) / 2^32). 2^32 is chosen because it's
+ * efficient and provides enough precision for all allowed inputs.
+ */
+ OPTIMIZER_HIDE_VAR(t);
+ t = ((u64)t * 1290168) >> 32;
+ OPTIMIZER_HIDE_VAR(t);
+
+ return t & ((1 << d) - 1);
+}
+
+/*
+ * "Decompress" @y in the interval [0, 2^d - 1] into the larger interval [0, Q
+ * - 1] for the given 1 <= @d <= 11 by computing:
+ *
+ * round((Q / 2^d) * y)
+ * = floor((y * Q + 2^(d-1)) / 2^d)
+ *
+ * Reference: FIPS 203 Section 4.2.1, "Conversion and Compression Algorithms"
+ */
+static u16 decompress_d(u16 y, int d)
+{
+ u32 t;
+
+ OPTIMIZER_HIDE_VAR(y);
+ t = (u32)y * Q;
+ OPTIMIZER_HIDE_VAR(t);
+ return (t + (1 << (d - 1))) >> d;
+}
+
+/*
+ * Encode the ring element @in into 32 * @d bytes at @out.
+ *
+ * Reference: FIPS 203 Algorithm 5, ByteEncode_d
+ */
+static void byte_encode(u8 *out, const struct mlkem_ring_elem *in, int d)
+{
+ u32 acc = 0;
+ int bits = 0;
+
+ for (int i = 0; i < 256; i++) {
+ acc |= (u32)in->x[i] << bits;
+ bits += d;
+ for (; bits >= 8; bits -= 8, acc >>= 8)
+ *out++ = (u8)acc;
+ }
+}
+
+/*
+ * Decode the 32 * @d bytes at @in into a ring element @out with coefficients in
+ * the interval [0, 2^d - 1]. Note that when @d >= 12, decoded coefficients can
+ * be out of range (i.e. >= Q) and need to be validated afterwards.
+ *
+ * Reference: FIPS 203 Algorithm 6, ByteDecode_d
+ */
+static void byte_decode(struct mlkem_ring_elem *out, const u8 *in, int d)
+{
+ const u32 mask = (1 << d) - 1;
+ u32 acc = 0;
+ int bits = 0;
+
+ for (int i = 0; i < 256; i++) {
+ for (; bits < d; bits += 8)
+ acc |= (u32)(*in++) << bits;
+ out->x[i] = acc & mask;
+ acc >>= d;
+ bits -= d;
+ }
+}
+
+static void compress_and_encode(u8 *out, struct mlkem_ring_elem *f, int d)
+{
+ for (int i = 0; i < N; i++)
+ f->x[i] = compress_d(f->x[i], d);
+ byte_encode(out, f, d);
+}
+
+static void decode_and_decompress(struct mlkem_ring_elem *out, const u8 *in,
+ int d)
+{
+ byte_decode(out, in, d);
+ for (int i = 0; i < N; i++)
+ out->x[i] = decompress_d(out->x[i], d);
+}
+
+/*
+ * Check in constant time whether any coefficients in @f are outside the
+ * interval [0, Q - 1]. This is needed after calling byte_decode() with d=12.
+ */
+static bool is_out_of_range(const struct mlkem_ring_elem *f)
+{
+ u32 bad = 0;
+
+ for (int i = 0; i < N; i++) {
+ OPTIMIZER_HIDE_VAR(bad);
+ bad |= Q - 1 - f->x[i]; /* Set bit 31 if f->x[i] >= Q. */
+ }
+ OPTIMIZER_HIDE_VAR(bad);
+ return bad >> 31;
+}
+
+/*
+ * Expand the 33 input bytes (@s, @b) into a random polynomial @out with
+ * coefficients in [-ETA, ETA] using a Centered Binomial
+ * Distribution (CBD). @shake and @block are used as temporary space.
+ *
+ * This is FIPS 203 Algorithm 8 "SamplePolyCBD_eta", composed with PRF_eta.
+ * (SamplePolyCBD_eta is invoked only with the output of PRF_eta.)
+ */
+static void sample_poly_cbd(struct mlkem_ring_elem *out, const u8 s[32], u8 b,
+ struct shake_ctx *shake,
+ u8 block[at_least 64 * ETA])
+{
+ /* Expand (s, b) into '64 * ETA' bytes. */
+ shake256_init(shake);
+ shake_update(shake, s, 32);
+ shake_update(shake, &b, 1);
+ shake_squeeze(shake, block, 64 * ETA);
+
+ /*
+ * Generate the coefficients by counting the number of 1 bits in each
+ * ETA-bit group, then subtracting the counts of adjacent pairs.
+ */
+ static_assert(ETA == 2);
+ for (int i = 0; i < 256; i += 16) {
+ const u64 mask = 0x5555555555555555;
+ u64 t = get_unaligned_le64(&block[i / 2]);
+ u64 popcounts = (t & mask) + ((t >> 1) & mask);
+
+ for (int j = 0; j < 16; j++) {
+ u16 x = (popcounts >> (4 * j)) & 0x3;
+ u16 y = (popcounts >> (4 * j + 2)) & 0x3;
+
+ out->x[i + j] = reduce_once(Q + x - y);
+ }
+ }
+}
+
+/*
+ * Generate the element of the matrix NTT(A) at @row and @column from the seed
+ * @rho. The output, which is an element of the ring T_q, is written to @out.
+ * @shake and @block are used as temporary space.
+ *
+ * Reference: FIPS 203 Algorithm 7, SampleNTT
+ */
+static void sample_ntt(struct mlkem_ring_elem *out, const u8 rho[32], u8 row,
+ u8 column, struct shake_ctx *shake,
+ u8 block[at_least SHAKE128_BLOCK_SIZE])
+{
+ shake128_init(shake);
+ shake_update(shake, rho, 32);
+ shake_update(shake, &column, 1);
+ shake_update(shake, &row, 1);
+
+ for (int i = 0; i < 256;) {
+ /*
+ * Squeeze the next block, then use rejection sampling to
+ * generate up to two coefficients from each 3-byte group.
+ */
+ static_assert(SHAKE128_BLOCK_SIZE % 3 == 0);
+ shake_squeeze(shake, block, SHAKE128_BLOCK_SIZE);
+
+ for (int j = 0; j < SHAKE128_BLOCK_SIZE && i < 256; j += 3) {
+ u16 d1 = block[j] | ((block[j + 1] & 0xf) << 8);
+ u16 d2 = (block[j + 1] >> 4) | (block[j + 2] << 4);
+
+ if (d1 < Q)
+ out->x[i++] = d1;
+ if (i < N && d2 < Q)
+ out->x[i++] = d2;
+ }
+ }
+}
+
+/*
+ * Generate a K-PKE key pair (@pk, @sk) from the random seed @d. @pk is 384*k +
+ * 32 bytes, and @sk is 384*k bytes.
+ *
+ * Reference: FIPS 203 Algorithm 13, K-PKE.KeyGen
+ */
+static int k_pke_keygen(u8 *pk, u8 *sk, const u8 d[32],
+ const struct mlkem_parameter_set *params)
+{
+ struct k_pke_keygen_workspace *ws __free(kfree_sensitive) =
+ kmalloc_obj(*ws);
+ const u8 k = params->k;
+
+ if (!ws)
+ return -ENOMEM;
+
+ /* rho || sigma = G(d || k) */
+ sha3_512_init(&ws->sha3);
+ sha3_update(&ws->sha3, d, 32);
+ sha3_update(&ws->sha3, &k, 1);
+ sha3_final(&ws->sha3, ws->rho_and_sigma);
+
+ /* Generate and encode the secret key NTT(s). */
+ for (int i = 0; i < k; i++) {
+ sample_poly_cbd(&ws->s[i], ws->sigma, i, &ws->shake, ws->block);
+ ntt(&ws->s[i]);
+ byte_encode(&sk[384 * i], &ws->s[i], 12);
+ }
+
+ /*
+ * Generate and encode the public key pk = NTT(t) || rho.
+ * To save memory, do it one row at a time.
+ */
+ for (int i = 0; i < k; i++) { /* For each row in A */
+ /* NTT(t) = NTT(A) * NTT(s) + NTT(e) */
+ ws->t = (struct mlkem_ring_elem){};
+ for (int j = 0; j < k; j++) { /* For each column in A */
+ sample_ntt(&ws->tmp, ws->rho, i, j, &ws->shake,
+ ws->block);
+ poly_mul_acc(&ws->t, &ws->tmp, &ws->s[j]);
+ }
+ sample_poly_cbd(&ws->tmp, ws->sigma, k + i, &ws->shake,
+ ws->block);
+ ntt(&ws->tmp);
+ poly_add(&ws->t, &ws->tmp);
+ byte_encode(&pk[384 * i], &ws->t, 12);
+ }
+ memcpy(&pk[384 * k], ws->rho, 32);
+ return 0;
+}
+
+/*
+ * Encrypt the message @m using the public key @pk and randomness @r.
+ *
+ * Reference: FIPS 203 Algorithm 14, K-PKE.Encrypt
+ */
+static int k_pke_encrypt(u8 *ct, const u8 *pk, const u8 m[32], const u8 r[32],
+ struct k_pke_encrypt_workspace *ws,
+ const struct mlkem_parameter_set *params)
+{
+ const u8 k = params->k;
+ const u8 *rho = &pk[384 * k];
+
+ /* Generate the vector NTT(y) from r. */
+ for (int i = 0; i < k; i++) {
+ sample_poly_cbd(&ws->y[i], r, i, &ws->shake, ws->block);
+ ntt(&ws->y[i]);
+ }
+
+ /*
+ * Compute, compress, and encode u = NTT^-1( NTT(A^T) * NTT(y) ) + e_1.
+ * To save memory, do it one row at a time.
+ */
+ for (int i = 0; i < k; i++) { /* For each row in A^T (column in A) */
+ /* u = NTT(A^T) * NTT(y) */
+ ws->u = (struct mlkem_ring_elem){};
+ for (int j = 0; j < k;
+ j++) { /* For each column in A^T (row in A) */
+ sample_ntt(&ws->tmp, rho, j, i, &ws->shake, ws->block);
+ poly_mul_acc(&ws->u, &ws->tmp, &ws->y[j]);
+ }
+ /* u = NTT^-1(u) */
+ invntt(&ws->u);
+
+ /* u += e_1, where e_1 is generated from r */
+ sample_poly_cbd(&ws->tmp, r, k + i, &ws->shake, ws->block);
+ poly_add(&ws->u, &ws->tmp);
+
+ /* Compress and encode u into the ciphertext ct. */
+ compress_and_encode(&ct[32 * params->du * i], &ws->u,
+ params->du);
+ }
+
+ /* v = NTT^-1( NTT(t^T) * NTT(y) ) */
+ ws->v = (struct mlkem_ring_elem){};
+ for (int i = 0; i < k; i++) {
+ /* Decode next element of NTT(t) from pk. */
+ byte_decode(&ws->tmp, &pk[384 * i], 12);
+ if (is_out_of_range(&ws->tmp)) {
+ memzero_explicit(ct, params->ct_len);
+ return -EBADMSG; /* pk is malformed. */
+ }
+ poly_mul_acc(&ws->v, &ws->tmp, &ws->y[i]);
+ }
+ invntt(&ws->v);
+
+ /* v += e_2, where e_2 is generated from r */
+ sample_poly_cbd(&ws->tmp, r, 2 * k, &ws->shake, ws->block);
+ poly_add(&ws->v, &ws->tmp);
+
+ /* Add the message m, after decompressing from 1 bit to [0, Q - 1]. */
+ for (int i = 0; i < N; i++)
+ ws->v.x[i] =
+ reduce_once(ws->v.x[i] +
+ decompress_d((m[i / 8] >> (i & 7)) & 1, 1));
+
+ /* Compress and encode v into the ciphertext ct. */
+ compress_and_encode(&ct[32 * params->du * k], &ws->v, params->dv);
+ return 0;
+}
+
+/*
+ * Decrypt the ciphertext @ct using the secret key @sk.
+ *
+ * Reference: FIPS 203 Algorithm 15, K-PKE.Decrypt
+ */
+static int k_pke_decrypt(u8 m[32], const u8 *sk, const u8 *ct,
+ struct k_pke_decrypt_workspace *ws,
+ const struct mlkem_parameter_set *params)
+{
+ /* Compute s^T * u' */
+ ws->su = (struct mlkem_ring_elem){};
+ for (int i = 0; i < params->k; i++) {
+ /* Decode next element of NTT(s) from sk. */
+ byte_decode(&ws->s, &sk[384 * i], 12);
+ if (is_out_of_range(&ws->s))
+ return -EBADMSG; /* sk is malformed. */
+
+ /* Decode and decompress next element of u' from ct. */
+ decode_and_decompress(&ws->tmp, &ct[32 * params->du * i],
+ params->du);
+ /* NTT(u') [one element] */
+ ntt(&ws->tmp);
+
+ /* NTT(s^T) * NTT(u') [one element] */
+ poly_mul_acc(&ws->su, &ws->s, &ws->tmp);
+ }
+ invntt(&ws->su);
+
+ /* Decode and decompress v' from the ciphertext ct. */
+ decode_and_decompress(&ws->tmp, &ct[32 * params->du * params->k],
+ params->dv);
+
+ /* w = v' - s^T * u' */
+ poly_sub(&ws->tmp, &ws->su);
+
+ /* Decode the plaintext m from the polynomial w. */
+ compress_and_encode(m, &ws->tmp, 1);
+ return 0;
+}
+
+/*
+ * Generate an ML-KEM key pair (@pk, @sk) from the random seed @seed.
+ * The lengths of @pk and @sk are determined by the chosen @params.
+ *
+ * Reference: FIPS 203 Algorithm 16, ML-KEM.KeyGen_internal. Formally it takes
+ * two 32-byte seeds 'd' and 'z'; here @seed is d || z.
+ */
+static int mlkem_keygen_internal(u8 *pk, u8 *sk,
+ const u8 seed[MLKEM_SEED_BYTES],
+ const struct mlkem_parameter_set *params)
+{
+ u8 *sk_ptr;
+ int err;
+
+ /* Generate a key pair for the public key encryption scheme K-PKE. */
+ err = k_pke_keygen(pk, sk, &seed[0], params);
+ if (err)
+ return err;
+
+ /*
+ * Generate the ML-KEM key pair as
+ * Public key (encapsulation key): pk = pk_pke
+ * Secret key (decapsulation key): sk = sk_pke || pk || H(pk) || z
+ *
+ * For pk there's nothing to add. This just handles the sk.
+ */
+ sk_ptr = &sk[384 * params->k]; /* end of sk_pke */
+ memcpy(sk_ptr, pk, params->pk_len);
+ sk_ptr += params->pk_len;
+ sha3_256(pk, params->pk_len, sk_ptr);
+ sk_ptr += SHA3_256_DIGEST_SIZE;
+ memcpy(sk_ptr, &seed[32], 32);
+
+ return 0;
+}
+
+int mlkem768_keygen_internal(u8 pk[MLKEM768_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM768_SECRET_KEY_BYTES],
+ const u8 seed[MLKEM_SEED_BYTES])
+{
+ return mlkem_keygen_internal(pk, sk, seed, &mlkem768);
+}
+EXPORT_SYMBOL_NS_GPL(mlkem768_keygen_internal, "CRYPTO_INTERNAL");
+
+int mlkem768_keygen(u8 pk[MLKEM768_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM768_SECRET_KEY_BYTES])
+{
+ u8 seed[MLKEM_SEED_BYTES];
+ int err;
+
+ get_random_bytes(seed, sizeof(seed));
+ err = mlkem768_keygen_internal(pk, sk, seed);
+ memzero_explicit(seed, sizeof(seed));
+ return err;
+}
+EXPORT_SYMBOL_NS_GPL(mlkem768_keygen, "CRYPTO_INTERNAL");
+
+int mlkem1024_keygen_internal(u8 pk[MLKEM1024_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM1024_SECRET_KEY_BYTES],
+ const u8 seed[MLKEM_SEED_BYTES])
+{
+ return mlkem_keygen_internal(pk, sk, seed, &mlkem1024);
+}
+EXPORT_SYMBOL_NS_GPL(mlkem1024_keygen_internal, "CRYPTO_INTERNAL");
+
+int mlkem1024_keygen(u8 pk[MLKEM1024_PUBLIC_KEY_BYTES],
+ u8 sk[MLKEM1024_SECRET_KEY_BYTES])
+{
+ u8 seed[MLKEM_SEED_BYTES];
+ int err;
+
+ get_random_bytes(seed, sizeof(seed));
+ err = mlkem1024_keygen_internal(pk, sk, seed);
+ memzero_explicit(seed, sizeof(seed));
+ return err;
+}
+EXPORT_SYMBOL_NS_GPL(mlkem1024_keygen, "CRYPTO_INTERNAL");
+
+/*
+ * Generate and encapsulate a shared secret, using the random seed @eseed.
+ *
+ * Reference: FIPS 203 Algorithm 17, ML-KEM.Encaps_internal
+ */
+static int mlkem_encaps_internal(u8 *ct, u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 *pk,
+ const u8 eseed[MLKEM_ESEED_BYTES],
+ const struct mlkem_parameter_set *params)
+{
+ struct mlkem_encap_workspace *ws __free(kfree_sensitive) =
+ kmalloc_obj(*ws);
+ int err;
+
+ if (!ws)
+ return -ENOMEM;
+
+ /* Derived shared secret key K and randomness r. */
+ sha3_256(pk, params->pk_len, ws->pk_hash);
+ sha3_512_init(&ws->sha3);
+ sha3_update(&ws->sha3, eseed, MLKEM_ESEED_BYTES);
+ sha3_update(&ws->sha3, ws->pk_hash, SHA3_256_DIGEST_SIZE);
+ sha3_final(&ws->sha3, ws->K_and_r);
+
+ /* Encrypt eseed using K-PKE with randomness r. */
+ err = k_pke_encrypt(ct, pk, eseed, ws->r, &ws->k_pke_enc, params);
+ if (err == 0)
+ memcpy(ss, ws->K, MLKEM_SHARED_SECRET_BYTES);
+ return err;
+}
+
+int mlkem768_encaps_internal(u8 ct[MLKEM768_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM768_PUBLIC_KEY_BYTES],
+ const u8 eseed[MLKEM_ESEED_BYTES])
+{
+ return mlkem_encaps_internal(ct, ss, pk, eseed, &mlkem768);
+}
+EXPORT_SYMBOL_NS_GPL(mlkem768_encaps_internal, "CRYPTO_INTERNAL");
+
+int mlkem768_encaps(u8 ct[MLKEM768_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM768_PUBLIC_KEY_BYTES])
+{
+ u8 eseed[MLKEM_ESEED_BYTES];
+ int err;
+
+ get_random_bytes(eseed, sizeof(eseed));
+ err = mlkem768_encaps_internal(ct, ss, pk, eseed);
+ memzero_explicit(eseed, sizeof(eseed));
+ return err;
+}
+EXPORT_SYMBOL_NS_GPL(mlkem768_encaps, "CRYPTO_INTERNAL");
+
+int mlkem1024_encaps_internal(u8 ct[MLKEM1024_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM1024_PUBLIC_KEY_BYTES],
+ const u8 eseed[MLKEM_ESEED_BYTES])
+{
+ return mlkem_encaps_internal(ct, ss, pk, eseed, &mlkem1024);
+}
+EXPORT_SYMBOL_NS_GPL(mlkem1024_encaps_internal, "CRYPTO_INTERNAL");
+
+int mlkem1024_encaps(u8 ct[MLKEM1024_CIPHERTEXT_BYTES],
+ u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 pk[MLKEM1024_PUBLIC_KEY_BYTES])
+{
+ u8 eseed[MLKEM_ESEED_BYTES];
+ int err;
+
+ get_random_bytes(eseed, sizeof(eseed));
+ err = mlkem1024_encaps_internal(ct, ss, pk, eseed);
+ memzero_explicit(eseed, sizeof(eseed));
+ return err;
+}
+EXPORT_SYMBOL_NS_GPL(mlkem1024_encaps, "CRYPTO_INTERNAL");
+
+/* Reference: FIPS 203 Algorithm 21, ML-KEM.Decaps */
+static int mlkem_decaps(u8 ss[MLKEM_SHARED_SECRET_BYTES], const u8 *ct,
+ const u8 *sk, const struct mlkem_parameter_set *params)
+{
+ struct mlkem_decap_workspace *ws __free(kfree_sensitive) =
+ kmalloc_obj(*ws);
+ const u8 *sk_pke, *pk_pke, *h, *z;
+ u32 mask;
+ int err;
+
+ if (!ws) {
+ err = -ENOMEM;
+ goto decaps_failed;
+ }
+
+ /* Extract the parts of the secret key. */
+ sk_pke = sk; /* PKE secret key */
+ pk_pke = &sk_pke[384 * params->k]; /* PKE public key */
+ h = &pk_pke[params->pk_len]; /* Hash of PKE public key */
+ z = &h[SHA3_256_DIGEST_SIZE]; /* Implicit rejection value */
+
+ /* Validate h, as specified in the "Hash check" in FIPS 203. */
+ sha3_256(pk_pke, params->pk_len, ws->h);
+ if (crypto_memneq(h, ws->h, SHA3_256_DIGEST_SIZE)) {
+ /*
+ * This is technically a branch on the contents of the secret
+ * key, but it only indicates whether it's malformed or not.
+ */
+ err = -EBADMSG;
+ goto decaps_failed;
+ }
+
+ /* m' = K-PKE.Decrypt(sk_pke, ct) */
+ err = k_pke_decrypt(ws->m_prime, sk_pke, ct, &ws->k_pke_dec, params);
+ if (err) {
+ /*
+ * This is technically a branch on the contents of the secret
+ * key, but it only indicates whether it's malformed or not.
+ */
+ goto decaps_failed;
+ }
+
+ /* (K', r') = G(m' || h) */
+ sha3_512_init(&ws->sha3);
+ sha3_update(&ws->sha3, ws->m_prime, 32);
+ sha3_update(&ws->sha3, h, SHA3_256_DIGEST_SIZE);
+ sha3_final(&ws->sha3, ws->K_prime_and_r_prime);
+
+ /* K_bar = J(z || ct) */
+ shake256_init(&ws->shake);
+ shake_update(&ws->shake, z, 32);
+ shake_update(&ws->shake, ct, params->ct_len);
+ shake_squeeze(&ws->shake, ws->K_bar, 32);
+
+ /* ct' = K-PKE.Encrypt(pk_pke, m', r') */
+ err = k_pke_encrypt(ws->ct_prime, pk_pke, ws->m_prime, ws->r_prime,
+ &ws->k_pke_enc, params);
+ if (err)
+ goto decaps_failed;
+
+ /*
+ * Return the shared secret: K_bar if ct != ct', else K'.
+ *
+ * This entire section must be constant-time with respect to not only
+ * the contents of ct, ct', K', and K_bar, but also whether K' or K_bar
+ * is chosen. crypto_memneq() isn't necessarily safe to use here, as it
+ * checks its result with a ternary expression.
+ */
+ mask = 0;
+ for (int i = 0; i < params->ct_len; i++) {
+ OPTIMIZER_HIDE_VAR(mask);
+ mask |= ct[i] ^ ws->ct_prime[i];
+ }
+ OPTIMIZER_HIDE_VAR(mask);
+ mask = -mask; /* bit 31 is 1 if ct != ct', else 0 */
+ OPTIMIZER_HIDE_VAR(mask);
+ mask = (s32)mask >> 31; /* mask = 0xffffffff if ct != ct', else 0 */
+ OPTIMIZER_HIDE_VAR(mask);
+ for (int i = 0; i < 32; i++)
+ ss[i] = (ws->K_bar[i] & mask) | (ws->K_prime[i] & ~mask);
+ return 0;
+
+decaps_failed:
+ /*
+ * On error, return a random shared secret. This is a safer default in
+ * the case where the caller forgets to check the return value.
+ */
+ get_random_bytes(ss, MLKEM_SHARED_SECRET_BYTES);
+ return err;
+}
+
+int mlkem768_decaps(u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 ct[MLKEM768_CIPHERTEXT_BYTES],
+ const u8 sk[MLKEM768_SECRET_KEY_BYTES])
+{
+ return mlkem_decaps(ss, ct, sk, &mlkem768);
+}
+EXPORT_SYMBOL_NS_GPL(mlkem768_decaps, "CRYPTO_INTERNAL");
+
+int mlkem1024_decaps(u8 ss[MLKEM_SHARED_SECRET_BYTES],
+ const u8 ct[MLKEM1024_CIPHERTEXT_BYTES],
+ const u8 sk[MLKEM1024_SECRET_KEY_BYTES])
+{
+ return mlkem_decaps(ss, ct, sk, &mlkem1024);
+}
+EXPORT_SYMBOL_NS_GPL(mlkem1024_decaps, "CRYPTO_INTERNAL");
+
+MODULE_DESCRIPTION("ML-KEM (Module-Lattice-Based Key Encapsulation Mechanism)");
+MODULE_LICENSE("GPL");
--
2.54.0
next prev parent reply other threads:[~2026-05-25 18:46 UTC|newest]
Thread overview: 6+ messages / expand[flat|nested] mbox.gz Atom feed top
2026-05-25 18:43 [PATCH 0/5] ML-KEM and X-Wing support Eric Biggers
2026-05-25 18:43 ` Eric Biggers [this message]
2026-05-25 18:44 ` [PATCH 2/5] lib/crypto: mlkem: Add KUnit tests for ML-KEM Eric Biggers
2026-05-25 18:44 ` [PATCH 3/5] lib/crypto: mlkem: Add FIPS 140-3 tests Eric Biggers
2026-05-25 18:44 ` [PATCH 4/5] lib/crypto: xwing: Add support for X-Wing KEM Eric Biggers
2026-05-25 18:44 ` [PATCH 5/5] lib/crypto: xwing: Add KUnit tests " Eric Biggers
Reply instructions:
You may reply publicly to this message via plain-text email
using any one of the following methods:
* Save the following mbox file, import it into your mail client,
and reply-to-all from there: mbox
Avoid top-posting and favor interleaved quoting:
https://en.wikipedia.org/wiki/Posting_style#Interleaved_style
* Reply using the --to, --cc, and --in-reply-to
switches of git-send-email(1):
git send-email \
--in-reply-to=20260525184403.101818-2-ebiggers@kernel.org \
--to=ebiggers@kernel.org \
--cc=Jason@zx2c4.com \
--cc=ardb@kernel.org \
--cc=cleech@redhat.com \
--cc=herbert@gondor.apana.org.au \
--cc=linux-crypto@vger.kernel.org \
--cc=linux-kernel@vger.kernel.org \
--cc=ryan.appel.333@gmail.com \
/path/to/YOUR_REPLY
https://kernel.org/pub/software/scm/git/docs/git-send-email.html
* If your mail client supports setting the In-Reply-To header
via mailto: links, try the mailto: link
Be sure your reply has a Subject: header at the top and a blank line
before the message body.
This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox