From 9247fedd9953765af4544ae940ac319f7f9c388c Mon Sep 17 00:00:00 2001 From: K1 Date: Fri, 29 Nov 2024 21:01:34 +0800 Subject: [PATCH] SM4 supports cipher.Block interface By integrating the SM4 API from Tongsuo to support the cipher.Block interface in the golang standard library, enhancing the usability of the SM4 interface. Note, we should add enable-export-sm4 to config options while building Tongsuo. --- .github/workflows/main.yml | 8 +- README.md | 2 +- crypto/init.go | 1 + crypto/shim.h | 1 + crypto/sm4/sm4.go | 51 +++++++++++ crypto/sm4/sm4_test.go | 174 +++++++++++++++++++++++++++++++++++++ examples/sm4/main.go | 73 +++++++--------- 7 files changed, 263 insertions(+), 47 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a54e710..2787110 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -35,7 +35,7 @@ jobs: - name: Build Tongsuo run: | cd Tongsuo - ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 make -j4 make install @@ -69,7 +69,7 @@ jobs: - name: Build Tongsuo run: | cd Tongsuo - ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 make -j4 make install @@ -108,7 +108,7 @@ jobs: - name: Build Tongsuo Static run: | cd tongsuo - ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls no-shared + ./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 no-shared make -j4 make install @@ -147,7 +147,7 @@ jobs: run: | mkdir _build cd _build - perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls + perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls enable-export-sm4 nmake /S nmake install working-directory: Tongsuo diff --git a/README.md b/README.md index eae4ca7..02d51e9 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ cd Tongsuo git checkout 8.3-stable -./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls +./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls enable-export-sm4 make -j make install ``` diff --git a/crypto/init.go b/crypto/init.go index 5be290c..6c0312c 100644 --- a/crypto/init.go +++ b/crypto/init.go @@ -48,6 +48,7 @@ var ( ErrInternalError = errors.New("internal error") ErrEmptyKey = errors.New("empty key") ErrNoData = errors.New("no data") + ErrInvalidKeySize = errors.New("invalid key size") ) func init() { diff --git a/crypto/shim.h b/crypto/shim.h index fb3b73f..70f3851 100644 --- a/crypto/shim.h +++ b/crypto/shim.h @@ -25,6 +25,7 @@ #include #include #include +#include /* shim methods */ extern int X_tscrypto_init(); diff --git a/crypto/sm4/sm4.go b/crypto/sm4/sm4.go index 5583fe0..74904aa 100644 --- a/crypto/sm4/sm4.go +++ b/crypto/sm4/sm4.go @@ -12,11 +12,18 @@ import "C" import ( "bytes" + "crypto/cipher" "fmt" + "unsafe" "github.com/tongsuo-project/tongsuo-go-sdk/crypto" ) +const ( + BlockSize = 16 + KeySize = 16 +) + type Encrypter interface { // crypto.EncryptionCipherCtx SetPadding(pad bool) @@ -50,6 +57,50 @@ type sm4Decrypter struct { tag []byte } +type sm4Cipher struct { + rk [32]uint32 +} + +func (c *sm4Cipher) BlockSize() int { + return BlockSize +} + +func NewCipher(key []byte) (cipher.Block, error) { + if len(key) != KeySize { + return nil, fmt.Errorf("invalid key size: %w", crypto.ErrInvalidKeySize) + } + + cipher := &sm4Cipher{} + ret := C.SM4_set_key((*C.uchar)(&key[0]), (*C.SM4_KEY)(unsafe.Pointer(&cipher.rk))) + if ret != 1 { + return nil, fmt.Errorf("failed to set key: %w", crypto.ErrInternalError) + } + + return cipher, nil +} + +func (c *sm4Cipher) Encrypt(dst, src []byte) { + if len(src) < BlockSize { + panic("sm4: input not full block") + } + if len(dst) < BlockSize { + panic("sm4: output not full block") + } + + C.SM4_encrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk))) +} + +func (c *sm4Cipher) Decrypt(dst, src []byte) { + if len(src) < BlockSize { + panic("sm4: input not full block") + } + if len(dst) < BlockSize { + panic("sm4: output not full block") + } + + C.SM4_decrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk))) +} + func getSM4Cipher(mode int) (*crypto.Cipher, error) { var cipher *crypto.Cipher var err error diff --git a/crypto/sm4/sm4_test.go b/crypto/sm4/sm4_test.go index a81b71c..f25de20 100644 --- a/crypto/sm4/sm4_test.go +++ b/crypto/sm4/sm4_test.go @@ -9,6 +9,7 @@ package sm4_test import ( "bytes" + "crypto/cipher" "encoding/hex" "strings" "testing" @@ -20,6 +21,179 @@ import ( const hexPlainText1 = `AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFE EEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA` +func TestSM4ECBWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("681EDF34D206965E86B3E94F536E4246") + + block, err := sm4.NewCipher(key) + if err != nil { + t.Fatal("failed to create SM4 cipher: ", err) + } + + cipherText1 := make([]byte, len(plainText)) + + for i := 0; i < len(plainText); i += block.BlockSize() { + block.Encrypt(cipherText1[i:i+block.BlockSize()], plainText[i:i+block.BlockSize()]) + } + + if !bytes.Equal(cipherText1, cipherText) { + t.Fatalf("exp:%x got:%x", cipherText, cipherText1) + } + + plainText1 := make([]byte, len(cipherText1)) + + for i := 0; i < len(cipherText1); i += block.BlockSize() { + block.Decrypt(plainText1[i:i+block.BlockSize()], cipherText1[i:i+block.BlockSize()]) + } + + if !bytes.Equal(plainText, plainText1) { + t.Fatalf("exp:%x got:%x", plainText, plainText1) + } +} + +func testCryptWithCipherBlock(t *testing.T, mode string, key, iv, plainText, cipherText []byte) { + t.Helper() + + block, err := sm4.NewCipher(key) + if err != nil { + t.Fatal("failed to create SM4 cipher: ", err) + } + + cipherText1 := make([]byte, len(plainText)) + + switch mode { + case "CBC": + stream := cipher.NewCBCEncrypter(block, iv) + stream.CryptBlocks(cipherText1, plainText) + case "CFB": + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(cipherText1, plainText) + case "OFB": + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(cipherText1, plainText) + case "CTR": + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(cipherText1, plainText) + } + + if !bytes.Equal(cipherText1, cipherText) { + t.Fatalf("exp:%x got:%x", cipherText, cipherText1) + } + + plainText1 := make([]byte, len(plainText)) + + switch mode { + case "CBC": + stream := cipher.NewCBCDecrypter(block, iv) + stream.CryptBlocks(plainText1, cipherText1) + case "CFB": + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(plainText1, cipherText1) + case "OFB": + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(plainText1, cipherText1) + case "CTR": + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(plainText1, cipherText1) + } + + if !bytes.Equal(plainText, plainText1) { + t.Fatalf("exp:%x got:%x", plainText, plainText1) + } +} + +func TestSM4CBCWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") + + testCryptWithCipherBlock(t, "CBC", key, iv, plainText, cipherText) +} + +func TestSM4CFBWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A70569ED258A85A0467CC92AAB393DD978995") + + testCryptWithCipherBlock(t, "CFB", key, iv, plainText, cipherText) +} + +func TestSM4OFBWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") + cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A7056F2075D28B5235F58D50027E4177D2BCE") + + testCryptWithCipherBlock(t, "OFB", key, iv, plainText, cipherText) +} + +func TestSM4CTRWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + hexCipherText := `C2B4759E78AC3CF43D0852F4E8D5F9FD7256E8A5FCB65A350EE00630912E44492A0B17E1B85B060D0FBA612D8A95831638 +B361FD5FFACD942F081485A83CA35D` + plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", "")) + cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", "")) + + testCryptWithCipherBlock(t, "CTR", key, iv, plainText, cipherText) +} + +func TestSM4GCMWithCipherBlock(t *testing.T) { + t.Parallel() + + key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210") + iv, _ := hex.DecodeString("00001234567800000000ABCD") + aad, _ := hex.DecodeString("FEEDFACEDEADBEEFFEEDFACEDEADBEEFABADDAD2") + tag, _ := hex.DecodeString("83DE3541E4C2B58177E065A9BF7B62EC") + hexCipherText := `17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A5 +6834CBCF98C397B4024A2691233B8D` + plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", "")) + cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", "")) + + block, err := sm4.NewCipher(key) + if err != nil { + t.Fatal("failed to create SM4 cipher: ", err) + } + + stream, err := cipher.NewGCM(block) + if err != nil { + t.Fatal("failed to create GCM: ", err) + } + + cipherText1 := stream.Seal(nil, iv, plainText, aad) + + if !bytes.Equal(cipherText1, append(cipherText, tag...)) { + t.Fatalf("exp:%x got:%x", cipherText1, append(cipherText, tag...)) + } + + stream2, err := cipher.NewGCM(block) + if err != nil { + t.Fatal("failed to create GCM: ", err) + } + + plainText1, err := stream2.Open(nil, iv, cipherText1, aad) + if err != nil { + t.Fatal("failed to decrypt: ", err) + } + + if !bytes.Equal(plainText1, plainText) { + t.Fatalf("exp:%x got:%x", plainText1, plainText) + } +} + func doEncrypt(t *testing.T, mode int, key, iv, plainText, cipherText []byte) { t.Helper() diff --git a/examples/sm4/main.go b/examples/sm4/main.go index d7b5c54..bd30437 100644 --- a/examples/sm4/main.go +++ b/examples/sm4/main.go @@ -9,11 +9,11 @@ package main import ( "bytes" + "crypto/cipher" "encoding/hex" "fmt" "log" - "github.com/tongsuo-project/tongsuo-go-sdk/crypto" "github.com/tongsuo-project/tongsuo-go-sdk/crypto/sm4" ) @@ -23,20 +23,18 @@ func sm4CBCEncrypt() { plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") - enc, err := sm4.NewEncrypter(crypto.CipherModeCBC, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create encrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - enc.SetPadding(false) + cipherText1 := make([]byte, len(plainText)) - actualCipherText, err := enc.EncryptAll(plainText) - if err != nil { - log.Fatal("failed to encrypt: ", err) - } + stream := cipher.NewCBCEncrypter(block, iv) + stream.CryptBlocks(cipherText1, plainText) - if !bytes.Equal(cipherText, actualCipherText) { - log.Fatalf("exp:%x got:%x", cipherText, actualCipherText) + if !bytes.Equal(cipherText1, cipherText) { + log.Fatalf("exp:%x got:%x", cipherText, cipherText1) } fmt.Println("[sm4CBCEncrypt]") @@ -52,20 +50,18 @@ func sm4CBCDecrypt() { plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210") cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B") - enc, err := sm4.NewDecrypter(crypto.CipherModeCBC, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create decrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - enc.SetPadding(false) + plainText1 := make([]byte, len(cipherText)) - actualPlainText, err := enc.DecryptAll(cipherText) - if err != nil { - log.Fatal("failed to decrypt: ", err) - } + stream := cipher.NewCBCDecrypter(block, iv) + stream.CryptBlocks(plainText1, cipherText) - if !bytes.Equal(plainText, actualPlainText) { - log.Fatalf("exp:%x got:%x", plainText, actualPlainText) + if !bytes.Equal(plainText, plainText1) { + log.Fatalf("exp:%x got:%x", plainText, plainText1) } fmt.Println("[sm4CBCDecrypt]") @@ -83,29 +79,20 @@ func sm4GCMEncrypt() { plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") cipherText, _ := hex.DecodeString("17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A56834CBCF98C397B4024A2691233B8D") - enc, err := sm4.NewEncrypter(crypto.CipherModeGCM, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create encrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - enc.SetAAD(aad) - - actualCipherText, err := enc.EncryptAll(plainText) + stream, err := cipher.NewGCM(block) if err != nil { - log.Fatal("failed to encrypt: ", err) + log.Fatal("failed to create GCM: ", err) } - if !bytes.Equal(cipherText, actualCipherText) { - log.Fatalf("exp:%x got:%x", cipherText, actualCipherText) - } - - actualTag, err := enc.GetTag() - if err != nil { - log.Fatal("failed to get tag: ", err) - } + cipherText1 := stream.Seal(nil, iv, plainText, aad) - if !bytes.Equal(tag, actualTag) { - log.Fatalf("exp:%x got:%x", tag, actualTag) + if !bytes.Equal(cipherText1, append(cipherText, tag...)) { + log.Fatalf("exp:%x got:%x", cipherText1, append(cipherText, tag...)) } fmt.Println("[sm4GCMEncrypt]") @@ -125,21 +112,23 @@ func sm4GCMDecrypt() { plainText, _ := hex.DecodeString("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFEEEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA") cipherText, _ := hex.DecodeString("17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A56834CBCF98C397B4024A2691233B8D") - dec, err := sm4.NewDecrypter(crypto.CipherModeGCM, key, iv) + block, err := sm4.NewCipher(key) if err != nil { - log.Fatal("failed to create decrypter: ", err) + log.Fatal("failed to create SM4 cipher: ", err) } - dec.SetTag(tag) - dec.SetAAD(aad) + stream, err := cipher.NewGCM(block) + if err != nil { + log.Fatal("failed to create GCM: ", err) + } - actualPlainText, err := dec.DecryptAll(cipherText) + plainText1, err := stream.Open(nil, iv, append(cipherText, tag...), aad) if err != nil { log.Fatal("failed to decrypt: ", err) } - if !bytes.Equal(plainText, actualPlainText) { - log.Fatalf("exp:%x got:%x", plainText, actualPlainText) + if !bytes.Equal(plainText1, plainText) { + log.Fatalf("exp:%x got:%x", plainText1, plainText) } fmt.Println("[sm4GCMDecrypt]")