diff --git a/conn.go b/conn.go index b007f7d..224522d 100644 --- a/conn.go +++ b/conn.go @@ -198,6 +198,15 @@ func (c *Conn) CurrentCipher() (string, error) { return C.GoString(p), nil } +func (c *Conn) GetVersion() (string, error) { + p := C.X_SSL_get_version(c.ssl) + if p == nil { + return "", errors.New("Failed to get version") + } + + return C.GoString(p), nil +} + func (c *Conn) fillInputBuffer() error { for { n, err := c.into_ssl.ReadFromOnce(c.conn) diff --git a/examples/client/ntls_client.go b/examples/client/ntls_client.go new file mode 100644 index 0000000..7167383 --- /dev/null +++ b/examples/client/ntls_client.go @@ -0,0 +1,151 @@ +// Copyright 2023 The Tongsuo Project Authors. All Rights Reserved. +// +// Licensed under the Apache License 2.0 (the "License"). You may not use +// this file except in compliance with the License. You can obtain a copy +// in the file LICENSE in the source distribution or at +// https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE + +package main + +import ( + "bufio" + "flag" + "fmt" + "os" + + ts "github.com/tongsuo-project/tongsuo-go-sdk" +) + +func main() { + cipherSuite := "" + signCertFile := "" + signKeyFile := "" + encCertFile := "" + encKeyFile := "" + caFile := "" + connAddr := "" + + flag.StringVar(&connAddr, "conn", "127.0.0.1:443", "host:port") + flag.StringVar(&cipherSuite, "cipher", "ECC-SM2-SM4-CBC-SM3", "cipher suite") + flag.StringVar(&signCertFile, "sign_cert", "", "sign certificate file") + flag.StringVar(&signKeyFile, "sign_key", "", "sign private key file") + flag.StringVar(&encCertFile, "enc_cert", "", "encrypt certificate file") + flag.StringVar(&encKeyFile, "enc_key", "", "encrypt private key file") + flag.StringVar(&caFile, "CAfile", "", "CA certificate file") + + flag.Parse() + + ctx, err := ts.NewCtxWithVersion(ts.NTLS) + if err != nil { + panic(err) + } + + if err := ctx.SetCipherList(cipherSuite); err != nil { + panic(err) + } + + if signCertFile != "" { + signCertPEM, err := os.ReadFile(signCertFile) + if err != nil { + panic(err) + } + signCert, err := ts.LoadCertificateFromPEM(signCertPEM) + if err != nil { + panic(err) + } + + if err := ctx.UseSignCertificate(signCert); err != nil { + panic(err) + } + } + + if signKeyFile != "" { + signKeyPEM, err := os.ReadFile(signKeyFile) + if err != nil { + panic(err) + } + signKey, err := ts.LoadPrivateKeyFromPEM(signKeyPEM) + if err != nil { + panic(err) + } + + if err := ctx.UseSignPrivateKey(signKey); err != nil { + panic(err) + } + } + + if encCertFile != "" { + encCertPEM, err := os.ReadFile(encCertFile) + if err != nil { + panic(err) + } + encCert, err := ts.LoadCertificateFromPEM(encCertPEM) + if err != nil { + panic(err) + } + + if err := ctx.UseEncryptCertificate(encCert); err != nil { + panic(err) + } + } + + if encKeyFile != "" { + encKeyPEM, err := os.ReadFile(encKeyFile) + if err != nil { + panic(err) + } + + encKey, err := ts.LoadPrivateKeyFromPEM(encKeyPEM) + if err != nil { + panic(err) + } + + if err := ctx.UseEncryptPrivateKey(encKey); err != nil { + panic(err) + } + } + + if caFile != "" { + if err := ctx.LoadVerifyLocations(caFile, ""); err != nil { + panic(err) + } + } + + conn, err := ts.Dial("tcp", connAddr, ctx, ts.InsecureSkipHostVerification) + if err != nil { + panic(err) + } + defer conn.Close() + + cipher, err := conn.CurrentCipher() + if err != nil { + panic(err) + } + + ver, err := conn.GetVersion() + if err != nil { + panic(err) + } + + fmt.Println("New connection: " + ver + ", cipher=" + cipher) + + reader := bufio.NewReader(os.Stdin) + text, _ := reader.ReadString('\n') + + request := text + "\n" + fmt.Println(">>>\n" + request) + if _, err := conn.Write([]byte(request)); err != nil { + panic(err) + } + + buffer := make([]byte, 4096) + n, err := conn.Read(buffer) + if err != nil { + fmt.Println("read error:", err) + return + } + + fmt.Println("<<<\n" + string(buffer[:n])) + + return +} diff --git a/examples/server/ntls_server.go b/examples/server/ntls_server.go new file mode 100644 index 0000000..f218e2f --- /dev/null +++ b/examples/server/ntls_server.go @@ -0,0 +1,178 @@ +// Copyright 2023 The Tongsuo Project Authors. All Rights Reserved. +// +// Licensed under the Apache License 2.0 (the "License"). You may not use +// this file except in compliance with the License. You can obtain a copy +// in the file LICENSE in the source distribution or at +// https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE + +package main + +import ( + "bufio" + "flag" + "log" + "net" + "os" + + ts "github.com/tongsuo-project/tongsuo-go-sdk" +) + +func handleConn(conn net.Conn) { + defer conn.Close() + + // Read incoming data into buffer + req, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + log.Printf("Error reading incoming data: %v", err) + return + } + + ntls := conn.(*ts.Conn) + ver, err := ntls.GetVersion() + if err != nil { + log.Println("failed get version: ", err) + return + } + + cipher, err := ntls.CurrentCipher() + if err != nil { + log.Println("failed get cipher: ", err) + return + } + + log.Println("New connection: " + ver + ", cipher=" + cipher) + log.Println("Recv:\n" + req) + + // Send a response back to the client + if _, err = conn.Write([]byte(req + "\n")); err != nil { + log.Printf("Unable to send response: %v", err) + return + } + + log.Println("Sent:\n" + req) + log.Println("Close connection") +} + +func newNTLSServer(acceptAddr string, signCertFile string, signKeyFile string, encCertFile string, encKeyFile string, cafile string) (net.Listener, error) { + ctx, err := ts.NewCtxWithVersion(ts.NTLS) + if err != nil { + log.Println(err) + return nil, err + } + + if err := ctx.LoadVerifyLocations(cafile, ""); err != nil { + log.Println(err) + return nil, err + } + + encCertPEM, err := os.ReadFile(encCertFile) + if err != nil { + log.Println(err) + return nil, err + } + + signCertPEM, err := os.ReadFile(signCertFile) + if err != nil { + log.Println(err) + return nil, err + } + + encCert, err := ts.LoadCertificateFromPEM(encCertPEM) + if err != nil { + log.Println(err) + return nil, err + } + + signCert, err := ts.LoadCertificateFromPEM(signCertPEM) + if err != nil { + log.Println(err) + return nil, err + } + + if err := ctx.UseEncryptCertificate(encCert); err != nil { + log.Println(err) + return nil, err + } + + if err := ctx.UseSignCertificate(signCert); err != nil { + log.Println(err) + return nil, err + } + + encKeyPEM, err := os.ReadFile(encKeyFile) + if err != nil { + log.Println(err) + return nil, err + } + + signKeyPEM, err := os.ReadFile(signKeyFile) + if err != nil { + log.Println(err) + return nil, err + } + + encKey, err := ts.LoadPrivateKeyFromPEM(encKeyPEM) + if err != nil { + log.Println(err) + return nil, err + } + + signKey, err := ts.LoadPrivateKeyFromPEM(signKeyPEM) + if err != nil { + log.Println(err) + return nil, err + } + + if err := ctx.UseEncryptPrivateKey(encKey); err != nil { + log.Println(err) + return nil, err + } + + if err := ctx.UseSignPrivateKey(signKey); err != nil { + log.Println(err) + return nil, err + } + + lis, err := ts.Listen("tcp", acceptAddr, ctx) + if err != nil { + log.Println(err) + return nil, err + } + + return lis, nil +} + +func main() { + signCertFile := "" + signKeyFile := "" + encCertFile := "" + encKeyFile := "" + caFile := "" + acceptAddr := "" + + flag.StringVar(&acceptAddr, "accept", "127.0.0.1:443", "host:port") + flag.StringVar(&signCertFile, "sign_cert", "", "sign certificate file") + flag.StringVar(&signKeyFile, "sign_key", "", "sign private key file") + flag.StringVar(&encCertFile, "enc_cert", "", "encrypt certificate file") + flag.StringVar(&encKeyFile, "enc_key", "", "encrypt private key file") + flag.StringVar(&caFile, "CAfile", "", "CA certificate file") + + flag.Parse() + + server, err := newNTLSServer(acceptAddr, signCertFile, signKeyFile, encCertFile, encKeyFile, caFile) + + if err != nil { + return + } + defer server.Close() + + for { + conn, err := server.Accept() + if err != nil { + log.Println("failed accept: ", err) + continue + } + + go handleConn(conn) + } +} diff --git a/shim.c b/shim.c index 02a9c82..61ca40a 100644 --- a/shim.c +++ b/shim.c @@ -445,9 +445,15 @@ long X_SSL_clear_options(SSL* ssl, long options) { long X_SSL_set_tlsext_host_name(SSL *ssl, const char *name) { return SSL_set_tlsext_host_name(ssl, name); } -const char * X_SSL_get_cipher_name(const SSL *ssl) { + +const char *X_SSL_get_cipher_name(const SSL *ssl) { return SSL_get_cipher_name(ssl); } + +const char *X_SSL_get_version(const SSL *ssl) { + return SSL_get_version(ssl); +} + int X_SSL_session_reused(SSL *ssl) { return SSL_session_reused(ssl); } diff --git a/shim.h b/shim.h index a6c9b32..ea41c72 100644 --- a/shim.h +++ b/shim.h @@ -50,7 +50,8 @@ extern long X_SSL_set_options(SSL* ssl, long options); extern long X_SSL_get_options(SSL* ssl); extern long X_SSL_clear_options(SSL* ssl, long options); extern long X_SSL_set_tlsext_host_name(SSL *ssl, const char *name); -extern const char * X_SSL_get_cipher_name(const SSL *ssl); +extern const char *X_SSL_get_cipher_name(const SSL *ssl); +extern const char *X_SSL_get_version(const SSL *ssl); extern int X_SSL_session_reused(SSL *ssl); extern int X_SSL_new_index();