Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions cmd/cloudflared/tunnel/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,19 @@ func configureProxyFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
Value: false,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "header",
Aliases: []string{"H"},
Usage: "Add custom header when forwarding to origin (format: 'Name: Value')",
EnvVars: []string{"TUNNEL_HEADERS"},
Hidden: shouldHide,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "remove-header",
Usage: "Remove header when forwarding to origin",
EnvVars: []string{"TUNNEL_REMOVE_HEADERS"},
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.ManagementHostname,
Usage: "Management hostname to signify incoming management requests",
Expand Down
4 changes: 4 additions & 0 deletions config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ type OriginRequestConfig struct {
Http2Origin *bool `yaml:"http2Origin" json:"http2Origin,omitempty"`
// Access holds all access related configs
Access *AccessConfig `yaml:"access" json:"access,omitempty"`
// Custom headers to add/modify when forwarding to origin
Headers map[string]string `yaml:"headers" json:"headers,omitempty"`
// Headers to remove when forwarding to origin
RemoveHeaders []string `yaml:"removeHeaders" json:"removeHeaders,omitempty"`
}

type AccessConfig struct {
Expand Down
98 changes: 98 additions & 0 deletions ingress/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ingress

import (
"encoding/json"
"strings"
"time"

"github.com/urfave/cli/v2"
Expand Down Expand Up @@ -137,6 +138,16 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
var proxyPort uint
var proxyType string
var http2Origin bool

var cliHeaders map[string]string
var cliRemoveHeaders []string
if c.IsSet("header") {
cliHeaders = parseHeadersFromCLI(c)
}
if c.IsSet("remove-header") {
cliRemoveHeaders = parseRemoveHeadersFromCLI(c)
}

if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) {
connectTimeout = config.CustomDuration{Duration: c.Duration(flag)}
}
Expand Down Expand Up @@ -209,6 +220,8 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
ProxyPort: proxyPort,
ProxyType: proxyType,
Http2Origin: http2Origin,
Headers: cliHeaders,
RemoveHeaders: cliRemoveHeaders,
}
}

Expand Down Expand Up @@ -333,6 +346,12 @@ type OriginRequestConfig struct {

// Access holds all access related configs
Access config.AccessConfig `yaml:"access" json:"access,omitempty"`

// Custom headers to add/modify when forwarding to origin
Headers map[string]string `yaml:"headers" json:"headers,omitempty"`

// Headers to remove when forwarding to origin
RemoveHeaders []string `yaml:"removeHeaders" json:"removeHeaders,omitempty"`
}

func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) {
Expand Down Expand Up @@ -456,6 +475,26 @@ func (defaults *OriginRequestConfig) setAccess(overrides config.OriginRequestCon
}
}

func (defaults *OriginRequestConfig) setHeaders(overrides config.OriginRequestConfig) {
if overrides.Headers != nil {
if defaults.Headers == nil {
defaults.Headers = make(map[string]string)
}
for k, v := range overrides.Headers {
defaults.Headers[k] = v
}
}
}

func (defaults *OriginRequestConfig) setRemoveHeaders(overrides config.OriginRequestConfig) {
if overrides.RemoveHeaders != nil {
if defaults.RemoveHeaders == nil {
defaults.RemoveHeaders = make([]string, 0)
}
defaults.RemoveHeaders = append(defaults.RemoveHeaders, overrides.RemoveHeaders...)
}
}

// SetConfig gets config for the requests that cloudflared sends to origins.
// Each field has a setter method which sets a value for the field by trying to find:
// 1. The user config for this rule
Expand Down Expand Up @@ -485,6 +524,8 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi
cfg.setIPRules(overrides)
cfg.setHttp2Origin(overrides)
cfg.setAccess(overrides)
cfg.setHeaders(overrides)
cfg.setRemoveHeaders(overrides)

return cfg
}
Expand Down Expand Up @@ -540,6 +581,8 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig
IPRules: convertToRawIPRules(c.IPRules),
Http2Origin: defaultBoolToNil(c.Http2Origin),
Access: access,
Headers: c.Headers,
RemoveHeaders: c.RemoveHeaders,
}
}

Expand Down Expand Up @@ -583,3 +626,58 @@ func zeroUIntToNil(v uint) *uint {

return &v
}

func parseHeadersFromCLI(c *cli.Context) map[string]string {
headers := make(map[string]string)

if c.IsSet("header") {
headerFlags := c.StringSlice("header")
for _, headerFlag := range headerFlags {
if name, value, valid := parseHeaderFlag(headerFlag); valid {
headers[name] = value
}
}
}

return headers
}

func parseRemoveHeadersFromCLI(c *cli.Context) []string {
if c.IsSet("remove-header") {
return c.StringSlice("remove-header")
}
return nil
}

func parseHeaderFlag(headerFlag string) (name, value string, valid bool) {
parts := strings.SplitN(headerFlag, ":", 2)
if len(parts) != 2 {
return "", "", false
}

name = strings.TrimSpace(parts[0])
value = strings.TrimSpace(parts[1])

if name == "" || value == "" || !isValidHeaderName(name) {
return "", "", false
}

return name, value, true
}

func isValidHeaderName(name string) bool {
if name == "" || strings.Contains(name, ":") {
return false
}
if strings.ContainsAny(name, " \t\r\n") {
return false
}
if strings.TrimSpace(name) == "" {
return false
}
if len(name) > 256 {
return false
}

return true
}
183 changes: 183 additions & 0 deletions ingress/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ingress
import (
"encoding/json"
"flag"
"strings"
"testing"
"time"

Expand All @@ -12,6 +13,7 @@ import (

"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ipaccess"
"github.com/stretchr/testify/assert"
)

// Ensure that the nullable config from `config` package and the
Expand Down Expand Up @@ -415,6 +417,187 @@ func TestDefaultConfigFromCLI(t *testing.T) {
require.Equal(t, expected, actual)
}

func TestOriginRequestConfigHeaders(t *testing.T) {
config := OriginRequestConfig{
Headers: map[string]string{
"X-Custom-Header": "custom-value",
"Authorization": "Bearer token123",
},
RemoveHeaders: []string{"X-Unwanted", "Server"},
}

jsonData, err := json.Marshal(config)
assert.NoError(t, err)
assert.Contains(t, string(jsonData), "X-Custom-Header")
assert.Contains(t, string(jsonData), "custom-value")
assert.Contains(t, string(jsonData), "X-Unwanted")

var unmarshaled OriginRequestConfig
err = json.Unmarshal(jsonData, &unmarshaled)
assert.NoError(t, err)
assert.Equal(t, "custom-value", unmarshaled.Headers["X-Custom-Header"])
assert.Equal(t, "Bearer token123", unmarshaled.Headers["Authorization"])
assert.Contains(t, unmarshaled.RemoveHeaders, "X-Unwanted")
assert.Contains(t, unmarshaled.RemoveHeaders, "Server")
}

func TestParseHeaderFlag(t *testing.T) {
name, value, valid := parseHeaderFlag("X-Custom-Header: custom-value")
assert.True(t, valid)
assert.Equal(t, "X-Custom-Header", name)
assert.Equal(t, "custom-value", value)

name, value, valid = parseHeaderFlag(" Authorization : Bearer token ")
assert.True(t, valid)
assert.Equal(t, "Authorization", name)
assert.Equal(t, "Bearer token", value)

name, value, valid = parseHeaderFlag("X-Header: ")
assert.False(t, valid)

name, value, valid = parseHeaderFlag(" : value")
assert.False(t, valid)

_, _, valid = parseHeaderFlag("invalid-format")
assert.False(t, valid)

_, _, valid = parseHeaderFlag(": value-only")
assert.False(t, valid)

_, _, valid = parseHeaderFlag("name-only:")
assert.False(t, valid)

_, _, valid = parseHeaderFlag("")
assert.False(t, valid)

name, value, valid = parseHeaderFlag("X-Special: value with @#$%^&*()")
assert.True(t, valid)
assert.Equal(t, "X-Special", name)
assert.Equal(t, "value with @#$%^&*()", value)

name, value, valid = parseHeaderFlag("X-URL: https://example.com:8080/path")
assert.True(t, valid)
assert.Equal(t, "X-URL", name)
assert.Equal(t, "https://example.com:8080/path", value)
}

func TestIsValidHeaderName(t *testing.T) {
assert.True(t, isValidHeaderName("X-Custom-Header"))
assert.True(t, isValidHeaderName("Authorization"))
assert.True(t, isValidHeaderName("Content-Type"))
assert.True(t, isValidHeaderName("X-API-Key"))
assert.True(t, isValidHeaderName("User-Agent"))

assert.False(t, isValidHeaderName(""))
assert.False(t, isValidHeaderName(" "))
assert.False(t, isValidHeaderName("\t"))
assert.False(t, isValidHeaderName("\n"))
assert.False(t, isValidHeaderName("\r"))

assert.False(t, isValidHeaderName("Header With Space"))
assert.False(t, isValidHeaderName("Header\tWith\tTab"))
assert.False(t, isValidHeaderName("Header\nWith\nNewline"))
assert.False(t, isValidHeaderName("Header\rWith\rCarriageReturn"))

assert.False(t, isValidHeaderName(":Header"))
assert.False(t, isValidHeaderName("Header:"))
assert.False(t, isValidHeaderName("Header::Value"))

longHeader := strings.Repeat("A", 257)
assert.False(t, isValidHeaderName(longHeader))

boundaryHeader := strings.Repeat("A", 256)
assert.True(t, isValidHeaderName(boundaryHeader))

assert.True(t, isValidHeaderName("X"))
assert.True(t, isValidHeaderName("a"))
assert.True(t, isValidHeaderName("1"))

assert.True(t, isValidHeaderName("X-Header"))
assert.True(t, isValidHeaderName("X_Header"))
assert.True(t, isValidHeaderName("X.Header"))
}

func TestParseHeadersFromCLI(t *testing.T) {
app := cli.NewApp()
app.Flags = []cli.Flag{
&cli.StringSliceFlag{
Name: "header",
},
}

app.Action = func(c *cli.Context) error {
headers := parseHeadersFromCLI(c)

assert.Equal(t, 3, len(headers))
assert.Equal(t, "test-value", headers["X-Test-Header"])
assert.Equal(t, "static-key-123", headers["X-API-Key"])
assert.Equal(t, "Bearer token", headers["Authorization"])

assert.NotContains(t, headers, "Invalid-Header")
assert.NotContains(t, headers, "X-Empty")

return nil
}

err := app.Run([]string{"app", "--header", "X-Test-Header: test-value", "--header", "X-API-Key: static-key-123", "--header", "Authorization: Bearer token", "--header", "Invalid-Header", "--header", "X-Empty: "})
assert.NoError(t, err)
}

func TestParseRemoveHeadersFromCLI(t *testing.T) {
app := cli.NewApp()
app.Flags = []cli.Flag{
&cli.StringSliceFlag{
Name: "remove-header",
},
}

app.Action = func(c *cli.Context) error {
removeHeaders := parseRemoveHeadersFromCLI(c)

assert.Equal(t, 3, len(removeHeaders))
assert.Contains(t, removeHeaders, "X-Unwanted")
assert.Contains(t, removeHeaders, "Server")
assert.Contains(t, removeHeaders, "User-Agent")

return nil
}

err := app.Run([]string{"app", "--remove-header", "X-Unwanted", "--remove-header", "Server", "--remove-header", "User-Agent"})
assert.NoError(t, err)
}

func TestParseHeadersFromCLINotSet(t *testing.T) {
app := cli.NewApp()

app.Action = func(c *cli.Context) error {
headers := parseHeadersFromCLI(c)

assert.Equal(t, 0, len(headers))
assert.NotNil(t, headers)

return nil
}

err := app.Run([]string{"app"})
assert.NoError(t, err)
}

func TestParseRemoveHeadersFromCLINotSet(t *testing.T) {
app := cli.NewApp()

app.Action = func(c *cli.Context) error {
removeHeaders := parseRemoveHeadersFromCLI(c)

assert.Nil(t, removeHeaders)

return nil
}

err := app.Run([]string{"app"})
assert.NoError(t, err)
}

func newIPRule(t *testing.T, prefix string, ports []int, allow bool) ipaccess.Rule {
rule, err := ipaccess.NewRuleByCIDR(&prefix, ports, allow)
require.NoError(t, err)
Expand Down
Loading