Skip to content

Commit b17e799

Browse files
committed
net: add multi listener impl for net.Listener
This adds an implementation of net.Listener which listens on and accepts connections from multiple addresses. Signed-off-by: Daman Arora <[email protected]>
1 parent fe8a2dd commit b17e799

File tree

2 files changed

+651
-0
lines changed

2 files changed

+651
-0
lines changed

net/multi_listen.go

+193
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package net
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"net"
23+
"reflect"
24+
"sync"
25+
)
26+
27+
// connErrPair pairs conn and error which is returned by accept on sub-listeners.
28+
type connErrPair struct {
29+
conn net.Conn
30+
err error
31+
}
32+
33+
// multiListener implements net.Listener
34+
type multiListener struct {
35+
listeners []net.Listener
36+
wg sync.WaitGroup
37+
mu sync.Mutex
38+
closed bool
39+
40+
// connErrChs is used as a buffer of accepted connections, one entry per
41+
// sub-listener.
42+
connErrChs []chan connErrPair
43+
}
44+
45+
// compile time check to ensure *multiListener implements net.Listener
46+
var _ net.Listener = &multiListener{}
47+
48+
// MultiListen returns net.Listener which can listen on and accept connections for
49+
// the given network on multiple addresses. Internally it uses stdlib to create
50+
// sub-listener and multiplexes connection requests using go-routines.
51+
// The network must be "tcp", "tcp4" or "tcp6".
52+
// It follows the semantics of net.Listen that primarily means:
53+
// 1. If the host is an unspecified/zero IP address with "tcp" network, MultiListen
54+
// listens on all available unicast and anycast IP addresses of the local system.
55+
// 2. Use "tcp4" or "tcp6" to exclusively listen on IPv4 or IPv6 family, respectively.
56+
// 3. The host can accept names (e.g, localhost) and it will create a listener for at
57+
// most one of the host's IP.
58+
func MultiListen(ctx context.Context, network string, addrs ...string) (net.Listener, error) {
59+
var lc net.ListenConfig
60+
return multiListen(
61+
ctx,
62+
network,
63+
addrs,
64+
func(ctx context.Context, network, address string) (net.Listener, error) {
65+
return lc.Listen(ctx, network, address)
66+
})
67+
}
68+
69+
// multiListen implements MultiListen by consuming stdlib functions as dependency allowing
70+
// mocking for unit-testing.
71+
func multiListen(
72+
ctx context.Context,
73+
network string,
74+
addrs []string,
75+
listenFunc func(ctx context.Context, network, address string) (net.Listener, error),
76+
) (net.Listener, error) {
77+
if !(network == "tcp" || network == "tcp4" || network == "tcp6") {
78+
return nil, fmt.Errorf("network %q not supported", network)
79+
}
80+
if len(addrs) == 0 {
81+
return nil, fmt.Errorf("no address provided to listen on")
82+
}
83+
84+
ml := &multiListener{}
85+
for _, addr := range addrs {
86+
l, err := listenFunc(ctx, network, addr)
87+
if err != nil {
88+
// close all the sub-listeners and exit
89+
_ = ml.Close()
90+
return nil, err
91+
}
92+
ml.listeners = append(ml.listeners, l)
93+
}
94+
95+
for _, l := range ml.listeners {
96+
ml.wg.Add(1)
97+
connErrCh := make(chan connErrPair)
98+
ml.connErrChs = append(ml.connErrChs, connErrCh)
99+
go func(l net.Listener) {
100+
defer ml.wg.Done()
101+
for {
102+
// Accept() is blocking, unless ml.Close() is called, in which
103+
// case it will return immediately with an error.
104+
conn, err := l.Accept()
105+
106+
ml.mu.Lock()
107+
closed := ml.closed
108+
ml.mu.Unlock()
109+
if closed {
110+
return
111+
}
112+
connErrCh <- connErrPair{conn: conn, err: err}
113+
}
114+
}(l)
115+
}
116+
return ml, nil
117+
}
118+
119+
// Accept implements net.Listener. It waits for and returns a connection from
120+
// any of the sub-listener.
121+
func (ml *multiListener) Accept() (net.Conn, error) {
122+
cases := make([]reflect.SelectCase, len(ml.connErrChs))
123+
for i, ch := range ml.connErrChs {
124+
cases[i] = reflect.SelectCase{
125+
Dir: reflect.SelectRecv,
126+
Chan: reflect.ValueOf(ch),
127+
}
128+
}
129+
130+
// wait for any sub-listener to enqueue an accepted connection
131+
_, value, ok := reflect.Select(cases)
132+
if !ok {
133+
// All the "connErrChs" channels will be closed only when Close() is called on the multiListener.
134+
// Closing of the channels implies that all sub-listeners are closed, which causes a
135+
// "use of closed network connection" error on their Accept() calls. We return the same error
136+
// for multiListener.Accept() if multiListener.Close() has already been called.
137+
return nil, fmt.Errorf("use of closed network connection")
138+
}
139+
connErr := value.Interface().(connErrPair)
140+
return connErr.conn, connErr.err
141+
}
142+
143+
// Close implements net.Listener. It will close all sub-listeners and wait for
144+
// the go-routines to exit.
145+
func (ml *multiListener) Close() error {
146+
ml.mu.Lock()
147+
closed := ml.closed
148+
ml.mu.Unlock()
149+
if closed {
150+
return fmt.Errorf("use of closed network connection")
151+
}
152+
153+
ml.mu.Lock()
154+
ml.closed = true
155+
// close the "connErrChs" channels after closing queued connections if any.
156+
for _, connErrCh := range ml.connErrChs {
157+
select {
158+
case connErr := <-connErrCh:
159+
if connErr.conn != nil {
160+
_ = connErr.conn.Close()
161+
}
162+
default:
163+
}
164+
close(connErrCh)
165+
}
166+
ml.mu.Unlock()
167+
168+
// Closing the listeners causes Accept() to immediately return an error,
169+
// which serves as the exit condition for the sub-listener go-routines.
170+
for _, l := range ml.listeners {
171+
_ = l.Close()
172+
}
173+
174+
// Wait for all the sub-listener go-routines to exit.
175+
ml.wg.Wait()
176+
return nil
177+
}
178+
179+
// Addr is an implementation of the net.Listener interface. It always returns
180+
// the address of the first listener. Callers should use conn.LocalAddr() to
181+
// obtain the actual local address of the sub-listener.
182+
func (ml *multiListener) Addr() net.Addr {
183+
return ml.listeners[0].Addr()
184+
}
185+
186+
// Addrs is like Addr, but returns the address for all registered listeners.
187+
func (ml *multiListener) Addrs() []net.Addr {
188+
var ret []net.Addr
189+
for _, l := range ml.listeners {
190+
ret = append(ret, l.Addr())
191+
}
192+
return ret
193+
}

0 commit comments

Comments
 (0)