diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index db1cb5af3f7..24ec38172ad 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -110,8 +110,8 @@ const ( tunnelCmdErrorMessage = `You did not specify any valid additional argument to the cloudflared tunnel command. -If you are trying to run a Quick Tunnel then you need to explicitly pass the --url flag. -Eg. cloudflared tunnel --url localhost:8080/. +If you are trying to run a Quick Tunnel then you need to explicitly pass a --url or --unix-socket flag. +Eg. 'cloudflared tunnel --url localhost:8080/' or 'cloudflared tunnel --unix-socket /tmp/socket'. Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes. For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/) @@ -215,6 +215,7 @@ var ( "overwrite-dns", "help", } + runQuickTunnel = RunQuickTunnel ) func Flags() []cli.Flag { @@ -313,9 +314,9 @@ func TunnelCommand(c *cli.Context) error { // Run a quick tunnel // A unauthenticated named tunnel hosted on ..com // We don't support running proxy-dns and a quick tunnel at the same time as the same process - shouldRunQuickTunnel := c.IsSet("url") || c.IsSet(ingress.HelloWorldFlag) + shouldRunQuickTunnel := c.IsSet("url") || c.IsSet("unix-socket") || c.IsSet(ingress.HelloWorldFlag) if !c.IsSet("proxy-dns") && c.String("quick-service") != "" && shouldRunQuickTunnel { - return RunQuickTunnel(sc) + return runQuickTunnel(sc) } // If user provides a config, check to see if they meant to use `tunnel run` instead diff --git a/cmd/cloudflared/tunnel/cmd_test.go b/cmd/cloudflared/tunnel/cmd_test.go index b29b396612c..faf1de00ec0 100644 --- a/cmd/cloudflared/tunnel/cmd_test.go +++ b/cmd/cloudflared/tunnel/cmd_test.go @@ -1,9 +1,12 @@ package tunnel import ( + "flag" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" ) func TestHostnameFromURI(t *testing.T) { @@ -15,3 +18,76 @@ func TestHostnameFromURI(t *testing.T) { assert.Equal(t, "", hostnameFromURI("trash")) assert.Equal(t, "", hostnameFromURI("https://awesomesauce.com")) } + +func TestShouldRunQuickTunnel(t *testing.T) { + tests := []struct { + name string + flags map[string]string + expectError bool + }{ + { + name: "Quick tunnel with URL set", + flags: map[string]string{"url": "http://127.0.0.1:8080", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectError: false, + }, + { + name: "Quick tunnel with unix-socket set", + flags: map[string]string{"unix-socket": "/tmp/socket", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectError: false, + }, + { + name: "Quick tunnel with hello-world flag", + flags: map[string]string{"hello-world": "true", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectError: false, + }, + { + name: "Quick tunnel with proxy-dns (invalid combo)", + flags: map[string]string{"url": "http://127.0.0.1:9090", "proxy-dns": "true", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectError: true, + }, + { + name: "No quick-service set", + flags: map[string]string{"url": "http://127.0.0.1:9090"}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock RunQuickTunnel Function + originalRunQuickTunnel := runQuickTunnel + defer func() { runQuickTunnel = originalRunQuickTunnel }() + mockCalled := false + runQuickTunnel = func(sc *subcommandContext) error { + mockCalled = true + return nil + } + + // Mock App Context + app := &cli.App{} + set := flagSetFromMap(tt.flags) + context := cli.NewContext(app, set, nil) + + // Call TunnelCommand + err := TunnelCommand(context) + + // Validate + if tt.expectError { + assert.False(t, mockCalled) + require.Error(t, err) + } else { + assert.True(t, mockCalled) + require.NoError(t, err) + } + }) + } +} + +func flagSetFromMap(flags map[string]string) *flag.FlagSet { + set := flag.NewFlagSet("test", 0) + for key, value := range flags { + set.String(key, "", "") + set.Set(key, value) + } + return set +}