From eee279bdb3850c1e9365c94776a3eaf898d19e37 Mon Sep 17 00:00:00 2001
From: hi019 <65871571+hi019@users.noreply.github.com>
Date: Fri, 31 Dec 2021 12:32:39 -0500
Subject: [PATCH] Fix using IP ranges in config.TrustedProxies (#1607) (#1614)

* Fix using IP ranges in config.TrustedProxies (#1607)

* Add tests

* Remove debugging var

* Remove tests

* Update test

Co-authored-by: RW <rene@gofiber.io>
---
 app.go      | 27 +++++++---------
 ctx.go      | 14 ++++++--
 ctx_test.go | 92 +++++++++++++++++++++++++++++++++++++++++++----------
 3 files changed, 100 insertions(+), 33 deletions(-)

diff --git a/app.go b/app.go
index ebb8418c55..f97564733f 100644
--- a/app.go
+++ b/app.go
@@ -350,8 +350,9 @@ type Config struct {
 	// Read EnableTrustedProxyCheck doc.
 	//
 	// Default: []string
-	TrustedProxies    []string `json:"trusted_proxies"`
-	trustedProxiesMap map[string]struct{}
+	TrustedProxies     []string `json:"trusted_proxies"`
+	trustedProxiesMap  map[string]struct{}
+	trustedProxyRanges []*net.IPNet
 
 	//If set to true, will print all routes with their method, path and handler.
 	// Default: false
@@ -501,8 +502,8 @@ func New(config ...Config) *App {
 	}
 
 	app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
-	for _, ip := range app.config.TrustedProxies {
-		app.handleTrustedProxy(ip)
+	for _, ipAddress := range app.config.TrustedProxies {
+		app.handleTrustedProxy(ipAddress)
 	}
 
 	// Init app
@@ -512,23 +513,19 @@ func New(config ...Config) *App {
 	return app
 }
 
-// Checks if the given IP address is a range whether or not, adds it to the trustedProxiesMap
+// Adds an ip address to trustedProxyRanges or trustedProxiesMap based on whether it is an IP range or not
 func (app *App) handleTrustedProxy(ipAddress string) {
-	// Detects IP address is range whether or not
 	if strings.Contains(ipAddress, "/") {
-		// Parsing IP address
-		ip, ipnet, err := net.ParseCIDR(ipAddress)
+		_, ipNet, err := net.ParseCIDR(ipAddress)
+
 		if err != nil {
 			fmt.Printf("[Warning] IP range `%s` could not be parsed. \n", ipAddress)
-			return
-		}
-		// Iterates IP address which is between range
-		for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); utils.IncrementIPRange(ip) {
-			app.config.trustedProxiesMap[ip.String()] = struct{}{}
 		}
-		return
+
+		app.config.trustedProxyRanges = append(app.config.trustedProxyRanges, ipNet)
+	} else {
+		app.config.trustedProxiesMap[ipAddress] = struct{}{}
 	}
-	app.config.trustedProxiesMap[ipAddress] = struct{}{}
 }
 
 // Mount attaches another app instance as a sub-router along a routing path.
diff --git a/ctx.go b/ctx.go
index df2b754f86..9f06c3df69 100644
--- a/ctx.go
+++ b/ctx.go
@@ -1310,8 +1310,18 @@ func (c *Ctx) IsProxyTrusted() bool {
 		return true
 	}
 
-	_, trustProxy := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()]
-	return trustProxy
+	_, trusted := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()]
+	if trusted {
+		return trusted
+	}
+
+	for _, ipNet := range c.app.config.trustedProxyRanges {
+		if ipNet.Contains(c.fasthttp.RemoteIP()) {
+			return true
+		}
+	}
+
+	return false
 }
 
 // IsLocalHost will return true if address is a localhost address.
diff --git a/ctx_test.go b/ctx_test.go
index 7591372d33..0406ada7d4 100644
--- a/ctx_test.go
+++ b/ctx_test.go
@@ -985,6 +985,30 @@ func Test_Ctx_Hostname_TrustedProxy(t *testing.T) {
 	}
 }
 
+// go test -run Test_Ctx_Hostname_UntrustedProxyRange
+func Test_Ctx_Hostname_TrustedProxyRange(t *testing.T) {
+	t.Parallel()
+
+	app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0/30"}})
+	c := app.AcquireCtx(&fasthttp.RequestCtx{})
+	c.Request().SetRequestURI("http://google.com/test")
+	c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
+	utils.AssertEqual(t, "google1.com", c.Hostname())
+	app.ReleaseCtx(c)
+}
+
+// go test -run Test_Ctx_Hostname_UntrustedProxyRange
+func Test_Ctx_Hostname_UntrustedProxyRange(t *testing.T) {
+	t.Parallel()
+
+	app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"1.0.0.0/30"}})
+	c := app.AcquireCtx(&fasthttp.RequestCtx{})
+	c.Request().SetRequestURI("http://google.com/test")
+	c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
+	utils.AssertEqual(t, "google.com", c.Hostname())
+	app.ReleaseCtx(c)
+}
+
 // go test -run Test_Ctx_Port
 func Test_Ctx_Port(t *testing.T) {
 	t.Parallel()
@@ -1032,22 +1056,6 @@ func Test_Ctx_IP_TrustedProxy(t *testing.T) {
 	utils.AssertEqual(t, "0.0.0.1", c.IP())
 }
 
-// go test -run Test_Ctx_IP_Range_TrustedProxy
-func Test_Ctx_IP_Range_TrustedProxy(t *testing.T) {
-	t.Parallel()
-	app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0", "1.1.1.1/30", "1.1.1.1/100"}, ProxyHeader: HeaderXForwardedFor})
-	c := app.AcquireCtx(&fasthttp.RequestCtx{})
-	defer app.ReleaseCtx(c)
-	expected := map[string]struct{}{
-		"0.0.0.0": {},
-		"1.1.1.0": {},
-		"1.1.1.1": {},
-		"1.1.1.2": {},
-		"1.1.1.3": {},
-	}
-	utils.AssertEqual(t, expected, app.config.trustedProxiesMap)
-}
-
 // go test -run Test_Ctx_IPs  -parallel
 func Test_Ctx_IPs(t *testing.T) {
 	t.Parallel()
@@ -1399,6 +1407,58 @@ func Test_Ctx_Protocol_TrustedProxy(t *testing.T) {
 	utils.AssertEqual(t, "http", c.Protocol())
 }
 
+// go test -run Test_Ctx_Protocol_TrustedProxyRange
+func Test_Ctx_Protocol_TrustedProxyRange(t *testing.T) {
+	t.Parallel()
+	app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0/30"}})
+	c := app.AcquireCtx(&fasthttp.RequestCtx{})
+	defer app.ReleaseCtx(c)
+
+	c.Request().Header.Set(HeaderXForwardedProto, "https")
+	utils.AssertEqual(t, "https", c.Protocol())
+	c.Request().Header.Reset()
+
+	c.Request().Header.Set(HeaderXForwardedProtocol, "https")
+	utils.AssertEqual(t, "https", c.Protocol())
+	c.Request().Header.Reset()
+
+	c.Request().Header.Set(HeaderXForwardedSsl, "on")
+	utils.AssertEqual(t, "https", c.Protocol())
+	c.Request().Header.Reset()
+
+	c.Request().Header.Set(HeaderXUrlScheme, "https")
+	utils.AssertEqual(t, "https", c.Protocol())
+	c.Request().Header.Reset()
+
+	utils.AssertEqual(t, "http", c.Protocol())
+}
+
+// go test -run Test_Ctx_Protocol_UntrustedProxyRange
+func Test_Ctx_Protocol_UntrustedProxyRange(t *testing.T) {
+	t.Parallel()
+	app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"1.1.1.1/30"}})
+	c := app.AcquireCtx(&fasthttp.RequestCtx{})
+	defer app.ReleaseCtx(c)
+
+	c.Request().Header.Set(HeaderXForwardedProto, "https")
+	utils.AssertEqual(t, "http", c.Protocol())
+	c.Request().Header.Reset()
+
+	c.Request().Header.Set(HeaderXForwardedProtocol, "https")
+	utils.AssertEqual(t, "http", c.Protocol())
+	c.Request().Header.Reset()
+
+	c.Request().Header.Set(HeaderXForwardedSsl, "on")
+	utils.AssertEqual(t, "http", c.Protocol())
+	c.Request().Header.Reset()
+
+	c.Request().Header.Set(HeaderXUrlScheme, "https")
+	utils.AssertEqual(t, "http", c.Protocol())
+	c.Request().Header.Reset()
+
+	utils.AssertEqual(t, "http", c.Protocol())
+}
+
 // go test -run Test_Ctx_Protocol_UnTrustedProxy
 func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) {
 	t.Parallel()