From 76a4c804bb79c2b17220b1542707668e5528b6c0 Mon Sep 17 00:00:00 2001 From: Garrett Thornburg Date: Sat, 22 Jul 2017 00:10:32 -0400 Subject: [PATCH] Switch to checking type of net.Addr to grab IP + update tests --- connection.go | 45 ++++++++++++++++++++++----------------------- connection_test.go | 40 ++++------------------------------------ main.go | 7 ++++++- mock_test.go | 28 ++++++++-------------------- 4 files changed, 40 insertions(+), 80 deletions(-) diff --git a/connection.go b/connection.go index 2ef38b9..39b7605 100644 --- a/connection.go +++ b/connection.go @@ -20,6 +20,7 @@ type connection struct { incoming net.Conn outgoing net.Conn proxy + localIP string } func (c *connection) Handle() { @@ -98,27 +99,6 @@ func (c *connection) Close() { logger.Info.Println(c.id, "Connection closed.") } -func parseAddrFromHostport(hostport string) (string, error) { - if len(hostport) == 0 { - return "", errors.New("Hostport string provided was empty.") - } - - colonIndex := strings.IndexByte(hostport, ':') - if colonIndex == -1 { - return "", errors.New("No colon was provided in the net.Conn local address (hostport string).") - } - - if i := strings.Index(hostport, "]:"); i != -1 { - return hostport[:i+len("]")], nil - } - - if strings.Contains(hostport, "]") { - return "", errors.New("Invalid ipv6 local address provided as hostport string.") - } - - return hostport[:colonIndex], nil -} - // COPIED FROM STD LIB TO USE WITH PROXY-AUTH HEADER // parseBasicAuth parses an HTTP Basic Authentication string. // "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). @@ -165,11 +145,30 @@ func newConnectionId() string { return "[" + hex.EncodeToString(bytes) + "]" } -func NewConnection(incoming net.Conn) *connection { +func localIPString(addr net.Addr) (string, error) { + switch a := addr.(type) { + case *net.TCPAddr: + return a.IP.String(), nil + case *net.IPAddr: + return a.IP.String(), nil + } + + return "", errors.New("Could not find IP Address in net.Addr: " + addr.String()) +} + +func NewConnection(incoming net.Conn) (*connection, error) { newId := fmt.Sprint(newConnectionId(), " [", incoming.RemoteAddr().String(), "]") + localAddr := incoming.LocalAddr() + incomingLocalIP, err := localIPString(localAddr) + + if err != nil { + fmt.Println(err) + return nil, err + } return &connection{ id: newId, incoming: incoming, - } + localIP: incomingLocalIP, + }, nil } diff --git a/connection_test.go b/connection_test.go index 3ef4949..230434c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -58,7 +58,7 @@ func TestInvalidCredentials(t *testing.T) { incoming := NewMockConn() defer incoming.CloseClient() - conn := NewConnection(incoming) + conn, _ := NewConnection(incoming) go conn.Handle() incoming.ClientWriter.Write([]byte(basicHttpProxyRequest())) @@ -84,7 +84,7 @@ func TestSampleProxy(t *testing.T) { cleanedUp := make(chan bool) incoming := NewMockConn() defer incoming.CloseClient() - conn := NewConnection(incoming) + conn, _ := NewConnection(incoming) go func() { conn.Handle() cleanedUp <- true @@ -114,7 +114,7 @@ func TestSampleProxyWithValidAuthCredentials(t *testing.T) { cleanedUp := make(chan bool) incoming := NewMockConn() - conn := NewConnection(incoming) + conn, _ := NewConnection(incoming) go func() { conn.Handle() cleanedUp <- true @@ -146,7 +146,7 @@ func TestSampleProxyViaConnect(t *testing.T) { cleanedUp := make(chan bool) incoming := NewMockConn() - conn := NewConnection(incoming) + conn, _ := NewConnection(incoming) go func() { conn.Handle() cleanedUp <- true @@ -173,35 +173,3 @@ func TestSampleProxyViaConnect(t *testing.T) { incoming.CloseClient() <-cleanedUp } - -func TestParsingAddrFromHostport(t *testing.T) { - _, err := parseAddrFromHostport("") - if err == nil { - t.Fatal("Expected an error.") - } - - _, err = parseAddrFromHostport("1.1.1.1") - if err == nil { - t.Fatal("Expected an error.") - } - - _, err = parseAddrFromHostport("[2001:db8::1]") - if err == nil { - t.Fatal("Expected an error.") - } - - _, err = parseAddrFromHostport("somerandomstring.com") - if err == nil { - t.Fatal("Expected an error.") - } - - ipv4Addr, _ := parseAddrFromHostport("1.1.1.1:8000") - if ipv4Addr != "1.1.1.1" { - t.Fatalf("Expected 1.1.1.1 but found %s", ipv4Addr) - } - - ipv6Addr, _ := parseAddrFromHostport("[2001:db8::1]:80") - if ipv6Addr != "[2001:db8::1]" { - t.Fatalf("Expected [2001:db8::1] but found %s", ipv6Addr) - } -} diff --git a/main.go b/main.go index d5aa8db..49c34d8 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,12 @@ import ( var config *Config func handleConnection(conn net.Conn) { - connection := NewConnection(conn) + connection, err := NewConnection(conn) + if err != nil { + logger.Fatal.Println(err) + return + } + defer connection.Close() connection.Handle() } diff --git a/mock_test.go b/mock_test.go index c3e8c8a..af75a81 100644 --- a/mock_test.go +++ b/mock_test.go @@ -6,20 +6,6 @@ import ( "time" ) -// Addr is a fake network interface which implements the net.Addr interface -type Addr struct { - NetworkString string - AddrString string -} - -func (a Addr) Network() string { - return a.NetworkString -} - -func (a Addr) String() string { - return a.AddrString -} - type MockConn struct { ServerReader *io.PipeReader ServerWriter *io.PipeWriter @@ -51,16 +37,18 @@ func (c MockConn) Read(data []byte) (n int, err error) { return c.ServerReader. func (c MockConn) Write(data []byte) (n int, err error) { return c.ServerWriter.Write(data) } func (c MockConn) LocalAddr() net.Addr { - return Addr{ - NetworkString: "tcp", - AddrString: "127.0.0.1", + return &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1:2342"), + Port: 2342, + Zone: "", } } func (c MockConn) RemoteAddr() net.Addr { - return Addr{ - NetworkString: "tcp", - AddrString: "127.0.0.1", + return &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1:2342"), + Port: 2342, + Zone: "", } }