Skip to content

Commit db0fa30

Browse files
authored
Merge pull request #8436 from SparkiDev/mlkem_cache_a
ML-KEM/Kyber: cache A from key generation for decapsulation
2 parents 896ec23 + 9253d1d commit db0fa30

File tree

4 files changed

+96
-26
lines changed

4 files changed

+96
-26
lines changed

configure.ac

+3
Original file line numberDiff line numberDiff line change
@@ -1399,6 +1399,9 @@ do
13991399
small)
14001400
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_KYBER_SMALL"
14011401
;;
1402+
cache-a)
1403+
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_MLKEM_CACHE_A"
1404+
;;
14021405
512)
14031406
ENABLED_KYBER512=yes
14041407
;;

wolfcrypt/benchmark/benchmark.c

+31-9
Original file line numberDiff line numberDiff line change
@@ -9630,17 +9630,37 @@ static void bench_kyber_keygen(int type, const char* name, int keySize,
96309630
#endif
96319631
}
96329632

9633-
static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
9633+
static void bench_kyber_encap(int type, const char* name, int keySize,
9634+
KyberKey* key1, KyberKey* key2)
96349635
{
96359636
int ret = 0, times, count, pending = 0;
96369637
double start;
96379638
const char**desc = bench_desc_words[lng_index];
96389639
byte ct[KYBER_MAX_CIPHER_TEXT_SIZE];
96399640
byte ss[KYBER_SS_SZ];
9641+
byte pub[KYBER_MAX_PUBLIC_KEY_SIZE];
9642+
word32 pubLen;
96409643
word32 ctSz;
96419644
DECLARE_MULTI_VALUE_STATS_VARS()
96429645

9643-
ret = wc_KyberKey_CipherTextSize(key, &ctSz);
9646+
ret = wc_KyberKey_PublicKeySize(key1, &pubLen);
9647+
if (ret != 0) {
9648+
return;
9649+
}
9650+
ret = wc_KyberKey_EncodePublicKey(key1, pub, pubLen);
9651+
if (ret != 0) {
9652+
return;
9653+
}
9654+
ret = wc_KyberKey_Init(type, key2, HEAP_HINT, INVALID_DEVID);
9655+
if (ret != 0) {
9656+
return;
9657+
}
9658+
ret = wc_KyberKey_DecodePublicKey(key2, pub, pubLen);
9659+
if (ret != 0) {
9660+
return;
9661+
}
9662+
9663+
ret = wc_KyberKey_CipherTextSize(key2, &ctSz);
96449664
if (ret != 0) {
96459665
return;
96469666
}
@@ -9651,10 +9671,10 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
96519671
/* while free pending slots in queue, submit ops */
96529672
for (times = 0; times < agreeTimes || pending > 0; times++) {
96539673
#ifdef KYBER_NONDETERMINISTIC
9654-
ret = wc_KyberKey_Encapsulate(key, ct, ss, &gRng);
9674+
ret = wc_KyberKey_Encapsulate(key2, ct, ss, &gRng);
96559675
#else
96569676
unsigned char rand[KYBER_ENC_RAND_SZ] = {0,};
9657-
ret = wc_KyberKey_EncapsulateWithRandom(key, ct, ss, rand,
9677+
ret = wc_KyberKey_EncapsulateWithRandom(key2, ct, ss, rand,
96589678
sizeof(rand));
96599679
#endif
96609680
if (ret != 0)
@@ -9681,7 +9701,7 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
96819701
do {
96829702
/* while free pending slots in queue, submit ops */
96839703
for (times = 0; times < agreeTimes || pending > 0; times++) {
9684-
ret = wc_KyberKey_Decapsulate(key, ss, ct, ctSz);
9704+
ret = wc_KyberKey_Decapsulate(key1, ss, ct, ctSz);
96859705
if (ret != 0)
96869706
goto exit_decap;
96879707
RECORD_MULTI_VALUE_STATS();
@@ -9702,7 +9722,8 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
97029722

97039723
void bench_kyber(int type)
97049724
{
9705-
KyberKey key;
9725+
KyberKey key1;
9726+
KyberKey key2;
97069727
const char* name = NULL;
97079728
int keySize = 0;
97089729

@@ -9749,10 +9770,11 @@ void bench_kyber(int type)
97499770
#endif
97509771
}
97519772

9752-
bench_kyber_keygen(type, name, keySize, &key);
9753-
bench_kyber_encap(name, keySize, &key);
9773+
bench_kyber_keygen(type, name, keySize, &key1);
9774+
bench_kyber_encap(type, name, keySize, &key1, &key2);
97549775

9755-
wc_KyberKey_Free(&key);
9776+
wc_KyberKey_Free(&key2);
9777+
wc_KyberKey_Free(&key1);
97569778
}
97579779
#endif
97589780

wolfcrypt/src/wc_kyber.c

+57-17
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@
6363
#error "Can't use small memory with assembly optimized code"
6464
#endif
6565
#endif
66+
#if defined(WOLFSSL_MLKEM_CACHE_A)
67+
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
68+
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
69+
#error "Can't cache A with small memory code"
70+
#endif
71+
#endif
6672

6773
#ifdef WOLFSSL_WC_KYBER
6874

@@ -265,10 +271,14 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
265271
sword16* e = NULL;
266272
#else
267273
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
274+
#ifndef WOLFSSL_MLKEM_CACHE_A
268275
sword16 e[(KYBER_MAX_K + 1) * KYBER_MAX_K * KYBER_N];
269276
#else
270277
sword16 e[KYBER_MAX_K * KYBER_N];
271278
#endif
279+
#else
280+
sword16 e[KYBER_MAX_K * KYBER_N];
281+
#endif
272282
#endif
273283
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
274284
sword16* a = NULL;
@@ -285,6 +295,8 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
285295
}
286296

287297
if (ret == 0) {
298+
key->flags = 0;
299+
288300
/* Establish parameters based on key type. */
289301
switch (key->type) {
290302
#ifndef WOLFSSL_NO_ML_KEM
@@ -332,9 +344,17 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
332344
if (ret == 0) {
333345
/* Allocate dynamic memory for matrix and error vector. */
334346
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
347+
#ifndef WOLFSSL_MLKEM_CACHE_A
348+
/* e (v) | a (m) */
335349
e = (sword16*)XMALLOC((kp + 1) * kp * KYBER_N * sizeof(sword16),
336350
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
337351
#else
352+
/* e (v) */
353+
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
354+
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
355+
#endif
356+
#else
357+
/* e (v) */
338358
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
339359
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
340360
#endif
@@ -346,8 +366,10 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
346366
if (ret == 0) {
347367
const byte* d = rand;
348368

349-
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
350-
/* Error vector allocated at end of a. */
369+
#ifdef WOLFSSL_MLKEM_CACHE_A
370+
a = key->a;
371+
#elif !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM)
372+
/* Matrix A allocated at end of error vector. */
351373
a = e + (kp * KYBER_N);
352374
#endif
353375

@@ -391,6 +413,9 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
391413
ret = kyber_gen_matrix(&key->prf, a, kp, pubSeed, 0);
392414
}
393415
if (ret == 0) {
416+
#ifdef WOLFSSL_MLKEM_CACHE_A
417+
key->flags |= KYBER_FLAG_A_SET;
418+
#endif
394419
/* Generate key pair from random data. */
395420
kyber_keygen(key->priv, key->pub, e, a, kp);
396421
#else
@@ -514,7 +539,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
514539
unsigned char* ct)
515540
{
516541
int ret = 0;
517-
sword16* sp = NULL;
542+
sword16* at = NULL;
518543
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
519544
sword16* k = NULL;
520545
sword16* ep = NULL;
@@ -523,12 +548,12 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
523548
unsigned int kp = 0;
524549
unsigned int compVecSz = 0;
525550
#ifndef WOLFSSL_NO_MALLOC
526-
sword16* at = NULL;
551+
sword16* sp = NULL;
527552
#else
528553
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
529-
sword16 at[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
554+
sword16 sp[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
530555
#else
531-
sword16 at[3 * KYBER_MAX_K * KYBER_N];
556+
sword16 sp[3 * KYBER_MAX_K * KYBER_N];
532557
#endif
533558
#endif
534559
#ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
@@ -588,13 +613,13 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
588613
if (ret == 0) {
589614
/* Allocate dynamic memory for all matrices, vectors and polynomials. */
590615
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
591-
at = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
616+
sp = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
592617
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
593618
#else
594-
at = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
619+
sp = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
595620
DYNAMIC_TYPE_TMP_BUFFER);
596621
#endif
597-
if (at == NULL) {
622+
if (sp == NULL) {
598623
ret = MEMORY_E;
599624
}
600625
}
@@ -603,15 +628,15 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
603628
if (ret == 0) {
604629
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
605630
/* Assign allocated dynamic memory to pointers.
606-
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p) */
631+
* sp (b) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p) */
632+
at = sp + KYBER_N * kp;
607633
k = at + KYBER_N * kp * kp;
608-
sp = k + KYBER_N;
609-
ep = sp + KYBER_N * kp;
634+
ep = k + KYBER_N;
610635
epp = ep + KYBER_N * kp;
611636
#else
612637
/* Assign allocated dynamic memory to pointers.
613-
* at (v) | sp (v) | bp (v) */
614-
sp = at + KYBER_N * kp;
638+
* sp (v) | at (v) | bp (v) */
639+
at = sp + KYBER_N * kp;
615640
#endif
616641

617642
/* Initialize the PRF for use in the noise generation. */
@@ -623,6 +648,21 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
623648
/* Generate noise using PRF. */
624649
ret = kyber_get_noise(&key->prf, kp, sp, ep, epp, coins);
625650
}
651+
#ifdef WOLFSSL_MLKEM_CACHE_A
652+
if ((ret == 0) && ((key->flags & KYBER_FLAG_A_SET) != 0)) {
653+
unsigned int i;
654+
/* Transpose matrix. */
655+
for (i = 0; i < kp; i++) {
656+
unsigned int j;
657+
for (j = 0; j < kp; j++) {
658+
XMEMCPY(&at[(i * kp + j) * KYBER_N],
659+
&key->a[(j * kp + i) * KYBER_N],
660+
KYBER_N * 2);
661+
}
662+
}
663+
}
664+
else
665+
#endif
626666
if (ret == 0) {
627667
/* Generate the transposed matrix. */
628668
ret = kyber_gen_matrix(&key->prf, at, kp, key->pubSeed, 1);
@@ -632,7 +672,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
632672
sword16* v;
633673

634674
/* Assign remaining allocated dynamic memory to pointers.
635-
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p)*/
675+
* sp (v) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p)*/
636676
bp = epp + KYBER_N;
637677
v = bp + KYBER_N * kp;
638678

@@ -644,7 +684,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
644684
}
645685
if (ret == 0) {
646686
/* Assign remaining allocated dynamic memory to pointers.
647-
* at (v) | sp (v) | bp (v) */
687+
* sp (v) | at (v) | bp (v) */
648688
bp = sp + KYBER_N * kp;
649689
v = at;
650690

@@ -676,7 +716,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
676716

677717
#ifndef WOLFSSL_NO_MALLOC
678718
/* Dispose of dynamic memory allocated in function. */
679-
XFREE(at, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
719+
XFREE(sp, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
680720
#endif
681721

682722
return ret;

wolfssl/wolfcrypt/wc_kyber.h

+5
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ enum {
6262
KYBER_FLAG_PUB_SET = 0x0002,
6363
KYBER_FLAG_BOTH_SET = 0x0003,
6464
KYBER_FLAG_H_SET = 0x0004,
65+
KYBER_FLAG_A_SET = 0x0008,
6566

6667
/* 2 bits of random used to create noise value. */
6768
KYBER_CBD_ETA2 = 2,
@@ -137,6 +138,10 @@ struct KyberKey {
137138
byte h[KYBER_SYM_SZ];
138139
/* Randomizer for decapsulation. */
139140
byte z[KYBER_SYM_SZ];
141+
#ifdef WOLFSSL_MLKEM_CACHE_A
142+
/* A matrix from key generation. */
143+
sword16 a[KYBER_MAX_K * KYBER_MAX_K * KYBER_N];
144+
#endif
140145
};
141146

142147
#ifdef __cplusplus

0 commit comments

Comments
 (0)