Skip to content

Commit

Permalink
Switch to checking type of net.Addr to grab IP + update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
film42 committed Jul 22, 2017
1 parent 5fb86aa commit 76a4c80
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 80 deletions.
45 changes: 22 additions & 23 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type connection struct {
incoming net.Conn
outgoing net.Conn
proxy
localIP string
}

func (c *connection) Handle() {
Expand Down Expand Up @@ -98,27 +99,6 @@ func (c *connection) Close() {
logger.Info.Println(c.id, "Connection closed.")
}

func parseAddrFromHostport(hostport string) (string, error) {
if len(hostport) == 0 {
return "", errors.New("Hostport string provided was empty.")
}

colonIndex := strings.IndexByte(hostport, ':')
if colonIndex == -1 {
return "", errors.New("No colon was provided in the net.Conn local address (hostport string).")
}

if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[:i+len("]")], nil
}

if strings.Contains(hostport, "]") {
return "", errors.New("Invalid ipv6 local address provided as hostport string.")
}

return hostport[:colonIndex], nil
}

// COPIED FROM STD LIB TO USE WITH PROXY-AUTH HEADER
// parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
Expand Down Expand Up @@ -165,11 +145,30 @@ func newConnectionId() string {
return "[" + hex.EncodeToString(bytes) + "]"
}

func NewConnection(incoming net.Conn) *connection {
func localIPString(addr net.Addr) (string, error) {
switch a := addr.(type) {
case *net.TCPAddr:
return a.IP.String(), nil
case *net.IPAddr:
return a.IP.String(), nil
}

return "", errors.New("Could not find IP Address in net.Addr: " + addr.String())
}

func NewConnection(incoming net.Conn) (*connection, error) {
newId := fmt.Sprint(newConnectionId(), " [", incoming.RemoteAddr().String(), "]")
localAddr := incoming.LocalAddr()
incomingLocalIP, err := localIPString(localAddr)

if err != nil {
fmt.Println(err)
return nil, err
}

return &connection{
id: newId,
incoming: incoming,
}
localIP: incomingLocalIP,
}, nil
}
40 changes: 4 additions & 36 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestInvalidCredentials(t *testing.T) {

incoming := NewMockConn()
defer incoming.CloseClient()
conn := NewConnection(incoming)
conn, _ := NewConnection(incoming)
go conn.Handle()

incoming.ClientWriter.Write([]byte(basicHttpProxyRequest()))
Expand All @@ -84,7 +84,7 @@ func TestSampleProxy(t *testing.T) {
cleanedUp := make(chan bool)
incoming := NewMockConn()
defer incoming.CloseClient()
conn := NewConnection(incoming)
conn, _ := NewConnection(incoming)
go func() {
conn.Handle()
cleanedUp <- true
Expand Down Expand Up @@ -114,7 +114,7 @@ func TestSampleProxyWithValidAuthCredentials(t *testing.T) {

cleanedUp := make(chan bool)
incoming := NewMockConn()
conn := NewConnection(incoming)
conn, _ := NewConnection(incoming)
go func() {
conn.Handle()
cleanedUp <- true
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestSampleProxyViaConnect(t *testing.T) {

cleanedUp := make(chan bool)
incoming := NewMockConn()
conn := NewConnection(incoming)
conn, _ := NewConnection(incoming)
go func() {
conn.Handle()
cleanedUp <- true
Expand All @@ -173,35 +173,3 @@ func TestSampleProxyViaConnect(t *testing.T) {
incoming.CloseClient()
<-cleanedUp
}

func TestParsingAddrFromHostport(t *testing.T) {
_, err := parseAddrFromHostport("")
if err == nil {
t.Fatal("Expected an error.")
}

_, err = parseAddrFromHostport("1.1.1.1")
if err == nil {
t.Fatal("Expected an error.")
}

_, err = parseAddrFromHostport("[2001:db8::1]")
if err == nil {
t.Fatal("Expected an error.")
}

_, err = parseAddrFromHostport("somerandomstring.com")
if err == nil {
t.Fatal("Expected an error.")
}

ipv4Addr, _ := parseAddrFromHostport("1.1.1.1:8000")
if ipv4Addr != "1.1.1.1" {
t.Fatalf("Expected 1.1.1.1 but found %s", ipv4Addr)
}

ipv6Addr, _ := parseAddrFromHostport("[2001:db8::1]:80")
if ipv6Addr != "[2001:db8::1]" {
t.Fatalf("Expected [2001:db8::1] but found %s", ipv6Addr)
}
}
7 changes: 6 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ import (
var config *Config

func handleConnection(conn net.Conn) {
connection := NewConnection(conn)
connection, err := NewConnection(conn)
if err != nil {
logger.Fatal.Println(err)
return
}

defer connection.Close()
connection.Handle()
}
Expand Down
28 changes: 8 additions & 20 deletions mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,6 @@ import (
"time"
)

// Addr is a fake network interface which implements the net.Addr interface
type Addr struct {
NetworkString string
AddrString string
}

func (a Addr) Network() string {
return a.NetworkString
}

func (a Addr) String() string {
return a.AddrString
}

type MockConn struct {
ServerReader *io.PipeReader
ServerWriter *io.PipeWriter
Expand Down Expand Up @@ -51,16 +37,18 @@ func (c MockConn) Read(data []byte) (n int, err error) { return c.ServerReader.
func (c MockConn) Write(data []byte) (n int, err error) { return c.ServerWriter.Write(data) }

func (c MockConn) LocalAddr() net.Addr {
return Addr{
NetworkString: "tcp",
AddrString: "127.0.0.1",
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1:2342"),
Port: 2342,
Zone: "",
}
}

func (c MockConn) RemoteAddr() net.Addr {
return Addr{
NetworkString: "tcp",
AddrString: "127.0.0.1",
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1:2342"),
Port: 2342,
Zone: "",
}
}

Expand Down

0 comments on commit 76a4c80

Please sign in to comment.