Skip to content

Commit

Permalink
Merge pull request #14 from dongbeiouba/examples/ntls
Browse files Browse the repository at this point in the history
Add ntls client and server example; Add API conn.GetVersion()
  • Loading branch information
itomsawyer authored Dec 1, 2023
2 parents 8c4fc8f + 381c145 commit 6d0268c
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 2 deletions.
9 changes: 9 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ func (c *Conn) CurrentCipher() (string, error) {
return C.GoString(p), nil
}

func (c *Conn) GetVersion() (string, error) {
p := C.X_SSL_get_version(c.ssl)
if p == nil {
return "", errors.New("Failed to get version")
}

return C.GoString(p), nil
}

func (c *Conn) fillInputBuffer() error {
for {
n, err := c.into_ssl.ReadFromOnce(c.conn)
Expand Down
151 changes: 151 additions & 0 deletions examples/client/ntls_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright 2023 The Tongsuo Project Authors. All Rights Reserved.
//
// Licensed under the Apache License 2.0 (the "License"). You may not use
// this file except in compliance with the License. You can obtain a copy
// in the file LICENSE in the source distribution or at
// https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE

package main

import (
"bufio"
"flag"
"fmt"
"os"

ts "github.com/tongsuo-project/tongsuo-go-sdk"
)

func main() {
cipherSuite := ""
signCertFile := ""
signKeyFile := ""
encCertFile := ""
encKeyFile := ""
caFile := ""
connAddr := ""

flag.StringVar(&connAddr, "conn", "127.0.0.1:443", "host:port")
flag.StringVar(&cipherSuite, "cipher", "ECC-SM2-SM4-CBC-SM3", "cipher suite")
flag.StringVar(&signCertFile, "sign_cert", "", "sign certificate file")
flag.StringVar(&signKeyFile, "sign_key", "", "sign private key file")
flag.StringVar(&encCertFile, "enc_cert", "", "encrypt certificate file")
flag.StringVar(&encKeyFile, "enc_key", "", "encrypt private key file")
flag.StringVar(&caFile, "CAfile", "", "CA certificate file")

flag.Parse()

ctx, err := ts.NewCtxWithVersion(ts.NTLS)
if err != nil {
panic(err)
}

if err := ctx.SetCipherList(cipherSuite); err != nil {
panic(err)
}

if signCertFile != "" {
signCertPEM, err := os.ReadFile(signCertFile)
if err != nil {
panic(err)
}
signCert, err := ts.LoadCertificateFromPEM(signCertPEM)
if err != nil {
panic(err)
}

if err := ctx.UseSignCertificate(signCert); err != nil {
panic(err)
}
}

if signKeyFile != "" {
signKeyPEM, err := os.ReadFile(signKeyFile)
if err != nil {
panic(err)
}
signKey, err := ts.LoadPrivateKeyFromPEM(signKeyPEM)
if err != nil {
panic(err)
}

if err := ctx.UseSignPrivateKey(signKey); err != nil {
panic(err)
}
}

if encCertFile != "" {
encCertPEM, err := os.ReadFile(encCertFile)
if err != nil {
panic(err)
}
encCert, err := ts.LoadCertificateFromPEM(encCertPEM)
if err != nil {
panic(err)
}

if err := ctx.UseEncryptCertificate(encCert); err != nil {
panic(err)
}
}

if encKeyFile != "" {
encKeyPEM, err := os.ReadFile(encKeyFile)
if err != nil {
panic(err)
}

encKey, err := ts.LoadPrivateKeyFromPEM(encKeyPEM)
if err != nil {
panic(err)
}

if err := ctx.UseEncryptPrivateKey(encKey); err != nil {
panic(err)
}
}

if caFile != "" {
if err := ctx.LoadVerifyLocations(caFile, ""); err != nil {
panic(err)
}
}

conn, err := ts.Dial("tcp", connAddr, ctx, ts.InsecureSkipHostVerification)
if err != nil {
panic(err)
}
defer conn.Close()

cipher, err := conn.CurrentCipher()
if err != nil {
panic(err)
}

ver, err := conn.GetVersion()
if err != nil {
panic(err)
}

fmt.Println("New connection: " + ver + ", cipher=" + cipher)

reader := bufio.NewReader(os.Stdin)
text, _ := reader.ReadString('\n')

request := text + "\n"
fmt.Println(">>>\n" + request)
if _, err := conn.Write([]byte(request)); err != nil {
panic(err)
}

buffer := make([]byte, 4096)
n, err := conn.Read(buffer)
if err != nil {
fmt.Println("read error:", err)
return
}

fmt.Println("<<<\n" + string(buffer[:n]))

return
}
178 changes: 178 additions & 0 deletions examples/server/ntls_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright 2023 The Tongsuo Project Authors. All Rights Reserved.
//
// Licensed under the Apache License 2.0 (the "License"). You may not use
// this file except in compliance with the License. You can obtain a copy
// in the file LICENSE in the source distribution or at
// https://github.com/Tongsuo-Project/tongsuo-go-sdk/blob/main/LICENSE

package main

import (
"bufio"
"flag"
"log"
"net"
"os"

ts "github.com/tongsuo-project/tongsuo-go-sdk"
)

func handleConn(conn net.Conn) {
defer conn.Close()

// Read incoming data into buffer
req, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
log.Printf("Error reading incoming data: %v", err)
return
}

ntls := conn.(*ts.Conn)
ver, err := ntls.GetVersion()
if err != nil {
log.Println("failed get version: ", err)
return
}

cipher, err := ntls.CurrentCipher()
if err != nil {
log.Println("failed get cipher: ", err)
return
}

log.Println("New connection: " + ver + ", cipher=" + cipher)
log.Println("Recv:\n" + req)

// Send a response back to the client
if _, err = conn.Write([]byte(req + "\n")); err != nil {
log.Printf("Unable to send response: %v", err)
return
}

log.Println("Sent:\n" + req)
log.Println("Close connection")
}

func newNTLSServer(acceptAddr string, signCertFile string, signKeyFile string, encCertFile string, encKeyFile string, cafile string) (net.Listener, error) {
ctx, err := ts.NewCtxWithVersion(ts.NTLS)
if err != nil {
log.Println(err)
return nil, err
}

if err := ctx.LoadVerifyLocations(cafile, ""); err != nil {
log.Println(err)
return nil, err
}

encCertPEM, err := os.ReadFile(encCertFile)
if err != nil {
log.Println(err)
return nil, err
}

signCertPEM, err := os.ReadFile(signCertFile)
if err != nil {
log.Println(err)
return nil, err
}

encCert, err := ts.LoadCertificateFromPEM(encCertPEM)
if err != nil {
log.Println(err)
return nil, err
}

signCert, err := ts.LoadCertificateFromPEM(signCertPEM)
if err != nil {
log.Println(err)
return nil, err
}

if err := ctx.UseEncryptCertificate(encCert); err != nil {
log.Println(err)
return nil, err
}

if err := ctx.UseSignCertificate(signCert); err != nil {
log.Println(err)
return nil, err
}

encKeyPEM, err := os.ReadFile(encKeyFile)
if err != nil {
log.Println(err)
return nil, err
}

signKeyPEM, err := os.ReadFile(signKeyFile)
if err != nil {
log.Println(err)
return nil, err
}

encKey, err := ts.LoadPrivateKeyFromPEM(encKeyPEM)
if err != nil {
log.Println(err)
return nil, err
}

signKey, err := ts.LoadPrivateKeyFromPEM(signKeyPEM)
if err != nil {
log.Println(err)
return nil, err
}

if err := ctx.UseEncryptPrivateKey(encKey); err != nil {
log.Println(err)
return nil, err
}

if err := ctx.UseSignPrivateKey(signKey); err != nil {
log.Println(err)
return nil, err
}

lis, err := ts.Listen("tcp", acceptAddr, ctx)
if err != nil {
log.Println(err)
return nil, err
}

return lis, nil
}

func main() {
signCertFile := ""
signKeyFile := ""
encCertFile := ""
encKeyFile := ""
caFile := ""
acceptAddr := ""

flag.StringVar(&acceptAddr, "accept", "127.0.0.1:443", "host:port")
flag.StringVar(&signCertFile, "sign_cert", "", "sign certificate file")
flag.StringVar(&signKeyFile, "sign_key", "", "sign private key file")
flag.StringVar(&encCertFile, "enc_cert", "", "encrypt certificate file")
flag.StringVar(&encKeyFile, "enc_key", "", "encrypt private key file")
flag.StringVar(&caFile, "CAfile", "", "CA certificate file")

flag.Parse()

server, err := newNTLSServer(acceptAddr, signCertFile, signKeyFile, encCertFile, encKeyFile, caFile)

if err != nil {
return
}
defer server.Close()

for {
conn, err := server.Accept()
if err != nil {
log.Println("failed accept: ", err)
continue
}

go handleConn(conn)
}
}
8 changes: 7 additions & 1 deletion shim.c
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,15 @@ long X_SSL_clear_options(SSL* ssl, long options) {
long X_SSL_set_tlsext_host_name(SSL *ssl, const char *name) {
return SSL_set_tlsext_host_name(ssl, name);
}
const char * X_SSL_get_cipher_name(const SSL *ssl) {

const char *X_SSL_get_cipher_name(const SSL *ssl) {
return SSL_get_cipher_name(ssl);
}

const char *X_SSL_get_version(const SSL *ssl) {
return SSL_get_version(ssl);
}

int X_SSL_session_reused(SSL *ssl) {
return SSL_session_reused(ssl);
}
Expand Down
3 changes: 2 additions & 1 deletion shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ extern long X_SSL_set_options(SSL* ssl, long options);
extern long X_SSL_get_options(SSL* ssl);
extern long X_SSL_clear_options(SSL* ssl, long options);
extern long X_SSL_set_tlsext_host_name(SSL *ssl, const char *name);
extern const char * X_SSL_get_cipher_name(const SSL *ssl);
extern const char *X_SSL_get_cipher_name(const SSL *ssl);
extern const char *X_SSL_get_version(const SSL *ssl);
extern int X_SSL_session_reused(SSL *ssl);
extern int X_SSL_new_index();

Expand Down

0 comments on commit 6d0268c

Please sign in to comment.