From 2c053d816832954134c2a14b9641a6e2fa3ebd52 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 30 Oct 2024 11:03:07 -0700 Subject: [PATCH] tcp transport: Parameterize metrics collector in TCP --- p2p/transport/tcp/metrics.go | 29 ++++++++++++------- p2p/transport/tcp/metrics_none.go | 8 +++-- p2p/transport/tcp/tcp.go | 6 ++-- .../sampledconn/sampledconn_common.go | 25 +++++++++++----- 4 files changed, 46 insertions(+), 22 deletions(-) diff --git a/p2p/transport/tcp/metrics.go b/p2p/transport/tcp/metrics.go index 213ee2200a..50820d870c 100644 --- a/p2p/transport/tcp/metrics.go +++ b/p2p/transport/tcp/metrics.go @@ -24,7 +24,7 @@ var ( const collectFrequency = 10 * time.Second -var collector *aggregatingCollector +var defaultCollector *aggregatingCollector var initMetricsOnce sync.Once @@ -34,8 +34,8 @@ func initMetrics() { bytesSentDesc = prometheus.NewDesc("tcp_sent_bytes", "TCP bytes sent", nil, nil) bytesRcvdDesc = prometheus.NewDesc("tcp_rcvd_bytes", "TCP bytes received", nil, nil) - collector = newAggregatingCollector() - prometheus.MustRegister(collector) + defaultCollector = newAggregatingCollector() + prometheus.MustRegister(defaultCollector) const direction = "direction" @@ -196,7 +196,7 @@ func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) { func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { c.mutex.Lock() - collector.removeConn(conn.id) + c.removeConn(conn.id) c.mutex.Unlock() closedConns.WithLabelValues(direction).Inc() } @@ -204,6 +204,8 @@ func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { type tracingConn struct { id uint64 + collector *aggregatingCollector + startTime time.Time isClient bool @@ -213,7 +215,8 @@ type tracingConn struct { closeErr error } -func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { +// newTracingConn wraps a manet.Conn with a tracingConn. A nil collector will use the default collector. +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (*tracingConn, error) { initMetricsOnce.Do(func() { initMetrics() }) conn, err := tcp.NewConn(c) if err != nil { @@ -224,8 +227,12 @@ func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { isClient: isClient, Conn: c, tcpConn: conn, + collector: collector, + } + if tc.collector == nil { + tc.collector = defaultCollector } - tc.id = collector.AddConn(tc) + tc.id = tc.collector.AddConn(tc) newConns.WithLabelValues(tc.getDirection()).Inc() return tc, nil } @@ -239,7 +246,7 @@ func (c *tracingConn) getDirection() string { func (c *tracingConn) Close() error { c.closeOnce.Do(func() { - collector.ClosedConn(c, c.getDirection()) + c.collector.ClosedConn(c, c.getDirection()) c.closeErr = c.Conn.Close() }) return c.closeErr @@ -258,10 +265,12 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { type tracingListener struct { manet.Listener + collector *aggregatingCollector } -func newTracingListener(l manet.Listener) *tracingListener { - return &tracingListener{Listener: l} +// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector. +func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener { + return &tracingListener{Listener: l, collector: collector} } func (l *tracingListener) Accept() (manet.Conn, error) { @@ -269,5 +278,5 @@ func (l *tracingListener) Accept() (manet.Conn, error) { if err != nil { return nil, err } - return newTracingConn(conn, false) + return newTracingConn(conn, l.collector, false) } diff --git a/p2p/transport/tcp/metrics_none.go b/p2p/transport/tcp/metrics_none.go index 8538b30c89..cbee982070 100644 --- a/p2p/transport/tcp/metrics_none.go +++ b/p2p/transport/tcp/metrics_none.go @@ -6,5 +6,9 @@ package tcp import manet "github.com/multiformats/go-multiaddr/net" -func newTracingConn(c manet.Conn, _ bool) (manet.Conn, error) { return c, nil } -func newTracingListener(l manet.Listener) manet.Listener { return l } +type aggregatingCollector struct{} + +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) { + return c, nil +} +func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l } diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index 1b145c2b45..e197b26660 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -142,6 +142,8 @@ type TcpTransport struct { rcmgr network.ResourceManager reuse reuseport.Transport + + metricsCollector *aggregatingCollector } var _ transport.Transport = &TcpTransport{} @@ -231,7 +233,7 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p c := conn if t.enableMetrics { var err error - c, err = newTracingConn(conn, true) + c, err = newTracingConn(conn, t.metricsCollector, true) if err != nil { return nil, err } @@ -277,7 +279,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { } if t.enableMetrics { - list = newTracingListener(&tcpListener{list, 0}) + list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector) } return t.upgrader.UpgradeListener(t, list), nil } diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go index eb71f7b44d..7324b45849 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go @@ -6,6 +6,8 @@ import ( "net" "syscall" "time" + + manet "github.com/multiformats/go-multiaddr/net" ) const peekSize = 3 @@ -16,7 +18,7 @@ var errNotSupported = errors.New("not supported on this platform") var ErrNotTCPConn = errors.New("passed conn is not a TCPConn") -func PeekBytes(conn net.Conn) (PeekedBytes, net.Conn, error) { +func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) { if c, ok := conn.(syscall.Conn); ok { b, err := OSPeekConn(c) if err == nil { @@ -28,7 +30,7 @@ func PeekBytes(conn net.Conn) (PeekedBytes, net.Conn, error) { // Fallback to wrapping the coonn } - if c, ok := conn.(tcpConnInterface); ok { + if c, ok := conn.(ManetTCPConnInterface); ok { return newFallbackSampledConn(c) } @@ -36,16 +38,18 @@ func PeekBytes(conn net.Conn) (PeekedBytes, net.Conn, error) { } type fallbackPeekingConn struct { - tcpConnInterface + ManetTCPConnInterface peekedBytes PeekedBytes bytesPeeked uint8 } // tcpConnInterface is the interface for TCPConn's functions -// NOTE: Skipping `SyscallConn() (syscall.RawConn, error)` since it can be -// misused given we've read a few bytes from the connection. +// NOTE: `SyscallConn() (syscall.RawConn, error)` is here to make using this as +// a TCP Conn easier, but it's a potential footgun as you could skipped the +// peeked bytes if using the fallback type tcpConnInterface interface { net.Conn + syscall.Conn CloseRead() error CloseWrite() error @@ -60,8 +64,13 @@ type tcpConnInterface interface { io.WriterTo } -func newFallbackSampledConn(conn tcpConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { - s := &fallbackPeekingConn{tcpConnInterface: conn} +type ManetTCPConnInterface interface { + manet.Conn + tcpConnInterface +} + +func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { + s := &fallbackPeekingConn{ManetTCPConnInterface: conn} _, err := io.ReadFull(conn, s.peekedBytes[:]) if err != nil { return s.peekedBytes, nil, err @@ -76,5 +85,5 @@ func (sc *fallbackPeekingConn) Read(b []byte) (int, error) { return red, nil } - return sc.tcpConnInterface.Read(b) + return sc.ManetTCPConnInterface.Read(b) }