Skip to content

Commit

Permalink
add transport test for error code
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 20, 2024
1 parent f7d4753 commit c3cf2a1
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 3 deletions.
4 changes: 3 additions & 1 deletion p2p/net/swarm/swarm_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ func (s *Stream) Reset() error {
}

func (s *Stream) ResetWithError(errCode network.StreamErrorCode) error {
panic("not implemented")
err := s.stream.ResetWithError(errCode)
s.closeAndRemoveStream()
return err
}

func (s *Stream) closeAndRemoveStream() {
Expand Down
72 changes: 72 additions & 0 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"go.uber.org/mock/gomock"

ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -798,3 +799,74 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
})
}
}

// TestStreamErrorCode tests correctness for resetting stream with errors
func TestStreamErrorCode(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
if tc.Name != "QUIC" {
t.Skip("only implemented for QUIC")
return
}
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client.Close()

checkError := func(err error, code network.StreamErrorCode, remote bool) {
t.Helper()
if err == nil {
t.Fatal("expected non nil error")
}
se := &network.StreamError{}
if errors.As(err, &se) {
require.Equal(t, se.ErrorCode, code)
require.Equal(t, se.Remote, remote)
return
}
t.Fatal("expected network.StreamError, got:", err)
}

errCh := make(chan error)
server.SetStreamHandler("/test", func(s network.Stream) {
defer s.Reset()
b := make([]byte, 10)
n, err := s.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = s.Write(b[:n])
if !assert.NoError(t, err) {
return
}
_, err = s.Read(b)
errCh <- err
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
s, err := client.NewStream(ctx, server.ID(), "/test")
require.NoError(t, err)

_, err = s.Write([]byte("hello"))
require.NoError(t, err)

buf := make([]byte, 10)
n, err := s.Read(buf)
require.NoError(t, err)
require.Equal(t, []byte("hello"), buf[:n])

err = s.ResetWithError(42)
require.NoError(t, err)

_, err = s.Read(buf)
checkError(err, 42, false)

_, err = s.Write(buf)
checkError(err, 42, false)

err = <-errCh
checkError(err, 42, true)
})
}
}
42 changes: 42 additions & 0 deletions p2p/transport/quic/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ func TestStreams(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
testStreams(t, tc)
})
t.Run(tc.Name, func(t *testing.T) {
testStreamsErrorCode(t, tc)
})
}
}

Expand Down Expand Up @@ -305,6 +308,45 @@ func testStreams(t *testing.T, tc *connTestCase) {
require.Equal(t, data, []byte("foobar"))
}

func testStreamsErrorCode(t *testing.T, tc *connTestCase) {
serverID, serverKey := createPeer(t)
_, clientKey := createPeer(t)

serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil)
require.NoError(t, err)
defer serverTransport.(io.Closer).Close()
ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")
defer ln.Close()

clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil)
require.NoError(t, err)
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
serverConn, err := ln.Accept()
require.NoError(t, err)
defer serverConn.Close()

str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
err = str.ResetWithError(42)
require.NoError(t, err)

sstr, err := serverConn.AcceptStream()
require.NoError(t, err)
_, err = io.ReadAll(sstr)
require.Error(t, err)
se := &network.StreamError{}
if errors.As(err, &se) {
require.Equal(t, se.ErrorCode, network.StreamErrorCode(42))
require.True(t, se.Remote)
} else {
t.Fatalf("expected error to be of network.StreamError type, got %T, %v", err, err)
}

}

func TestHandshakeFailPeerIDMismatch(t *testing.T) {
for _, tc := range connTestCases {
t.Run(tc.Name, func(t *testing.T) {
Expand Down
7 changes: 5 additions & 2 deletions p2p/transport/quic/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ type stream struct {
var _ network.MuxedStream = &stream{}

func parseStreamError(err error) error {
if err == nil {
return err
}
se := &quic.StreamError{}
if err != nil && errors.As(err, &se) {
if errors.As(err, &se) {
code := se.ErrorCode
if code > math.MaxUint32 {
code = 0
code = reset
}
err = &network.StreamError{
ErrorCode: network.StreamErrorCode(code),
Expand Down

0 comments on commit c3cf2a1

Please sign in to comment.