Skip to content

Commit a7e198e

Browse files
Fanglidingmmmray
andauthored
Fix WS reading X-Forwarded-For & Add tests (#3546)
Fixes #3545 --------- Co-authored-by: mmmray <[email protected]>
1 parent 9e6d7a3 commit a7e198e

File tree

4 files changed

+15
-11
lines changed

4 files changed

+15
-11
lines changed

transport/internet/httpupgrade/httpupgrade_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
151151
return
152152
}
153153

154-
_, err = c.Write([]byte("Response"))
154+
_, err = c.Write([]byte(c.RemoteAddr().String()))
155155
common.Must(err)
156156
}(conn)
157157
})
@@ -169,7 +169,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
169169
var b [1024]byte
170170
n, err := conn.Read(b[:])
171171
common.Must(err)
172-
if string(b[:n]) != "Response" {
172+
if string(b[:n]) != "1.1.1.1:0" {
173173
t.Error("response: ", string(b[:n]))
174174
}
175175

transport/internet/splithttp/splithttp_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
9696
return
9797
}
9898

99-
_, err = c.Write([]byte("Response"))
99+
_, err = c.Write([]byte(c.RemoteAddr().String()))
100100
common.Must(err)
101101
}(conn)
102102
})
@@ -113,7 +113,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
113113

114114
var b [1024]byte
115115
n, _ := conn.Read(b[:])
116-
if string(b[:n]) != "Response" {
116+
if string(b[:n]) != "1.1.1.1:0" {
117117
t.Error("response: ", string(b[:n]))
118118
}
119119

transport/internet/websocket/connection.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@ import (
1414
var _ buf.Writer = (*connection)(nil)
1515

1616
// connection is a wrapper for net.Conn over WebSocket connection.
17+
// remoteAddr is used to pass "virtual" remote IP addresses in X-Forwarded-For.
18+
// so we shouldn't directly read it form conn.
1719
type connection struct {
18-
conn *websocket.Conn
19-
reader io.Reader
20+
conn *websocket.Conn
21+
reader io.Reader
22+
remoteAddr net.Addr
2023
}
2124

2225
func NewConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
2326
return &connection{
24-
conn: conn,
25-
reader: extraReader,
27+
conn: conn,
28+
remoteAddr: remoteAddr,
29+
reader: extraReader,
2630
}
2731
}
2832

@@ -90,7 +94,7 @@ func (c *connection) LocalAddr() net.Addr {
9094
}
9195

9296
func (c *connection) RemoteAddr() net.Addr {
93-
return c.conn.RemoteAddr()
97+
return c.remoteAddr
9498
}
9599

96100
func (c *connection) SetDeadline(t time.Time) error {

transport/internet/websocket/ws_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
9191
return
9292
}
9393

94-
_, err = c.Write([]byte("Response"))
94+
_, err = c.Write([]byte(c.RemoteAddr().String()))
9595
common.Must(err)
9696
}(conn)
9797
})
@@ -109,7 +109,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
109109
var b [1024]byte
110110
n, err := conn.Read(b[:])
111111
common.Must(err)
112-
if string(b[:n]) != "Response" {
112+
if string(b[:n]) != "1.1.1.1:0" {
113113
t.Error("response: ", string(b[:n]))
114114
}
115115

0 commit comments

Comments
 (0)