Skip to content

Commit d2e48d4

Browse files
authored
[relay] Use instanceURL instead of Exposed address. (#4905)
Replaces string-based exposed address handling with URL-based InstanceURL() (type url.URL) across relay/server and relay/healthcheck; adds SchemeREL/SchemeRELS constants; updates getInstanceURL to return *url.URL with scheme and TLS validation; adjusts WS dialing and health-check logic to use URL fields.
1 parent 27dd97c commit d2e48d4

File tree

7 files changed

+47
-45
lines changed

7 files changed

+47
-45
lines changed

relay/cmd/root.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ func execute(cmd *cobra.Command, args []string) error {
160160
log.Debugf("failed to create relay server: %v", err)
161161
return fmt.Errorf("failed to create relay server: %v", err)
162162
}
163-
log.Infof("server will be available on: %s", srv.InstanceURL())
163+
instanceURL := srv.InstanceURL()
164+
log.Infof("server will be available on: %s", instanceURL.String())
164165
wg.Add(1)
165166
go func() {
166167
defer wg.Done()

relay/healthcheck/healthcheck.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ import (
66
"errors"
77
"net"
88
"net/http"
9-
"strings"
9+
"net/url"
1010
"sync"
1111
"time"
1212

1313
log "github.com/sirupsen/logrus"
1414

1515
"github.com/netbirdio/netbird/relay/protocol"
16+
"github.com/netbirdio/netbird/relay/server"
1617
)
1718

1819
const (
@@ -26,7 +27,7 @@ const (
2627

2728
type ServiceChecker interface {
2829
ListenerProtocols() []protocol.Protocol
29-
ExposedAddress() string
30+
InstanceURL() url.URL
3031
}
3132

3233
type HealthStatus struct {
@@ -134,7 +135,7 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
134135
}
135136
status.Listeners = listeners
136137

137-
if !strings.HasPrefix(s.config.ServiceChecker.ExposedAddress(), "rels") {
138+
if s.config.ServiceChecker.InstanceURL().Scheme != server.SchemeRELS {
138139
status.CertificateValid = false
139140
}
140141

@@ -156,14 +157,9 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
156157
}
157158

158159
func (s *Server) validateConnection(ctx context.Context) bool {
159-
exposedAddress := s.config.ServiceChecker.ExposedAddress()
160-
if exposedAddress == "" {
161-
log.Error("exposed address is empty, cannot validate certificate")
162-
return false
163-
}
164-
165-
if err := dialWS(ctx, exposedAddress); err != nil {
166-
log.Errorf("failed to dial WebSocket listener at %s: %v", exposedAddress, err)
160+
addr := s.config.ServiceChecker.InstanceURL()
161+
if err := dialWS(ctx, addr); err != nil {
162+
log.Errorf("failed to dial WebSocket listener at %s: %v", addr.String(), err)
167163
return false
168164
}
169165

relay/healthcheck/ws.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@ package healthcheck
33
import (
44
"context"
55
"fmt"
6-
"strings"
6+
"net/url"
77

88
"github.com/coder/websocket"
99

10+
"github.com/netbirdio/netbird/relay/server"
1011
"github.com/netbirdio/netbird/shared/relay"
1112
)
1213

13-
func dialWS(ctx context.Context, address string) error {
14-
addressSplit := strings.Split(address, "/")
14+
func dialWS(ctx context.Context, address url.URL) error {
1515
scheme := "ws"
16-
if addressSplit[0] == "rels:" {
16+
if address.Scheme == server.SchemeRELS {
1717
scheme = "wss"
1818
}
19-
url := fmt.Sprintf("%s://%s%s", scheme, addressSplit[2], relay.WebSocketURLPath)
19+
wsURL := fmt.Sprintf("%s://%s%s", scheme, address.Host, relay.WebSocketURLPath)
2020

21-
conn, resp, err := websocket.Dial(ctx, url, nil)
21+
conn, resp, err := websocket.Dial(ctx, wsURL, nil)
2222
if resp != nil {
2323
defer func() {
2424
if resp.Body != nil {

relay/server/relay.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net"
7+
"net/url"
78
"sync"
89
"time"
910

@@ -22,7 +23,7 @@ type Config struct {
2223
TLSSupport bool
2324
AuthValidator Validator
2425

25-
instanceURL string
26+
instanceURL url.URL
2627
}
2728

2829
func (c *Config) validate() error {
@@ -37,7 +38,7 @@ func (c *Config) validate() error {
3738
if err != nil {
3839
return fmt.Errorf("invalid url: %v", err)
3940
}
40-
c.instanceURL = instanceURL
41+
c.instanceURL = *instanceURL
4142

4243
if c.AuthValidator == nil {
4344
return fmt.Errorf("auth validator is required")
@@ -53,7 +54,7 @@ type Relay struct {
5354

5455
store *store.Store
5556
notifier *store.PeerNotifier
56-
instanceURL string
57+
instanceURL url.URL
5758
exposedAddress string
5859
preparedMsg *preparedMsg
5960

@@ -97,7 +98,7 @@ func NewRelay(config Config) (*Relay, error) {
9798
notifier: store.NewPeerNotifier(),
9899
}
99100

100-
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
101+
r.preparedMsg, err = newPreparedMsg(r.instanceURL.String())
101102
if err != nil {
102103
metricsCancel()
103104
return nil, fmt.Errorf("prepare message: %v", err)
@@ -177,11 +178,6 @@ func (r *Relay) Shutdown(ctx context.Context) {
177178
}
178179

179180
// InstanceURL returns the instance URL of the relay server
180-
func (r *Relay) InstanceURL() string {
181+
func (r *Relay) InstanceURL() url.URL {
181182
return r.instanceURL
182183
}
183-
184-
// ExposedAddress returns the exposed address (domain:port) where clients connect
185-
func (r *Relay) ExposedAddress() string {
186-
return r.exposedAddress
187-
}

relay/server/server.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package server
33
import (
44
"context"
55
"crypto/tls"
6+
"net/url"
67
"sync"
78

89
"github.com/hashicorp/go-multierror"
@@ -39,7 +40,7 @@ type Server struct {
3940
//
4041
// config: A Config struct containing the necessary configuration:
4142
// - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
42-
// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
43+
// - InstanceURL: The public address (in domain:port format) used as the server's instance URL. Required.
4344
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
4445
// - AuthValidator: A Validator used to authenticate peers. Required.
4546
//
@@ -119,11 +120,6 @@ func (r *Server) Shutdown(ctx context.Context) error {
119120
return nberrors.FormatErrorOrNil(multiErr)
120121
}
121122

122-
// InstanceURL returns the instance URL of the relay server.
123-
func (r *Server) InstanceURL() string {
124-
return r.relay.instanceURL
125-
}
126-
127123
func (r *Server) ListenerProtocols() []protocol.Protocol {
128124
result := make([]protocol.Protocol, 0)
129125

@@ -135,6 +131,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol {
135131
return result
136132
}
137133

138-
func (r *Server) ExposedAddress() string {
139-
return r.relay.ExposedAddress()
134+
func (r *Server) InstanceURL() url.URL {
135+
return r.relay.InstanceURL()
140136
}

relay/server/url.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@ import (
66
"strings"
77
)
88

9+
const (
10+
SchemeREL = "rel"
11+
SchemeRELS = "rels"
12+
)
13+
914
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
1015
// provided address according to TLS definition and parses the address before returning it
11-
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
16+
func getInstanceURL(exposedAddress string, tlsSupported bool) (*url.URL, error) {
1217
addr := exposedAddress
1318
split := strings.Split(exposedAddress, "://")
1419
switch {
@@ -17,17 +22,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
1722
case len(split) == 1 && !tlsSupported:
1823
addr = "rel://" + exposedAddress
1924
case len(split) > 2:
20-
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
25+
return nil, fmt.Errorf("invalid exposed address: %s", exposedAddress)
2126
}
2227

2328
parsedURL, err := url.ParseRequestURI(addr)
2429
if err != nil {
25-
return "", fmt.Errorf("invalid exposed address: %v", err)
30+
return nil, fmt.Errorf("invalid exposed address: %v", err)
31+
}
32+
33+
if parsedURL.Scheme != SchemeREL && parsedURL.Scheme != SchemeRELS {
34+
return nil, fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
2635
}
2736

28-
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
29-
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
37+
// Validate scheme matches TLS configuration
38+
if tlsSupported && parsedURL.Scheme == SchemeREL {
39+
return nil, fmt.Errorf("non-TLS scheme '%s' provided but TLS is supported", SchemeREL)
3040
}
3141

32-
return parsedURL.String(), nil
42+
return parsedURL, nil
3343
}

relay/server/relay_test.go renamed to relay/server/url_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func TestGetInstanceURL(t *testing.T) {
1313
{"Valid address with TLS", "example.com", true, "rels://example.com", false},
1414
{"Valid address without TLS", "example.com", false, "rel://example.com", false},
1515
{"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
16-
{"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
16+
{"Invalid address with non TLS scheme and TLS true", "rel://example.com", true, "", true},
1717
{"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
1818
{"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
1919
{"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
@@ -28,8 +28,11 @@ func TestGetInstanceURL(t *testing.T) {
2828
if (err != nil) != tt.expectError {
2929
t.Errorf("expected error: %v, got: %v", tt.expectError, err)
3030
}
31-
if url != tt.expectedURL {
32-
t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
31+
if !tt.expectError && url != nil && url.String() != tt.expectedURL {
32+
t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url.String())
33+
}
34+
if tt.expectError && url != nil {
35+
t.Errorf("expected nil URL on error, got: %s", url.String())
3336
}
3437
})
3538
}

0 commit comments

Comments
 (0)