Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quicreuse: make it possible to use an application-constructed quic.Transport #3122

Merged
merged 10 commits into from
Jan 10, 2025
65 changes: 59 additions & 6 deletions p2p/transport/quicreuse/connmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"sync"

Expand All @@ -15,6 +16,22 @@ import (
quicmetrics "github.com/quic-go/quic-go/metrics"
)

type QUICListener interface {
Accept(ctx context.Context) (quic.Connection, error)
Close() error
Addr() net.Addr
}

var _ QUICListener = &quic.Listener{}

type QUICTransport interface {
Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error)
Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
WriteTo(b []byte, addr net.Addr) (int, error)
ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error)
io.Closer
}

type ConnManager struct {
reuseUDP4 *reuse
reuseUDP6 *reuse
Expand Down Expand Up @@ -101,6 +118,32 @@ func (c *ConnManager) getReuse(network string) (*reuse, error) {
}
}

// LendTransport is an advanced method used to lend an existing QUICTransport
// to the ConnManager. The ConnManager will close the returned channel when it
// is done with the transport, so that the owner may safely close the transport.
func (c *ConnManager) LendTransport(network string, tr QUICTransport, conn net.PacketConn) (<-chan struct{}, error) {
c.quicListenersMu.Lock()
defer c.quicListenersMu.Unlock()

localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
return nil, errors.New("expected a conn.LocalAddr() to return a *net.UDPAddr")
}

refCountedTr := &refcountedTransport{
QUICTransport: tr,
packetConn: conn,
borrowDoneSignal: make(chan struct{}),
}

var reuse *reuse
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return refCountedTr.borrowDoneSignal, reuse.AddTransport(refCountedTr, localAddr)
}

func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
return c.ListenQUICAndAssociate(nil, addr, tlsConf, allowWindowIncrease)
}
Expand Down Expand Up @@ -175,7 +218,7 @@ func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr
ctx: ctx,
ctxCancel: cancel,
owningTransport: t,
tr: &t.Transport,
tr: t.QUICTransport,
}, nil
}
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
Expand All @@ -201,10 +244,12 @@ func (c *ConnManager) transportForListen(association any, network string, laddr
}
return &singleOwnerTransport{
packetConn: conn,
Transport: quic.Transport{
Conn: conn,
StatelessResetKey: &c.srk,
TokenGeneratorKey: &c.tokenKey,
Transport: &wrappedQUICTransport{
&quic.Transport{
Conn: conn,
StatelessResetKey: &c.srk,
TokenGeneratorKey: &c.tokenKey,
},
},
}, nil
}
Expand Down Expand Up @@ -279,7 +324,7 @@ func (c *ConnManager) TransportWithAssociationForDial(association any, network s
if err != nil {
return nil, err
}
return &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil
return &singleOwnerTransport{Transport: &wrappedQUICTransport{&quic.Transport{Conn: conn, StatelessResetKey: &c.srk}}, packetConn: conn}, nil
}

func (c *ConnManager) Protocols() []int {
Expand All @@ -299,3 +344,11 @@ func (c *ConnManager) Close() error {
func (c *ConnManager) ClientConfig() *quic.Config {
return c.clientConfig
}

type wrappedQUICTransport struct {
*quic.Transport
}

func (t *wrappedQUICTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
return t.Transport.Listen(tlsConf, conf)
}
62 changes: 60 additions & 2 deletions p2p/transport/quicreuse/connmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {
quicTr, err := cm.transportForListen(nil, netw, naddr)
require.NoError(t, err)
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
if _, ok := quicTr.(*singleOwnerTransport).packetConn.(quic.OOBCapablePacketConn); !ok {
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
}
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestConnectionPassedToQUICForDialing(t *testing.T) {

require.NoError(t, err, "dial error")
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
if _, ok := quicTr.(*singleOwnerTransport).packetConn.(quic.OOBCapablePacketConn); !ok {
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
}
}
Expand Down Expand Up @@ -257,3 +257,61 @@ func testListener(t *testing.T, enableReuseport bool) {

checkClosed(t, cm)
}

func TestExternalTransport(t *testing.T) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero})
require.NoError(t, err)
defer conn.Close()
port := conn.LocalAddr().(*net.UDPAddr).Port
tr := &quic.Transport{Conn: conn}
defer tr.Close()

cm, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
require.NoError(t, err)
doneWithTr, err := cm.LendTransport("udp4", &wrappedQUICTransport{tr}, conn)
require.NoError(t, err)

// make sure this transport is used when listening on the same port
ln, err := cm.ListenQUICAndAssociate(
"quic",
ma.StringCast(fmt.Sprintf("/ip4/0.0.0.0/udp/%d", port)),
&tls.Config{NextProtos: []string{"libp2p"}},
func(quic.Connection, uint64) bool { return false },
)
require.NoError(t, err)
defer ln.Close()
require.Equal(t, port, ln.Addr().(*net.UDPAddr).Port)

// make sure this transport is used when dialing out
udpLn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
defer udpLn.Close()
addrChan := make(chan net.Addr, 1)
go func() {
_, addr, _ := udpLn.ReadFrom(make([]byte, 2000))
addrChan <- addr
}()
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
_, err = cm.DialQUIC(
ctx,
ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", udpLn.LocalAddr().(*net.UDPAddr).Port)),
&tls.Config{NextProtos: []string{"libp2p"}},
func(quic.Connection, uint64) bool { return false },
)
require.ErrorIs(t, err, context.DeadlineExceeded)

select {
case addr := <-addrChan:
require.Equal(t, port, addr.(*net.UDPAddr).Port)
case <-time.After(time.Second):
t.Fatal("timeout")
}

cm.Close()
select {
case <-doneWithTr:
default:
t.Fatal("doneWithTr not closed")
}
}
2 changes: 1 addition & 1 deletion p2p/transport/quicreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type protoConf struct {
}

type quicListener struct {
l *quic.Listener
l QUICListener
transport refCountedQuicTransport
running chan struct{}
addrs []ma.Multiaddr
Expand Down
6 changes: 2 additions & 4 deletions p2p/transport/quicreuse/nonquic_packetconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ import (
"context"
"net"
"time"

"github.com/quic-go/quic-go"
)

// nonQUICPacketConn is a net.PacketConn that can be used to read and write
// non-QUIC packets on a quic.Transport. This lets us reuse this UDP port for
// other transports like WebRTC.
type nonQUICPacketConn struct {
owningTransport refCountedQuicTransport
tr *quic.Transport
tr QUICTransport
ctx context.Context
ctxCancel context.CancelFunc
readCtx context.Context
Expand All @@ -32,7 +30,7 @@ func (n *nonQUICPacketConn) Close() error {

// LocalAddr implements net.PacketConn.
func (n *nonQUICPacketConn) LocalAddr() net.Addr {
return n.tr.Conn.LocalAddr()
return n.owningTransport.LocalAddr()
}

// ReadFrom implements net.PacketConn.
Expand Down
86 changes: 66 additions & 20 deletions p2p/transport/quicreuse/reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package quicreuse
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"time"
Expand All @@ -25,23 +27,30 @@ type refCountedQuicTransport interface {
IncreaseCount()

Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
Listen(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, error)
Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error)
}

type singleOwnerTransport struct {
quic.Transport
Transport QUICTransport

// Used to write packets directly around QUIC.
packetConn net.PacketConn
}

var _ QUICTransport = &singleOwnerTransport{}

func (c *singleOwnerTransport) IncreaseCount() {}
func (c *singleOwnerTransport) DecreaseCount() {
c.Transport.Close()
func (c *singleOwnerTransport) DecreaseCount() { c.Transport.Close() }
func (c *singleOwnerTransport) LocalAddr() net.Addr {
return c.packetConn.LocalAddr()
}

func (c *singleOwnerTransport) LocalAddr() net.Addr {
return c.Transport.Conn.LocalAddr()
func (c *singleOwnerTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error) {
return c.Transport.Dial(ctx, addr, tlsConf, conf)
}

func (c *singleOwnerTransport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
return c.Transport.ReadNonQUICPacket(ctx, b)
}

func (c *singleOwnerTransport) Close() error {
Expand All @@ -54,14 +63,18 @@ func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.Transport.WriteTo(b, addr)
}

func (c *singleOwnerTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
return c.Transport.Listen(tlsConf, conf)
}

// Constant. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
)

type refcountedTransport struct {
quic.Transport
QUICTransport

// Used to write packets directly around QUIC.
packetConn net.PacketConn
Expand All @@ -70,6 +83,11 @@ type refcountedTransport struct {
refCount int
unusedSince time.Time

// Only set for transports we are borrowing.
// If set, we will _never_ close the underlying transport. We only close this
// channel to signal to the owner that we are done with it.
borrowDoneSignal chan struct{}

assocations map[any]struct{}
}

Expand Down Expand Up @@ -109,17 +127,24 @@ func (c *refcountedTransport) IncreaseCount() {
}

func (c *refcountedTransport) Close() error {
// TODO(when we drop support for go 1.19) use errors.Join
c.Transport.Close()
return c.packetConn.Close()
if c.borrowDoneSignal != nil {
close(c.borrowDoneSignal)
return nil
}

return errors.Join(c.QUICTransport.Close(), c.packetConn.Close())
}

func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.Transport.WriteTo(b, addr)
return c.QUICTransport.WriteTo(b, addr)
}

func (c *refcountedTransport) LocalAddr() net.Addr {
return c.Transport.Conn.LocalAddr()
return c.packetConn.LocalAddr()
}

func (c *refcountedTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
return c.QUICTransport.Listen(tlsConf, conf)
}

func (c *refcountedTransport) DecreaseCount() {
Expand Down Expand Up @@ -302,15 +327,34 @@ func (r *reuse) transportForDialLocked(association any, network string, source *
if err != nil {
return nil, err
}
tr := &refcountedTransport{Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
TokenGeneratorKey: r.tokenGeneratorKey,
}, packetConn: conn}
tr := &refcountedTransport{
QUICTransport: &wrappedQUICTransport{
Transport: &quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
TokenGeneratorKey: r.tokenGeneratorKey,
},
},
packetConn: conn,
}
r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr
return tr, nil
}

func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error {
r.mutex.Lock()
defer r.mutex.Unlock()

if !laddr.IP.IsUnspecified() {
return errors.New("adding transport for specific IP not supported")
}
if _, ok := r.globalDialers[laddr.Port]; ok {
return fmt.Errorf("already have global dialer for port %d", laddr.Port)
}
r.globalDialers[laddr.Port] = tr
return nil
}

func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
Expand Down Expand Up @@ -351,9 +395,11 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
}
localAddr := conn.LocalAddr().(*net.UDPAddr)
tr := &refcountedTransport{
Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
QUICTransport: &wrappedQUICTransport{
Transport: &quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
},
},
packetConn: conn,
}
Expand Down
Loading