Skip to content

Commit c6a1750

Browse files
committed
Net listener callback
1 parent 73fc4cb commit c6a1750

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

acceptor.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ type Acceptor struct {
4949
listeners map[string]net.Listener
5050
connectionValidator ConnectionValidator
5151
tlsConfig *tls.Config
52+
newListenerCallback NewListenerCallback
5253
sessionFactory
5354
}
5455

@@ -59,6 +60,9 @@ type ConnectionValidator interface {
5960
Validate(netConn net.Conn, session SessionID) error
6061
}
6162

63+
// NewListenerCallback is a function that returns a net.Listener for the given address and tls.Config struct.
64+
type NewListenerCallback func(address string, tlsConfig *tls.Config) (net.Listener, error)
65+
6266
// Start accepting connections.
6367
func (a *Acceptor) Start() (err error) {
6468
socketAcceptHost := ""
@@ -90,6 +94,15 @@ func (a *Acceptor) Start() (err error) {
9094
a.tlsConfig = tlsConfig
9195
}
9296

97+
if a.newListenerCallback == nil {
98+
a.newListenerCallback = func(address string, tlsConfig *tls.Config) (net.Listener, error) {
99+
if tlsConfig != nil {
100+
return tls.Listen("tcp", address, a.tlsConfig)
101+
}
102+
return net.Listen("tcp", address)
103+
}
104+
}
105+
93106
var useTCPProxy bool
94107
if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) {
95108
if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil {
@@ -98,11 +111,7 @@ func (a *Acceptor) Start() (err error) {
98111
}
99112

100113
for address := range a.listeners {
101-
if a.tlsConfig != nil {
102-
if a.listeners[address], err = tls.Listen("tcp", address, a.tlsConfig); err != nil {
103-
return
104-
}
105-
} else if a.listeners[address], err = net.Listen("tcp", address); err != nil {
114+
if a.listeners[address], err = a.newListenerCallback(address, a.tlsConfig); err != nil {
106115
return
107116
} else if useTCPProxy {
108117
a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]}
@@ -435,3 +444,9 @@ func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
435444
func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) {
436445
a.tlsConfig = tlsConfig
437446
}
447+
448+
// SetNewListenerCallback allows the creator of the Acceptor to specify the callback used to create each net.Listener
449+
// which will be used in the Start() method.
450+
func (a *Acceptor) SetNewListenerCallback(cb NewListenerCallback) {
451+
a.newListenerCallback = cb
452+
}

acceptor_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,51 @@ func TestAcceptor_SetTLSConfig(t *testing.T) {
126126
assert.NotNil(t, conn)
127127
defer conn.Close()
128128
}
129+
130+
func TestAcceptor_SetCallback(t *testing.T) {
131+
sessionSettings := NewSessionSettings()
132+
sessionSettings.Set(config.BeginString, BeginStringFIX42)
133+
sessionSettings.Set(config.SenderCompID, "sender")
134+
sessionSettings.Set(config.TargetCompID, "target")
135+
136+
genericSettings := NewSettings()
137+
138+
genericSettings.GlobalSettings().Set("SocketAcceptPort", "5001")
139+
_, err := genericSettings.AddSession(sessionSettings)
140+
require.NoError(t, err)
141+
142+
logger, err := NewNullLogFactory().Create()
143+
require.NoError(t, err)
144+
acceptor := &Acceptor{settings: genericSettings, globalLog: logger}
145+
defer acceptor.Stop()
146+
// example of a customized tls.Config that loads the certificates dynamically by the `GetCertificate` function
147+
// as opposed to the Certificates slice, that is static in nature, and is only populated once and needs application restart to reload the certs.
148+
customizedTLSConfig := tls.Config{
149+
Certificates: []tls.Certificate{},
150+
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
151+
cert, err := tls.LoadX509KeyPair("_test_data/localhost.crt", "_test_data/localhost.key")
152+
if err != nil {
153+
return nil, err
154+
}
155+
return &cert, nil
156+
},
157+
}
158+
159+
didUseCallback := false
160+
acceptor.SetTLSConfig(&customizedTLSConfig)
161+
acceptor.SetNewListenerCallback(func(address string, tlsConfig *tls.Config) (net.Listener, error) {
162+
didUseCallback = true
163+
assert.Equal(t, &customizedTLSConfig, tlsConfig)
164+
return tls.Listen("tcp", address, tlsConfig)
165+
})
166+
assert.NoError(t, acceptor.Start())
167+
assert.Len(t, acceptor.listeners, 1)
168+
169+
conn, err := tls.Dial("tcp", "localhost:5001", &tls.Config{
170+
InsecureSkipVerify: true,
171+
})
172+
require.NoError(t, err)
173+
assert.NotNil(t, conn)
174+
assert.True(t, didUseCallback)
175+
defer conn.Close()
176+
}

0 commit comments

Comments
 (0)