diff --git a/aes256cbc/aes256cbc.go b/aes256cbc/aes256cbc.go new file mode 100644 index 0000000..4f89342 --- /dev/null +++ b/aes256cbc/aes256cbc.go @@ -0,0 +1,169 @@ +// Package aes256cbc is a helper to generate OpenSSL compatible encryption +// with autmatic IV derivation and storage. As long as the key is known all +// data can also get decrypted using OpenSSL CLI. +// Code from http://dequeue.blogspot.de/2014/11/decrypting-something-encrypted-with.html +package aes256cbc + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/rand" + "encoding/base64" + "fmt" + "io" +) + +// OpenSSL salt is always this string + 8 bytes of actual salt +var openSSLSaltHeader = []byte("Salted__") + +type openSSLCreds [48]byte + +// openSSLEvpBytesToKey follows the OpenSSL (undocumented?) convention for extracting the key and IV from passphrase. +// It uses the EVP_BytesToKey() method which is basically: +// D_i = HASH^count(D_(i-1) || password || salt) where || denotes concatentaion, until there are sufficient bytes available +// 48 bytes since we're expecting to handle AES-256, 32bytes for a key and 16bytes for the IV +func (c *openSSLCreds) Extract(password, salt []byte) (key, iv []byte) { + m := c[:] + buf := make([]byte, 0, 16+len(password)+len(salt)) + var prevSum [16]byte + for i := 0; i < 3; i++ { + n := 0 + if i > 0 { + n = 16 + } + buf = buf[:n+len(password)+len(salt)] + copy(buf, prevSum[:]) + copy(buf[n:], password) + copy(buf[n+len(password):], salt) + prevSum = md5.Sum(buf) + copy(m[i*16:], prevSum[:]) + } + return c[:32], c[32:] +} + +// DecryptString decrypts a base64 encoded string that was encrypted using OpenSSL and AES-256-CBC. +func DecryptString(passphrase, encryptedBase64String string) (string, error) { + text, err := DecryptBase64([]byte(passphrase), []byte(encryptedBase64String)) + return string(text), err +} + +// DecryptBase64 decrypts a base64 encoded []byte that was encrypted using OpenSSL and AES-256-CBC. +func DecryptBase64(passphrase, encryptedBase64 []byte) ([]byte, error) { + encrypted := make([]byte, base64.StdEncoding.DecodedLen(len(encryptedBase64))) + n, err := base64.StdEncoding.Decode(encrypted, encryptedBase64) + if err != nil { + return nil, err + } + return Decrypt(passphrase, encrypted[:n]) +} + +// Decrypt decrypts a []byte that was encrypted using OpenSSL and AES-256-CBC. +func Decrypt(passphrase, encrypted []byte) ([]byte, error) { + if len(encrypted) < aes.BlockSize { + return nil, fmt.Errorf("Cipher data length less than aes block size") + } + saltHeader := encrypted[:aes.BlockSize] + if !bytes.Equal(saltHeader[:8], openSSLSaltHeader) { + return nil, fmt.Errorf("Does not appear to have been encrypted with OpenSSL, salt header missing.") + } + var creds openSSLCreds + key, iv := creds.Extract(passphrase, saltHeader[8:]) + + if len(encrypted) == 0 || len(encrypted)%aes.BlockSize != 0 { + return nil, fmt.Errorf("bad blocksize(%v), aes.BlockSize = %v\n", len(encrypted), aes.BlockSize) + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + cbc := cipher.NewCBCDecrypter(c, iv) + cbc.CryptBlocks(encrypted[aes.BlockSize:], encrypted[aes.BlockSize:]) + return pkcs7Unpad(encrypted[aes.BlockSize:]) +} + +// EncryptString encrypts a string in a manner compatible to OpenSSL encryption +// functions using AES-256-CBC as encryption algorithm and encode to base64 format. +func EncryptString(passphrase, plaintextString string) (string, error) { + encryptedBase64, err := EncryptBase64([]byte(passphrase), []byte(plaintextString)) + return string(encryptedBase64), err +} + +// EncryptBase64 encrypts a []byte in a manner compatible to OpenSSL encryption +// functions using AES-256-CBC as encryption algorithm and encode to base64 format. +func EncryptBase64(passphrase, plaintext []byte) ([]byte, error) { + encrypted, err := Encrypt(passphrase, plaintext) + encryptedBase64 := make([]byte, base64.StdEncoding.EncodedLen(len(encrypted))) + base64.StdEncoding.Encode(encryptedBase64, encrypted) + return encryptedBase64, err +} + +// Encrypt encrypts a []byte in a manner compatible to OpenSSL encryption +// functions using AES-256-CBC as encryption algorithm +func Encrypt(passphrase, plaintext []byte) ([]byte, error) { + var salt [8]byte // Generate an 8 byte salt + _, err := io.ReadFull(rand.Reader, salt[:]) + if err != nil { + return nil, err + } + + data := make([]byte, len(plaintext)+aes.BlockSize) + copy(data[0:], openSSLSaltHeader) + copy(data[8:], salt[:]) + copy(data[aes.BlockSize:], plaintext) + + var creds openSSLCreds + key, iv := creds.Extract(passphrase, salt[:]) + encrypted, err := encrypt(key, iv, data) + if err != nil { + return nil, err + } + return encrypted, nil +} + +func encrypt(key, iv, data []byte) ([]byte, error) { + padded, err := pkcs7Pad(data) + if err != nil { + return nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + cbc := cipher.NewCBCEncrypter(c, iv) + cbc.CryptBlocks(padded[aes.BlockSize:], padded[aes.BlockSize:]) + return padded, nil +} + +var padPatterns [aes.BlockSize][]byte + +func init() { + for i := 0; i < len(padPatterns); i++ { + padPatterns[i] = bytes.Repeat([]byte{byte(i)}, i) + } +} + +// pkcs7Pad appends padding. +func pkcs7Pad(data []byte) ([]byte, error) { + padlen := 1 + for ((len(data) + padlen) % aes.BlockSize) != 0 { + padlen = padlen + 1 + } + return append(data, padPatterns[padlen]...), nil +} + +// pkcs7Unpad returns slice of the original data without padding. +func pkcs7Unpad(data []byte) ([]byte, error) { + if len(data)%aes.BlockSize != 0 || len(data) == 0 { + return nil, fmt.Errorf("invalid data len %d", len(data)) + } + padlen := int(data[len(data)-1]) + if padlen > aes.BlockSize || padlen == 0 { + return nil, fmt.Errorf("invalid padding") + } + if !bytes.Equal(padPatterns[padlen], data[len(data)-padlen:]) { + return nil, fmt.Errorf("invalid padding") + } + return data[:len(data)-padlen], nil +} diff --git a/aes256cbc/openssl_test.go b/aes256cbc/aes256cbc_test.go similarity index 67% rename from aes256cbc/openssl_test.go rename to aes256cbc/aes256cbc_test.go index eb3b9a3..a7cae35 100644 --- a/aes256cbc/openssl_test.go +++ b/aes256cbc/aes256cbc_test.go @@ -7,16 +7,22 @@ import ( "testing" ) +func Benchmark_Decrypt(b *testing.B) { + opensslEncrypted := []byte("U2FsdGVkX19ZM5qQJGe/d5A/4pccgH+arBGTp+QnWPU=") + passphrase := []byte("z4yH36a6zerhfE5427ZV") + for i := 0; i < b.N; i++ { + DecryptBase64(passphrase, opensslEncrypted) + } +} + func TestDecryptFromString(t *testing.T) { // > echo -n "hallowelt" | openssl aes-256-cbc -pass pass:z4yH36a6zerhfE5427ZV -a -salt // U2FsdGVkX19ZM5qQJGe/d5A/4pccgH+arBGTp+QnWPU= - opensslEncrypted := []byte("U2FsdGVkX19ZM5qQJGe/d5A/4pccgH+arBGTp+QnWPU=") - passphrase := []byte("z4yH36a6zerhfE5427ZV") - - o := New() + opensslEncrypted := "U2FsdGVkX19ZM5qQJGe/d5A/4pccgH+arBGTp+QnWPU=" + passphrase := "z4yH36a6zerhfE5427ZV" - data, err := o.DecryptString(passphrase, opensslEncrypted) + data, err := DecryptString(passphrase, opensslEncrypted) if err != nil { t.Fatalf("Test errored: %s", err) @@ -28,33 +34,29 @@ func TestDecryptFromString(t *testing.T) { } func TestEncryptToDecrypt(t *testing.T) { - plaintext := []byte("hallowelt") - passphrase := []byte("z4yH36a6zerhfE5427ZV") + plaintext := "hallowelt" + passphrase := "z4yH36a6zerhfE5427ZV" - o := New() - - enc, err := o.EncryptString(passphrase, plaintext) + enc, err := EncryptString(passphrase, plaintext) if err != nil { t.Fatalf("Test errored at encrypt: %s", err) } - dec, err := o.DecryptString(passphrase, enc) + dec, err := DecryptString(passphrase, string(enc)) if err != nil { t.Fatalf("Test errored at decrypt: %s", err) } - if !bytes.Equal(dec, plaintext) { + if string(dec) != plaintext { t.Errorf("Decrypted text did not match input.") } } func TestEncryptToOpenSSL(t *testing.T) { - plaintext := []byte("hallowelt") - passphrase := []byte("z4yH36a6zerhfE5427ZV") - - o := New() + plaintext := "hallowelt" + passphrase := "z4yH36a6zerhfE5427ZV" - enc, err := o.EncryptString(passphrase, plaintext) + enc, err := EncryptString(passphrase, plaintext) if err != nil { t.Fatalf("Test errored at encrypt: %s", err) } @@ -71,7 +73,7 @@ func TestEncryptToOpenSSL(t *testing.T) { t.Errorf("OpenSSL errored: %s", err) } - if !bytes.Equal(out.Bytes(), plaintext) { + if out.String() != plaintext { t.Errorf("OpenSSL output did not match input.\nOutput was: %s", out.String()) } } diff --git a/aes256cbc/examples_test.go b/aes256cbc/examples_test.go new file mode 100644 index 0000000..f6f6910 --- /dev/null +++ b/aes256cbc/examples_test.go @@ -0,0 +1,30 @@ +package aes256cbc + +import "fmt" + +func ExampleEncryptString() { + plaintext := "Hello World!" + passphrase := "z4yH36a6zerhfE5427ZV" + + enc, err := EncryptString(passphrase, plaintext) + if err != nil { + fmt.Printf("An error occurred: %s\n", err) + } + + fmt.Printf("Encrypted text: %s\n", string(enc)) +} + +func ExampleDecryptString() { + opensslEncrypted := "U2FsdGVkX19ZM5qQJGe/d5A/4pccgH+arBGTp+QnWPU=" + passphrase := "z4yH36a6zerhfE5427ZV" + + dec, err := DecryptString(passphrase, opensslEncrypted) + if err != nil { + fmt.Printf("An error occurred: %s\n", err) + } + + fmt.Printf("Decrypted text: %s\n", string(dec)) + + // Output: + // Decrypted text: hallowelt +} diff --git a/aes256cbc/openssl.go b/aes256cbc/openssl.go deleted file mode 100644 index 8bef58b..0000000 --- a/aes256cbc/openssl.go +++ /dev/null @@ -1,197 +0,0 @@ -package aes256cbc - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/md5" - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -// OpenSSL is a helper to generate OpenSSL compatible encryption -// with autmatic IV derivation and storage. As long as the key is known all -// data can also get decrypted using OpenSSL CLI. -// Code from http://dequeue.blogspot.de/2014/11/decrypting-something-encrypted-with.html -type OpenSSL struct { - openSSLSaltHeader []byte -} - -type openSSLCreds struct { - key []byte - iv []byte -} - -// New instanciates and initializes a new OpenSSL encrypter -func New() *OpenSSL { - return &OpenSSL{ - openSSLSaltHeader: []byte("Salted__"), // OpenSSL salt is always this string + 8 bytes of actual salt - } -} - -// DecryptString a base64 encoded string that was encrypted -// using OpenSSL and AES-256-CBC -// also compatible with crptojs -func (o *OpenSSL) DecryptString(passphrase, encryptedBase64String []byte) ([]byte, error) { - dbuf := make([]byte, base64.StdEncoding.DecodedLen(len(encryptedBase64String))) - n, err := base64.StdEncoding.Decode(dbuf, encryptedBase64String) - if err != nil { - return nil, err - } - return o.Decrypt(passphrase, dbuf[:n]) -} - -// Decrypt encrypted data that was encrypted using OpenSSL and AES-256-CBC -func (o *OpenSSL) Decrypt(passphrase, encrypted []byte) ([]byte, error) { - if len(encrypted) < aes.BlockSize { - return nil, fmt.Errorf("Cipher data length less than aes block size") - } - saltHeader := encrypted[:aes.BlockSize] - if !bytes.Equal(saltHeader[:8], o.openSSLSaltHeader) { - return nil, fmt.Errorf("Does not appear to have been encrypted with OpenSSL, salt header missing.") - } - salt := saltHeader[8:] - creds, err := o.extractOpenSSLCreds(passphrase, salt) - if err != nil { - return nil, err - } - return o.decrypt(creds.key, creds.iv, encrypted) -} - -func (o *OpenSSL) decrypt(key, iv, data []byte) ([]byte, error) { - if len(data) == 0 || len(data)%aes.BlockSize != 0 { - return nil, fmt.Errorf("bad blocksize(%v), aes.BlockSize = %v\n", len(data), aes.BlockSize) - } - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - cbc := cipher.NewCBCDecrypter(c, iv) - cbc.CryptBlocks(data[aes.BlockSize:], data[aes.BlockSize:]) - out, err := o.pkcs7Unpad(data[aes.BlockSize:], aes.BlockSize) - if out == nil { - return nil, err - } - return out, nil -} - -// EncryptString in a manner compatible to OpenSSL encryption -// functions using AES-256-CBC as encryption algorithm -// also compatible with crptojs from https://code.google.com/p/crypto-js/ -func (o *OpenSSL) EncryptString(passphrase, plaintextString []byte) ([]byte, error) { - enc, err := o.Encrypt(passphrase, plaintextString) - if err != nil { - return nil, err - } - - return []byte(base64.StdEncoding.EncodeToString(enc)), nil -} - -// Encrypt in a manner compatible to OpenSSL encryption -// functions using AES-256-CBC as encryption algorithm -// also compatible with crptojs from https://code.google.com/p/crypto-js/ -func (o *OpenSSL) Encrypt(passphrase, plaintextString []byte) ([]byte, error) { - salt := make([]byte, 8) // Generate an 8 byte salt - _, err := io.ReadFull(rand.Reader, salt) - if err != nil { - return nil, err - } - - data := make([]byte, len(plaintextString)+aes.BlockSize) - copy(data[0:], o.openSSLSaltHeader) - copy(data[8:], salt) - copy(data[aes.BlockSize:], plaintextString) - - creds, err := o.extractOpenSSLCreds([]byte(passphrase), salt) - if err != nil { - return nil, err - } - - enc, err := o.encrypt(creds.key, creds.iv, data) - if err != nil { - return nil, err - } - - return enc, nil -} - -func (o *OpenSSL) encrypt(key, iv, data []byte) ([]byte, error) { - padded, err := o.pkcs7Pad(data, aes.BlockSize) - if err != nil { - return nil, err - } - - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - cbc := cipher.NewCBCEncrypter(c, iv) - cbc.CryptBlocks(padded[aes.BlockSize:], padded[aes.BlockSize:]) - - return padded, nil -} - -// openSSLEvpBytesToKey follows the OpenSSL (undocumented?) convention for extracting the key and IV from passphrase. -// It uses the EVP_BytesToKey() method which is basically: -// D_i = HASH^count(D_(i-1) || password || salt) where || denotes concatentaion, until there are sufficient bytes available -// 48 bytes since we're expecting to handle AES-256, 32bytes for a key and 16bytes for the IV -func (o *OpenSSL) extractOpenSSLCreds(password, salt []byte) (openSSLCreds, error) { - m := make([]byte, 48) - prev := []byte{} - for i := 0; i < 3; i++ { - prev = o.hash(prev, password, salt) - copy(m[i*16:], prev) - } - return openSSLCreds{key: m[:32], iv: m[32:]}, nil -} - -func (o *OpenSSL) hash(prev, password, salt []byte) []byte { - a := make([]byte, len(prev)+len(password)+len(salt)) - copy(a, prev) - copy(a[len(prev):], password) - copy(a[len(prev)+len(password):], salt) - return o.md5sum(a) -} - -func (o *OpenSSL) md5sum(data []byte) []byte { - h := md5.New() - h.Write(data) - return h.Sum(nil) -} - -// pkcs7Pad appends padding. -func (o *OpenSSL) pkcs7Pad(data []byte, blocklen int) ([]byte, error) { - if blocklen <= 0 { - return nil, fmt.Errorf("invalid blocklen %d", blocklen) - } - padlen := 1 - for ((len(data) + padlen) % blocklen) != 0 { - padlen = padlen + 1 - } - - pad := bytes.Repeat([]byte{byte(padlen)}, padlen) - return append(data, pad...), nil -} - -// pkcs7Unpad returns slice of the original data without padding. -func (o *OpenSSL) pkcs7Unpad(data []byte, blocklen int) ([]byte, error) { - if blocklen <= 0 { - return nil, fmt.Errorf("invalid blocklen %d", blocklen) - } - if len(data)%blocklen != 0 || len(data) == 0 { - return nil, fmt.Errorf("invalid data len %d", len(data)) - } - padlen := int(data[len(data)-1]) - if padlen > blocklen || padlen == 0 { - return nil, fmt.Errorf("invalid padding") - } - pad := data[len(data)-padlen:] - for i := 0; i < padlen; i++ { - if pad[i] != byte(padlen) { - return nil, fmt.Errorf("invalid padding") - } - } - return data[:len(data)-padlen], nil -} diff --git a/main.go b/main.go index f62022a..7cec2b9 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "log" "net" "net/http" + _ "net/http/pprof" "os" "runtime" "runtime/debug" @@ -20,7 +21,7 @@ import ( "syscall" "time" - "github.com/xindong/frontd/aes256cbc" + "github.com/idada/frontd/aes256cbc" ) const ( @@ -40,7 +41,6 @@ var ( var ( _SecretPassphase []byte - _Aes256CBC = aes256cbc.New() ) var ( @@ -136,7 +136,13 @@ func handleConn(c net.Conn) { } else { rdr = bufio.NewReader(c) } - defer _BufioReaderPool.Put(rdr) + bufioReleased := false + defer func() { + if !bufioReleased { + rdr.Reset(nil) + _BufioReaderPool.Put(rdr) + } + }() addr, err := handleBinaryHdr(rdr, c) if err != nil { @@ -189,6 +195,8 @@ func handleConn(c net.Conn) { err = tunneling(string(addr), rdr, c, header) if err != nil { log.Println(err) + } else { + bufioReleased = true } } @@ -309,9 +317,25 @@ func tunneling(addr string, rdr *bufio.Reader, c net.Conn, header *bytes.Buffer) header.WriteTo(backend) } + if n := rdr.Buffered(); n > 0 { + var data []byte + data, err = rdr.Peek(n) + if err != nil { + writeErrCode(c, []byte("4103"), false) + return err + } + _, err = backend.Write(data) + if err != nil { + writeErrCode(c, []byte("4102"), false) + return err + } + } + rdr.Reset(nil) + _BufioReaderPool.Put(rdr) + // Start transfering data go pipe(c, backend) - pipe(backend, rdr) + pipe(backend, c) return nil } @@ -337,7 +361,7 @@ func backendAddrDecrypt(key []byte) ([]byte, error) { } // Try to decrypt it (AES) - addr, err := _Aes256CBC.Decrypt(_SecretPassphase, key) + addr, err := aes256cbc.Decrypt(_SecretPassphase, key) if err != nil { return nil, err } diff --git a/main_test.go b/main_test.go index c3c306c..ab8eedf 100644 --- a/main_test.go +++ b/main_test.go @@ -17,8 +17,8 @@ import ( "testing" "time" - "github.com/xindong/frontd/aes256cbc" - "github.com/xindong/frontd/reuse" + "github.com/idada/frontd/aes256cbc" + "github.com/idada/frontd/reuse" "golang.org/x/net/websocket" ) @@ -114,9 +114,7 @@ func servEcho() { // TestTextDecryptAES --- func TestTextDecryptAES(t *testing.T) { - o := aes256cbc.New() - - dec, err := o.DecryptString(_secret, _expectAESCiphertext) + dec, err := aes256cbc.DecryptBase64(_secret, _expectAESCiphertext) if err != nil { panic(err) } @@ -143,9 +141,7 @@ func TestHTTPServer(t *testing.T) { } func encryptText(plaintext, passphrase []byte) ([]byte, error) { - o := aes256cbc.New() - - return o.EncryptString(passphrase, plaintext) + return aes256cbc.EncryptBase64(passphrase, plaintext) } func testHTTPServer(hdrs map[string]string, expected string) { @@ -311,8 +307,7 @@ func testProtocol(cipherAddr, expected []byte) { // TestBinaryProtocolDecrypt --- func TestBinaryProtocolDecrypt(*testing.T) { - o := aes256cbc.New() - b, err := o.Encrypt(_secret, _echoServerAddr) + b, err := aes256cbc.Encrypt(_secret, _echoServerAddr) if err != nil { panic(err) } @@ -377,8 +372,7 @@ func BenchmarkEncryptText(b *testing.B) { func BenchmarkDecryptText(b *testing.B) { for i := 0; i < b.N; i++ { - o := aes256cbc.New() - _, err := o.DecryptString(_secret, _expectAESCiphertext) + _, err := aes256cbc.DecryptBase64(_secret, _expectAESCiphertext) if err != nil { panic(err) }