Skip to content

Commit 426741a

Browse files
committed
testing: add TCPProxy
1 parent 6e944d6 commit 426741a

File tree

5 files changed

+197
-6
lines changed

5 files changed

+197
-6
lines changed

mgo_unix.go

-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@ import "os"
1313
func (inst *MgoInstance) DestroyWithLog() {
1414
inst.killAndCleanup(os.Interrupt)
1515
}
16-

mgo_windows.go

-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,3 @@ import "os"
1616
func (inst *MgoInstance) DestroyWithLog() {
1717
inst.killAndCleanup(os.Kill)
1818
}
19-

osenv.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,25 @@ var testingVariables = []string{
6767
}
6868

6969
func (s *OsEnvSuite) setEnviron() {
70-
var isWhitelisted func (string) bool
70+
var isWhitelisted func(string) bool
7171
switch runtime.GOOS {
7272
case "windows":
7373
// Lowercase variable names for comparison as they are case
7474
// insenstive on windows. Fancy folding not required for ascii.
7575
lowerEnv := make(map[string]struct{},
76-
len(windowsVariables) + len(testingVariables))
76+
len(windowsVariables)+len(testingVariables))
7777
for _, envVar := range windowsVariables {
7878
lowerEnv[strings.ToLower(envVar)] = struct{}{}
7979
}
8080
for _, envVar := range testingVariables {
8181
lowerEnv[strings.ToLower(envVar)] = struct{}{}
8282
}
83-
isWhitelisted = func (envVar string) bool {
83+
isWhitelisted = func(envVar string) bool {
8484
_, ok := lowerEnv[strings.ToLower(envVar)]
8585
return ok
8686
}
8787
default:
88-
isWhitelisted = func (envVar string) bool {
88+
isWhitelisted = func(envVar string) bool {
8989
for _, testingVar := range testingVariables {
9090
if testingVar == envVar {
9191
return true

tcpproxy.go

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright 2015 Canonical Ltd.
2+
// Licensed under the LGPLv3, see LICENCE file for details.
3+
4+
package testing
5+
6+
import (
7+
"io"
8+
"net"
9+
"sync"
10+
11+
jc "github.com/juju/testing/checkers"
12+
gc "gopkg.in/check.v1"
13+
)
14+
15+
// TCPProxy is a simple TCP proxy that can be used
16+
// to deliberately break TCP connections.
17+
type TCPProxy struct {
18+
listener net.Listener
19+
// mu guards the fields below it.
20+
mu sync.Mutex
21+
// closed holds whether the proxy has been closed.
22+
closed bool
23+
// conns holds all connections that have been made.
24+
conns []io.Closer
25+
}
26+
27+
// NewTCPProxy runs a proxy that copies to and from
28+
// the given remote TCP address. When the proxy
29+
// is closed, its listener and all connections will be closed.
30+
func NewTCPProxy(c *gc.C, remoteAddr string) *TCPProxy {
31+
listener, err := net.Listen("tcp", "127.0.0.1:0")
32+
c.Assert(err, jc.ErrorIsNil)
33+
p := &TCPProxy{
34+
listener: listener,
35+
}
36+
go func() {
37+
for {
38+
client, err := p.listener.Accept()
39+
if err != nil {
40+
if !p.isClosed() {
41+
c.Error("cannot accept: %v", err)
42+
}
43+
return
44+
}
45+
p.addConn(client)
46+
server, err := net.Dial("tcp", remoteAddr)
47+
if err != nil {
48+
if !p.isClosed() {
49+
c.Error("cannot dial remote address: %v", err)
50+
}
51+
return
52+
}
53+
p.addConn(server)
54+
go stream(client, server)
55+
go stream(server, client)
56+
}
57+
}()
58+
return p
59+
}
60+
61+
func (p *TCPProxy) addConn(c net.Conn) {
62+
p.mu.Lock()
63+
defer p.mu.Unlock()
64+
if p.closed {
65+
c.Close()
66+
} else {
67+
p.conns = append(p.conns, c)
68+
}
69+
}
70+
71+
// Close closes the TCPProxy and any connections that
72+
// are currently active.
73+
func (p *TCPProxy) Close() error {
74+
p.mu.Lock()
75+
defer p.mu.Unlock()
76+
p.closed = true
77+
p.listener.Close()
78+
for _, c := range p.conns {
79+
c.Close()
80+
}
81+
return nil
82+
}
83+
84+
// Addr returns the TCP address of the proxy. Dialing
85+
// this address will cause a connection to be made
86+
// to the remote address; any data written will be
87+
// written there, and any data read from the remote
88+
// address will be available to read locally.
89+
func (p *TCPProxy) Addr() string {
90+
// Note: this only works because we explicitly listen on 127.0.0.1 rather
91+
// than the wildcard address.
92+
return p.listener.Addr().String()
93+
}
94+
func (p *TCPProxy) isClosed() bool {
95+
p.mu.Lock()
96+
defer p.mu.Unlock()
97+
return p.closed
98+
}
99+
100+
func stream(dst io.WriteCloser, src io.ReadCloser) {
101+
defer dst.Close()
102+
defer src.Close()
103+
io.Copy(dst, src)
104+
}

tcpproxy_test.go

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright 2015 Canonical Ltd.
2+
// Licensed under the LGPLv3, see LICENCE file for details.
3+
4+
package testing_test
5+
6+
import (
7+
"fmt"
8+
"io"
9+
"net"
10+
"sync"
11+
12+
"github.com/juju/testing"
13+
gc "gopkg.in/check.v1"
14+
)
15+
16+
var _ = gc.Suite(&tcpProxySuite{})
17+
18+
type tcpProxySuite struct{}
19+
20+
func (*tcpProxySuite) TestTCPProxy(c *gc.C) {
21+
var wg sync.WaitGroup
22+
23+
listener, err := net.Listen("tcp", "127.0.0.1:0")
24+
c.Assert(err, gc.IsNil)
25+
defer listener.Close()
26+
wg.Add(1)
27+
go tcpEcho(&wg, listener)
28+
29+
p := testing.NewTCPProxy(c, listener.Addr().String())
30+
c.Assert(p.Addr(), gc.Not(gc.Equals), listener.Addr().String())
31+
32+
// Dial the proxy and check that we see the text echoed correctly.
33+
conn, err := net.Dial("tcp", p.Addr())
34+
c.Assert(err, gc.IsNil)
35+
defer conn.Close()
36+
txt := "hello, world\n"
37+
fmt.Fprint(conn, txt)
38+
39+
buf := make([]byte, len(txt))
40+
n, err := io.ReadFull(conn, buf)
41+
c.Assert(err, gc.IsNil)
42+
c.Assert(string(buf[0:n]), gc.Equals, txt)
43+
44+
// Close the connection and check that we see
45+
// the connection closed for read.
46+
conn.(*net.TCPConn).CloseWrite()
47+
n, err = conn.Read(buf)
48+
c.Assert(err, gc.Equals, io.EOF)
49+
c.Assert(n, gc.Equals, 0)
50+
51+
// Make another connection and close the proxy,
52+
// which should close down the proxy and cause us
53+
// to get an error.
54+
conn, err = net.Dial("tcp", p.Addr())
55+
c.Assert(err, gc.IsNil)
56+
defer conn.Close()
57+
58+
p.Close()
59+
_, err = conn.Read(buf)
60+
c.Assert(err, gc.Equals, io.EOF)
61+
62+
// Make sure that we cannot dial the proxy address either.
63+
conn, err = net.Dial("tcp", p.Addr())
64+
c.Assert(err, gc.ErrorMatches, ".*connection refused")
65+
66+
listener.Close()
67+
// Make sure that all our connections have gone away too.
68+
wg.Wait()
69+
}
70+
71+
// tcpEcho listens on the given listener for TCP connections,
72+
// writes all traffic received back to the sender, and calls
73+
// wg.Done when all its goroutines have completed.
74+
func tcpEcho(wg *sync.WaitGroup, listener net.Listener) {
75+
defer wg.Done()
76+
for {
77+
conn, err := listener.Accept()
78+
if err != nil {
79+
return
80+
}
81+
wg.Add(1)
82+
go func() {
83+
defer wg.Done()
84+
defer conn.Close()
85+
// Echo anything that was written.
86+
io.Copy(conn, conn)
87+
}()
88+
}
89+
}

0 commit comments

Comments
 (0)