Skip to content

Commit

Permalink
Simply modify
Browse files Browse the repository at this point in the history
  • Loading branch information
ZBCccc committed Sep 19, 2024
1 parent e081e85 commit 5a34186
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 34 deletions.
25 changes: 13 additions & 12 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,13 +576,6 @@ func (c *Ctx) SetTLSExtServernameCallback(sni_cb TLSExtServernameCallback) {

type TLSExtAlpnCallback func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) SSLTLSExtErr

// SetTLSExtAlpnCallback sets callback function for Application Layer Protocol Negotiation
// (ALPN) rfc7301 (https://tools.ietf.org/html/rfc7301).
func (c *Ctx) SetTLSExtAlpnCallback(alpn_cb TLSExtAlpnCallback, arg unsafe.Pointer) {
c.alpn_cb = alpn_cb
C.SSL_CTX_set_alpn_select_cb(c.ctx, (*[0]byte)(C.alpn_cb), arg)
}

func (ctx *Ctx) SetServerALPNProtos(protos []string) {
// Construct the protocol list (format: length byte of each protocol + protocol content)
var protoList []byte
Expand All @@ -591,20 +584,28 @@ func (ctx *Ctx) SetServerALPNProtos(protos []string) {
protoList = append(protoList, []byte(proto)...) // Add the protocol content
}

ctx.SetTLSExtAlpnCallback(func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) SSLTLSExtErr {
ctx.alpn_cb = func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) SSLTLSExtErr {
// Use OpenSSL function to select the protocol
ret := ssl.SslSelectNextProto(out, outlen, unsafe.Pointer(&protoList[0]), uint(len(protoList)), in, inlen)
ret := C.SSL_select_next_proto(
(**C.uchar)(out),
(*C.uchar)(outlen),
(*C.uchar)(unsafe.Pointer(&protoList[0])),
C.uint(len(protoList)),
(*C.uchar)(in),
C.uint(inlen),
)

if ret != OPENSSL_NPN_NEGOTIATED {
return SSLTLSExtErrAlertFatal
}

return SSLTLSExtErrOK
}, nil)
}
C.SSL_CTX_set_alpn_select_cb(ctx.ctx, (*[0]byte)(C.alpn_cb), nil)
}

// SetALPNProtos sets the ALPN protocol list
func (ctx *Ctx) SetALPNProtos(protos []string) error {
// SetClientALPNProtos sets the ALPN protocol list
func (ctx *Ctx) SetClientALPNProtos(protos []string) error {
// Construct the protocol list (format: length byte of each protocol + protocol content)
var protoList []byte
for _, proto := range protos {
Expand Down
2 changes: 1 addition & 1 deletion examples/tlcp_client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func main() {
panic(err)
}

if err := ctx.SetALPNProtos(alpnProtocols); err != nil {
if err := ctx.SetClientALPNProtos(alpnProtocols); err != nil {
panic(err)
}

Expand Down
7 changes: 3 additions & 4 deletions examples/tlcp_server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func handleConn(conn net.Conn) {
log.Println("Close connection")
}

func newNTLSServerWithSNI(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCertKey, cafile string) (net.Listener, error) {
func newNTLSServer(acceptAddr string, certKeyPairs map[string]crypto.GMDoubleCertKey, cafile string, alpnProtocols []string) (net.Listener, error) {

ctx, err := ts.NewCtxWithVersion(ts.NTLS)
if err != nil {
Expand All @@ -103,8 +103,7 @@ func newNTLSServerWithSNI(acceptAddr string, certKeyPairs map[string]crypto.GMDo
}

// Set ALPN
supportedProtos := []string{"h2", "http/1.1"}
ctx.SetServerALPNProtos(supportedProtos)
ctx.SetServerALPNProtos(alpnProtocols)

// Set SNI callback
ctx.SetTLSExtServernameCallback(func(ssl *ts.SSL) ts.SSLTLSExtErr {
Expand Down Expand Up @@ -304,7 +303,7 @@ func main() {
return
}

server, err := newNTLSServerWithSNI(acceptAddr, certFiles, caFile)
server, err := newNTLSServer(acceptAddr, certFiles, caFile, alpnProtocols)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion ntls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ func TestALPN(t *testing.T) {
}

// Set the ALPN protocols for the context
if err := ctx.SetALPNProtos(alpnProtocols); err != nil {
if err := ctx.SetClientALPNProtos(alpnProtocols); err != nil {
t.Error(err)
return
}
Expand Down
16 changes: 0 additions & 16 deletions ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,3 @@ func alpn_cb_thunk(p unsafe.Pointer, con *C.SSL, out unsafe.Pointer, outlen unsa
// Ensure the out parameter is treated as a pointer to const unsigned char
return C.int(alpn_cb(s, out, outlen, in, inlen, arg))
}

// SslSelectNextProto selects the next protocol from the list of protocols
// provided by the server and the client's list of supported protocols.
// It takes pointers to the output buffer, output length, server buffer,
// server length, input buffer, and input length as parameters. The function
// returns an integer indicating the result of the selection process.
func (s *SSL) SslSelectNextProto(out unsafe.Pointer, outlen unsafe.Pointer, server unsafe.Pointer, serverlen uint, in unsafe.Pointer, inlen uint) C.int {
return C.SSL_select_next_proto(
(**C.uchar)(out),
(*C.uchar)(outlen),
(*C.uchar)(server),
C.uint(serverlen),
(*C.uchar)(in),
C.uint(inlen),
)
}

0 comments on commit 5a34186

Please sign in to comment.