Skip to content

Commit

Permalink
Attempt to dial with the localAddr and fall back with a log warning
Browse files Browse the repository at this point in the history
  • Loading branch information
film42 committed Jul 22, 2017
1 parent 76a4c80 commit 031e691
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 15 deletions.
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Config struct {
Credentials []Credential
StripProxyHeaders bool `json:"strip_proxy_headers"`
Port int
FollowLocalAddr bool `json:"follow_local_addr"`
}

func (config *Config) AuthenticationRequired() bool {
Expand Down
35 changes: 24 additions & 11 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,26 @@ type connection struct {
incoming net.Conn
outgoing net.Conn
proxy
localIP string
localAddr net.Addr
}

func (c *connection) Dial(network, address string) (net.Conn, error) {
if config.FollowLocalAddr {
dialer := &net.Dialer{LocalAddr: c.localAddr}

// Try to dial with the incoming LocalAddr to keep the incoming and outgoing IPs the same.
conn, err := dialer.Dial(network, address)
if err == nil {
return conn, nil
}

// If an error occurs, fallback to the default interface. This might happen if you connected
// via a loopback interace, like testing on the same machine. We should be more specifc about
// error handling, but falling back is fine for now.
logger.Warn.Println(c.id, "Ignoring net.Addr for", c.localAddr, "dialing due to error:", err)
}

return net.Dial(network, address)
}

func (c *connection) Handle() {
Expand Down Expand Up @@ -145,7 +164,7 @@ func newConnectionId() string {
return "[" + hex.EncodeToString(bytes) + "]"
}

func localIPString(addr net.Addr) (string, error) {
func localAddrString(addr net.Addr) (string, error) {
switch a := addr.(type) {
case *net.TCPAddr:
return a.IP.String(), nil
Expand All @@ -159,16 +178,10 @@ func localIPString(addr net.Addr) (string, error) {
func NewConnection(incoming net.Conn) (*connection, error) {
newId := fmt.Sprint(newConnectionId(), " [", incoming.RemoteAddr().String(), "]")
localAddr := incoming.LocalAddr()
incomingLocalIP, err := localIPString(localAddr)

if err != nil {
fmt.Println(err)
return nil, err
}

return &connection{
id: newId,
incoming: incoming,
localIP: incomingLocalIP,
id: newId,
incoming: incoming,
localAddr: localAddr,
}, nil
}
3 changes: 1 addition & 2 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"errors"
"fmt"
"net"
"net/http"
"strings"
)
Expand Down Expand Up @@ -39,7 +38,7 @@ func (hp *httpProxy) SetupOutgoing(connection *connection, request *http.Request
}

// Create our outgoing connection.
outgoingConn, err := net.Dial("tcp", host)
outgoingConn, err := connection.Dial("tcp", host)
if err != nil {
return errors.New(fmt.Sprint("Error making outgoing request to", request.Host, err))
}
Expand Down
3 changes: 1 addition & 2 deletions https.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"errors"
"fmt"
"net"
"net/http"
)

Expand All @@ -14,7 +13,7 @@ type httpsProxy struct{}
func (hp *httpsProxy) SetupOutgoing(connection *connection, request *http.Request) error {
// Create our outgoing connection.
outgoingHost := request.URL.Host
outgoingConn, err := net.Dial("tcp", outgoingHost)
outgoingConn, err := connection.Dial("tcp", outgoingHost)
if err != nil {
return errors.New(fmt.Sprint("Error opening outgoing connection to", outgoingHost, err))
}
Expand Down
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func main() {
configPtr := flag.String("config", "", "config file")
portPtr := flag.Int("port", 25000, "listen port")
stripProxyHeadersPtr := flag.Bool("strip-proxy-headers", true, "strip proxy headers from http requests")
followLocalAddrPtr := flag.Bool("follow-local-addr", true, "Use the local address of the incoming connection when connecting upstream")
usernamePtr := flag.String("username", "", "username for proxy authentication")
passwordPtr := flag.String("password", "", "password for proxy authentication")
flag.Parse()
Expand All @@ -80,6 +81,7 @@ func main() {
config = &Config{
Port: *portPtr,
StripProxyHeaders: *stripProxyHeadersPtr,
FollowLocalAddr: *followLocalAddrPtr,
}

if *usernamePtr != "" {
Expand Down

0 comments on commit 031e691

Please sign in to comment.