@@ -10,6 +10,7 @@ import (
1010	"net/http" 
1111	"net/http/httptest" 
1212	"strings" 
13+ 	"sync" 
1314	"testing" 
1415
1516	"nhooyr.io/websocket/internal/test/assert" 
@@ -142,6 +143,42 @@ func TestAccept(t *testing.T) {
142143		_ , err  :=  Accept (w , r , nil )
143144		assert .Contains (t , err , `failed to hijack connection` )
144145	})
146+ 	t .Run ("closeRace" , func (t  * testing.T ) {
147+ 		t .Parallel ()
148+ 
149+ 		server , _  :=  net .Pipe ()
150+ 
151+ 		rw  :=  bufio .NewReadWriter (bufio .NewReader (server ), bufio .NewWriter (server ))
152+ 		newResponseWriter  :=  func () http.ResponseWriter  {
153+ 			return  mockHijacker {
154+ 				ResponseWriter : httptest .NewRecorder (),
155+ 				hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
156+ 					return  server , rw , nil 
157+ 				},
158+ 			}
159+ 		}
160+ 		w  :=  newResponseWriter ()
161+ 
162+ 		r  :=  httptest .NewRequest ("GET" , "/" , nil )
163+ 		r .Header .Set ("Connection" , "Upgrade" )
164+ 		r .Header .Set ("Upgrade" , "websocket" )
165+ 		r .Header .Set ("Sec-WebSocket-Version" , "13" )
166+ 		r .Header .Set ("Sec-WebSocket-Key" , xrand .Base64 (16 ))
167+ 
168+ 		c , err  :=  Accept (w , r , nil )
169+ 		wg  :=  & sync.WaitGroup {}
170+ 		wg .Add (2 )
171+ 		go  func () {
172+ 			c .Close (StatusInternalError , "the sky is falling" )
173+ 			wg .Done ()
174+ 		}()
175+ 		go  func () {
176+ 			c .CloseNow ()
177+ 			wg .Done ()
178+ 		}()
179+ 		wg .Wait ()
180+ 		assert .Success (t , err )
181+ 	})
145182}
146183
147184func  Test_verifyClientHandshake (t  * testing.T ) {
0 commit comments