Skip to content

Commit

Permalink
fix race condition on listeners
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Chatelain committed Dec 27, 2024
1 parent 625b541 commit c1cead0
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 7 deletions.
16 changes: 13 additions & 3 deletions pkg/agent/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ func HandleConn(conn net.Conn) {
Err: false,
}
}

if err := encoder.Encode(bindResponse); err != nil {
logrus.Error(err)
}
Expand All @@ -296,7 +295,7 @@ func HandleConn(conn net.Conn) {
}
case *protocol.ListenerSockRequestPacket:
sockRequest := e.Payload.(*protocol.ListenerSockRequestPacket)
encoder := protocol.NewEncoder(conn)
socketEncDec := protocol.NewEncoderDecoder(conn)

var sockResponse protocol.ListenerSockResponsePacket
if _, ok := listenerConntrack[sockRequest.SockID]; !ok {
Expand All @@ -305,7 +304,7 @@ func HandleConn(conn net.Conn) {
sockResponse.Err = true
}

if err := encoder.Encode(sockResponse); err != nil {
if err := socketEncDec.Encode(sockResponse); err != nil {
logrus.Error(err)
return
}
Expand All @@ -314,7 +313,18 @@ func HandleConn(conn net.Conn) {
return
}

if err := socketEncDec.Decode(); err != nil {
logrus.Error(err)
return
}
netConn := listenerConntrack[sockRequest.SockID]

if err := socketEncDec.Payload.(*protocol.ListenerSocketConnectionReady).Err; err != false {
logrus.Debug("Socket relay session failed: error from proxy")
netConn.Close()
return
}

relay.StartRelay(netConn, conn)

case *protocol.ListenerCloseResponsePacket:
Expand Down
2 changes: 2 additions & 0 deletions pkg/protocol/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ func interfaceFromPayloadType(payloadType uint8) (interface{}, error) {
return &ListenerCloseRequestPacket{}, nil
case MessageListenerCloseResponse:
return &ListenerCloseResponsePacket{}, nil
case MessageListenerSocketConnectionReady:
return &ListenerSocketConnectionReady{}, nil
default:
return nil, fmt.Errorf("decode called for unknown payload type: %d", payloadType)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/protocol/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func payloadTypeFromInterface(payload interface{}) (uint8, error) {
return MessageListenerCloseRequest, nil
case ListenerCloseResponsePacket:
return MessageListenerCloseResponse, nil
case ListenerSocketConnectionReady:
return MessageListenerSocketConnectionReady, nil
default:
return 0, fmt.Errorf("payloadTypeFromInterface called for unknown payload type: %v", payload)
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/protocol/packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
MessageListenerCloseRequest
MessageListenerCloseResponse
MessageClose
MessageListenerSocketConnectionReady
)

const (
Expand Down Expand Up @@ -60,6 +61,10 @@ type ListenerSockResponsePacket struct {
Err bool
}

type ListenerSocketConnectionReady struct {
Err bool
}

// ListenerRequestPacket is used when a new listener socket is created by the proxy.
type ListenerRequestPacket struct {
Network string
Expand Down
19 changes: 15 additions & 4 deletions pkg/proxy/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,23 +150,34 @@ func (l *LigoloListener) relayTCP() error {
logrus.Error(err)
return
}
// Get response back (ListenerSocketResponsePacket)
if err := forwarderProtocolEncDec.Decode(); err != nil {
logrus.Error(err)
return
}

if err := forwarderProtocolEncDec.Payload.(*protocol.ListenerSockResponsePacket).Err; err != false {
logrus.Error(forwarderProtocolEncDec.Payload.(*protocol.ListenerSockResponsePacket).ErrString)
return
}
// Got socket access!

logrus.Debug("Listener relay established!")
// If no error, establish TCP conn!
logrus.Debugf("Listener relay established to %s (%s)!", l.to, l.network)

// Dial the "to" target
connFailed := false
lconn, err := net.Dial(l.network, l.to)
if err != nil {
logrus.Error(err)
connFailed = true
}

// Send connect ack (avoid races)
connectionAckPacket := protocol.ListenerSocketConnectionReady{Err: connFailed}
if err := forwarderProtocolEncDec.Encode(connectionAckPacket); err != nil {
logrus.Error(err)
return
}

if connFailed {
return
}

Expand Down

0 comments on commit c1cead0

Please sign in to comment.