@@ -22,10 +22,9 @@ import (
22
22
"crypto/cipher"
23
23
"crypto/rand"
24
24
"fmt"
25
- "io"
26
-
27
25
"github.com/pkg/errors"
28
26
"github.com/tjfoc/gmsm/sm4"
27
+ "io"
29
28
)
30
29
31
30
type CryptoType int
@@ -314,43 +313,54 @@ func Sm4DecryptECB(encrypted, key []byte) (decrypted []byte, err error) {
314
313
}
315
314
316
315
func Sm4EncryptCBC (origData , key , iv []byte ) (encrypted []byte , err error ) {
317
- if err = sm4 .SetIV (iv ); err != nil {
316
+ if err = sm4 .SetIV (EnsureByteArrayLength16 ( iv ) ); err != nil {
318
317
return nil , err
319
318
}
320
319
return sm4 .Sm4Cbc (key , origData , true )
321
320
}
322
321
323
322
func Sm4DecryptCBC (encrypted , key , iv []byte ) (decrypted []byte , err error ) {
324
- if err = sm4 .SetIV (iv ); err != nil {
323
+ if err = sm4 .SetIV (EnsureByteArrayLength16 ( iv ) ); err != nil {
325
324
return nil , err
326
325
}
327
326
return sm4 .Sm4Cbc (key , encrypted , false )
328
327
}
329
328
330
329
func Sm4EncryptCFB (origData , key , iv []byte ) (encrypted []byte , err error ) {
331
- if err = sm4 .SetIV (iv ); err != nil {
330
+ if err = sm4 .SetIV (EnsureByteArrayLength16 ( iv ) ); err != nil {
332
331
return nil , err
333
332
}
334
333
return sm4 .Sm4CFB (key , origData , true )
335
334
}
336
335
337
336
func Sm4DecryptCFB (encrypted , key , iv []byte ) (decrypted []byte , err error ) {
338
- if err = sm4 .SetIV (iv ); err != nil {
337
+ if err = sm4 .SetIV (EnsureByteArrayLength16 ( iv ) ); err != nil {
339
338
return nil , err
340
339
}
341
340
return sm4 .Sm4CFB (key , encrypted , false )
342
341
}
343
342
344
343
func Sm4EncryptOFB (origData , key , iv []byte ) (encrypted []byte , err error ) {
345
- if err = sm4 .SetIV (iv ); err != nil {
344
+ if err = sm4 .SetIV (EnsureByteArrayLength16 ( iv ) ); err != nil {
346
345
return nil , err
347
346
}
348
347
return sm4 .Sm4OFB (key , origData , true )
349
348
}
350
349
351
350
func Sm4DecryptOFB (encrypted , key , iv []byte ) (decrypted []byte , err error ) {
352
- if err = sm4 .SetIV (iv ); err != nil {
351
+ if err = sm4 .SetIV (EnsureByteArrayLength16 ( iv ) ); err != nil {
353
352
return nil , err
354
353
}
355
354
return sm4 .Sm4OFB (key , encrypted , false )
356
355
}
356
+
357
+ func EnsureByteArrayLength16 (input []byte ) []byte {
358
+ if len (input ) == 16 {
359
+ return input
360
+ }
361
+ repeated := append (input , input ... )
362
+ for len (repeated ) < 16 {
363
+ repeated = append (repeated , input ... )
364
+ }
365
+ return repeated [:16 ]
366
+ }
0 commit comments