diff --git a/common/net_table.go b/common/net_table.go index ec1875b..9bdd7d4 100644 --- a/common/net_table.go +++ b/common/net_table.go @@ -1,28 +1,29 @@ package common import ( - "github.com/meshbird/meshbird/log" - "github.com/meshbird/meshbird/network/protocol" - "github.com/meshbird/meshbird/secure" "net" "sync" "time" + + "github.com/meshbird/meshbird/log" + "github.com/meshbird/meshbird/network/protocol" + "github.com/meshbird/meshbird/secure" ) type NetTable struct { BaseService - localNode *LocalNode - waitGroup sync.WaitGroup - dhtInChan chan string + localNode *LocalNode + waitGroup sync.WaitGroup + dhtInChan chan string - lock sync.RWMutex - blackList map[string]time.Time - peers map[string]*RemoteNode + lock sync.RWMutex + blackList map[string]time.Time + peers map[string]*RemoteNode heartbeatTicker <-chan time.Time - logger log.Logger + logger log.Logger } func (nt NetTable) Name() string { @@ -55,7 +56,7 @@ func (nt *NetTable) Stop() { } } -func (nt *NetTable) GetDHTInChannel() chan <- string { +func (nt *NetTable) GetDHTInChannel() chan<- string { return nt.dhtInChan } @@ -164,7 +165,7 @@ func (nt *NetTable) SendPacket(dstIP net.IP, payload []byte) { return } - payloadEnc, err := secure.EncryptIV(payload, nt.localNode.State().Secret.Key, nt.localNode.State().Secret.Key) + payloadEnc, err := secure.EncryptIV(payload, nt.localNode.State().Secret.Key) if err != nil { nt.logger.Error("error on encrypt, %v", err) return diff --git a/common/remotenode.go b/common/remotenode.go index 968fcba..429783d 100644 --- a/common/remotenode.go +++ b/common/remotenode.go @@ -5,6 +5,7 @@ import ( "io" "net" "time" + "github.com/meshbird/meshbird/log" "github.com/meshbird/meshbird/network/protocol" "github.com/meshbird/meshbird/secure" @@ -80,7 +81,7 @@ func (rn *RemoteNode) listen(ln *LocalNode) { case protocol.TypeTransfer: rn.logger.Debug("Writing to interface...") payloadEncrypted := pack.Data.Msg.(protocol.TransferMessage).Bytes() - payload, errDec := secure.DecryptIV(payloadEncrypted, ln.State().Secret.Key, ln.State().Secret.Key) + payload, errDec := secure.DecryptIV(payloadEncrypted, ln.State().Secret.Key) if errDec != nil { rn.logger.Error("error on decrypt, %v", err) break @@ -96,4 +97,4 @@ func (rn *RemoteNode) listen(ln *LocalNode) { rn.lastHeartbeat = time.Now() } } -} \ No newline at end of file +} diff --git a/secure/crypt.go b/secure/crypt.go index f721856..92dfaea 100644 --- a/secure/crypt.go +++ b/secure/crypt.go @@ -1,43 +1,59 @@ package secure import ( - "bytes" "crypto/aes" "crypto/cipher" + "crypto/rand" + "io" + "log" ) -func EncryptIV(decrypted []byte, key []byte, iv []byte) ([]byte, error) { - ac, err := aes.NewCipher(key) +func EncryptIV(decrypted []byte, key []byte) ([]byte, error) { + + c, err := aes.NewCipher(key) if err != nil { + log.Println("[CRYPT][AES][ENC] Problem %s", err.Error()) return nil, err } - c := cipher.NewCBCEncrypter(ac, iv) - decrypted = PKCS5Padding(decrypted, ac.BlockSize()) - encrypted := make([]byte, len(decrypted)) - c.CryptBlocks(encrypted, decrypted) - return encrypted, nil -} -func DecryptIV(encrypted []byte, key []byte, iv []byte) ([]byte, error) { - ac, err := aes.NewCipher(key) + gcm, err := cipher.NewGCM(c) if err != nil { + log.Println("[CRYPT][AES][ENC] Problem %s", err.Error()) + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + log.Println("[CRYPT][AES][NONCE] Problem %s", err.Error()) return nil, err } - c := cipher.NewCBCDecrypter(ac, iv) - decrypted := make([]byte, len(encrypted)) - c.CryptBlocks(decrypted, encrypted) - decrypted = PKCS5UnPadding(decrypted) - return decrypted, nil -} -func PKCS5Padding(src []byte, blockSize int) []byte { - padding := blockSize - len(src)%blockSize - padtext := bytes.Repeat([]byte{byte(padding)}, padding) - return append(src, padtext...) + return gcm.Seal(nonce, nonce, decrypted, nil), nil + } -func PKCS5UnPadding(src []byte) []byte { - length := len(src) - unpadding := int(src[length-1]) - return src[:(length - unpadding)] +func DecryptIV(ciphertext []byte, key []byte) ([]byte, error) { + + c, err := aes.NewCipher(key) + if err != nil { + log.Println("[DECRYPT][AES] Problem %s", err.Error()) + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + log.Println("[DECRYPT][AES] Problem %s", err.Error()) + return nil, err + } + + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + log.Println("[DECRYPT][AES] Problem %s", "Cyphertext too short") + return nil, err + } + + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + return gcm.Open(nil, nonce, ciphertext, nil) + } diff --git a/secure/crypt_test.go b/secure/crypt_test.go index 491c156..f1b3521 100644 --- a/secure/crypt_test.go +++ b/secure/crypt_test.go @@ -38,7 +38,7 @@ func BenchmarkEncryptAesCbc(b *testing.B) { func BenchmarkDescryptAesCbc(b *testing.B) { key := randomBytes(16) iv := randomBytes(16) - encrypted, err := EncryptIV(original, key, iv) + encrypted, err := EncryptIV(original, key) if err != nil { b.Fatal(err) }