Skip to content

Commit

Permalink
SM4 supports cipher.Block interface
Browse files Browse the repository at this point in the history
By integrating the SM4 API from Tongsuo to support the cipher.Block
interface in the golang standard library, enhancing the usability of
the SM4 interface.

Note, we should add enable-export-sm4 to config options while building
Tongsuo.
  • Loading branch information
dongbeiouba committed Nov 29, 2024
1 parent 7c0ad4e commit 9247fed
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 47 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Build Tongsuo
run: |
cd Tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4
make -j4
make install
Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:
- name: Build Tongsuo
run: |
cd Tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4
make -j4
make install
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
- name: Build Tongsuo Static
run: |
cd tongsuo
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls no-shared
./config --prefix=${RUNNER_TEMP}/tongsuo --libdir=${RUNNER_TEMP}/tongsuo/lib enable-ntls enable-export-sm4 no-shared
make -j4
make install
Expand Down Expand Up @@ -147,7 +147,7 @@ jobs:
run: |
mkdir _build
cd _build
perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls
perl ..\Configure VC-WIN64A no-makedepend --prefix=%RUNNER_TEMP%\tongsuo enable-ntls enable-export-sm4
nmake /S
nmake install
working-directory: Tongsuo
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ cd Tongsuo

git checkout 8.3-stable

./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls
./config --prefix=/opt/tongsuo --libdir=/opt/tongsuo/lib enable-ntls enable-export-sm4
make -j
make install
```
Expand Down
1 change: 1 addition & 0 deletions crypto/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
ErrInternalError = errors.New("internal error")
ErrEmptyKey = errors.New("empty key")
ErrNoData = errors.New("no data")
ErrInvalidKeySize = errors.New("invalid key size")
)

func init() {
Expand Down
1 change: 1 addition & 0 deletions crypto/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <openssl/x509v3.h>
#include <openssl/ec.h>
#include <openssl/opensslv.h>
#include <openssl/sm4.h>

/* shim methods */
extern int X_tscrypto_init();
Expand Down
51 changes: 51 additions & 0 deletions crypto/sm4/sm4.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@ import "C"

import (
"bytes"
"crypto/cipher"
"fmt"
"unsafe"

"github.com/tongsuo-project/tongsuo-go-sdk/crypto"
)

const (
BlockSize = 16
KeySize = 16
)

type Encrypter interface {
// crypto.EncryptionCipherCtx
SetPadding(pad bool)
Expand Down Expand Up @@ -50,6 +57,50 @@ type sm4Decrypter struct {
tag []byte
}

type sm4Cipher struct {
rk [32]uint32
}

func (c *sm4Cipher) BlockSize() int {
return BlockSize
}

func NewCipher(key []byte) (cipher.Block, error) {
if len(key) != KeySize {
return nil, fmt.Errorf("invalid key size: %w", crypto.ErrInvalidKeySize)
}

cipher := &sm4Cipher{}
ret := C.SM4_set_key((*C.uchar)(&key[0]), (*C.SM4_KEY)(unsafe.Pointer(&cipher.rk)))
if ret != 1 {
return nil, fmt.Errorf("failed to set key: %w", crypto.ErrInternalError)
}

return cipher, nil
}

func (c *sm4Cipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("sm4: input not full block")
}
if len(dst) < BlockSize {
panic("sm4: output not full block")
}

C.SM4_encrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk)))
}

func (c *sm4Cipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("sm4: input not full block")
}
if len(dst) < BlockSize {
panic("sm4: output not full block")
}

C.SM4_decrypt((*C.uchar)(&src[0]), (*C.uchar)(&dst[0]), (*C.SM4_KEY)(unsafe.Pointer(&c.rk)))
}

func getSM4Cipher(mode int) (*crypto.Cipher, error) {
var cipher *crypto.Cipher
var err error
Expand Down
174 changes: 174 additions & 0 deletions crypto/sm4/sm4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package sm4_test

import (
"bytes"
"crypto/cipher"
"encoding/hex"
"strings"
"testing"
Expand All @@ -20,6 +21,179 @@ import (
const hexPlainText1 = `AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCCDDDDDDDDDDDDDDDDEEEEEEEEEEEEEEEEFFFFFFFFFFFFFFFFE
EEEEEEEEEEEEEEEAAAAAAAAAAAAAAAA`

func TestSM4ECBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("681EDF34D206965E86B3E94F536E4246")

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

cipherText1 := make([]byte, len(plainText))

for i := 0; i < len(plainText); i += block.BlockSize() {
block.Encrypt(cipherText1[i:i+block.BlockSize()], plainText[i:i+block.BlockSize()])
}

if !bytes.Equal(cipherText1, cipherText) {
t.Fatalf("exp:%x got:%x", cipherText, cipherText1)
}

plainText1 := make([]byte, len(cipherText1))

for i := 0; i < len(cipherText1); i += block.BlockSize() {
block.Decrypt(plainText1[i:i+block.BlockSize()], cipherText1[i:i+block.BlockSize()])
}

if !bytes.Equal(plainText, plainText1) {
t.Fatalf("exp:%x got:%x", plainText, plainText1)
}
}

func testCryptWithCipherBlock(t *testing.T, mode string, key, iv, plainText, cipherText []byte) {
t.Helper()

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

cipherText1 := make([]byte, len(plainText))

switch mode {
case "CBC":
stream := cipher.NewCBCEncrypter(block, iv)
stream.CryptBlocks(cipherText1, plainText)
case "CFB":
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(cipherText1, plainText)
case "OFB":
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(cipherText1, plainText)
case "CTR":
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(cipherText1, plainText)
}

if !bytes.Equal(cipherText1, cipherText) {
t.Fatalf("exp:%x got:%x", cipherText, cipherText1)
}

plainText1 := make([]byte, len(plainText))

switch mode {
case "CBC":
stream := cipher.NewCBCDecrypter(block, iv)
stream.CryptBlocks(plainText1, cipherText1)
case "CFB":
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
case "OFB":
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
case "CTR":
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(plainText1, cipherText1)
}

if !bytes.Equal(plainText, plainText1) {
t.Fatalf("exp:%x got:%x", plainText, plainText1)
}
}

func TestSM4CBCWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("2677F46B09C122CC975533105BD4A22AF6125F7275CE552C3A2BBCF533DE8A3B")

testCryptWithCipherBlock(t, "CBC", key, iv, plainText, cipherText)
}

func TestSM4CFBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A70569ED258A85A0467CC92AAB393DD978995")

testCryptWithCipherBlock(t, "CFB", key, iv, plainText, cipherText)
}

func TestSM4OFBWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
plainText, _ := hex.DecodeString("0123456789ABCDEFFEDCBA98765432100123456789ABCDEFFEDCBA9876543210")
cipherText, _ := hex.DecodeString("693D9A535BAD5BB1786F53D7253A7056F2075D28B5235F58D50027E4177D2BCE")

testCryptWithCipherBlock(t, "OFB", key, iv, plainText, cipherText)
}

func TestSM4CTRWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
hexCipherText := `C2B4759E78AC3CF43D0852F4E8D5F9FD7256E8A5FCB65A350EE00630912E44492A0B17E1B85B060D0FBA612D8A95831638
B361FD5FFACD942F081485A83CA35D`
plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", ""))
cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", ""))

testCryptWithCipherBlock(t, "CTR", key, iv, plainText, cipherText)
}

func TestSM4GCMWithCipherBlock(t *testing.T) {
t.Parallel()

key, _ := hex.DecodeString("0123456789ABCDEFFEDCBA9876543210")
iv, _ := hex.DecodeString("00001234567800000000ABCD")
aad, _ := hex.DecodeString("FEEDFACEDEADBEEFFEEDFACEDEADBEEFABADDAD2")
tag, _ := hex.DecodeString("83DE3541E4C2B58177E065A9BF7B62EC")
hexCipherText := `17F399F08C67D5EE19D0DC9969C4BB7D5FD46FD3756489069157B282BB200735D82710CA5C22F0CCFA7CBF93D496AC15A5
6834CBCF98C397B4024A2691233B8D`
plainText, _ := hex.DecodeString(strings.ReplaceAll(hexPlainText1, "\n", ""))
cipherText, _ := hex.DecodeString(strings.ReplaceAll(hexCipherText, "\n", ""))

block, err := sm4.NewCipher(key)
if err != nil {
t.Fatal("failed to create SM4 cipher: ", err)
}

stream, err := cipher.NewGCM(block)
if err != nil {
t.Fatal("failed to create GCM: ", err)
}

cipherText1 := stream.Seal(nil, iv, plainText, aad)

if !bytes.Equal(cipherText1, append(cipherText, tag...)) {
t.Fatalf("exp:%x got:%x", cipherText1, append(cipherText, tag...))
}

stream2, err := cipher.NewGCM(block)
if err != nil {
t.Fatal("failed to create GCM: ", err)
}

plainText1, err := stream2.Open(nil, iv, cipherText1, aad)
if err != nil {
t.Fatal("failed to decrypt: ", err)
}

if !bytes.Equal(plainText1, plainText) {
t.Fatalf("exp:%x got:%x", plainText1, plainText)
}
}

func doEncrypt(t *testing.T, mode int, key, iv, plainText, cipherText []byte) {
t.Helper()

Expand Down
Loading

0 comments on commit 9247fed

Please sign in to comment.