63
63
#error "Can't use small memory with assembly optimized code"
64
64
#endif
65
65
#endif
66
+ #if defined(WOLFSSL_MLKEM_CACHE_A )
67
+ #if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM ) || \
68
+ defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM )
69
+ #error "Can't cache A with small memory code"
70
+ #endif
71
+ #endif
66
72
67
73
#ifdef WOLFSSL_WC_KYBER
68
74
@@ -265,10 +271,14 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
265
271
sword16 * e = NULL ;
266
272
#else
267
273
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
274
+ #ifndef WOLFSSL_MLKEM_CACHE_A
268
275
sword16 e [(KYBER_MAX_K + 1 ) * KYBER_MAX_K * KYBER_N ];
269
276
#else
270
277
sword16 e [KYBER_MAX_K * KYBER_N ];
271
278
#endif
279
+ #else
280
+ sword16 e [KYBER_MAX_K * KYBER_N ];
281
+ #endif
272
282
#endif
273
283
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
274
284
sword16 * a = NULL ;
@@ -285,6 +295,8 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
285
295
}
286
296
287
297
if (ret == 0 ) {
298
+ key -> flags = 0 ;
299
+
288
300
/* Establish parameters based on key type. */
289
301
switch (key -> type ) {
290
302
#ifndef WOLFSSL_NO_ML_KEM
@@ -332,9 +344,17 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
332
344
if (ret == 0 ) {
333
345
/* Allocate dynamic memory for matrix and error vector. */
334
346
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
347
+ #ifndef WOLFSSL_MLKEM_CACHE_A
348
+ /* e (v) | a (m) */
335
349
e = (sword16 * )XMALLOC ((kp + 1 ) * kp * KYBER_N * sizeof (sword16 ),
336
350
key -> heap , DYNAMIC_TYPE_TMP_BUFFER );
337
351
#else
352
+ /* e (v) */
353
+ e = (sword16 * )XMALLOC (kp * KYBER_N * sizeof (sword16 ),
354
+ key -> heap , DYNAMIC_TYPE_TMP_BUFFER );
355
+ #endif
356
+ #else
357
+ /* e (v) */
338
358
e = (sword16 * )XMALLOC (kp * KYBER_N * sizeof (sword16 ),
339
359
key -> heap , DYNAMIC_TYPE_TMP_BUFFER );
340
360
#endif
@@ -346,8 +366,10 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
346
366
if (ret == 0 ) {
347
367
const byte * d = rand ;
348
368
349
- #ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
350
- /* Error vector allocated at end of a. */
369
+ #ifdef WOLFSSL_MLKEM_CACHE_A
370
+ a = key -> a ;
371
+ #elif !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM )
372
+ /* Matrix A allocated at end of error vector. */
351
373
a = e + (kp * KYBER_N );
352
374
#endif
353
375
@@ -391,6 +413,9 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
391
413
ret = kyber_gen_matrix (& key -> prf , a , kp , pubSeed , 0 );
392
414
}
393
415
if (ret == 0 ) {
416
+ #ifdef WOLFSSL_MLKEM_CACHE_A
417
+ key -> flags |= KYBER_FLAG_A_SET ;
418
+ #endif
394
419
/* Generate key pair from random data. */
395
420
kyber_keygen (key -> priv , key -> pub , e , a , kp );
396
421
#else
@@ -514,7 +539,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
514
539
unsigned char * ct )
515
540
{
516
541
int ret = 0 ;
517
- sword16 * sp = NULL ;
542
+ sword16 * at = NULL ;
518
543
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
519
544
sword16 * k = NULL ;
520
545
sword16 * ep = NULL ;
@@ -523,12 +548,12 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
523
548
unsigned int kp = 0 ;
524
549
unsigned int compVecSz = 0 ;
525
550
#ifndef WOLFSSL_NO_MALLOC
526
- sword16 * at = NULL ;
551
+ sword16 * sp = NULL ;
527
552
#else
528
553
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
529
- sword16 at [((KYBER_MAX_K + 3 ) * KYBER_MAX_K + 3 ) * KYBER_N ];
554
+ sword16 sp [((KYBER_MAX_K + 3 ) * KYBER_MAX_K + 3 ) * KYBER_N ];
530
555
#else
531
- sword16 at [3 * KYBER_MAX_K * KYBER_N ];
556
+ sword16 sp [3 * KYBER_MAX_K * KYBER_N ];
532
557
#endif
533
558
#endif
534
559
#ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
@@ -588,13 +613,13 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
588
613
if (ret == 0 ) {
589
614
/* Allocate dynamic memory for all matrices, vectors and polynomials. */
590
615
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
591
- at = (sword16 * )XMALLOC (((kp + 3 ) * kp + 3 ) * KYBER_N * sizeof (sword16 ),
616
+ sp = (sword16 * )XMALLOC (((kp + 3 ) * kp + 3 ) * KYBER_N * sizeof (sword16 ),
592
617
key -> heap , DYNAMIC_TYPE_TMP_BUFFER );
593
618
#else
594
- at = (sword16 * )XMALLOC (3 * kp * KYBER_N * sizeof (sword16 ), key -> heap ,
619
+ sp = (sword16 * )XMALLOC (3 * kp * KYBER_N * sizeof (sword16 ), key -> heap ,
595
620
DYNAMIC_TYPE_TMP_BUFFER );
596
621
#endif
597
- if (at == NULL ) {
622
+ if (sp == NULL ) {
598
623
ret = MEMORY_E ;
599
624
}
600
625
}
@@ -603,15 +628,15 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
603
628
if (ret == 0 ) {
604
629
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
605
630
/* Assign allocated dynamic memory to pointers.
606
- * at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p) */
631
+ * sp (b) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p) */
632
+ at = sp + KYBER_N * kp ;
607
633
k = at + KYBER_N * kp * kp ;
608
- sp = k + KYBER_N ;
609
- ep = sp + KYBER_N * kp ;
634
+ ep = k + KYBER_N ;
610
635
epp = ep + KYBER_N * kp ;
611
636
#else
612
637
/* Assign allocated dynamic memory to pointers.
613
- * at (v) | sp (v) | bp (v) */
614
- sp = at + KYBER_N * kp ;
638
+ * sp (v) | at (v) | bp (v) */
639
+ at = sp + KYBER_N * kp ;
615
640
#endif
616
641
617
642
/* Initialize the PRF for use in the noise generation. */
@@ -623,6 +648,21 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
623
648
/* Generate noise using PRF. */
624
649
ret = kyber_get_noise (& key -> prf , kp , sp , ep , epp , coins );
625
650
}
651
+ #ifdef WOLFSSL_MLKEM_CACHE_A
652
+ if ((ret == 0 ) && ((key -> flags & KYBER_FLAG_A_SET ) != 0 )) {
653
+ unsigned int i ;
654
+ /* Transpose matrix. */
655
+ for (i = 0 ; i < kp ; i ++ ) {
656
+ unsigned int j ;
657
+ for (j = 0 ; j < kp ; j ++ ) {
658
+ XMEMCPY (& at [(i * kp + j ) * KYBER_N ],
659
+ & key -> a [(j * kp + i ) * KYBER_N ],
660
+ KYBER_N * 2 );
661
+ }
662
+ }
663
+ }
664
+ else
665
+ #endif
626
666
if (ret == 0 ) {
627
667
/* Generate the transposed matrix. */
628
668
ret = kyber_gen_matrix (& key -> prf , at , kp , key -> pubSeed , 1 );
@@ -632,7 +672,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
632
672
sword16 * v ;
633
673
634
674
/* Assign remaining allocated dynamic memory to pointers.
635
- * at (m ) | k (p ) | sp (v ) | ep (p) | epp (v) | bp (v) | v (p)*/
675
+ * sp (v ) | at (m ) | k (p ) | ep (p) | epp (v) | bp (v) | v (p)*/
636
676
bp = epp + KYBER_N ;
637
677
v = bp + KYBER_N * kp ;
638
678
@@ -644,7 +684,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
644
684
}
645
685
if (ret == 0 ) {
646
686
/* Assign remaining allocated dynamic memory to pointers.
647
- * at (v) | sp (v) | bp (v) */
687
+ * sp (v) | at (v) | bp (v) */
648
688
bp = sp + KYBER_N * kp ;
649
689
v = at ;
650
690
@@ -676,7 +716,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
676
716
677
717
#ifndef WOLFSSL_NO_MALLOC
678
718
/* Dispose of dynamic memory allocated in function. */
679
- XFREE (at , key -> heap , DYNAMIC_TYPE_TMP_BUFFER );
719
+ XFREE (sp , key -> heap , DYNAMIC_TYPE_TMP_BUFFER );
680
720
#endif
681
721
682
722
return ret ;
0 commit comments