Skip to content

Commit

Permalink
feat: support selection of encryption and decryption methods, add gm …
Browse files Browse the repository at this point in the history
…sm4 crypto type
  • Loading branch information
destinyoooo authored and dk-lockdown committed Dec 30, 2024
1 parent 35fc701 commit 75e86e9
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 10 deletions.
1 change: 1 addition & 0 deletions docker/conf/config_sdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,4 @@ app_config:
- table: departments
columns: [ "dept_name" ]
aeskey: 123456789abcdefg
cryptoType: aesgcm
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ require (
github.com/spf13/cobra v1.1.1
github.com/stretchr/testify v1.7.1
github.com/testcontainers/testcontainers-go v0.13.0
github.com/tjfoc/gmsm v1.4.1
github.com/uber-go/atomic v1.4.0
github.com/valyala/fasthttp v1.34.0
go.etcd.io/etcd/api/v3 v3.5.0-alpha.0
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,8 @@ github.com/tikv/client-go/v2 v2.0.0-alpha.0.20210831090540-391fcd842dc8/go.mod h
github.com/tikv/pd v1.1.0-beta.0.20210323121136-78679e5e209d/go.mod h1:Jw9KG11C/23Rr7DW4XWQ7H5xOgGZo6DFL1OKAF4+Igw=
github.com/tikv/pd v1.1.0-beta.0.20210818112400-0c5667766690 h1:qGn7fDqj7IZ5dozy7QVkoj+0bama92ruVGHqoCBg7W4=
github.com/tikv/pd v1.1.0-beta.0.20210818112400-0c5667766690/go.mod h1:rammPjeZgpvfrQRPkijcx8tlxF1XM5+m6kRXrkDzCAA=
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE=
github.com/tklauser/go-sysconf v0.3.4/go.mod h1:Cl2c8ZRWfHD5IrfHo9VN+FX9kCFjIOyVklgXycLB6ek=
github.com/tklauser/numcpus v0.2.1/go.mod h1:9aU+wOc6WjUIZEwWMP62PL/41d65P+iks1gBkr4QyP8=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
Expand Down Expand Up @@ -1412,6 +1414,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
Expand Down Expand Up @@ -1505,6 +1508,7 @@ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
Expand Down
19 changes: 10 additions & 9 deletions pkg/filter/crypto/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ type _filter struct {
}

type ColumnCrypto struct {
Table string
Columns []string
AesKey string
Table string
Columns []string
AesKey string
CryptoType misc.CryptoType
}

type columnIndex struct {
Expand Down Expand Up @@ -304,7 +305,7 @@ func encryptInsertValues(columns []*columnIndex, config *ColumnCrypto, valueList
if param, ok := arg.(*driver.ValueExpr); ok {
value := param.GetBytes()
if len(value) != 0 {
encoded, err := misc.AesEncryptGCM(value, []byte(config.AesKey), []byte(aesIV))
encoded, err := misc.CryptoEncrypt(value, []byte(config.AesKey), []byte(aesIV), config.CryptoType)
if err != nil {
return errors.Wrapf(err, "Encryption of %s failed", column.Column)
}
Expand All @@ -326,7 +327,7 @@ func encryptUpdateValues(updateStmt *ast.UpdateStmt, config *ColumnCrypto) error
if param, ok := arg.(*driver.ValueExpr); ok {
value := param.GetBytes()
if len(value) != 0 {
encoded, err := misc.AesEncryptGCM(value, []byte(config.AesKey), []byte(aesIV))
encoded, err := misc.CryptoEncrypt(value, []byte(config.AesKey), []byte(aesIV), config.CryptoType)
if err != nil {
return errors.Wrapf(err, "Encryption of %s failed", column.Column)
}
Expand All @@ -345,14 +346,14 @@ func encryptBindVars(columns []*columnIndex, config *ColumnCrypto, args *map[str
parameterID := fmt.Sprintf("v%d", column.Index+1)
param := (*args)[parameterID]
if arg, ok := param.(string); ok {
encoded, err := misc.AesEncryptGCM([]byte(arg), []byte(config.AesKey), []byte(aesIV))
encoded, err := misc.CryptoEncrypt([]byte(arg), []byte(config.AesKey), []byte(aesIV), config.CryptoType)
if err != nil {
return errors.Errorf("Encryption of %s failed: %v", column.Column, err)
}
val := hex.EncodeToString(encoded)
(*args)[parameterID] = val
} else if arg, ok := param.([]byte); ok {
encoded, err := misc.AesEncryptGCM(arg, []byte(config.AesKey), []byte(aesIV))
encoded, err := misc.CryptoEncrypt(arg, []byte(config.AesKey), []byte(aesIV), config.CryptoType)
if err != nil {
return errors.Errorf("Encryption of %s failed: %v", column.Column, err)
}
Expand All @@ -372,7 +373,7 @@ func decryptDecodedResult(decodedResult *mysql.Result, config *ColumnCrypto, col
if protoValue != nil {
if originalVal, ok := protoValue.Val.([]byte); ok {
if n, err := hex.Decode(originalVal, originalVal); err == nil {
if decodedVal, err := misc.AesDecryptGCM(originalVal[:n], []byte(config.AesKey), []byte(aesIV)); err == nil {
if decodedVal, err := misc.CryptoDecrypt(originalVal[:n], []byte(config.AesKey), []byte(aesIV), config.CryptoType); err == nil {
r.Values[column.Index].Val = decodedVal
}
}
Expand All @@ -385,7 +386,7 @@ func decryptDecodedResult(decodedResult *mysql.Result, config *ColumnCrypto, col
if protoValue != nil {
if originalVal, ok := protoValue.Val.([]byte); ok {
if n, err := hex.Decode(originalVal, originalVal); err == nil {
if decodedVal, err := misc.AesDecryptGCM(originalVal[:n], []byte(config.AesKey), []byte(aesIV)); err == nil {
if decodedVal, err := misc.CryptoDecrypt(originalVal[:n], []byte(config.AesKey), []byte(aesIV), config.CryptoType); err == nil {
r.Values[column.Index].Val = decodedVal
}
}
Expand Down
188 changes: 187 additions & 1 deletion pkg/misc/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,112 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"github.com/pkg/errors"
"github.com/tjfoc/gmsm/sm4"
"io"
)

"github.com/pkg/errors"
type CryptoType int

const (
CryptoAESGCM CryptoType = iota
CryptoAESCBC
CryptoAESECB
CryptoAESCFB
CryptoSM4GCM
CryptoSM4ECB
CryptoSM4CBC
CryptoSM4CFB
CryptoSM4OFB
)

func (c *CryptoType) UnmarshalText(text []byte) error {
if c == nil {
return errors.New("can't unmarshal a nil *CryptoType")
}
if !c.unmarshalText(bytes.ToLower(text)) {
return fmt.Errorf("unrecognized protocol type: %q", text)
}
return nil
}

func (c *CryptoType) unmarshalText(text []byte) bool {
switch string(text) {
case "aesgcm":
*c = CryptoAESGCM
case "aescbc":
*c = CryptoAESCBC
case "aesecb":
*c = CryptoAESECB
case "aescfb":
*c = CryptoAESCFB
case "sm4gcm":
*c = CryptoSM4GCM
case "sm4ecb":
*c = CryptoSM4ECB
case "sm4cbc":
*c = CryptoSM4CBC
case "sm4cfb":
*c = CryptoSM4CFB
case "sm4ofb":
*c = CryptoSM4OFB
default:
return false
}
return true
}

func CryptoEncrypt(data []byte, key []byte, iv []byte, cryptoType CryptoType) ([]byte, error) {
switch cryptoType {
case CryptoAESGCM:
return AesEncryptGCM(data, key, iv)
case CryptoAESCBC:
return AesEncryptCBC(data, key, iv)
case CryptoAESECB:
return AesEncryptECB(data, key)
case CryptoAESCFB:
return AesEncryptCFB(data, key)
case CryptoSM4GCM:
return Sm4EncryptGCM(data, key, iv)
case CryptoSM4ECB:
return Sm4EncryptECB(data, key)
case CryptoSM4CBC:
return Sm4EncryptCBC(data, key, iv)
case CryptoSM4CFB:
return Sm4EncryptCFB(data, key, iv)
case CryptoSM4OFB:
return Sm4EncryptOFB(data, key, iv)
default:
return AesEncryptGCM(data, key, iv)
}
}

func CryptoDecrypt(encrypted []byte, key []byte, iv []byte, cryptoType CryptoType) ([]byte, error) {
switch cryptoType {
case CryptoAESGCM:
return AesDecryptGCM(encrypted, key, iv)
case CryptoAESCBC:
return AesDecryptCBC(encrypted, key, iv)
case CryptoAESECB:
return AesDecryptECB(encrypted, key)
case CryptoAESCFB:
return AesDecryptCFB(encrypted, key)
case CryptoSM4GCM:
return Sm4DecryptGCM(encrypted, key, iv)
case CryptoSM4ECB:
return Sm4DecryptECB(encrypted, key)
case CryptoSM4CBC:
return Sm4DecryptCBC(encrypted, key, iv)
case CryptoSM4CFB:
return Sm4DecryptCFB(encrypted, key, iv)
case CryptoSM4OFB:
return Sm4DecryptOFB(encrypted, key, iv)
default:
return AesDecryptGCM(encrypted, key, iv)
}
}

func AesEncryptGCM(origData []byte, key []byte, iv []byte) (encrypted []byte, err error) {
var block cipher.Block
block, err = aes.NewCipher(key)
Expand Down Expand Up @@ -178,3 +279,88 @@ func AesDecryptCFB(encrypted []byte, key []byte) (decrypted []byte, err error) {
stream.XORKeyStream(encrypted, encrypted)
return encrypted, err
}

func Sm4EncryptGCM(origData, key []byte, iv []byte) (encrypted []byte, err error) {
// Sm4GCM /**
// key: 对称加密密钥
// IV: IV向量
// in:
// A: 附加的可鉴别数据(ADD)
// mode: true - 加密; false - 解密验证
//
// return: 密文C, 鉴别标签T, 错误
encrypted, _, err = sm4.Sm4GCM(key, iv, origData, []byte{}, true)
if err != nil {
return nil, err
}
return encrypted, nil
}

func Sm4DecryptGCM(encrypted, key []byte, iv []byte) (decrypted []byte, err error) {
decrypted, _, err = sm4.Sm4GCM(key, iv, encrypted, []byte{}, true)
if err != nil {
return nil, err
}
return decrypted, nil
}

func Sm4EncryptECB(origData, key []byte) (encrypted []byte, err error) {
return sm4.Sm4Ecb(key, origData, true)
}

func Sm4DecryptECB(encrypted, key []byte) (decrypted []byte, err error) {
return sm4.Sm4Ecb(key, encrypted, false)
}

func Sm4EncryptCBC(origData, key, iv []byte) (encrypted []byte, err error) {
if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil {
return nil, err
}
return sm4.Sm4Cbc(key, origData, true)
}

func Sm4DecryptCBC(encrypted, key, iv []byte) (decrypted []byte, err error) {
if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil {
return nil, err
}
return sm4.Sm4Cbc(key, encrypted, false)
}

func Sm4EncryptCFB(origData, key, iv []byte) (encrypted []byte, err error) {
if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil {
return nil, err
}
return sm4.Sm4CFB(key, origData, true)
}

func Sm4DecryptCFB(encrypted, key, iv []byte) (decrypted []byte, err error) {
if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil {
return nil, err
}
return sm4.Sm4CFB(key, encrypted, false)
}

func Sm4EncryptOFB(origData, key, iv []byte) (encrypted []byte, err error) {
if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil {
return nil, err
}
return sm4.Sm4OFB(key, origData, true)
}

func Sm4DecryptOFB(encrypted, key, iv []byte) (decrypted []byte, err error) {
if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil {
return nil, err
}
return sm4.Sm4OFB(key, encrypted, false)
}

func EnsureByteArrayLength16(input []byte) []byte {
if len(input) == 16 {
return input
}
repeated := append(input, input...)
for len(repeated) < 16 {
repeated = append(repeated, input...)
}
return repeated[:16]
}
81 changes: 81 additions & 0 deletions pkg/misc/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,84 @@ func TestAesDecryptCFB(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []byte("exampleplaintext"), decrypted)
}

func TestSm4EncryptGCM(t *testing.T) {
key, _ := hex.DecodeString("31323334353637383961626364656667")
plaintext := []byte("sunset4")
encrypted, err := Sm4EncryptGCM(plaintext, key, []byte("greatdbpack!"))
assert.Nil(t, err)
t.Logf("%x", encrypted)
}

func TestSm4DecryptGCM(t *testing.T) {
key, _ := hex.DecodeString("31323334353637383961626364656667")
encrypted, _ := hex.DecodeString("4b3dd6cb3e0145")
decrypted, err := Sm4DecryptGCM(encrypted, key, []byte("greatdbpack!"))
assert.Nil(t, err)
t.Logf("%s", decrypted)
assert.Equal(t, []byte("sunset4"), decrypted)
}

func TestSm4EncryptECB(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
plaintext := []byte("exampleplaintext")
encrypted, err := Sm4EncryptECB(plaintext, key)
assert.Nil(t, err)
t.Logf("%x", encrypted)
}

func TestSm4DecryptECB(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
encrypted, _ := hex.DecodeString("1cadd74166afbe5f4bdaf6ebb49d4c46ce96714d2c0839338f995f4854c61b58")
decrypted, err := Sm4DecryptECB(encrypted, key)
assert.Nil(t, err)
assert.Equal(t, []byte("exampleplaintext"), decrypted)
}

func TestSm4EncryptCBC(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
plaintext := []byte("exampleplaintext")
encrypted, err := Sm4EncryptCBC(plaintext, key, []byte("impressivedbpack"))
assert.Nil(t, err)
t.Logf("%x", encrypted)
}

func TestSm4DecryptCBC(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
encrypted, _ := hex.DecodeString("2e88063cb32a13ce8fbfb60512c23d78d257734049682849d7c82a19f00e131a")
decrypted, err := Sm4DecryptCBC(encrypted, key, []byte("impressivedbpack"))
assert.Nil(t, err)
assert.Equal(t, []byte("exampleplaintext"), decrypted)
}

func TestSm4EncryptCFB(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
plaintext := []byte("exampleplaintext")
encrypted, err := Sm4EncryptCFB(plaintext, key, []byte("impressivedbpack"))
assert.Nil(t, err)
t.Logf("%x", encrypted)
}

func TestSm4DecryptCFB(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
encrypted, _ := hex.DecodeString("5ce63f4fac3744073aa91ac44bdc4ab44a19895a9fcb106947eae2cecfd99e62")
decrypted, err := Sm4DecryptCFB(encrypted, key, []byte("impressivedbpack"))
assert.Nil(t, err)
assert.Equal(t, []byte("exampleplaintext"), decrypted)
}

func TestSm4EncryptOFB(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
plaintext := []byte("exampleplaintext")
encrypted, err := Sm4EncryptOFB(plaintext, key, []byte("impressivedbpack"))
assert.Nil(t, err)
t.Logf("%x", encrypted)
}

func TestSm4DecryptOFB(t *testing.T) {
key, _ := hex.DecodeString("6368616e676520746869732070617373")
encrypted, _ := hex.DecodeString("5ce63f4fac3744073aa91ac44bdc4ab4f83abab6ff8e4fd91da0740e339f9b2d")
decrypted, err := Sm4DecryptOFB(encrypted, key, []byte("impressivedbpack"))
assert.Nil(t, err)
assert.Equal(t, []byte("exampleplaintext"), decrypted)
}

0 comments on commit 75e86e9

Please sign in to comment.