@@ -18,11 +18,15 @@ pub use aead::{
18
18
array:: { Array , AssocArraySize } ,
19
19
} ;
20
20
21
- use aead:: { PostfixTagged , array:: ArraySize , inout:: InOutBuf } ;
21
+ use aead:: {
22
+ TagPosition ,
23
+ array:: ArraySize ,
24
+ inout:: { InOut , InOutBuf } ,
25
+ } ;
22
26
use cipher:: {
23
27
BlockCipherDecrypt , BlockCipherEncrypt , BlockSizeUser ,
24
- consts:: { U12 , U16 } ,
25
- typenum:: Unsigned ,
28
+ consts:: { U2 , U12 , U16 } ,
29
+ typenum:: Prod ,
26
30
} ;
27
31
use core:: marker:: PhantomData ;
28
32
use dbl:: Dbl ;
@@ -55,7 +59,9 @@ pub type Nonce<NonceSize> = Array<u8, NonceSize>;
55
59
/// OCB3 tag
56
60
pub type Tag < TagSize > = Array < u8 , TagSize > ;
57
61
58
- pub ( crate ) type Block = Array < u8 , U16 > ;
62
+ type BlockSize = U16 ;
63
+ pub ( crate ) type Block = Array < u8 , BlockSize > ;
64
+ type DoubleBlock = Array < u8 , Prod < BlockSize , U2 > > ;
59
65
60
66
mod sealed {
61
67
use aead:: array:: {
@@ -168,6 +174,7 @@ where
168
174
{
169
175
type NonceSize = NonceSize ;
170
176
type TagSize = TagSize ;
177
+ const TAG_POSITION : TagPosition = TagPosition :: Postfix ;
171
178
}
172
179
173
180
impl < Cipher , NonceSize , TagSize > From < Cipher > for Ocb3 < Cipher , NonceSize , TagSize >
@@ -190,14 +197,6 @@ where
190
197
}
191
198
}
192
199
193
- impl < Cipher , NonceSize , TagSize > PostfixTagged for Ocb3 < Cipher , NonceSize , TagSize >
194
- where
195
- Cipher : BlockSizeUser < BlockSize = U16 > + BlockCipherEncrypt + BlockCipherDecrypt ,
196
- NonceSize : sealed:: NonceSizes ,
197
- TagSize : sealed:: TagSizes ,
198
- {
199
- }
200
-
201
200
impl < Cipher , NonceSize , TagSize > AeadInOut for Ocb3 < Cipher , NonceSize , TagSize >
202
201
where
203
202
Cipher : BlockSizeUser < BlockSize = U16 > + BlockCipherEncrypt + BlockCipherDecrypt ,
@@ -215,29 +214,30 @@ where
215
214
}
216
215
217
216
// First, try to process many blocks at once.
218
- let ( processed_bytes , mut offset_i, mut checksum_i) = self . wide_encrypt ( nonce, buffer) ;
217
+ let ( tail , index , mut offset_i, mut checksum_i) = self . wide_encrypt ( nonce, buffer) ;
219
218
220
- let mut i = ( processed_bytes / 16 ) + 1 ;
219
+ let mut i = index ;
221
220
222
221
// Then, process the remaining blocks.
223
- for p_i in Block :: slice_as_chunks_mut ( & mut buffer[ processed_bytes..] ) . 0 {
222
+ let ( blocks, mut tail) : ( InOutBuf < ' _ , ' _ , Block > , _ ) = tail. into_chunks ( ) ;
223
+
224
+ for p_i in blocks {
224
225
// offset_i = offset_{i-1} xor L_{ntz(i)}
225
226
inplace_xor ( & mut offset_i, & self . ll [ ntz ( i) ] ) ;
226
227
// checksum_i = checksum_{i-1} xor p_i
227
- inplace_xor ( & mut checksum_i, p_i) ;
228
+ inplace_xor ( & mut checksum_i, p_i. get_in ( ) ) ;
228
229
// c_i = offset_i xor ENCIPHER(K, p_i xor offset_i)
229
- let c_i = p_i;
230
- inplace_xor ( c_i, & offset_i) ;
231
- self . cipher . encrypt_block ( c_i) ;
232
- inplace_xor ( c_i, & offset_i) ;
230
+ let mut c_i = p_i;
231
+ c_i. xor_in2out ( & offset_i) ;
232
+ self . cipher . encrypt_block ( c_i. get_out ( ) ) ;
233
+ inplace_xor ( c_i. get_out ( ) , & offset_i) ;
233
234
234
235
i += 1 ;
235
236
}
236
237
237
238
// Process any partial blocks.
238
- if ( buffer. len ( ) % 16 ) != 0 {
239
- let processed_bytes = ( i - 1 ) * 16 ;
240
- let remaining_bytes = buffer. len ( ) - processed_bytes;
239
+ if !tail. is_empty ( ) {
240
+ let remaining_bytes = tail. len ( ) ;
241
241
242
242
// offset_* = offset_m xor L_*
243
243
inplace_xor ( & mut offset_i, & self . ll_star ) ;
@@ -247,15 +247,13 @@ where
247
247
self . cipher . encrypt_block ( & mut pad) ;
248
248
// checksum_* = checksum_m xor (P_* || 1 || zeros(127-bitlen(P_*)))
249
249
let checksum_rhs = & mut [ 0u8 ; 16 ] ;
250
- checksum_rhs[ ..remaining_bytes] . copy_from_slice ( & buffer [ processed_bytes.. ] ) ;
250
+ checksum_rhs[ ..remaining_bytes] . copy_from_slice ( tail . get_in ( ) ) ;
251
251
checksum_rhs[ remaining_bytes] = 0b1000_0000 ;
252
252
inplace_xor ( & mut checksum_i, checksum_rhs. as_ref ( ) ) ;
253
253
// C_* = P_* xor Pad[1..bitlen(P_*)]
254
- let p_star = & mut buffer [ processed_bytes.. ] ;
254
+ let p_star = tail . get_out ( ) ;
255
255
let pad = & mut pad[ ..p_star. len ( ) ] ;
256
- for ( aa, bb) in p_star. iter_mut ( ) . zip ( pad) {
257
- * aa ^= * bb;
258
- }
256
+ tail. xor_in2out ( pad) ;
259
257
}
260
258
261
259
let tag = self . compute_tag ( associated_data, & mut checksum_i, & offset_i) ;
@@ -270,7 +268,7 @@ where
270
268
buffer : InOutBuf < ' _ , ' _ , u8 > ,
271
269
tag : & aead:: Tag < Self > ,
272
270
) -> aead:: Result < ( ) > {
273
- let expected_tag = self . decrypt_in_place_return_tag ( nonce, associated_data, buffer) ;
271
+ let expected_tag = self . decrypt_inout_return_tag ( nonce, associated_data, buffer) ;
274
272
if expected_tag. ct_eq ( tag) . into ( ) {
275
273
Ok ( ( ) )
276
274
} else {
@@ -286,41 +284,40 @@ where
286
284
TagSize : sealed:: TagSizes ,
287
285
{
288
286
/// Decrypts in place and returns expected tag.
289
- pub ( crate ) fn decrypt_in_place_return_tag (
287
+ pub ( crate ) fn decrypt_inout_return_tag (
290
288
& self ,
291
289
nonce : & Nonce < NonceSize > ,
292
290
associated_data : & [ u8 ] ,
293
- buffer : & mut [ u8 ] ,
291
+ buffer : InOutBuf < ' _ , ' _ , u8 > ,
294
292
) -> aead:: Tag < Self > {
295
293
if ( buffer. len ( ) > C_MAX ) || ( associated_data. len ( ) > A_MAX ) {
296
294
unimplemented ! ( )
297
295
}
298
296
299
297
// First, try to process many blocks at once.
300
- let ( processed_bytes , mut offset_i, mut checksum_i) = self . wide_decrypt ( nonce, buffer) ;
298
+ let ( tail , index , mut offset_i, mut checksum_i) = self . wide_decrypt ( nonce, buffer) ;
301
299
302
- let mut i = ( processed_bytes / 16 ) + 1 ;
300
+ let mut i = index ;
303
301
304
302
// Then, process the remaining blocks.
305
- let ( blocks, _remaining ) = Block :: slice_as_chunks_mut ( & mut buffer [ processed_bytes.. ] ) ;
303
+ let ( blocks, mut tail ) : ( InOutBuf < ' _ , ' _ , Block > , _ ) = tail . into_chunks ( ) ;
306
304
for c_i in blocks {
307
305
// offset_i = offset_{i-1} xor L_{ntz(i)}
308
306
inplace_xor ( & mut offset_i, & self . ll [ ntz ( i) ] ) ;
309
307
// p_i = offset_i xor DECIPHER(K, c_i xor offset_i)
310
- let p_i = c_i;
311
- inplace_xor ( p_i, & offset_i) ;
312
- self . cipher . decrypt_block ( p_i) ;
313
- inplace_xor ( p_i, & offset_i) ;
308
+ let mut p_i = c_i;
309
+ p_i. xor_in2out ( & offset_i) ;
310
+ self . cipher . decrypt_block ( p_i. get_out ( ) ) ;
311
+ inplace_xor ( p_i. get_out ( ) , & offset_i) ;
314
312
// checksum_i = checksum_{i-1} xor p_i
315
- inplace_xor ( & mut checksum_i, p_i) ;
313
+ inplace_xor ( & mut checksum_i, p_i. get_out ( ) ) ;
316
314
317
315
i += 1 ;
318
316
}
319
317
320
318
// Process any partial blocks.
321
- if ( buffer. len ( ) % 16 ) != 0 {
322
- let processed_bytes = ( i - 1 ) * 16 ;
323
- let remaining_bytes = buffer. len ( ) - processed_bytes;
319
+ if !tail. is_empty ( ) {
320
+ let remaining_bytes = tail. len ( ) ;
324
321
325
322
// offset_* = offset_m xor L_*
326
323
inplace_xor ( & mut offset_i, & self . ll_star ) ;
@@ -329,14 +326,12 @@ where
329
326
inplace_xor ( & mut pad, & offset_i) ;
330
327
self . cipher . encrypt_block ( & mut pad) ;
331
328
// P_* = C_* xor Pad[1..bitlen(C_*)]
332
- let c_star = & mut buffer [ processed_bytes.. ] ;
329
+ let c_star = tail . get_in ( ) ;
333
330
let pad = & mut pad[ ..c_star. len ( ) ] ;
334
- for ( aa, bb) in c_star. iter_mut ( ) . zip ( pad) {
335
- * aa ^= * bb;
336
- }
331
+ tail. xor_in2out ( pad) ;
337
332
// checksum_* = checksum_m xor (P_* || 1 || zeros(127-bitlen(P_*)))
338
333
let checksum_rhs = & mut [ 0u8 ; 16 ] ;
339
- checksum_rhs[ ..remaining_bytes] . copy_from_slice ( & buffer [ processed_bytes.. ] ) ;
334
+ checksum_rhs[ ..remaining_bytes] . copy_from_slice ( tail . get_out ( ) ) ;
340
335
checksum_rhs[ remaining_bytes] = 0b1000_0000 ;
341
336
inplace_xor ( & mut checksum_i, checksum_rhs. as_ref ( ) ) ;
342
337
}
@@ -347,81 +342,85 @@ where
347
342
/// Encrypts plaintext in groups of two.
348
343
///
349
344
/// Adapted from https://www.cs.ucdavis.edu/~rogaway/ocb/news/code/ocb.c
350
- fn wide_encrypt ( & self , nonce : & Nonce < NonceSize > , buffer : & mut [ u8 ] ) -> ( usize , Block , Block ) {
345
+ fn wide_encrypt < ' i , ' o > (
346
+ & self ,
347
+ nonce : & Nonce < NonceSize > ,
348
+ buffer : InOutBuf < ' i , ' o , u8 > ,
349
+ ) -> ( InOutBuf < ' i , ' o , u8 > , usize , Block , Block ) {
351
350
const WIDTH : usize = 2 ;
352
351
353
352
let mut i = 1 ;
354
353
355
354
let mut offset_i = [ Block :: default ( ) ; WIDTH ] ;
356
- offset_i[ offset_i . len ( ) - 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
355
+ offset_i[ 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
357
356
let mut checksum_i = Block :: default ( ) ;
358
- for wide_blocks in buffer. chunks_exact_mut ( <Block as AssocArraySize >:: Size :: USIZE * WIDTH ) {
359
- let p_i = split_into_two_blocks ( wide_blocks) ;
360
357
358
+ let ( wide_blocks, tail) : ( InOutBuf < ' _ , ' _ , DoubleBlock > , _ ) = buffer. into_chunks ( ) ;
359
+ for wide_block in wide_blocks. into_iter ( ) {
360
+ let mut p_i = split_into_two_blocks ( wide_block) ;
361
361
// checksum_i = checksum_{i-1} xor p_i
362
362
for p_ij in & p_i {
363
- inplace_xor ( & mut checksum_i, p_ij) ;
363
+ inplace_xor ( & mut checksum_i, p_ij. get_in ( ) ) ;
364
364
}
365
365
366
366
// offset_i = offset_{i-1} xor L_{ntz(i)}
367
- offset_i[ 0 ] = offset_i[ offset_i . len ( ) - 1 ] ;
367
+ offset_i[ 0 ] = offset_i[ 1 ] ;
368
368
inplace_xor ( & mut offset_i[ 0 ] , & self . ll [ ntz ( i) ] ) ;
369
- for j in 1 ..p_i. len ( ) {
370
- offset_i[ j] = offset_i[ j - 1 ] ;
371
- inplace_xor ( & mut offset_i[ j] , & self . ll [ ntz ( i + j) ] ) ;
372
- }
369
+ offset_i[ 1 ] = offset_i[ 0 ] ;
370
+ inplace_xor ( & mut offset_i[ 1 ] , & self . ll [ ntz ( i + 1 ) ] ) ;
373
371
374
372
// c_i = offset_i xor ENCIPHER(K, p_i xor offset_i)
375
373
for j in 0 ..p_i. len ( ) {
376
- inplace_xor ( p_i[ j] , & offset_i[ j] ) ;
377
- self . cipher . encrypt_block ( p_i[ j] ) ;
378
- inplace_xor ( p_i[ j] , & offset_i[ j] )
374
+ p_i[ j] . xor_in2out ( & offset_i[ j] ) ;
375
+ self . cipher . encrypt_block ( p_i[ j] . get_out ( ) ) ;
376
+ inplace_xor ( p_i[ j] . get_out ( ) , & offset_i[ j] ) ;
379
377
}
380
378
381
379
i += WIDTH ;
382
380
}
383
381
384
- let processed_bytes = ( buffer. len ( ) / ( WIDTH * 16 ) ) * ( WIDTH * 16 ) ;
385
-
386
- ( processed_bytes, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
382
+ ( tail, i, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
387
383
}
388
384
389
385
/// Decrypts plaintext in groups of two.
390
386
///
391
387
/// Adapted from https://www.cs.ucdavis.edu/~rogaway/ocb/news/code/ocb.c
392
- fn wide_decrypt ( & self , nonce : & Nonce < NonceSize > , buffer : & mut [ u8 ] ) -> ( usize , Block , Block ) {
388
+ fn wide_decrypt < ' i , ' o > (
389
+ & self ,
390
+ nonce : & Nonce < NonceSize > ,
391
+ buffer : InOutBuf < ' i , ' o , u8 > ,
392
+ ) -> ( InOutBuf < ' i , ' o , u8 > , usize , Block , Block ) {
393
393
const WIDTH : usize = 2 ;
394
394
395
395
let mut i = 1 ;
396
396
397
397
let mut offset_i = [ Block :: default ( ) ; WIDTH ] ;
398
- offset_i[ offset_i . len ( ) - 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
398
+ offset_i[ 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
399
399
let mut checksum_i = Block :: default ( ) ;
400
- for wide_blocks in buffer. chunks_exact_mut ( 16 * WIDTH ) {
401
- let c_i = split_into_two_blocks ( wide_blocks) ;
400
+
401
+ let ( wide_blocks, tail) : ( InOutBuf < ' _ , ' _ , DoubleBlock > , _ ) = buffer. into_chunks ( ) ;
402
+ for wide_block in wide_blocks. into_iter ( ) {
403
+ let mut c_i = split_into_two_blocks ( wide_block) ;
402
404
403
405
// offset_i = offset_{i-1} xor L_{ntz(i)}
404
- offset_i[ 0 ] = offset_i[ offset_i . len ( ) - 1 ] ;
406
+ offset_i[ 0 ] = offset_i[ 1 ] ;
405
407
inplace_xor ( & mut offset_i[ 0 ] , & self . ll [ ntz ( i) ] ) ;
406
- for j in 1 ..c_i. len ( ) {
407
- offset_i[ j] = offset_i[ j - 1 ] ;
408
- inplace_xor ( & mut offset_i[ j] , & self . ll [ ntz ( i + j) ] ) ;
409
- }
408
+ offset_i[ 1 ] = offset_i[ 0 ] ;
409
+ inplace_xor ( & mut offset_i[ 1 ] , & self . ll [ ntz ( i + 1 ) ] ) ;
410
410
411
411
// p_i = offset_i xor DECIPHER(K, c_i xor offset_i)
412
412
// checksum_i = checksum_{i-1} xor p_i
413
413
for j in 0 ..c_i. len ( ) {
414
- inplace_xor ( c_i[ j] , & offset_i[ j] ) ;
415
- self . cipher . decrypt_block ( c_i[ j] ) ;
416
- inplace_xor ( c_i[ j] , & offset_i[ j] ) ;
417
- inplace_xor ( & mut checksum_i, c_i[ j] ) ;
414
+ c_i[ j] . xor_in2out ( & offset_i[ j] ) ;
415
+ self . cipher . decrypt_block ( c_i[ j] . get_out ( ) ) ;
416
+ inplace_xor ( c_i[ j] . get_out ( ) , & offset_i[ j] ) ;
417
+ inplace_xor ( & mut checksum_i, c_i[ j] . get_out ( ) ) ;
418
418
}
419
419
420
420
i += WIDTH ;
421
421
}
422
422
423
- let processed_bytes = ( buffer. len ( ) / ( WIDTH * 16 ) ) * ( WIDTH * 16 ) ;
424
- ( processed_bytes, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
423
+ ( tail, i, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
425
424
}
426
425
427
426
/// Computes HASH function defined in https://www.rfc-editor.org/rfc/rfc7253.html#section-4.1
@@ -580,11 +579,10 @@ pub(crate) fn ntz(n: usize) -> usize {
580
579
}
581
580
582
581
#[ inline]
583
- pub ( crate ) fn split_into_two_blocks ( two_blocks : & mut [ u8 ] ) -> [ & mut Block ; 2 ] {
584
- const BLOCK_SIZE : usize = 16 ;
585
- debug_assert_eq ! ( two_blocks. len( ) , BLOCK_SIZE * 2 ) ;
586
- let ( b0, b1) = two_blocks. split_at_mut ( BLOCK_SIZE ) ;
587
- [ b0. try_into ( ) . unwrap ( ) , b1. try_into ( ) . unwrap ( ) ]
582
+ pub ( crate ) fn split_into_two_blocks < ' i , ' o > (
583
+ two_blocks : InOut < ' i , ' o , DoubleBlock > ,
584
+ ) -> [ InOut < ' i , ' o , Block > ; 2 ] {
585
+ Array :: < InOut < ' i , ' o , Block > , U2 > :: from ( two_blocks) . into ( )
588
586
}
589
587
590
588
#[ cfg( test) ]
0 commit comments