From 6078a49f3b67833dff3fccd10e4ff313e889f1af Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Tue, 29 Jul 2025 11:26:45 -0700 Subject: [PATCH 01/38] add debug endpoint; add test configs --- cmd/proxy/main.go | 5 ++ .../config/cluster-a-mux-client-proxy-1.yaml | 27 ++++++ .../config/cluster-a-mux-client-proxy-2.yaml | 27 ++++++ .../config/cluster-b-mux-server-proxy-1.yaml | 27 ++++++ .../config/cluster-b-mux-server-proxy-2.yaml | 27 ++++++ develop/config/dynamic-config.yaml | 4 +- proxy/debug.go | 52 ++++++++++++ proxy/stream_tracker.go | 83 +++++++++++++++++++ 8 files changed, 251 insertions(+), 1 deletion(-) create mode 100644 develop/config/cluster-a-mux-client-proxy-1.yaml create mode 100644 develop/config/cluster-a-mux-client-proxy-2.yaml create mode 100644 develop/config/cluster-b-mux-server-proxy-1.yaml create mode 100644 develop/config/cluster-b-mux-server-proxy-2.yaml create mode 100644 proxy/debug.go create mode 100644 proxy/stream_tracker.go diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 41b31daf..6085343b 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -69,6 +69,11 @@ func startPProfHTTPServer(logger log.Logger, c config.ProfilingConfig) { return } + // Add debug endpoint handler + http.HandleFunc("/debug/connections", func(w http.ResponseWriter, r *http.Request) { + proxy.HandleDebugInfo(w, r, logger) + }) + go func() { logger.Info("Start pprof http server", tag.NewStringTag("address", addr)) if err := http.ListenAndServe(addr, nil); err != nil { diff --git a/develop/config/cluster-a-mux-client-proxy-1.yaml b/develop/config/cluster-a-mux-client-proxy-1.yaml new file mode 100644 index 00000000..d0a3c526 --- /dev/null +++ b/develop/config/cluster-a-mux-client-proxy-1.yaml @@ -0,0 +1,27 @@ +inbound: + name: "a-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:7233" +outbound: + name: "a-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6133" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "client" + client: + serverAddress: "localhost:7003" +shardCount: + mode: "lcm" + localShardCount: 2 + remoteShardCount: 3 +profiling: + pprofAddress: "localhost:6060" diff --git a/develop/config/cluster-a-mux-client-proxy-2.yaml b/develop/config/cluster-a-mux-client-proxy-2.yaml new file mode 100644 index 00000000..bc5f43dd --- /dev/null +++ b/develop/config/cluster-a-mux-client-proxy-2.yaml @@ -0,0 +1,27 @@ +inbound: + name: "a-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:7233" +outbound: + name: "a-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6233" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "client" + client: + serverAddress: "localhost:7003" +shardCount: + mode: "lcm" + localShardCount: 2 + remoteShardCount: 3 +profiling: + pprofAddress: "localhost:6061" \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-1.yaml b/develop/config/cluster-b-mux-server-proxy-1.yaml new file mode 100644 index 00000000..70f01281 --- /dev/null +++ b/develop/config/cluster-b-mux-server-proxy-1.yaml @@ -0,0 +1,27 @@ +inbound: + name: "b-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:8233" +outbound: + name: "b-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6333" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "server" + server: + listenAddress: "0.0.0.0:6334" +shardCount: + mode: "lcm" + localShardCount: 
3 + remoteShardCount: 2 +profiling: + pprofAddress: "localhost:6070" \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-2.yaml b/develop/config/cluster-b-mux-server-proxy-2.yaml new file mode 100644 index 00000000..db3384c4 --- /dev/null +++ b/develop/config/cluster-b-mux-server-proxy-2.yaml @@ -0,0 +1,27 @@ +inbound: + name: "b-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:8233" +outbound: + name: "b-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6433" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "server" + server: + listenAddress: "0.0.0.0:6434" +shardCount: + mode: "lcm" + localShardCount: 3 + remoteShardCount: 2 +profiling: + pprofAddress: "localhost:6071" \ No newline at end of file diff --git a/develop/config/dynamic-config.yaml b/develop/config/dynamic-config.yaml index 95de71bc..b5a46d29 100644 --- a/develop/config/dynamic-config.yaml +++ b/develop/config/dynamic-config.yaml @@ -26,4 +26,6 @@ history.persistenceMaxQPS: constraints: {} frontend.persistenceMaxQPS: - value: 100000 - constraints: {} \ No newline at end of file + constraints: {} +history.shardUpdateMinInterval: + - value: 1s \ No newline at end of file diff --git a/proxy/debug.go b/proxy/debug.go new file mode 100644 index 00000000..40520a25 --- /dev/null +++ b/proxy/debug.go @@ -0,0 +1,52 @@ +package proxy + +import ( + "encoding/json" + "net/http" + "time" + + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" +) + +type ( + // StreamInfo represents information about an active gRPC stream + StreamInfo struct { + ID string `json:"id"` + Method string `json:"method"` + Direction string `json:"direction"` + ClientShard string `json:"client_shard"` + ServerShard string `json:"server_shard"` + StartTime time.Time `json:"start_time"` + LastSeen time.Time `json:"last_seen"` + } + + DebugResponse struct { + Timestamp time.Time `json:"timestamp"` + ActiveStreams []StreamInfo `json:"active_streams"` + StreamCount int `json:"stream_count"` + } +) + +func HandleDebugInfo(w http.ResponseWriter, r *http.Request, logger log.Logger) { + w.Header().Set("Content-Type", "application/json") + + var activeStreams []StreamInfo + var streamCount int + + // Get active streams information + streamTracker := GetGlobalStreamTracker() + activeStreams = streamTracker.GetActiveStreams() + streamCount = streamTracker.GetStreamCount() + + response := DebugResponse{ + Timestamp: time.Now(), + ActiveStreams: activeStreams, + StreamCount: streamCount, + } + + if err := json.NewEncoder(w).Encode(response); err != nil { + logger.Error("Failed to encode debug response", tag.Error(err)) + http.Error(w, "Internal server error", http.StatusInternalServerError) + } +} diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go new file mode 100644 index 00000000..6dfe948c --- /dev/null +++ b/proxy/stream_tracker.go @@ -0,0 +1,83 @@ +package proxy + +import ( + "sync" + "time" +) + +// StreamTracker tracks active gRPC streams for debugging +type StreamTracker struct { + mu sync.RWMutex + streams map[string]*StreamInfo +} + +// NewStreamTracker creates a new stream tracker +func NewStreamTracker() *StreamTracker { + return &StreamTracker{ + streams: make(map[string]*StreamInfo), + } +} + +// RegisterStream adds a new active stream +func (st *StreamTracker) RegisterStream(id, method, direction, clientShard, serverShard string) { + st.mu.Lock() + defer st.mu.Unlock() + + now := time.Now() + st.streams[id] = 
&StreamInfo{ + ID: id, + Method: method, + Direction: direction, + ClientShard: clientShard, + ServerShard: serverShard, + StartTime: now, + LastSeen: now, + } +} + +// UpdateStream updates the last seen time for a stream +func (st *StreamTracker) UpdateStream(id string) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + } +} + +// UnregisterStream removes a stream from tracking +func (st *StreamTracker) UnregisterStream(id string) { + st.mu.Lock() + defer st.mu.Unlock() + + delete(st.streams, id) +} + +// GetActiveStreams returns a copy of all active streams +func (st *StreamTracker) GetActiveStreams() []StreamInfo { + st.mu.RLock() + defer st.mu.RUnlock() + + streams := make([]StreamInfo, 0, len(st.streams)) + for _, stream := range st.streams { + streams = append(streams, *stream) + } + + return streams +} + +// GetStreamCount returns the number of active streams +func (st *StreamTracker) GetStreamCount() int { + st.mu.RLock() + defer st.mu.RUnlock() + + return len(st.streams) +} + +// Global stream tracker instance +var globalStreamTracker = NewStreamTracker() + +// GetGlobalStreamTracker returns the global stream tracker instance +func GetGlobalStreamTracker() *StreamTracker { + return globalStreamTracker +} From 4b1eb9b3cdb23aca65194031a56a553ca02e38c2 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 30 Jul 2025 17:41:53 -0700 Subject: [PATCH 02/38] add shard manager --- cmd/proxy/main.go | 6 +- config/config.go | 29 + .../config/cluster-a-mux-client-proxy-1.yaml | 23 +- .../config/cluster-a-mux-client-proxy-2.yaml | 25 +- .../config/cluster-b-mux-server-proxy-1.yaml | 25 +- .../config/cluster-b-mux-server-proxy-2.yaml | 25 +- endtoendtest/echo_server.go | 1 + go.mod | 11 + go.sum | 32 ++ proxy/cluster_connection.go | 8 +- proxy/cluster_connection_test.go | 10 +- proxy/debug.go | 26 +- proxy/fx.go | 1 + proxy/proxy.go | 17 +- proxy/shard_manager.go | 499 ++++++++++++++++++ proxy/test/replication_failover_test.go | 2 +- 16 files changed, 715 insertions(+), 25 deletions(-) create mode 100644 proxy/shard_manager.go diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 6085343b..1212854e 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -63,7 +63,7 @@ func buildCLIOptions() *cli.App { return app } -func startPProfHTTPServer(logger log.Logger, c config.ProfilingConfig) { +func startPProfHTTPServer(logger log.Logger, c config.ProfilingConfig, proxyInstance *proxy.Proxy) { addr := c.PProfHTTPAddress if len(addr) == 0 { return @@ -71,7 +71,7 @@ func startPProfHTTPServer(logger log.Logger, c config.ProfilingConfig) { // Add debug endpoint handler http.HandleFunc("/debug/connections", func(w http.ResponseWriter, r *http.Request) { - proxy.HandleDebugInfo(w, r, logger) + proxy.HandleDebugInfo(w, r, proxyInstance, logger) }) go func() { @@ -106,7 +106,7 @@ func startProxy(c *cli.Context) error { } cfg := proxyParams.ConfigProvider.GetS2SProxyConfig() - startPProfHTTPServer(proxyParams.Logger, cfg.ProfilingConfig) + startPProfHTTPServer(proxyParams.Logger, cfg.ProfilingConfig, proxyParams.Proxy) if err := proxyParams.Proxy.Start(); err != nil { return err diff --git a/config/config.go b/config/config.go index 53e85fbc..ae1958f0 100644 --- a/config/config.go +++ b/config/config.go @@ -40,6 +40,7 @@ type ShardCountMode string const ( ShardCountDefault ShardCountMode = "" ShardCountLCM ShardCountMode = "lcm" + ShardCountFixed ShardCountMode = "fixed" ) type HealthCheckProtocol string @@ -156,6 +157,7 @@ type ( 
OutboundHealthCheck *HealthCheckConfig `yaml:"outboundHealthCheck"` NamespaceNameTranslation NameTranslationConfig `yaml:"namespaceNameTranslation"` SearchAttributeTranslation SATranslationConfig `yaml:"searchAttributeTranslation"` + MemberlistConfig *MemberlistConfig `yaml:"memberlist"` Metrics *MetricsConfig `yaml:"metrics"` ProfilingConfig ProfilingConfig `yaml:"profiling"` Logging LoggingConfig `yaml:"logging"` @@ -217,6 +219,33 @@ type ( LoggingConfig struct { ThrottleMaxRPS float64 `yaml:"throttleMaxRPS"` } + + MemberlistConfig struct { + // Enable distributed shard management using memberlist + Enabled bool `yaml:"enabled"` + // Enable proxy-to-proxy forwarding (requires Enabled=true) + EnableForwarding bool `yaml:"enableForwarding"` + // Node name for this proxy instance in the cluster + NodeName string `yaml:"nodeName"` + // Bind address for memberlist cluster communication + BindAddr string `yaml:"bindAddr"` + // Bind port for memberlist cluster communication + BindPort int `yaml:"bindPort"` + // List of existing cluster members to join + JoinAddrs []string `yaml:"joinAddrs"` + // Shard assignment strategy (deprecated - now uses actual ownership tracking) + ShardStrategy string `yaml:"shardStrategy"` + // Map of node names to their proxy service addresses for forwarding + ProxyAddresses map[string]string `yaml:"proxyAddresses"` + // Use TCP-only transport (disables UDP) for restricted networks + TCPOnly bool `yaml:"tcpOnly"` + // Disable TCP pings when using TCP-only mode + DisableTCPPings bool `yaml:"disableTCPPings"` + // Probe timeout for memberlist health checks + ProbeTimeoutMs int `yaml:"probeTimeoutMs"` + // Probe interval for memberlist health checks + ProbeIntervalMs int `yaml:"probeIntervalMs"` + } ) func FromServerTLSConfig(cfg ServerTLSConfig) encryption.TLSConfig { diff --git a/develop/config/cluster-a-mux-client-proxy-1.yaml b/develop/config/cluster-a-mux-client-proxy-1.yaml index d0a3c526..426eede2 100644 --- a/develop/config/cluster-a-mux-client-proxy-1.yaml +++ b/develop/config/cluster-a-mux-client-proxy-1.yaml @@ -19,9 +19,30 @@ mux: mode: "client" client: serverAddress: "localhost:7003" +# shardCount: +# mode: "lcm" +# localShardCount: 2 +# remoteShardCount: 3 shardCount: - mode: "lcm" + mode: "fixed" localShardCount: 2 remoteShardCount: 3 profiling: pprofAddress: "localhost:6060" +memberlist: + enabled: true + enableForwarding: false + nodeName: "proxy-node-a-1" + bindAddr: "0.0.0.0" + bindPort: 6135 + joinAddrs: + - "localhost:6235" + proxyAddresses: + "proxy-node-1": "localhost:7001" + "proxy-node-2": "proxy-node-2:7001" + "proxy-node-3": "proxy-node-3:7001" + # TCP-only configuration for restricted networks + tcpOnly: true # Use TCP transport only, disable UDP + disableTCPPings: true # Disable TCP pings for faster convergence + probeTimeoutMs: 1000 # Longer timeout for network latency + probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/develop/config/cluster-a-mux-client-proxy-2.yaml b/develop/config/cluster-a-mux-client-proxy-2.yaml index bc5f43dd..4027cb22 100644 --- a/develop/config/cluster-a-mux-client-proxy-2.yaml +++ b/develop/config/cluster-a-mux-client-proxy-2.yaml @@ -19,9 +19,30 @@ mux: mode: "client" client: serverAddress: "localhost:7003" +# shardCount: +# mode: "lcm" +# localShardCount: 2 +# remoteShardCount: 3 shardCount: - mode: "lcm" + mode: "fixed" localShardCount: 2 remoteShardCount: 3 profiling: - pprofAddress: "localhost:6061" \ No newline at end of file + pprofAddress: 
"localhost:6061" +memberlist: + enabled: true + enableForwarding: false + nodeName: "proxy-node-a-2" + bindAddr: "0.0.0.0" + bindPort: 6235 + joinAddrs: + - "localhost:6135" + proxyAddresses: + "proxy-node-1": "localhost:7001" + "proxy-node-2": "proxy-node-2:7001" + "proxy-node-3": "proxy-node-3:7001" + # TCP-only configuration for restricted networks + tcpOnly: true # Use TCP transport only, disable UDP + disableTCPPings: true # Disable TCP pings for faster convergence + probeTimeoutMs: 1000 # Longer timeout for network latency + probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-1.yaml b/develop/config/cluster-b-mux-server-proxy-1.yaml index 70f01281..5155b29f 100644 --- a/develop/config/cluster-b-mux-server-proxy-1.yaml +++ b/develop/config/cluster-b-mux-server-proxy-1.yaml @@ -19,9 +19,30 @@ mux: mode: "server" server: listenAddress: "0.0.0.0:6334" +# shardCount: +# mode: "lcm" +# localShardCount: 3 +# remoteShardCount: 2 shardCount: - mode: "lcm" + mode: "fixed" localShardCount: 3 remoteShardCount: 2 profiling: - pprofAddress: "localhost:6070" \ No newline at end of file + pprofAddress: "localhost:6070" +memberlist: + enabled: true + enableForwarding: false + nodeName: "proxy-node-b-1" + bindAddr: "0.0.0.0" + bindPort: 6335 + joinAddrs: + - "localhost:6435" + proxyAddresses: + "proxy-node-1": "localhost:7001" + "proxy-node-2": "proxy-node-2:7001" + "proxy-node-3": "proxy-node-3:7001" + # TCP-only configuration for restricted networks + tcpOnly: true # Use TCP transport only, disable UDP + disableTCPPings: true # Disable TCP pings for faster convergence + probeTimeoutMs: 1000 # Longer timeout for network latency + probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-2.yaml b/develop/config/cluster-b-mux-server-proxy-2.yaml index db3384c4..6cefe758 100644 --- a/develop/config/cluster-b-mux-server-proxy-2.yaml +++ b/develop/config/cluster-b-mux-server-proxy-2.yaml @@ -19,9 +19,30 @@ mux: mode: "server" server: listenAddress: "0.0.0.0:6434" +# shardCount: +# mode: "lcm" +# localShardCount: 3 +# remoteShardCount: 2 shardCount: - mode: "lcm" + mode: "fixed" localShardCount: 3 remoteShardCount: 2 profiling: - pprofAddress: "localhost:6071" \ No newline at end of file + pprofAddress: "localhost:6071" +memberlist: + enabled: true + enableForwarding: false + nodeName: "proxy-node-b-2" + bindAddr: "0.0.0.0" + bindPort: 6435 + joinAddrs: + - "localhost:6335" + proxyAddresses: + "proxy-node-1": "localhost:7001" + "proxy-node-2": "proxy-node-2:7001" + "proxy-node-3": "proxy-node-3:7001" + # TCP-only configuration for restricted networks + tcpOnly: true # Use TCP transport only, disable UDP + disableTCPPings: true # Disable TCP pings for faster convergence + probeTimeoutMs: 1000 # Longer timeout for network latency + probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/endtoendtest/echo_server.go b/endtoendtest/echo_server.go index 22e9504f..83288ad8 100644 --- a/endtoendtest/echo_server.go +++ b/endtoendtest/echo_server.go @@ -115,6 +115,7 @@ func NewEchoServer( configProvider := config.NewMockConfigProvider(*localClusterInfo.S2sProxyConfig) proxy = s2sproxy.NewProxy( configProvider, + nil, logger, ) diff --git a/go.mod b/go.mod index 95719fa8..25be7132 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gogo/status v1.1.1 
github.com/golang/mock v1.7.0-rc.1 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 + github.com/hashicorp/memberlist v0.5.1 github.com/hashicorp/yamux v0.1.2 github.com/keilerkonzept/visit v1.1.1 github.com/pkg/errors v0.9.1 @@ -40,6 +41,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 // indirect github.com/apache/thrift v0.21.0 // indirect + github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da // indirect github.com/aws/aws-sdk-go v1.55.6 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -66,6 +68,7 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v1.0.0 // indirect + github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect @@ -76,6 +79,12 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect + github.com/hashicorp/go-immutable-radix v1.0.0 // indirect + github.com/hashicorp/go-msgpack/v2 v2.1.1 // indirect + github.com/hashicorp/go-multierror v1.0.0 // indirect + github.com/hashicorp/go-sockaddr v1.0.0 // indirect + github.com/hashicorp/golang-lru v0.5.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/iancoleman/strcase v0.3.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -89,6 +98,7 @@ require ( github.com/lib/pq v1.10.9 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/miekg/dns v1.1.26 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect @@ -106,6 +116,7 @@ require ( github.com/robfig/cron v1.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sony/gobreaker v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect diff --git a/go.sum b/go.sum index 602451ec..487fdb62 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3 github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da h1:8GUt8eRujhVEGZFFEjBj46YV4rDjvGrNxb0KMWYkL2I= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/aws/aws-sdk-go v1.55.6 h1:cSg4pvZ3m8dgYcgqB97MrcdjUmZ1BeMYKUxMMB89IPk= github.com/aws/aws-sdk-go v1.55.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/benbjohnson/clock v0.0.0-20160125162948-a620c1cc9866/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE= @@ -141,6 
+143,8 @@ github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6 github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -171,8 +175,24 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5uk github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-immutable-radix v1.0.0 h1:AKDB1HM5PWEA7i4nhcpwOrO2byshxBjXVn/J/3+z5/0= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-msgpack/v2 v2.1.1 h1:xQEY9yB2wnHitoSzk/B9UjXWRQ67QKu5AOm8aFp8N3I= +github.com/hashicorp/go-msgpack/v2 v2.1.1/go.mod h1:upybraOAblm4S7rx0+jeNy+CWWhzywQsSRV5033mMu4= +github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-sockaddr v1.0.0 h1:GeH6tui99pF4NJgfnhp+L6+FfobzVW3Ah46sLo0ICXs= +github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hashicorp/memberlist v0.5.1 h1:mk5dRuzeDNis2bi6LLoQIXfMH7JQvAzt3mQD0vNZZUo= +github.com/hashicorp/memberlist v0.5.1/go.mod h1:zGDXV6AqbDTKTM6yxW0I4+JtFzZAJVoIPvss4hV8F24= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= @@ -221,6 +241,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod 
h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU= +github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -235,6 +257,8 @@ github.com/olivere/elastic/v7 v7.0.32/go.mod h1:c7PVmLe3Fxq77PIfY/bZmxY/TAamBhCz github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c h1:Lgl0gzECD8GnQ5QCWA8o6BtfL6mDH5rQgM4/fX3avOs= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -270,6 +294,8 @@ github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWN github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/samuel/go-thrift v0.0.0-20190219015601-e8b6b52668fe/go.mod h1:Vrkh1pnjV9Bl8c3P9zH0/D4NlOHWP5d4/hF4YTULaec= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sirupsen/logrus v1.0.2-0.20170726183946-abee6f9b0679/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -377,6 +403,7 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -414,6 +441,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 
+golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -442,6 +470,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -465,6 +495,7 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -483,6 +514,7 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index 5ee388e3..90004a5d 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -59,6 +59,7 @@ type ( inboundClient closableClientConn inboundObserver *ReplicationStreamObserver outboundObserver *ReplicationStreamObserver + shardManager ShardManager logger log.Logger } // contextAwareServer represents a startable gRPC server used to provide the Temporal interface on some connection. 
@@ -103,12 +104,13 @@ func sanitizeConnectionName(name string) string { } // NewClusterConnection unpacks the connConfig and creates the inbound and outbound clients and servers. -func NewClusterConnection(lifetime context.Context, connConfig config.ClusterConnConfig, logger log.Logger) (*ClusterConnection, error) { +func NewClusterConnection(lifetime context.Context, connConfig config.ClusterConnConfig, shardManager ShardManager, logger log.Logger) (*ClusterConnection, error) { // The name is used in metrics and in the protocol for identifying the multi-client-conn. Sanitize it or else grpc.Dial will be very unhappy. sanitizedConnectionName := sanitizeConnectionName(connConfig.Name) cc := &ClusterConnection{ - lifetime: lifetime, - logger: log.With(logger, tag.NewStringTag("clusterConn", sanitizedConnectionName)), + lifetime: lifetime, + logger: log.With(logger, tag.NewStringTag("clusterConn", sanitizedConnectionName)), + shardManager: shardManager, } var err error cc.inboundClient, err = createClient(lifetime, sanitizedConnectionName, connConfig.LocalServer.Connection, "inbound") diff --git a/proxy/cluster_connection_test.go b/proxy/cluster_connection_test.go index cf45a56d..27db4275 100644 --- a/proxy/cluster_connection_test.go +++ b/proxy/cluster_connection_test.go @@ -172,25 +172,25 @@ func newPairedLocalClusterConnection(t *testing.T, isMux bool, logger log.Logger var localCtx context.Context localCtx, cancelLocalCC = context.WithCancel(t.Context()) localCC, err = NewClusterConnection(localCtx, makeTCPClusterConfig("TCP-only Connection Local Proxy", - a.localTemporalAddr, a.localProxyInbound, a.localProxyOutbound, a.remoteProxyInbound), logger) + a.localTemporalAddr, a.localProxyInbound, a.localProxyOutbound, a.remoteProxyInbound), nil, logger) require.NoError(t, err) var remoteCtx context.Context remoteCtx, cancelRemoteCC = context.WithCancel(t.Context()) remoteCC, err = NewClusterConnection(remoteCtx, makeTCPClusterConfig("TCP-only Connection Remote Proxy", - a.remoteTemporalAddr, a.remoteProxyInbound, a.remoteProxyOutbound, a.localProxyInbound), logger) + a.remoteTemporalAddr, a.remoteProxyInbound, a.remoteProxyOutbound, a.localProxyInbound), nil, logger) require.NoError(t, err) } else { var localCtx context.Context localCtx, cancelLocalCC = context.WithCancel(t.Context()) localCC, err = NewClusterConnection(localCtx, makeMuxClusterConfig("Mux Connection Local Establishing Proxy", - config.ConnTypeMuxClient, a.localTemporalAddr, a.localProxyOutbound, a.remoteProxyInbound), logger) + config.ConnTypeMuxClient, a.localTemporalAddr, a.localProxyOutbound, a.remoteProxyInbound), nil, logger) require.NoError(t, err) var remoteCtx context.Context remoteCtx, cancelRemoteCC = context.WithCancel(t.Context()) remoteCC, err = NewClusterConnection(remoteCtx, makeMuxClusterConfig("Mux Connection Remote Receiving Proxy", - config.ConnTypeMuxServer, a.remoteTemporalAddr, a.remoteProxyOutbound, a.remoteProxyInbound), logger) + config.ConnTypeMuxServer, a.remoteTemporalAddr, a.remoteProxyOutbound, a.remoteProxyInbound), nil, logger) require.NoError(t, err) } clientFromLocal, err := grpc.NewClient(a.localProxyOutbound, grpcutil.MakeDialOptions(nil, metrics.GetStandardGRPCClientInterceptor("outbound-local"))...) 
@@ -259,7 +259,7 @@ func TestMuxCCFailover(t *testing.T) { cancel() newConnection, err := NewClusterConnection(t.Context(), makeMuxClusterConfig("newRemoteMux", config.ConnTypeMuxServer, plcc.addresses.remoteTemporalAddr, plcc.addresses.remoteProxyOutbound, plcc.addresses.remoteProxyInbound, - func(cc *config.ClusterConnConfig) { cc.RemoteServer.Connection.MuxCount = 5 }), logger) + func(cc *config.ClusterConnConfig) { cc.RemoteServer.Connection.MuxCount = 5 }), nil, logger) require.NoError(t, err) newConnection.Start() // Wait for localCC's client retry... diff --git a/proxy/debug.go b/proxy/debug.go index 40520a25..2410b33e 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -5,6 +5,7 @@ import ( "net/http" "time" + "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" ) @@ -21,28 +22,45 @@ type ( LastSeen time.Time `json:"last_seen"` } + // ShardDebugInfo contains debug information about shard distribution + ShardDebugInfo struct { + Enabled bool `json:"enabled"` + ForwardingEnabled bool `json:"forwarding_enabled"` + NodeName string `json:"node_name"` + LocalShards []history.ClusterShardID `json:"local_shards"` + LocalShardCount int `json:"local_shard_count"` + ClusterNodes []string `json:"cluster_nodes"` + ClusterSize int `json:"cluster_size"` + RemoteShards map[string]string `json:"remote_shards"` // shard_id -> node_name + RemoteShardCounts map[string]int `json:"remote_shard_counts"` // node_name -> shard_count + } + DebugResponse struct { - Timestamp time.Time `json:"timestamp"` - ActiveStreams []StreamInfo `json:"active_streams"` - StreamCount int `json:"stream_count"` + Timestamp time.Time `json:"timestamp"` + ActiveStreams []StreamInfo `json:"active_streams"` + StreamCount int `json:"stream_count"` + ShardInfo ShardDebugInfo `json:"shard_info"` } ) -func HandleDebugInfo(w http.ResponseWriter, r *http.Request, logger log.Logger) { +func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Proxy, logger log.Logger) { w.Header().Set("Content-Type", "application/json") var activeStreams []StreamInfo var streamCount int + var shardInfo ShardDebugInfo // Get active streams information streamTracker := GetGlobalStreamTracker() activeStreams = streamTracker.GetActiveStreams() streamCount = streamTracker.GetStreamCount() + shardInfo = proxyInstance.GetShardInfo() response := DebugResponse{ Timestamp: time.Now(), ActiveStreams: activeStreams, StreamCount: streamCount, + ShardInfo: shardInfo, } if err := json.NewEncoder(w).Encode(response); err != nil { diff --git a/proxy/fx.go b/proxy/fx.go index aba1ded7..16afc573 100644 --- a/proxy/fx.go +++ b/proxy/fx.go @@ -6,4 +6,5 @@ import ( var Module = fx.Options( fx.Provide(NewProxy), + fx.Provide(NewShardManager), ) diff --git a/proxy/proxy.go b/proxy/proxy.go index e0e2e6ed..5456b0da 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -30,17 +30,19 @@ type ( inboundHealthCheckServer *http.Server outboundHealthCheckServer *http.Server metricsServer *http.Server + shardManager ShardManager logger log.Logger } ) -func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { +func NewProxy(configProvider config.ConfigProvider, shardManager ShardManager, logger log.Logger) *Proxy { s2sConfig := config.ToClusterConnConfig(configProvider.GetS2SProxyConfig()) ctx, cancel := context.WithCancel(context.Background()) proxy := &Proxy{ lifetime: ctx, cancel: cancel, clusterConnections: make(map[migrationId]*ClusterConnection, len(s2sConfig.MuxTransports)), + 
shardManager: shardManager, logger: log.NewThrottledLogger( logger, func() float64 { @@ -55,7 +57,7 @@ func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { proxy.metricsConfig = s2sConfig.Metrics } for _, clusterCfg := range s2sConfig.ClusterConnections { - cc, err := NewClusterConnection(ctx, clusterCfg, logger) + cc, err := NewClusterConnection(ctx, clusterCfg, shardManager, logger) if err != nil { logger.Fatal("Incorrectly configured Mux cluster connection", tag.Error(err), tag.NewStringTag("name", clusterCfg.Name)) continue @@ -169,6 +171,12 @@ func (s *Proxy) Start() error { ` it needs at least the following path: metrics.prometheus.listenAddress`) } + if s.shardManager != nil { + if err := s.shardManager.Start(s.lifetime); err != nil { + return err + } + } + for _, v := range s.clusterConnections { v.Start() } @@ -199,3 +207,8 @@ func (s *Proxy) Describe() string { sb.WriteString("]") return sb.String() } + +// GetShardInfo returns debug information about shard distribution +func (s *Proxy) GetShardInfo() ShardDebugInfo { + return s.shardManager.GetShardInfo() +} diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go new file mode 100644 index 00000000..0aff2b76 --- /dev/null +++ b/proxy/shard_manager.go @@ -0,0 +1,499 @@ +package proxy + +import ( + "context" + "encoding/json" + "fmt" + "hash/fnv" + "sort" + "strconv" + "sync" + "time" + + "github.com/hashicorp/memberlist" + "go.temporal.io/server/client/history" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + + "github.com/temporalio/s2s-proxy/config" +) + +type ( + // ShardManager manages distributed shard ownership across proxy instances + ShardManager interface { + // Start initializes the memberlist cluster and starts the manager + Start(lifetime context.Context) error + // Stop shuts down the manager and leaves the cluster + Stop() + // RegisterShard registers a clientShardID as owned by this proxy instance + RegisterShard(clientShardID history.ClusterShardID) + // UnregisterShard removes a clientShardID from this proxy's ownership + UnregisterShard(clientShardID history.ClusterShardID) + // GetShardOwner returns the proxy node name that owns the given shard + GetShardOwner(clientShardID history.ClusterShardID) (string, bool) + // GetProxyAddress returns the proxy service address for the given node name + GetProxyAddress(nodeName string) (string, bool) + // IsLocalShard checks if this proxy instance owns the given shard + IsLocalShard(clientShardID history.ClusterShardID) bool + // GetMemberNodes returns all active proxy nodes in the cluster + GetMemberNodes() []string + // GetLocalShards returns all shards currently handled by this proxy instance + GetLocalShards() []history.ClusterShardID + // GetShardInfo returns debug information about shard distribution + GetShardInfo() ShardDebugInfo + } + + shardManagerImpl struct { + config *config.MemberlistConfig + logger log.Logger + ml *memberlist.Memberlist + delegate *shardDelegate + mutex sync.RWMutex + localAddr string + started bool + } + + // shardDelegate implements memberlist.Delegate for shard state management + shardDelegate struct { + manager *shardManagerImpl + logger log.Logger + localShards map[string]history.ClusterShardID // key: "clusterID:shardID" + mutex sync.RWMutex + } + + // ShardMessage represents shard ownership changes broadcast to cluster + ShardMessage struct { + Type string `json:"type"` // "register" or "unregister" + NodeName string `json:"node"` + ClientShard history.ClusterShardID 
`json:"shard"` + Timestamp time.Time `json:"timestamp"` + } + + // NodeShardState represents all shards owned by a node + NodeShardState struct { + NodeName string `json:"node"` + Shards map[string]history.ClusterShardID `json:"shards"` + Updated time.Time `json:"updated"` + } +) + +// NewShardManager creates a new shard manager instance +func NewShardManager(configProvider config.ConfigProvider, logger log.Logger) (ShardManager, error) { + cfg := configProvider.GetS2SProxyConfig().MemberlistConfig + if cfg == nil || !cfg.Enabled { + return &noopShardManager{}, nil + } + + delegate := &shardDelegate{ + logger: logger, + localShards: make(map[string]history.ClusterShardID), + } + + sm := &shardManagerImpl{ + config: cfg, + logger: logger, + delegate: delegate, + } + + delegate.manager = sm + + return sm, nil +} + +func (sm *shardManagerImpl) Start(lifetime context.Context) error { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + if sm.started { + return nil + } + + // Configure memberlist + var mlConfig *memberlist.Config + if sm.config.TCPOnly { + mlConfig = memberlist.DefaultWANConfig() + // Disable UDP for restricted networks + mlConfig.DisableTcpPings = sm.config.DisableTCPPings + } else { + mlConfig = memberlist.DefaultLocalConfig() + } + + mlConfig.Name = sm.config.NodeName + mlConfig.BindAddr = sm.config.BindAddr + mlConfig.BindPort = sm.config.BindPort + mlConfig.Delegate = sm.delegate + mlConfig.Events = &shardEventDelegate{manager: sm, logger: sm.logger} + + // Configure timeouts if specified + if sm.config.ProbeTimeoutMs > 0 { + mlConfig.ProbeTimeout = time.Duration(sm.config.ProbeTimeoutMs) * time.Millisecond + } + if sm.config.ProbeIntervalMs > 0 { + mlConfig.ProbeInterval = time.Duration(sm.config.ProbeIntervalMs) * time.Millisecond + } + + // Create memberlist + ml, err := memberlist.Create(mlConfig) + if err != nil { + return fmt.Errorf("failed to create memberlist: %w", err) + } + + sm.ml = ml + sm.localAddr = fmt.Sprintf("%s:%d", sm.config.BindAddr, sm.config.BindPort) + + // Join existing cluster if configured + if len(sm.config.JoinAddrs) > 0 { + num, err := ml.Join(sm.config.JoinAddrs) + if err != nil { + sm.logger.Warn("Failed to join some cluster members", tag.Error(err)) + } + sm.logger.Info("Joined memberlist cluster", tag.NewStringTag("members", strconv.Itoa(num))) + } + + sm.started = true + sm.logger.Info("Shard manager started", + tag.NewStringTag("node", sm.config.NodeName), + tag.NewStringTag("addr", sm.localAddr)) + + context.AfterFunc(lifetime, func() { + sm.Stop() + }) + return nil +} + +func (sm *shardManagerImpl) Stop() { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + if !sm.started || sm.ml == nil { + return + } + + // Leave the cluster gracefully + err := sm.ml.Leave(5 * time.Second) + if err != nil { + sm.logger.Error("Error leaving memberlist cluster", tag.Error(err)) + } + + err = sm.ml.Shutdown() + if err != nil { + sm.logger.Error("Error shutting down memberlist", tag.Error(err)) + } + + sm.started = false + sm.logger.Info("Shard manager stopped") +} + +func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) { + sm.delegate.addLocalShard(clientShardID) + sm.broadcastShardChange("register", clientShardID) + + // Trigger memberlist metadata update to propagate NodeMeta to other nodes + if sm.ml != nil { + if err := sm.ml.UpdateNode(0); err != nil { // 0 timeout means immediate update + sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + } + } +} + +func (sm *shardManagerImpl) 
UnregisterShard(clientShardID history.ClusterShardID) { + sm.delegate.removeLocalShard(clientShardID) + sm.broadcastShardChange("unregister", clientShardID) + + // Trigger memberlist metadata update to propagate NodeMeta to other nodes + if sm.ml != nil { + if err := sm.ml.UpdateNode(0); err != nil { // 0 timeout means immediate update + sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + } + } +} + +func (sm *shardManagerImpl) GetShardOwner(clientShardID history.ClusterShardID) (string, bool) { + if !sm.started { + return "", false + } + + // Use consistent hashing to determine shard owner + return sm.consistentHashOwner(clientShardID), true +} + +func (sm *shardManagerImpl) IsLocalShard(clientShardID history.ClusterShardID) bool { + if !sm.started { + return true // If not using memberlist, handle locally + } + + owner, found := sm.GetShardOwner(clientShardID) + return found && owner == sm.config.NodeName +} + +func (sm *shardManagerImpl) GetProxyAddress(nodeName string) (string, bool) { + if sm.config.ProxyAddresses == nil { + return "", false + } + addr, found := sm.config.ProxyAddresses[nodeName] + return addr, found +} + +func (sm *shardManagerImpl) GetMemberNodes() []string { + if !sm.started || sm.ml == nil { + return []string{sm.config.NodeName} + } + + // Use a timeout to prevent deadlocks when memberlist is busy + membersChan := make(chan []*memberlist.Node, 1) + go func() { + defer func() { + if r := recover(); r != nil { + sm.logger.Error("Panic in GetMemberNodes", tag.NewStringTag("error", fmt.Sprintf("%v", r))) + } + }() + membersChan <- sm.ml.Members() + }() + + select { + case members := <-membersChan: + nodes := make([]string, len(members)) + for i, member := range members { + nodes[i] = member.Name + } + return nodes + case <-time.After(100 * time.Millisecond): + // Timeout: return cached node name to prevent hanging + sm.logger.Warn("GetMemberNodes timeout, returning self node", + tag.NewStringTag("node", sm.config.NodeName)) + return []string{sm.config.NodeName} + } +} + +func (sm *shardManagerImpl) GetLocalShards() []history.ClusterShardID { + sm.delegate.mutex.RLock() + defer sm.delegate.mutex.RUnlock() + + shards := make([]history.ClusterShardID, 0, len(sm.delegate.localShards)) + for _, shard := range sm.delegate.localShards { + shards = append(shards, shard) + } + return shards +} + +func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { + localShards := sm.GetLocalShards() + clusterNodes := sm.GetMemberNodes() + + // Build remote shard maps by querying memberlist metadata directly + remoteShards := make(map[string]string) + remoteShardCounts := make(map[string]int) + + // Initialize counts for all nodes + for _, node := range clusterNodes { + remoteShardCounts[node] = 0 + } + + // Count local shards for this node + remoteShardCounts[sm.config.NodeName] = len(localShards) + + // Collect shard ownership information from all cluster members + if sm.ml != nil { + for _, member := range sm.ml.Members() { + if len(member.Meta) > 0 { + var nodeState NodeShardState + if err := json.Unmarshal(member.Meta, &nodeState); err == nil { + nodeName := nodeState.NodeName + if nodeName != "" { + remoteShardCounts[nodeName] = len(nodeState.Shards) + + // Add remote shards (exclude local node) + if nodeName != sm.config.NodeName { + for _, shard := range nodeState.Shards { + shardKey := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) + remoteShards[shardKey] = nodeName + } + } + } + } + } + } + } + + return ShardDebugInfo{ + Enabled: true, + 
ForwardingEnabled: sm.config.EnableForwarding, + NodeName: sm.config.NodeName, + LocalShards: localShards, + LocalShardCount: len(localShards), + ClusterNodes: clusterNodes, + ClusterSize: len(clusterNodes), + RemoteShards: remoteShards, + RemoteShardCounts: remoteShardCounts, + } +} + +func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.ClusterShardID) { + if !sm.started || sm.ml == nil { + return + } + + msg := ShardMessage{ + Type: msgType, + NodeName: sm.config.NodeName, + ClientShard: shard, + Timestamp: time.Now(), + } + + data, err := json.Marshal(msg) + if err != nil { + sm.logger.Error("Failed to marshal shard message", tag.Error(err)) + return + } + + err = sm.ml.SendReliable(sm.ml.Members()[0], data) + if err != nil { + sm.logger.Error("Failed to broadcast shard change", tag.Error(err)) + } +} + +func (sm *shardManagerImpl) consistentHashOwner(shard history.ClusterShardID) string { + nodes := sm.GetMemberNodes() + if len(nodes) == 0 { + return sm.config.NodeName + } + + // Sort nodes for consistent ordering + sort.Strings(nodes) + + // Hash the shard ID + h := fnv.New32a() + shardKey := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) + h.Write([]byte(shardKey)) + hash := h.Sum32() + + // Use consistent hashing to determine owner + return nodes[hash%uint32(len(nodes))] +} + +// shardDelegate implements memberlist.Delegate +func (sd *shardDelegate) NodeMeta(limit int) []byte { + sd.mutex.RLock() + defer sd.mutex.RUnlock() + + state := NodeShardState{ + NodeName: sd.manager.config.NodeName, + Shards: sd.localShards, + Updated: time.Now(), + } + + data, err := json.Marshal(state) + if err != nil { + sd.logger.Error("Failed to marshal node meta", tag.Error(err)) + return nil + } + + if len(data) > limit { + // If metadata is too large, just send node name + return []byte(sd.manager.config.NodeName) + } + + return data +} + +func (sd *shardDelegate) NotifyMsg(data []byte) { + var msg ShardMessage + if err := json.Unmarshal(data, &msg); err != nil { + sd.logger.Error("Failed to unmarshal shard message", tag.Error(err)) + return + } + + sd.logger.Debug("Received shard message", + tag.NewStringTag("type", msg.Type), + tag.NewStringTag("node", msg.NodeName), + tag.NewStringTag("shard", ClusterShardIDtoString(msg.ClientShard))) +} + +func (sd *shardDelegate) GetBroadcasts(overhead, limit int) [][]byte { + // Not implementing broadcasts for now + return nil +} + +func (sd *shardDelegate) LocalState(join bool) []byte { + return sd.NodeMeta(512) +} + +func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { + var state NodeShardState + if err := json.Unmarshal(buf, &state); err != nil { + sd.logger.Error("Failed to unmarshal remote state", tag.Error(err)) + return + } + + sd.logger.Debug("Merged remote shard state", + tag.NewStringTag("node", state.NodeName), + tag.NewStringTag("shards", strconv.Itoa(len(state.Shards)))) +} + +func (sd *shardDelegate) addLocalShard(shard history.ClusterShardID) { + sd.mutex.Lock() + defer sd.mutex.Unlock() + + key := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) + sd.localShards[key] = shard +} + +func (sd *shardDelegate) removeLocalShard(shard history.ClusterShardID) { + sd.mutex.Lock() + defer sd.mutex.Unlock() + + key := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) + delete(sd.localShards, key) +} + +// shardEventDelegate handles memberlist cluster events +type shardEventDelegate struct { + manager *shardManagerImpl + logger log.Logger +} + +func (sed *shardEventDelegate) NotifyJoin(node 
*memberlist.Node) { + sed.logger.Info("Node joined cluster", + tag.NewStringTag("node", node.Name), + tag.NewStringTag("addr", node.Addr.String())) +} + +func (sed *shardEventDelegate) NotifyLeave(node *memberlist.Node) { + sed.logger.Info("Node left cluster", + tag.NewStringTag("node", node.Name), + tag.NewStringTag("addr", node.Addr.String())) +} + +func (sed *shardEventDelegate) NotifyUpdate(node *memberlist.Node) { + sed.logger.Debug("Node updated", + tag.NewStringTag("node", node.Name), + tag.NewStringTag("addr", node.Addr.String())) +} + +// noopShardManager provides a no-op implementation when memberlist is disabled +type noopShardManager struct{} + +func (nsm *noopShardManager) Start(_ context.Context) error { return nil } +func (nsm *noopShardManager) Stop() {} +func (nsm *noopShardManager) RegisterShard(history.ClusterShardID) {} +func (nsm *noopShardManager) UnregisterShard(history.ClusterShardID) {} +func (nsm *noopShardManager) GetShardOwner(history.ClusterShardID) (string, bool) { return "", false } +func (nsm *noopShardManager) GetProxyAddress(string) (string, bool) { return "", false } +func (nsm *noopShardManager) IsLocalShard(history.ClusterShardID) bool { return true } +func (nsm *noopShardManager) GetMemberNodes() []string { return []string{} } +func (nsm *noopShardManager) GetLocalShards() []history.ClusterShardID { + return []history.ClusterShardID{} +} +func (nsm *noopShardManager) GetShardInfo() ShardDebugInfo { + return ShardDebugInfo{ + Enabled: false, + ForwardingEnabled: false, + NodeName: "", + LocalShards: []history.ClusterShardID{}, + LocalShardCount: 0, + ClusterNodes: []string{}, + ClusterSize: 0, + RemoteShards: make(map[string]string), + RemoteShardCounts: make(map[string]int), + } +} diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index c2918520..6f3232fd 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -287,7 +287,7 @@ func (s *ReplicationTestSuite) createProxy( } configProvider := &simpleConfigProvider{cfg: *cfg} - proxy := s2sproxy.NewProxy(configProvider, s.logger) + proxy := s2sproxy.NewProxy(configProvider, nil, s.logger) s.NotNil(proxy) err := proxy.Start() From de6633bc01d8879d4d81306b3d78c13d3057094b Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Thu, 7 Aug 2025 14:58:51 -0700 Subject: [PATCH 03/38] proxy routing --- MEMBERLIST_TROUBLESHOOTING.md | 229 ++++++++++++++++++ PROXY_FORWARDING.md | 96 ++++++++ .../config/cluster-a-mux-client-proxy-1.yaml | 2 +- .../config/cluster-a-mux-client-proxy-2.yaml | 2 +- .../config/cluster-b-mux-server-proxy-1.yaml | 2 +- .../config/cluster-b-mux-server-proxy-2.yaml | 2 +- develop/config/dynamic-config.yaml | 6 +- proxy/debug.go | 39 ++- proxy/proxy.go | 149 ++++++++++++ proxy/shard_manager.go | 6 +- proxy/stream_tracker.go | 63 ++++- 11 files changed, 575 insertions(+), 21 deletions(-) create mode 100644 MEMBERLIST_TROUBLESHOOTING.md create mode 100644 PROXY_FORWARDING.md diff --git a/MEMBERLIST_TROUBLESHOOTING.md b/MEMBERLIST_TROUBLESHOOTING.md new file mode 100644 index 00000000..e22fed3b --- /dev/null +++ b/MEMBERLIST_TROUBLESHOOTING.md @@ -0,0 +1,229 @@ +# Memberlist Network Troubleshooting + +This guide helps resolve network connectivity issues with memberlist in the s2s-proxy. 
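+
+For reference, the yaml options discussed in this guide are applied to the underlying hashicorp/memberlist `Config` by `shardManagerImpl.Start` in `proxy/shard_manager.go`. The snippet below is a condensed, non-authoritative sketch of that mapping (the package and function names here are illustrative only). Note that memberlist's `DisableTcpPings` removes the TCP fallback ping that follows a failed UDP probe; fully avoiding UDP would require a custom transport.
+
+```go
+package example
+
+import (
+	"time"
+
+	"github.com/hashicorp/memberlist"
+
+	"github.com/temporalio/s2s-proxy/config"
+)
+
+// buildMemberlist sketches how the memberlist yaml block becomes a memberlist.Config.
+func buildMemberlist(cfg *config.MemberlistConfig) (*memberlist.Memberlist, error) {
+	mlConfig := memberlist.DefaultLocalConfig()
+	if cfg.TCPOnly {
+		// tcpOnly selects the more conservative WAN profile and optionally
+		// drops the TCP fallback ping used after a failed UDP probe.
+		mlConfig = memberlist.DefaultWANConfig()
+		mlConfig.DisableTcpPings = cfg.DisableTCPPings
+	}
+	mlConfig.Name = cfg.NodeName
+	mlConfig.BindAddr = cfg.BindAddr
+	mlConfig.BindPort = cfg.BindPort
+	if cfg.ProbeTimeoutMs > 0 {
+		mlConfig.ProbeTimeout = time.Duration(cfg.ProbeTimeoutMs) * time.Millisecond
+	}
+	if cfg.ProbeIntervalMs > 0 {
+		mlConfig.ProbeInterval = time.Duration(cfg.ProbeIntervalMs) * time.Millisecond
+	}
+
+	ml, err := memberlist.Create(mlConfig)
+	if err != nil {
+		return nil, err
+	}
+	if len(cfg.JoinAddrs) > 0 {
+		// Join is best-effort; failure to reach some seeds is not fatal here.
+		_, _ = ml.Join(cfg.JoinAddrs)
+	}
+	return ml, nil
+}
+```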
+ +## Common Issues + +### UDP Ping Failures + +**Symptoms:** +``` +[DEBUG] memberlist: Failed UDP ping: proxy-node-a-2 (timeout reached) +[WARN] memberlist: Was able to connect to proxy-node-a-2 over TCP but UDP probes failed, network may be misconfigured +``` + +**Causes:** +- UDP traffic blocked by firewalls +- Running in containers without UDP port mapping +- Network security policies blocking UDP +- NAT/proxy configurations + +**Solutions:** + +#### 1. Use TCP-Only Mode (Recommended) + +Update your configuration to use TCP-only transport: + +```yaml +memberlist: + enabled: true + enableForwarding: true + nodeName: "proxy-node-1" + bindAddr: "0.0.0.0" + bindPort: 7946 + joinAddrs: + - "proxy-node-2:7946" + - "proxy-node-3:7946" + # TCP-only configuration + tcpOnly: true # Disable UDP entirely + disableTCPPings: true # Improve performance in TCP-only mode + probeTimeoutMs: 1000 # Adjust for network latency + probeIntervalMs: 2000 # Reduce probe frequency +``` + +#### 2. Open UDP Ports + +If you want to keep UDP enabled: + +**Docker/Kubernetes:** +```bash +# Expose UDP port in Docker +docker run -p 7946:7946/udp -p 7946:7946/tcp ... + +# Kubernetes service +apiVersion: v1 +kind: Service +spec: + ports: + - name: memberlist-tcp + port: 7946 + protocol: TCP + - name: memberlist-udp + port: 7946 + protocol: UDP +``` + +**Firewall:** +```bash +# Linux iptables +iptables -A INPUT -p udp --dport 7946 -j ACCEPT +iptables -A INPUT -p tcp --dport 7946 -j ACCEPT + +# AWS Security Groups - allow UDP/TCP 7946 +``` + +#### 3. Adjust Bind Address + +For container environments, use specific bind addresses: + +```yaml +memberlist: + bindAddr: "0.0.0.0" # Listen on all interfaces + # OR + bindAddr: "10.0.0.1" # Specific container IP +``` + +## Configuration Options + +### Network Timing + +```yaml +memberlist: + probeTimeoutMs: 500 # Time to wait for ping response (default: 500ms) + probeIntervalMs: 1000 # Time between health probes (default: 1s) +``` + +**Adjust based on network conditions:** +- **Fast networks**: Lower values (500ms timeout, 1s interval) +- **Slow/high-latency networks**: Higher values (1000ms timeout, 2s interval) +- **Unreliable networks**: Much higher values (2000ms timeout, 5s interval) + +### Transport Modes + +#### Local Network Mode (Default) +```yaml +memberlist: + tcpOnly: false # Uses both UDP and TCP +``` +- Best for local networks +- Fastest failure detection +- Requires UDP connectivity + +#### TCP-Only Mode +```yaml +memberlist: + tcpOnly: true # TCP transport only + disableTCPPings: true # Optimize for TCP-only +``` +- Works in restricted networks +- Slightly slower failure detection +- More reliable in containerized environments + +## Testing Connectivity + +### 1. Test TCP Connectivity +```bash +# Test if TCP port is reachable +telnet proxy-node-2 7946 +nc -zv proxy-node-2 7946 +``` + +### 2. Test UDP Connectivity +```bash +# Test UDP port (if not using tcpOnly) +nc -u -zv proxy-node-2 7946 +``` + +### 3. Monitor Memberlist Logs +Enable debug logging to see detailed memberlist behavior: +```bash +# Set log level to debug +export LOG_LEVEL=debug +./s2s-proxy start --config your-config.yaml +``` + +### 4. 
Check Debug Endpoint +Query the debug endpoint to see cluster status: +```bash +curl http://localhost:6060/debug/connections | jq .shard_info +``` + +## Example Configurations + +### Docker Compose +```yaml +version: '3.8' +services: + proxy1: + image: s2s-proxy + ports: + - "7946:7946/tcp" + - "7946:7946/udp" # Only if not using tcpOnly + environment: + - CONFIG_PATH=/config/proxy.yaml +``` + +### Kubernetes +```yaml +apiVersion: apps/v1 +kind: Deployment +spec: + template: + spec: + containers: + - name: s2s-proxy + ports: + - containerPort: 7946 + protocol: TCP + - containerPort: 7946 + protocol: UDP # Only if not using tcpOnly +``` + +## Performance Impact + +**UDP + TCP Mode:** +- Fastest failure detection (~1-2 seconds) +- Best for stable networks +- Requires UDP connectivity + +**TCP-Only Mode:** +- Slightly slower failure detection (~2-5 seconds) +- More reliable in restricted environments +- Works everywhere TCP works + +## Recommended Settings by Environment + +### Local Development +```yaml +memberlist: + tcpOnly: false + probeTimeoutMs: 500 + probeIntervalMs: 1000 +``` + +### Docker/Containers +```yaml +memberlist: + tcpOnly: true + disableTCPPings: true + probeTimeoutMs: 1000 + probeIntervalMs: 2000 +``` + +### Kubernetes +```yaml +memberlist: + tcpOnly: true + disableTCPPings: true + probeTimeoutMs: 1500 + probeIntervalMs: 3000 +``` + +### High-Latency/Unreliable Networks +```yaml +memberlist: + tcpOnly: true + disableTCPPings: true + probeTimeoutMs: 2000 + probeIntervalMs: 5000 +``` diff --git a/PROXY_FORWARDING.md b/PROXY_FORWARDING.md new file mode 100644 index 00000000..0ee8639f --- /dev/null +++ b/PROXY_FORWARDING.md @@ -0,0 +1,96 @@ +# Proxy-to-Proxy Forwarding + +This document describes the proxy-to-proxy forwarding functionality that enables distributed shard management across multiple s2s-proxy instances. + +## Overview + +The proxy-to-proxy forwarding mechanism allows multiple proxy instances to work together as a cluster, where each proxy instance owns a subset of shards. When a replication stream request comes to a proxy that doesn't own the target shard, it automatically forwards the request to the proxy instance that does own that shard. + +## Architecture + +``` +Client → Proxy A (Inbound) → Proxy B (Inbound) → Target Server + (Forward) (Owner) +``` + +## How It Works + +1. **Shard Ownership**: Using consistent hashing via HashiCorp memberlist, each proxy instance is assigned ownership of specific shards +2. **Ownership Check**: When a `StreamWorkflowReplicationMessages` request arrives on an **inbound connection** with **forwarding enabled**, the proxy checks if it owns the required shard +3. **Forwarding**: If another proxy owns the shard, the request is forwarded to that proxy (only for inbound connections with forwarding enabled) +4. 
**Bidirectional Streaming**: The forwarding proxy acts as a transparent relay, forwarding both requests and responses + +## Key Components + +### Shard Manager +- **Interface**: `ShardManager` with methods for shard ownership and proxy address resolution +- **Implementation**: Uses memberlist for cluster membership and consistent hashing for shard distribution +- **Methods**: + - `IsLocalShard(shardID)` - Check if this proxy owns a shard + - `GetShardOwner(shardID)` - Get the node name that owns a shard + - `GetProxyAddress(nodeName)` - Get the service address for a proxy node + +### Forwarding Logic +- **Location**: `StreamWorkflowReplicationMessages` in `adminservice.go` +- **Conditions**: Forwards only when: + - **Inbound connection** (`s.IsInbound == true`) + - **Memberlist enabled** (`memberlist.enabled == true`) + - **Forwarding enabled** (`memberlist.enableForwarding == true`) +- **Checks**: Two shard ownership checks (only for inbound): + 1. `clientShardID` - the incoming shard from the client + 2. `serverShardID` - the target shard (after LCM remapping if applicable) +- **Forwarding Function**: `forwardToProxy()` handles the bidirectional streaming + +### Configuration + +```yaml +memberlist: + enabled: true + # Enable proxy-to-proxy forwarding + enableForwarding: true + nodeName: "proxy-node-1" + bindAddr: "0.0.0.0" + bindPort: 7946 + joinAddrs: + - "proxy-node-2:7946" + - "proxy-node-3:7946" + shardStrategy: "consistent" + proxyAddresses: + "proxy-node-1": "localhost:7001" + "proxy-node-2": "proxy-node-2:7001" + "proxy-node-3": "proxy-node-3:7001" +``` + +## Metrics + +The following Prometheus metrics track forwarding operations: + +- `shard_distribution` - Number of shards handled by each proxy instance +- `shard_forwarding_total` - Total forwarding operations (labels: from_node, to_node, result) +- `memberlist_cluster_size` - Number of nodes in the memberlist cluster +- `memberlist_events_total` - Memberlist events (join/leave) + +## Benefits + +1. **Horizontal Scaling**: Add more proxy instances to handle more shards +2. **High Availability**: Automatic shard redistribution when proxies fail +3. **Load Distribution**: Shards are evenly distributed across proxy instances +4. **Transparent**: Clients don't need to know about shard ownership +5. **Configurable**: Can enable cluster coordination without forwarding via `enableForwarding: false` +6. **Backward Compatible**: Works with existing setups when memberlist is disabled + +## Limitations + +- Forwarding adds one additional network hop for non-local shards +- Requires careful configuration of proxy addresses for inter-proxy communication +- Uses insecure gRPC connections for proxy-to-proxy communication (can be enhanced with TLS) + +## Example Deployment + +For a 3-proxy cluster handling temporal replication: + +1. **proxy-node-1**: Handles shards 0, 3, 6, 9, ... +2. **proxy-node-2**: Handles shards 1, 4, 7, 10, ... +3. **proxy-node-3**: Handles shards 2, 5, 8, 11, ... + +When a replication stream for shard 7 comes to proxy-node-1, it will automatically forward to proxy-node-2. 
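+
+As a minimal, standalone sketch of this decision (not the proxy's actual code: the real `ShardManager` works on `history.ClusterShardID` values with memberlist-backed consistent hashing, while this example uses plain `int32` shard IDs and a hard-coded round-robin assignment that simply mirrors the three-node layout above):
+
+```go
+package main
+
+import "fmt"
+
+// shardRouter is a toy stand-in for the ShardManager interface described above.
+type shardRouter struct {
+	localNode string
+	addresses map[string]string // node name -> proxy address, as in the proxyAddresses config
+}
+
+// owner assigns shards round-robin across three nodes, matching the example deployment.
+func (r *shardRouter) owner(shardID int32) string {
+	nodes := []string{"proxy-node-1", "proxy-node-2", "proxy-node-3"}
+	return nodes[int(shardID)%len(nodes)]
+}
+
+func (r *shardRouter) IsLocalShard(shardID int32) bool { return r.owner(shardID) == r.localNode }
+
+func (r *shardRouter) GetShardOwner(shardID int32) (string, bool) { return r.owner(shardID), true }
+
+func (r *shardRouter) GetProxyAddress(node string) (string, bool) {
+	addr, ok := r.addresses[node]
+	return addr, ok
+}
+
+func main() {
+	r := &shardRouter{
+		localNode: "proxy-node-1",
+		addresses: map[string]string{
+			"proxy-node-2": "proxy-node-2:7001",
+			"proxy-node-3": "proxy-node-3:7001",
+		},
+	}
+
+	// A stream for shard 7 arrives at proxy-node-1: it is not local,
+	// so it would be forwarded to the owner's address.
+	shardID := int32(7)
+	if r.IsLocalShard(shardID) {
+		fmt.Println("serve locally")
+	} else if owner, ok := r.GetShardOwner(shardID); ok {
+		if addr, ok := r.GetProxyAddress(owner); ok {
+			fmt.Printf("forward shard %d to %s at %s\n", shardID, owner, addr)
+		}
+	}
+}
+```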
\ No newline at end of file diff --git a/develop/config/cluster-a-mux-client-proxy-1.yaml b/develop/config/cluster-a-mux-client-proxy-1.yaml index 426eede2..27b5a1a3 100644 --- a/develop/config/cluster-a-mux-client-proxy-1.yaml +++ b/develop/config/cluster-a-mux-client-proxy-1.yaml @@ -31,7 +31,7 @@ profiling: pprofAddress: "localhost:6060" memberlist: enabled: true - enableForwarding: false + enableForwarding: true nodeName: "proxy-node-a-1" bindAddr: "0.0.0.0" bindPort: 6135 diff --git a/develop/config/cluster-a-mux-client-proxy-2.yaml b/develop/config/cluster-a-mux-client-proxy-2.yaml index 4027cb22..7f34f13a 100644 --- a/develop/config/cluster-a-mux-client-proxy-2.yaml +++ b/develop/config/cluster-a-mux-client-proxy-2.yaml @@ -31,7 +31,7 @@ profiling: pprofAddress: "localhost:6061" memberlist: enabled: true - enableForwarding: false + enableForwarding: true nodeName: "proxy-node-a-2" bindAddr: "0.0.0.0" bindPort: 6235 diff --git a/develop/config/cluster-b-mux-server-proxy-1.yaml b/develop/config/cluster-b-mux-server-proxy-1.yaml index 5155b29f..33467c58 100644 --- a/develop/config/cluster-b-mux-server-proxy-1.yaml +++ b/develop/config/cluster-b-mux-server-proxy-1.yaml @@ -31,7 +31,7 @@ profiling: pprofAddress: "localhost:6070" memberlist: enabled: true - enableForwarding: false + enableForwarding: true nodeName: "proxy-node-b-1" bindAddr: "0.0.0.0" bindPort: 6335 diff --git a/develop/config/cluster-b-mux-server-proxy-2.yaml b/develop/config/cluster-b-mux-server-proxy-2.yaml index 6cefe758..bf7749b0 100644 --- a/develop/config/cluster-b-mux-server-proxy-2.yaml +++ b/develop/config/cluster-b-mux-server-proxy-2.yaml @@ -31,7 +31,7 @@ profiling: pprofAddress: "localhost:6071" memberlist: enabled: true - enableForwarding: false + enableForwarding: true nodeName: "proxy-node-b-2" bindAddr: "0.0.0.0" bindPort: 6435 diff --git a/develop/config/dynamic-config.yaml b/develop/config/dynamic-config.yaml index b5a46d29..2d75917d 100644 --- a/develop/config/dynamic-config.yaml +++ b/develop/config/dynamic-config.yaml @@ -20,7 +20,7 @@ history.ReplicationEnableUpdateWithNewTaskMerge: history.enableWorkflowExecutionTimeoutTimer: - value: true history.EnableReplicationTaskTieredProcessing: - - value: true + - value: false history.persistenceMaxQPS: - value: 100000 constraints: {} @@ -28,4 +28,6 @@ frontend.persistenceMaxQPS: - value: 100000 constraints: {} history.shardUpdateMinInterval: - - value: 1s \ No newline at end of file + - value: 1s +history.ReplicationAllowMultiSourceShard: + - value: true \ No newline at end of file diff --git a/proxy/debug.go b/proxy/debug.go index 2410b33e..9f7d24e4 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -13,13 +13,18 @@ import ( type ( // StreamInfo represents information about an active gRPC stream StreamInfo struct { - ID string `json:"id"` - Method string `json:"method"` - Direction string `json:"direction"` - ClientShard string `json:"client_shard"` - ServerShard string `json:"server_shard"` - StartTime time.Time `json:"start_time"` - LastSeen time.Time `json:"last_seen"` + ID string `json:"id"` + Method string `json:"method"` + Direction string `json:"direction"` + ClientShard string `json:"client_shard"` + ServerShard string `json:"server_shard"` + StartTime time.Time `json:"start_time"` + LastSeen time.Time `json:"last_seen"` + TotalDuration string `json:"total_duration"` + IdleDuration string `json:"idle_duration"` + LastSyncWatermark *int64 `json:"last_sync_watermark,omitempty"` + LastSyncWatermarkTime *time.Time 
`json:"last_sync_watermark_time,omitempty"` + LastExclusiveHighWatermark *int64 `json:"last_exclusive_high_watermark,omitempty"` } // ShardDebugInfo contains debug information about shard distribution @@ -35,11 +40,20 @@ type ( RemoteShardCounts map[string]int `json:"remote_shard_counts"` // node_name -> shard_count } + // ChannelDebugInfo holds debug information about channels + ChannelDebugInfo struct { + RemoteSendChannels map[string]int `json:"remote_send_channels"` // shard ID -> buffer size + LocalAckChannels map[string]int `json:"local_ack_channels"` // shard ID -> buffer size + TotalSendChannels int `json:"total_send_channels"` + TotalAckChannels int `json:"total_ack_channels"` + } + DebugResponse struct { - Timestamp time.Time `json:"timestamp"` - ActiveStreams []StreamInfo `json:"active_streams"` - StreamCount int `json:"stream_count"` - ShardInfo ShardDebugInfo `json:"shard_info"` + Timestamp time.Time `json:"timestamp"` + ActiveStreams []StreamInfo `json:"active_streams"` + StreamCount int `json:"stream_count"` + ShardInfo ShardDebugInfo `json:"shard_info"` + ChannelInfo ChannelDebugInfo `json:"channel_info"` } ) @@ -49,18 +63,21 @@ func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Prox var activeStreams []StreamInfo var streamCount int var shardInfo ShardDebugInfo + var channelInfo ChannelDebugInfo // Get active streams information streamTracker := GetGlobalStreamTracker() activeStreams = streamTracker.GetActiveStreams() streamCount = streamTracker.GetStreamCount() shardInfo = proxyInstance.GetShardInfo() + channelInfo = proxyInstance.GetChannelInfo() response := DebugResponse{ Timestamp: time.Now(), ActiveStreams: activeStreams, StreamCount: streamCount, ShardInfo: shardInfo, + ChannelInfo: channelInfo, } if err := json.NewEncoder(w).Encode(response); err != nil { diff --git a/proxy/proxy.go b/proxy/proxy.go index 5456b0da..61beeef4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,7 +6,10 @@ import ( "fmt" "net/http" "strings" + "sync" + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" @@ -32,6 +35,13 @@ type ( metricsServer *http.Server shardManager ShardManager logger log.Logger + + remoteSendChannelsMu sync.RWMutex + remoteSendChannels map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse + localAckChannelsMu sync.RWMutex + localAckChannels map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesRequest + localReceiverCancelFuncsMu sync.RWMutex + localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc } ) @@ -49,6 +59,9 @@ func NewProxy(configProvider config.ConfigProvider, shardManager ShardManager, l return s2sConfig.Logging.GetThrottleMaxRPS() }, ), + remoteSendChannels: make(map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse), + localAckChannels: make(map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesRequest), + localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), } if len(s2sConfig.ClusterConnections) == 0 { panic(errors.New("cannot create proxy without inbound and outbound config")) @@ -212,3 +225,139 @@ func (s *Proxy) Describe() string { func (s *Proxy) GetShardInfo() ShardDebugInfo { return s.shardManager.GetShardInfo() } + +// GetChannelInfo returns debug information about active channels +func (s *Proxy) GetChannelInfo() ChannelDebugInfo { + remoteSendChannels 
:= make(map[string]int) + var totalSendChannels int + + // Collect remote send channel info first + s.remoteSendChannelsMu.RLock() + for shardID, ch := range s.remoteSendChannels { + shardKey := ClusterShardIDtoString(shardID) + remoteSendChannels[shardKey] = len(ch) + } + totalSendChannels = len(s.remoteSendChannels) + s.remoteSendChannelsMu.RUnlock() + + localAckChannels := make(map[string]int) + var totalAckChannels int + + // Collect local ack channel info separately + s.localAckChannelsMu.RLock() + for shardID, ch := range s.localAckChannels { + shardKey := ClusterShardIDtoString(shardID) + localAckChannels[shardKey] = len(ch) + } + totalAckChannels = len(s.localAckChannels) + s.localAckChannelsMu.RUnlock() + + return ChannelDebugInfo{ + RemoteSendChannels: remoteSendChannels, + LocalAckChannels: localAckChannels, + TotalSendChannels: totalSendChannels, + TotalAckChannels: totalAckChannels, + } +} + +// SetRemoteSendChan registers a send channel for a specific shard ID +func (s *Proxy) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan *adminservice.StreamWorkflowReplicationMessagesResponse) { + s.remoteSendChannelsMu.Lock() + defer s.remoteSendChannelsMu.Unlock() + s.remoteSendChannels[shardID] = sendChan + s.logger.Info("Registered remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) +} + +// GetRemoteSendChan retrieves the send channel for a specific shard ID +func (s *Proxy) GetRemoteSendChan(shardID history.ClusterShardID) (chan *adminservice.StreamWorkflowReplicationMessagesResponse, bool) { + s.remoteSendChannelsMu.RLock() + defer s.remoteSendChannelsMu.RUnlock() + ch, exists := s.remoteSendChannels[shardID] + return ch, exists +} + +// GetAllRemoteSendChans returns a map of all remote send channels +func (s *Proxy) GetAllRemoteSendChans() map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse { + s.remoteSendChannelsMu.RLock() + defer s.remoteSendChannelsMu.RUnlock() + + // Create a copy of the map + result := make(map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse, len(s.remoteSendChannels)) + for k, v := range s.remoteSendChannels { + result[k] = v + } + return result +} + +// RemoveRemoteSendChan removes the send channel for a specific shard ID +func (s *Proxy) RemoveRemoteSendChan(shardID history.ClusterShardID) { + s.remoteSendChannelsMu.Lock() + defer s.remoteSendChannelsMu.Unlock() + delete(s.remoteSendChannels, shardID) + s.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) +} + +// SetLocalAckChan registers an ack channel for a specific shard ID +func (s *Proxy) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan *adminservice.StreamWorkflowReplicationMessagesRequest) { + s.localAckChannelsMu.Lock() + defer s.localAckChannelsMu.Unlock() + s.localAckChannels[shardID] = ackChan + s.logger.Info("Registered local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) +} + +// GetLocalAckChan retrieves the ack channel for a specific shard ID +func (s *Proxy) GetLocalAckChan(shardID history.ClusterShardID) (chan *adminservice.StreamWorkflowReplicationMessagesRequest, bool) { + s.localAckChannelsMu.RLock() + defer s.localAckChannelsMu.RUnlock() + ch, exists := s.localAckChannels[shardID] + return ch, exists +} + +// RemoveLocalAckChan removes the ack channel for a specific shard ID +func (s *Proxy) RemoveLocalAckChan(shardID history.ClusterShardID) { 
+ s.localAckChannelsMu.Lock() + defer s.localAckChannelsMu.Unlock() + delete(s.localAckChannels, shardID) + s.logger.Info("Removed local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) +} + +// SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID +func (s *Proxy) SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) { + s.localReceiverCancelFuncsMu.Lock() + defer s.localReceiverCancelFuncsMu.Unlock() + s.localReceiverCancelFuncs[shardID] = cancelFunc + s.logger.Info("Registered local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) +} + +// GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID +func (s *Proxy) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) { + s.localReceiverCancelFuncsMu.RLock() + defer s.localReceiverCancelFuncsMu.RUnlock() + cancelFunc, exists := s.localReceiverCancelFuncs[shardID] + return cancelFunc, exists +} + +// RemoveLocalReceiverCancelFunc removes the cancel function for a local receiver for a specific shard ID +func (s *Proxy) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { + s.localReceiverCancelFuncsMu.Lock() + defer s.localReceiverCancelFuncsMu.Unlock() + delete(s.localReceiverCancelFuncs, shardID) + s.logger.Info("Removed local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) +} + +// TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed +func (s *Proxy) TerminatePreviousLocalReceiver(clientShardID history.ClusterShardID) { + // Check if there's a previous cancel function for this shard + if prevCancelFunc, exists := s.GetLocalReceiverCancelFunc(clientShardID); exists { + s.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(clientShardID))) + + // Cancel the previous receiver's context + prevCancelFunc() + + // Remove the cancel function from tracking + s.RemoveLocalReceiverCancelFunc(clientShardID) + + // Also clean up the associated ack channel if it exists + s.RemoveLocalAckChan(clientShardID) + } +} diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 0aff2b76..64a1c7ff 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -403,7 +403,7 @@ func (sd *shardDelegate) NotifyMsg(data []byte) { return } - sd.logger.Debug("Received shard message", + sd.logger.Info("Received shard message", tag.NewStringTag("type", msg.Type), tag.NewStringTag("node", msg.NodeName), tag.NewStringTag("shard", ClusterShardIDtoString(msg.ClientShard))) @@ -425,7 +425,7 @@ func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { return } - sd.logger.Debug("Merged remote shard state", + sd.logger.Info("Merged remote shard state", tag.NewStringTag("node", state.NodeName), tag.NewStringTag("shards", strconv.Itoa(len(state.Shards)))) } @@ -465,7 +465,7 @@ func (sed *shardEventDelegate) NotifyLeave(node *memberlist.Node) { } func (sed *shardEventDelegate) NotifyUpdate(node *memberlist.Node) { - sed.logger.Debug("Node updated", + sed.logger.Info("Node updated", tag.NewStringTag("node", node.Name), tag.NewStringTag("addr", node.Addr.String())) } diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go index 6dfe948c..66b368c4 100644 --- a/proxy/stream_tracker.go +++ b/proxy/stream_tracker.go @@ -1,6 +1,7 
@@ package proxy import ( + "fmt" "sync" "time" ) @@ -45,6 +46,29 @@ func (st *StreamTracker) UpdateStream(id string) { } } +// UpdateStreamSyncReplicationState updates the sync replication state information for a stream +func (st *StreamTracker) UpdateStreamSyncReplicationState(id string, inclusiveLowWatermark int64, watermarkTime *time.Time) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + stream.LastSyncWatermark = &inclusiveLowWatermark + stream.LastSyncWatermarkTime = watermarkTime + } +} + +// UpdateStreamReplicationMessages updates the replication messages information for a stream +func (st *StreamTracker) UpdateStreamReplicationMessages(id string, exclusiveHighWatermark int64) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + stream.LastExclusiveHighWatermark = &exclusiveHighWatermark + } +} + // UnregisterStream removes a stream from tracking func (st *StreamTracker) UnregisterStream(id string) { st.mu.Lock() @@ -58,9 +82,16 @@ func (st *StreamTracker) GetActiveStreams() []StreamInfo { st.mu.RLock() defer st.mu.RUnlock() + now := time.Now() streams := make([]StreamInfo, 0, len(st.streams)) for _, stream := range st.streams { - streams = append(streams, *stream) + // Create a copy and calculate both durations in seconds + streamCopy := *stream + totalSeconds := int(now.Sub(stream.StartTime).Seconds()) + idleSeconds := int(now.Sub(stream.LastSeen).Seconds()) + streamCopy.TotalDuration = formatDurationSeconds(totalSeconds) + streamCopy.IdleDuration = formatDurationSeconds(idleSeconds) + streams = append(streams, streamCopy) } return streams @@ -81,3 +112,33 @@ var globalStreamTracker = NewStreamTracker() func GetGlobalStreamTracker() *StreamTracker { return globalStreamTracker } + +// formatDurationSeconds formats a duration in seconds to a readable string +func formatDurationSeconds(totalSeconds int) string { + if totalSeconds < 60 { + return fmt.Sprintf("%ds", totalSeconds) + } + + minutes := totalSeconds / 60 + seconds := totalSeconds % 60 + + if minutes < 60 { + if seconds == 0 { + return fmt.Sprintf("%dm", minutes) + } + return fmt.Sprintf("%dm%ds", minutes, seconds) + } + + hours := minutes / 60 + minutes = minutes % 60 + + if minutes == 0 && seconds == 0 { + return fmt.Sprintf("%dh", hours) + } else if seconds == 0 { + return fmt.Sprintf("%dh%dm", hours, minutes) + } else if minutes == 0 { + return fmt.Sprintf("%dh%ds", hours, seconds) + } else { + return fmt.Sprintf("%dh%dm%ds", hours, minutes, seconds) + } +} From d797c523e11112f4b95dbb1b5c750cbf9520b78a Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 3 Sep 2025 16:18:26 -0700 Subject: [PATCH 04/38] route only in target proxy; remap taskID and track ack in proxy. 
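
Adds proxyStreamSender and proxyStreamReceiver scaffolding for the new "routing" shard
count mode. The sender assigns each forwarded replication task a monotonically increasing
proxy task ID and records the (source shard, original task ID) pair in a ring buffer, so
downstream ACKs in proxy-ID space can be translated back into per-source-shard ACKs at
each shard's highest original task ID. The streamRouting entry point in adminservice.go
is still a stub.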
--- config/config.go | 1 + .../config/cluster-a-mux-client-proxy-1.yaml | 13 +- .../config/cluster-a-mux-client-proxy-2.yaml | 13 +- .../config/cluster-b-mux-server-proxy-1.yaml | 7 +- .../config/cluster-b-mux-server-proxy-2.yaml | 7 +- develop/config/dynamic-config.yaml | 4 +- proxy/adminservice.go | 11 + proxy/debug.go | 29 +- proxy/proxy.go | 68 +- proxy/proxy_streams.go | 746 ++++++++++++++++++ proxy/shard_manager.go | 64 +- proxy/stream_tracker.go | 85 +- 12 files changed, 982 insertions(+), 66 deletions(-) create mode 100644 proxy/proxy_streams.go diff --git a/config/config.go b/config/config.go index ae1958f0..f33fa42b 100644 --- a/config/config.go +++ b/config/config.go @@ -41,6 +41,7 @@ const ( ShardCountDefault ShardCountMode = "" ShardCountLCM ShardCountMode = "lcm" ShardCountFixed ShardCountMode = "fixed" + ShardCountRouting ShardCountMode = "routing" ) type HealthCheckProtocol string diff --git a/develop/config/cluster-a-mux-client-proxy-1.yaml b/develop/config/cluster-a-mux-client-proxy-1.yaml index 27b5a1a3..bc79e067 100644 --- a/develop/config/cluster-a-mux-client-proxy-1.yaml +++ b/develop/config/cluster-a-mux-client-proxy-1.yaml @@ -23,10 +23,10 @@ mux: # mode: "lcm" # localShardCount: 2 # remoteShardCount: 3 -shardCount: - mode: "fixed" - localShardCount: 2 - remoteShardCount: 3 +# shardCount: +# mode: "fixed" +# localShardCount: 2 +# remoteShardCount: 3 profiling: pprofAddress: "localhost:6060" memberlist: @@ -38,9 +38,8 @@ memberlist: joinAddrs: - "localhost:6235" proxyAddresses: - "proxy-node-1": "localhost:7001" - "proxy-node-2": "proxy-node-2:7001" - "proxy-node-3": "proxy-node-3:7001" + "proxy-node-a-1": "localhost:6133" + "proxy-node-a-2": "localhost:6233" # TCP-only configuration for restricted networks tcpOnly: true # Use TCP transport only, disable UDP disableTCPPings: true # Disable TCP pings for faster convergence diff --git a/develop/config/cluster-a-mux-client-proxy-2.yaml b/develop/config/cluster-a-mux-client-proxy-2.yaml index 7f34f13a..b9cf33a7 100644 --- a/develop/config/cluster-a-mux-client-proxy-2.yaml +++ b/develop/config/cluster-a-mux-client-proxy-2.yaml @@ -23,10 +23,10 @@ mux: # mode: "lcm" # localShardCount: 2 # remoteShardCount: 3 -shardCount: - mode: "fixed" - localShardCount: 2 - remoteShardCount: 3 +# shardCount: +# mode: "fixed" +# localShardCount: 2 +# remoteShardCount: 3 profiling: pprofAddress: "localhost:6061" memberlist: @@ -38,9 +38,8 @@ memberlist: joinAddrs: - "localhost:6135" proxyAddresses: - "proxy-node-1": "localhost:7001" - "proxy-node-2": "proxy-node-2:7001" - "proxy-node-3": "proxy-node-3:7001" + "proxy-node-a-1": "localhost:6133" + "proxy-node-a-2": "localhost:6233" # TCP-only configuration for restricted networks tcpOnly: true # Use TCP transport only, disable UDP disableTCPPings: true # Disable TCP pings for faster convergence diff --git a/develop/config/cluster-b-mux-server-proxy-1.yaml b/develop/config/cluster-b-mux-server-proxy-1.yaml index 33467c58..a9912a22 100644 --- a/develop/config/cluster-b-mux-server-proxy-1.yaml +++ b/develop/config/cluster-b-mux-server-proxy-1.yaml @@ -24,7 +24,7 @@ mux: # localShardCount: 3 # remoteShardCount: 2 shardCount: - mode: "fixed" + mode: "routing" localShardCount: 3 remoteShardCount: 2 profiling: @@ -38,9 +38,8 @@ memberlist: joinAddrs: - "localhost:6435" proxyAddresses: - "proxy-node-1": "localhost:7001" - "proxy-node-2": "proxy-node-2:7001" - "proxy-node-3": "proxy-node-3:7001" + "proxy-node-b-1": "localhost:6333" + "proxy-node-b-2": "localhost:6433" # TCP-only configuration for 
restricted networks tcpOnly: true # Use TCP transport only, disable UDP disableTCPPings: true # Disable TCP pings for faster convergence diff --git a/develop/config/cluster-b-mux-server-proxy-2.yaml b/develop/config/cluster-b-mux-server-proxy-2.yaml index bf7749b0..96689bb4 100644 --- a/develop/config/cluster-b-mux-server-proxy-2.yaml +++ b/develop/config/cluster-b-mux-server-proxy-2.yaml @@ -24,7 +24,7 @@ mux: # localShardCount: 3 # remoteShardCount: 2 shardCount: - mode: "fixed" + mode: "routing" localShardCount: 3 remoteShardCount: 2 profiling: @@ -38,9 +38,8 @@ memberlist: joinAddrs: - "localhost:6335" proxyAddresses: - "proxy-node-1": "localhost:7001" - "proxy-node-2": "proxy-node-2:7001" - "proxy-node-3": "proxy-node-3:7001" + "proxy-node-b-1": "localhost:6333" + "proxy-node-b-2": "localhost:6433" # TCP-only configuration for restricted networks tcpOnly: true # Use TCP transport only, disable UDP disableTCPPings: true # Disable TCP pings for faster convergence diff --git a/develop/config/dynamic-config.yaml b/develop/config/dynamic-config.yaml index 2d75917d..dbe95f8b 100644 --- a/develop/config/dynamic-config.yaml +++ b/develop/config/dynamic-config.yaml @@ -28,6 +28,4 @@ frontend.persistenceMaxQPS: - value: 100000 constraints: {} history.shardUpdateMinInterval: - - value: 1s -history.ReplicationAllowMultiSourceShard: - - value: true \ No newline at end of file + - value: 1s \ No newline at end of file diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 29b27a4f..41bf42be 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -303,6 +303,10 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( targetMetadata.Set(history.MetadataKeyServerShardID, strconv.Itoa(int(newSourceShardID.ShardID))) } + if s.shardCountConfig.Mode == config.ShardCountRouting { + return s.streamRouting() + } + forwarder := newStreamForwarder( s.adminClient, targetStreamServer, @@ -321,6 +325,13 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( return nil } +// streamRouting: placeholder for future stream routing implementation +func (s *adminServiceProxyServer) streamRouting() error { + _ = &proxyStreamSender{} + _ = &proxyStreamReceiver{} + return nil +} + func mapShardIDUnique(sourceShardCount, targetShardCount, sourceShardID int32) int32 { targetShardID := servercommon.MapShardID(sourceShardCount, targetShardCount, sourceShardID) if len(targetShardID) != 1 { diff --git a/proxy/debug.go b/proxy/debug.go index 9f7d24e4..d5fcb6f6 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -11,20 +11,25 @@ import ( ) type ( + // StreamInfo represents information about an active gRPC stream StreamInfo struct { - ID string `json:"id"` - Method string `json:"method"` - Direction string `json:"direction"` - ClientShard string `json:"client_shard"` - ServerShard string `json:"server_shard"` - StartTime time.Time `json:"start_time"` - LastSeen time.Time `json:"last_seen"` - TotalDuration string `json:"total_duration"` - IdleDuration string `json:"idle_duration"` - LastSyncWatermark *int64 `json:"last_sync_watermark,omitempty"` - LastSyncWatermarkTime *time.Time `json:"last_sync_watermark_time,omitempty"` - LastExclusiveHighWatermark *int64 `json:"last_exclusive_high_watermark,omitempty"` + ID string `json:"id"` + Method string `json:"method"` + Direction string `json:"direction"` + Role string `json:"role,omitempty"` + ClientShard string `json:"client_shard"` + ServerShard string `json:"server_shard"` + StartTime time.Time `json:"start_time"` + LastSeen time.Time 
`json:"last_seen"` + TotalDuration string `json:"total_duration"` + IdleDuration string `json:"idle_duration"` + LastSyncWatermark *int64 `json:"last_sync_watermark,omitempty"` + LastSyncWatermarkTime *time.Time `json:"last_sync_watermark_time,omitempty"` + LastExclusiveHighWatermark *int64 `json:"last_exclusive_high_watermark,omitempty"` + LastTaskIDs []int64 `json:"last_task_ids"` + SenderDebug *SenderDebugInfo `json:"sender_debug,omitempty"` + ReceiverDebug *ReceiverDebugInfo `json:"receiver_debug,omitempty"` } // ShardDebugInfo contains debug information about shard distribution diff --git a/proxy/proxy.go b/proxy/proxy.go index 61beeef4..da678824 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -23,6 +23,19 @@ type ( // Needs some config revision before uncommenting: //accountId string } + + // RoutedAck wraps an ACK with the target shard it originated from + RoutedAck struct { + TargetShard history.ClusterShardID + Req *adminservice.StreamWorkflowReplicationMessagesRequest + } + + // RoutedMessage wraps a replication response with originating client shard info + RoutedMessage struct { + SourceShard history.ClusterShardID + Resp *adminservice.StreamWorkflowReplicationMessagesResponse + } + Proxy struct { lifetime context.Context cancel context.CancelFunc @@ -36,12 +49,17 @@ type ( shardManager ShardManager logger log.Logger - remoteSendChannelsMu sync.RWMutex - remoteSendChannels map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse - localAckChannelsMu sync.RWMutex - localAckChannels map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesRequest - localReceiverCancelFuncsMu sync.RWMutex + // remoteSendChannels maps shard IDs to send channels for replication message routing + remoteSendChannels map[history.ClusterShardID]chan RoutedMessage + remoteSendChannelsMu sync.RWMutex + + // localAckChannels maps shard IDs to ack channels for local acknowledgment handling + localAckChannels map[history.ClusterShardID]chan RoutedAck + localAckChannelsMu sync.RWMutex + + // localReceiverCancelFuncs maps shard IDs to context cancel functions for local receiver termination localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc + localReceiverCancelFuncsMu sync.RWMutex } ) @@ -59,8 +77,8 @@ func NewProxy(configProvider config.ConfigProvider, shardManager ShardManager, l return s2sConfig.Logging.GetThrottleMaxRPS() }, ), - remoteSendChannels: make(map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse), - localAckChannels: make(map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesRequest), + remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), + localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), } if len(s2sConfig.ClusterConnections) == 0 { @@ -261,7 +279,7 @@ func (s *Proxy) GetChannelInfo() ChannelDebugInfo { } // SetRemoteSendChan registers a send channel for a specific shard ID -func (s *Proxy) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan *adminservice.StreamWorkflowReplicationMessagesResponse) { +func (s *Proxy) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { s.remoteSendChannelsMu.Lock() defer s.remoteSendChannelsMu.Unlock() s.remoteSendChannels[shardID] = sendChan @@ -269,7 +287,7 @@ func (s *Proxy) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan } // GetRemoteSendChan retrieves 
the send channel for a specific shard ID -func (s *Proxy) GetRemoteSendChan(shardID history.ClusterShardID) (chan *adminservice.StreamWorkflowReplicationMessagesResponse, bool) { +func (s *Proxy) GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) { s.remoteSendChannelsMu.RLock() defer s.remoteSendChannelsMu.RUnlock() ch, exists := s.remoteSendChannels[shardID] @@ -277,18 +295,32 @@ func (s *Proxy) GetRemoteSendChan(shardID history.ClusterShardID) (chan *adminse } // GetAllRemoteSendChans returns a map of all remote send channels -func (s *Proxy) GetAllRemoteSendChans() map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse { +func (s *Proxy) GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { s.remoteSendChannelsMu.RLock() defer s.remoteSendChannelsMu.RUnlock() // Create a copy of the map - result := make(map[history.ClusterShardID]chan *adminservice.StreamWorkflowReplicationMessagesResponse, len(s.remoteSendChannels)) + result := make(map[history.ClusterShardID]chan RoutedMessage, len(s.remoteSendChannels)) for k, v := range s.remoteSendChannels { result[k] = v } return result } +// GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID +func (s *Proxy) GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage { + s.remoteSendChannelsMu.RLock() + defer s.remoteSendChannelsMu.RUnlock() + + result := make(map[history.ClusterShardID]chan RoutedMessage) + for k, v := range s.remoteSendChannels { + if k.ClusterID == clusterID { + result[k] = v + } + } + return result +} + // RemoveRemoteSendChan removes the send channel for a specific shard ID func (s *Proxy) RemoveRemoteSendChan(shardID history.ClusterShardID) { s.remoteSendChannelsMu.Lock() @@ -298,7 +330,7 @@ func (s *Proxy) RemoveRemoteSendChan(shardID history.ClusterShardID) { } // SetLocalAckChan registers an ack channel for a specific shard ID -func (s *Proxy) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan *adminservice.StreamWorkflowReplicationMessagesRequest) { +func (s *Proxy) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) { s.localAckChannelsMu.Lock() defer s.localAckChannelsMu.Unlock() s.localAckChannels[shardID] = ackChan @@ -306,7 +338,7 @@ func (s *Proxy) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan *ad } // GetLocalAckChan retrieves the ack channel for a specific shard ID -func (s *Proxy) GetLocalAckChan(shardID history.ClusterShardID) (chan *adminservice.StreamWorkflowReplicationMessagesRequest, bool) { +func (s *Proxy) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) { s.localAckChannelsMu.RLock() defer s.localAckChannelsMu.RUnlock() ch, exists := s.localAckChannels[shardID] @@ -346,18 +378,18 @@ func (s *Proxy) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { } // TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed -func (s *Proxy) TerminatePreviousLocalReceiver(clientShardID history.ClusterShardID) { +func (s *Proxy) TerminatePreviousLocalReceiver(serverShardID history.ClusterShardID) { // Check if there's a previous cancel function for this shard - if prevCancelFunc, exists := s.GetLocalReceiverCancelFunc(clientShardID); exists { - s.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(clientShardID))) + if prevCancelFunc, exists := 
s.GetLocalReceiverCancelFunc(serverShardID); exists { + s.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(serverShardID))) // Cancel the previous receiver's context prevCancelFunc() // Remove the cancel function from tracking - s.RemoveLocalReceiverCancelFunc(clientShardID) + s.RemoveLocalReceiverCancelFunc(serverShardID) // Also clean up the associated ack channel if it exists - s.RemoveLocalAckChan(clientShardID) + s.RemoveLocalAckChan(serverShardID) } } diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go new file mode 100644 index 00000000..f05ba829 --- /dev/null +++ b/proxy/proxy_streams.go @@ -0,0 +1,746 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "sync" + + "go.temporal.io/server/api/adminservice/v1" + replicationv1 "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/client/history" + servercommon "go.temporal.io/server/common" + "go.temporal.io/server/common/channel" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "google.golang.org/grpc/metadata" +) + +// proxyIDMapping stores the original source shard and task for a given proxy task ID +// Entries are kept in strictly increasing proxyID order. +type proxyIDMapping struct { + sourceShard history.ClusterShardID + sourceTask int64 +} + +// proxyIDRingBuffer is a dynamically growing ring buffer keyed by monotonically increasing proxy IDs. +// It supports O(1) append and O(k) pop up to a given watermark, while preserving insertion order. +type proxyIDRingBuffer struct { + entries []proxyIDMapping + head int + size int + startProxyID int64 // proxyID of the current head element when size > 0 +} + +func newProxyIDRingBuffer(capacity int) *proxyIDRingBuffer { + if capacity < 1 { + capacity = 1 + } + return &proxyIDRingBuffer{entries: make([]proxyIDMapping, capacity)} +} + +// ensureCapacity grows the buffer if it is full, preserving order. +func (b *proxyIDRingBuffer) ensureCapacity() { + if b.size < len(b.entries) { + return + } + newCap := len(b.entries) * 2 + if newCap == 0 { + newCap = 1 + } + newEntries := make([]proxyIDMapping, newCap) + // copy existing elements in order starting from head + for i := 0; i < b.size; i++ { + idx := (b.head + i) % len(b.entries) + newEntries[i] = b.entries[idx] + } + b.entries = newEntries + b.head = 0 +} + +// Append appends a mapping for the given proxyID. ProxyIDs must be strictly increasing and contiguous. +func (b *proxyIDRingBuffer) Append(proxyID int64, sourceShard history.ClusterShardID, sourceTask int64) { + b.ensureCapacity() + if b.size == 0 { + b.startProxyID = proxyID + } else { + // Maintain contiguity: next proxyID must be startProxyID + size + expected := b.startProxyID + int64(b.size) + if proxyID != expected { + // If contiguity is violated, grow holes by inserting empty mappings until aligned. + // In practice proxyID is always increasing by 1, so this branch should not trigger. + for expected < proxyID { + b.ensureCapacity() + pos := (b.head + b.size) % len(b.entries) + b.entries[pos] = proxyIDMapping{sourceShard: history.ClusterShardID{}, sourceTask: 0} + b.size++ + expected++ + } + } + } + pos := (b.head + b.size) % len(b.entries) + b.entries[pos] = proxyIDMapping{sourceShard: sourceShard, sourceTask: sourceTask} + b.size++ +} + +// PopUpTo pops and aggregates mappings up to and including the given watermark (proxy ID). +// Returns per-source-shard the maximal original source task acknowledged. 
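+// For example, if the buffer currently holds proxy IDs 5..9, PopUpTo(7) covers entries 5, 6
+// and 7, returns the highest original source task seen per source shard among them, and
+// advances the head so the buffer then starts at proxy ID 8; entries 8 and 9 are untouched.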
+func (b *proxyIDRingBuffer) PopUpTo(watermark int64) map[history.ClusterShardID]int64 { + result := make(map[history.ClusterShardID]int64) + if b.size == 0 { + return result + } + // if watermark is before head, nothing to pop + if watermark < b.startProxyID { + return result + } + count64 := watermark - b.startProxyID + 1 + if count64 <= 0 { + return result + } + count := int(count64) + if count > b.size { + count = b.size + } + for i := 0; i < count; i++ { + idx := (b.head + i) % len(b.entries) + m := b.entries[idx] + // Skip zero entries (shouldn't happen unless contiguity fix inserted holes) + if m.sourceShard.ClusterID == 0 && m.sourceShard.ShardID == 0 { + continue + } + if current, ok := result[m.sourceShard]; !ok || m.sourceTask > current { + result[m.sourceShard] = m.sourceTask + } + } + // advance head + b.head = (b.head + count) % len(b.entries) + b.size -= count + b.startProxyID += int64(count) + return result +} + +// AggregateUpTo computes the per-shard aggregation up to watermark without removing entries. +// Returns (aggregation, count) where count is the number of entries covered. +func (b *proxyIDRingBuffer) AggregateUpTo(watermark int64) (map[history.ClusterShardID]int64, int) { + result := make(map[history.ClusterShardID]int64) + if b.size == 0 { + return result, 0 + } + if watermark < b.startProxyID { + return result, 0 + } + count64 := watermark - b.startProxyID + 1 + if count64 <= 0 { + return result, 0 + } + count := int(count64) + if count > b.size { + count = b.size + } + for i := 0; i < count; i++ { + idx := (b.head + i) % len(b.entries) + m := b.entries[idx] + if m.sourceShard.ClusterID == 0 && m.sourceShard.ShardID == 0 { + continue + } + if current, ok := result[m.sourceShard]; !ok || m.sourceTask > current { + result[m.sourceShard] = m.sourceTask + } + } + return result, count +} + +// Discard advances the head by count entries, effectively removing them. +func (b *proxyIDRingBuffer) Discard(count int) { + if count <= 0 { + return + } + if count > b.size { + count = b.size + } + b.head = (b.head + count) % len(b.entries) + b.size -= count + b.startProxyID += int64(count) +} + +// proxyStreamSender is responsible for sending replication messages to the next hop +// (another proxy or a target server) and receiving ACKs back. +// This is scaffolding only – the concrete behavior will be wired in later. +type proxyStreamSender struct { + logger log.Logger + // shardID history.ClusterShardID + shardManager ShardManager + proxy *Proxy + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + streamID string + streamTracker *StreamTracker + // sendMsgChan carries replication messages to be sent to the remote side. 
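+	// Each RoutedMessage carries the originating source shard, which the send loop records in
+	// idRing alongside the original task ID so recvAck can translate proxy-ID watermarks back
+	// into per-source-shard ACKs.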
+ sendMsgChan chan RoutedMessage + + mu sync.Mutex + nextProxyTaskID int64 + idRing *proxyIDRingBuffer + // prevAckBySource tracks the last ack level sent per original source shard + prevAckBySource map[history.ClusterShardID]int64 +} + +// buildSenderDebugSnapshot returns a snapshot of the sender's ring buffer and related state +func (s *proxyStreamSender) buildSenderDebugSnapshot(maxEntries int) *SenderDebugInfo { + s.mu.Lock() + defer s.mu.Unlock() + + info := &SenderDebugInfo{ + PrevAckBySource: make(map[string]int64), + } + + info.NextProxyTaskID = s.nextProxyTaskID + + for k, v := range s.prevAckBySource { + info.PrevAckBySource[ClusterShardIDtoString(k)] = v + } + + if s.idRing != nil { + info.RingStartProxyID = s.idRing.startProxyID + info.RingSize = s.idRing.size + info.RingCapacity = len(s.idRing.entries) + info.RingHead = s.idRing.head + + // Build entries preview + if maxEntries <= 0 { + maxEntries = 20 + } + limit := s.idRing.size + if limit > maxEntries { + limit = maxEntries + } + info.EntriesPreview = make([]ProxyIDEntry, 0, limit) + for i := 0; i < limit; i++ { + idx := (s.idRing.head + i) % len(s.idRing.entries) + e := s.idRing.entries[idx] + info.EntriesPreview = append(info.EntriesPreview, ProxyIDEntry{ + ProxyID: s.idRing.startProxyID + int64(i), + SourceShard: ClusterShardIDtoString(e.sourceShard), + SourceTask: e.sourceTask, + }) + } + } + + return info +} + +func (s *proxyStreamSender) Run( + targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) { + + // Register this sender as the owner of the shard for the duration of the stream + if s.shardManager != nil { + s.shardManager.RegisterShard(s.targetShardID) + defer s.shardManager.UnregisterShard(s.targetShardID) + } + + // Register local stream tracking for sender (short id, include role) + s.streamTracker = GetGlobalStreamTracker() + s.streamID = fmt.Sprintf("snd-%s-%s", + ClusterShardIDtoString(s.sourceShardID), + ClusterShardIDtoString(s.targetShardID), + ) + s.streamTracker.RegisterStream( + s.streamID, + "StreamWorkflowReplicationMessages", + s.directionLabel, + ClusterShardIDtoString(s.sourceShardID), + ClusterShardIDtoString(s.targetShardID), + StreamRoleSender, + ) + defer s.streamTracker.UnregisterStream(s.streamID) + + wg := sync.WaitGroup{} + // lazy init maps + s.mu.Lock() + if s.idRing == nil { + s.idRing = newProxyIDRingBuffer(1024) + } + if s.prevAckBySource == nil { + s.prevAckBySource = make(map[history.ClusterShardID]int64) + } + s.mu.Unlock() + + // Register remote send channel for this shard so receiver can forward tasks locally + s.sendMsgChan = make(chan RoutedMessage, 100) + + s.proxy.SetRemoteSendChan(s.targetShardID, s.sendMsgChan) + defer s.proxy.RemoveRemoteSendChan(s.targetShardID) + + wg.Add(2) + go func() { + defer wg.Done() + _ = s.sendReplicationMessages(targetStreamServer, shutdownChan) + }() + go func() { + defer wg.Done() + _ = s.recvAck(targetStreamServer, shutdownChan) + }() + // Wait for shutdown signal (triggered by receiver or stream errors) + <-shutdownChan.Channel() + // Ensure send loop exits promptly + close(s.sendMsgChan) + // Do not block waiting for ack goroutine; it will terminate when stream ends +} + +// recvAck receives ACKs from the remote side and forwards them to the provided +// channel for aggregation/routing. Non-blocking shutdown is coordinated via +// shutdownChan. This is a placeholder implementation. 
+func (s *proxyStreamSender) recvAck( + targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) error { + defer func() { + s.logger.Info("Shutdown targetStreamServer.Recv loop.") + shutdownChan.Shutdown() + }() + for !shutdownChan.IsShutdown() { + req, err := targetStreamServer.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + + // Unmap proxy task IDs back to original source shard/task and ACK by source shard + if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + proxyAckWatermark := attr.SyncReplicationState.InclusiveLowWatermark + + // Log incoming upstream ACK watermark + s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", proxyAckWatermark)) + // track sync watermark + s.streamTracker.UpdateStreamSyncReplicationState(s.streamID, proxyAckWatermark, nil) + s.streamTracker.UpdateStream(s.streamID) + + s.mu.Lock() + shardToAck, pendingDiscard := s.idRing.AggregateUpTo(proxyAckWatermark) + s.mu.Unlock() + + if len(shardToAck) > 0 { + for srcShard, originalAck := range shardToAck { + // If proxy watermark has passed an empty-batch proxy-high, translate it to original-high + s.mu.Lock() + // record last ack per source shard + s.prevAckBySource[srcShard] = originalAck + s.mu.Unlock() + + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: originalAck, + InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + }, + }, + }, + } + + // Log outgoing ACK for this source shard + s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) + + s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) + } + + // TODO: ack to idle shards using prevAckBySource + + } else { + // No new shards to ACK: send previous ack levels per source shard (if known) + s.mu.Lock() + for srcShard, prev := range s.prevAckBySource { + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: prev, + InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + }, + }, + }, + } + // Log fallback ACK for this source shard + s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) + s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) + } + s.mu.Unlock() + } + + // Only after forwarding ACKs, discard the entries from the ring buffer + if pendingDiscard > 0 { + s.mu.Lock() + s.idRing.Discard(pendingDiscard) + s.mu.Unlock() + } + + // Update debug snapshot after ack processing + s.streamTracker.UpdateStreamSenderDebug(s.streamID, s.buildSenderDebugSnapshot(20)) + } + } + return nil +} + +// sendReplicationMessages sends replication messages read from sendMsgChan to +// the remote side. 
This is a placeholder implementation. +func (s *proxyStreamSender) sendReplicationMessages( + targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) error { + defer func() { + s.logger.Info("Shutdown sendMsgChan loop.") + shutdownChan.Shutdown() + }() + for !shutdownChan.IsShutdown() { + if s.sendMsgChan == nil { + return nil + } + select { + case routed, ok := <-s.sendMsgChan: + if !ok { + return nil + } + resp := routed.Resp + if m, ok := resp.Attributes.(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && m.Messages != nil { + // rewrite task ids + s.mu.Lock() + var originalIDs []int64 + var proxyIDs []int64 + // capture original exclusive high watermark before rewriting + originalHigh := m.Messages.ExclusiveHighWatermark + for _, t := range m.Messages.ReplicationTasks { + // allocate proxy task id + s.nextProxyTaskID++ + proxyID := s.nextProxyTaskID + // remember original + original := t.SourceTaskId + originalIDs = append(originalIDs, original) + s.idRing.Append(proxyID, routed.SourceShard, original) + // rewrite id + t.SourceTaskId = proxyID + t.RawTaskInfo.TaskId = proxyID + proxyIDs = append(proxyIDs, proxyID) + } + s.mu.Unlock() + // Log mapping from original -> proxy IDs + s.logger.Info(fmt.Sprintf("Sender forwarding ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs)) + + // Ensure exclusive high watermark is in proxy task ID space + if len(m.Messages.ReplicationTasks) > 0 { + m.Messages.ExclusiveHighWatermark = m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].RawTaskInfo.TaskId + 1 + } else { + // No tasks in this batch: allocate a synthetic proxy task id mapping + s.mu.Lock() + s.nextProxyTaskID++ + proxyHigh := s.nextProxyTaskID + s.idRing.Append(proxyHigh, routed.SourceShard, originalHigh) + m.Messages.ExclusiveHighWatermark = proxyHigh + s.mu.Unlock() + } + // track sent tasks ids and high watermark + ids := make([]int64, 0, len(m.Messages.ReplicationTasks)) + for _, t := range m.Messages.ReplicationTasks { + ids = append(ids, t.SourceTaskId) + } + s.streamTracker.UpdateStreamLastTaskIDs(s.streamID, ids) + s.streamTracker.UpdateStreamReplicationMessages(s.streamID, m.Messages.ExclusiveHighWatermark) + s.streamTracker.UpdateStreamSenderDebug(s.streamID, s.buildSenderDebugSnapshot(20)) + s.streamTracker.UpdateStream(s.streamID) + } + if err := targetStreamServer.Send(resp); err != nil { + return err + } + case <-shutdownChan.Channel(): + return nil + } + } + return nil +} + +// proxyStreamReceiver receives replication messages from a local/remote server and +// produces ACKs destined for the original sender. 
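+// It opens a StreamWorkflowReplicationMessages client against the local server, regroups each
+// incoming batch by target shard via WorkflowIDToHistoryShard, hands the tasks to the owning
+// sender's send channel (forwarding to remote owners is still a TODO), and sends the minimum
+// ACK watermark across all target shards back upstream.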
+type proxyStreamReceiver struct { + logger log.Logger + // shardID history.ClusterShardID + shardManager ShardManager + proxy *Proxy + adminClient adminservice.AdminServiceClient + localShardCount int32 + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + ackChan chan RoutedAck + // ack aggregation across target shards + ackByTarget map[history.ClusterShardID]int64 + lastSentMin int64 + // lastExclusiveHighOriginal tracks last exclusive high watermark seen from source (original id space) + lastExclusiveHighOriginal int64 + streamID string + streamTracker *StreamTracker +} + +// buildReceiverDebugSnapshot builds receiver ACK aggregation state for debugging +func (r *proxyStreamReceiver) buildReceiverDebugSnapshot() *ReceiverDebugInfo { + info := &ReceiverDebugInfo{ + AckByTarget: make(map[string]int64), + } + for k, v := range r.ackByTarget { + info.AckByTarget[ClusterShardIDtoString(k)] = v + } + info.LastAggregatedMin = r.lastSentMin + info.LastExclusiveHighOriginal = r.lastExclusiveHighOriginal + return info +} + +func (r *proxyStreamReceiver) Run( + shutdownChan channel.ShutdownOnce, +) { + // Terminate any previous local receiver for this shard + r.proxy.TerminatePreviousLocalReceiver(r.sourceShardID) + + r.logger = log.With(r.logger, + tag.NewStringTag("client", ClusterShardIDtoString(r.targetShardID)), + tag.NewStringTag("server", ClusterShardIDtoString(r.sourceShardID)), + ) + + // Build metadata for local server stream + md := metadata.New(map[string]string{}) + md.Set(history.MetadataKeyClientClusterID, fmt.Sprintf("%d", r.targetShardID.ClusterID)) + md.Set(history.MetadataKeyClientShardID, fmt.Sprintf("%d", r.targetShardID.ShardID)) + md.Set(history.MetadataKeyServerClusterID, fmt.Sprintf("%d", r.sourceShardID.ClusterID)) + md.Set(history.MetadataKeyServerShardID, fmt.Sprintf("%d", r.sourceShardID.ShardID)) + + outgoingContext := metadata.NewOutgoingContext(context.Background(), md) + outgoingContext, cancel := context.WithCancel(outgoingContext) + defer cancel() + + // Open stream receiver -> local server's stream sender for clientShardID + var sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient + var err error + sourceStreamClient, err = r.adminClient.StreamWorkflowReplicationMessages(outgoingContext) + if err != nil { + r.logger.Error("adminClient.StreamWorkflowReplicationMessages error", tag.Error(err)) + return + } + + // Setup ack channel and cancel func bookkeeping + r.ackChan = make(chan RoutedAck, 100) + r.proxy.SetLocalAckChan(r.sourceShardID, r.ackChan) + r.proxy.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) + defer func() { + r.proxy.RemoveLocalAckChan(r.sourceShardID) + r.proxy.RemoveLocalReceiverCancelFunc(r.sourceShardID) + }() + + // init aggregation state + r.ackByTarget = make(map[history.ClusterShardID]int64) + r.lastSentMin = 0 + + // Register a new local stream for tracking (short id, include role) + r.streamID = fmt.Sprintf("rcv-%s-%s", + ClusterShardIDtoString(r.sourceShardID), + ClusterShardIDtoString(r.targetShardID), + ) + r.streamTracker = GetGlobalStreamTracker() + r.streamTracker.RegisterStream( + r.streamID, + "StreamWorkflowReplicationMessages", + r.directionLabel, + ClusterShardIDtoString(r.sourceShardID), + ClusterShardIDtoString(r.targetShardID), + StreamRoleReceiver, + ) + defer r.streamTracker.UnregisterStream(r.streamID) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer func() { + shutdownChan.Shutdown() + wg.Done() + }() + _ = 
r.recvReplicationMessages(sourceStreamClient, shutdownChan) + }() + + go func() { + defer func() { + shutdownChan.Shutdown() + _ = sourceStreamClient.CloseSend() + wg.Done() + }() + _ = r.sendAck(sourceStreamClient, shutdownChan) + }() + + wg.Wait() +} + +// recvReplicationMessages receives from local server and routes to target shard owners. +func (r *proxyStreamReceiver) recvReplicationMessages( + sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, + shutdownChan channel.ShutdownOnce, +) error { + for !shutdownChan.IsShutdown() { + resp, err := sourceStreamClient.Recv() + if err == io.EOF { + r.logger.Info("sourceStreamClient.Recv encountered EOF", tag.Error(err)) + return nil + } + if err != nil { + r.logger.Error("sourceStreamClient.Recv encountered error", tag.Error(err)) + return err + } + + if attr, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && attr.Messages != nil { + // Group by recalculated target shard using namespace/workflow hash + tasksByTargetShard := make(map[history.ClusterShardID][]*replicationv1.ReplicationTask) + ids := make([]int64, 0, len(attr.Messages.ReplicationTasks)) + for _, task := range attr.Messages.ReplicationTasks { + if task.RawTaskInfo != nil && task.RawTaskInfo.NamespaceId != "" && task.RawTaskInfo.WorkflowId != "" { + targetShard := servercommon.WorkflowIDToHistoryShard(task.RawTaskInfo.NamespaceId, task.RawTaskInfo.WorkflowId, r.localShardCount) + targetClusterShard := history.ClusterShardID{ClusterID: r.targetShardID.ClusterID, ShardID: targetShard} + tasksByTargetShard[targetClusterShard] = append(tasksByTargetShard[targetClusterShard], task) + ids = append(ids, task.SourceTaskId) + } + } + + // Log every replication task id received at receiver + r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", attr.Messages.ExclusiveHighWatermark, ids)) + + // record last source exclusive high watermark (original id space) + r.lastExclusiveHighOriginal = attr.Messages.ExclusiveHighWatermark + + // update tracker for incoming messages + if r.streamTracker != nil && r.streamID != "" { + r.streamTracker.UpdateStreamLastTaskIDs(r.streamID, ids) + r.streamTracker.UpdateStreamReplicationMessages(r.streamID, attr.Messages.ExclusiveHighWatermark) + r.streamTracker.UpdateStreamReceiverDebug(r.streamID, r.buildReceiverDebugSnapshot()) + r.streamTracker.UpdateStream(r.streamID) + } + + // If replication tasks are empty, still log the empty batch and send watermark + if len(attr.Messages.ReplicationTasks) == 0 { + r.logger.Info("Receiver received empty replication batch", tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + for targetShardID, sendChan := range r.proxy.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) { + r.logger.Info("Sending high watermark to target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + sendChan <- RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, + Priority: attr.Messages.Priority, + }, + }, + }, + } + } + continue + } + + for targetShardID, tasks := range tasksByTargetShard { + forwardResp := 
&adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ReplicationTasks: tasks, + ExclusiveHighWatermark: tasks[len(tasks)-1].RawTaskInfo.TaskId + 1, + Priority: attr.Messages.Priority, + }, + }, + } + + if r.shardManager != nil && r.shardManager.IsLocalShard(targetShardID) { + if r.proxy != nil { + if sendCh, ok := r.proxy.GetRemoteSendChan(targetShardID); ok { + select { + case sendCh <- RoutedMessage{SourceShard: r.sourceShardID, Resp: forwardResp}: + case <-shutdownChan.Channel(): + return nil + } + } else { + r.logger.Error("No send channel found for target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + } + } else if r.shardManager != nil { + if owner, ok := r.shardManager.GetShardOwner(targetShardID); ok { + r.logger.Info("Target shard owned by remote node", tag.NewStringTag("owner", owner), tag.NewStringTag("shard", ClusterShardIDtoString(targetShardID))) + // TODO: forward via inter-proxy transport + } else { + r.logger.Warn("Unable to determine owner for target shard", tag.NewStringTag("shard", ClusterShardIDtoString(targetShardID))) + } + } + } + } + } + return nil +} + +// sendAck forwards ACKs from local ack channel upstream to the local server. +func (r *proxyStreamReceiver) sendAck( + sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, + shutdownChan channel.ShutdownOnce, +) error { + for !shutdownChan.IsShutdown() { + select { + case routed := <-r.ackChan: + // Update per-target watermark + if attr, ok := routed.Req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + r.ackByTarget[routed.TargetShard] = attr.SyncReplicationState.InclusiveLowWatermark + // Compute minimal watermark across targets + min := int64(0) + first := true + for _, wm := range r.ackByTarget { + if first || wm < min { + min = wm + first = false + } + } + if !first && min >= r.lastSentMin { + // Clamp ACK to last known exclusive high watermark from source + if r.lastExclusiveHighOriginal > 0 && min > r.lastExclusiveHighOriginal { + r.logger.Warn("Aggregated ACK exceeds last source high watermark; clamping", + tag.NewInt64("ack_min", min), + tag.NewInt64("source_exclusive_high", r.lastExclusiveHighOriginal)) + min = r.lastExclusiveHighOriginal + } + // Send aggregated minimal ack upstream + aggregated := &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: min, + }, + }, + } + r.logger.Info("Receiver sending aggregated ACK upstream", tag.NewInt64("inclusive_low", min)) + if err := sourceStreamClient.Send(aggregated); err != nil { + if err != io.EOF { + r.logger.Error("sourceStreamClient.Send encountered error", tag.Error(err)) + } else { + r.logger.Info("sourceStreamClient.Send encountered EOF", tag.Error(err)) + } + return err + } + // Track sync watermark for receiver stream + if r.streamTracker != nil && r.streamID != "" { + r.streamTracker.UpdateStreamSyncReplicationState(r.streamID, min, nil) + r.streamTracker.UpdateStream(r.streamID) + // Update receiver debug snapshot when we send an aggregated ACK + r.streamTracker.UpdateStreamReceiverDebug(r.streamID, r.buildReceiverDebugSnapshot()) + } + r.lastSentMin = min + } + } + case 
<-shutdownChan.Channel(): + return nil + } + } + return nil +} diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 64a1c7ff..70758270 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/memberlist" "go.temporal.io/server/client/history" + "go.temporal.io/server/common/channel" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" @@ -41,6 +42,8 @@ type ( GetLocalShards() []history.ClusterShardID // GetShardInfo returns debug information about shard distribution GetShardInfo() ShardDebugInfo + // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) + DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) } shardManagerImpl struct { @@ -328,6 +331,37 @@ func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { } } +// DeliverAckToShardOwner routes an ACK to the local shard owner or records intent for remote forwarding. +func (sm *shardManagerImpl) DeliverAckToShardOwner( + sourceShard history.ClusterShardID, + routedAck *RoutedAck, + proxy *Proxy, + shutdownChan channel.ShutdownOnce, + logger log.Logger, +) { + if sm.IsLocalShard(sourceShard) { + if proxy != nil { + if ackCh, ok := proxy.GetLocalAckChan(sourceShard); ok { + select { + case ackCh <- *routedAck: + case <-shutdownChan.Channel(): + return + } + } else { + logger.Warn("No local ack channel for source shard", tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) + } + } + return + } + + // TODO: forward to remote owner via inter-proxy transport. + if owner, ok := sm.GetShardOwner(sourceShard); ok { + logger.Info("ACK belongs to remote source shard owner", tag.NewStringTag("owner", owner), tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) + } else { + logger.Warn("Unable to determine source shard owner for ACK", tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) + } +} + func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.ClusterShardID) { if !sm.started || sm.ml == nil { return @@ -346,9 +380,21 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C return } - err = sm.ml.SendReliable(sm.ml.Members()[0], data) - if err != nil { - sm.logger.Error("Failed to broadcast shard change", tag.Error(err)) + for _, member := range sm.ml.Members() { + // Skip sending to self node + if member.Name == sm.config.NodeName { + continue + } + + // Send in goroutine to make it non-blocking + go func(m *memberlist.Node) { + err := sm.ml.SendReliable(m, data) + if err != nil { + sm.logger.Error("Failed to broadcast shard change", + tag.Error(err), + tag.NewStringTag("target_node", m.Name)) + } + }(member) } } @@ -497,3 +543,15 @@ func (nsm *noopShardManager) GetShardInfo() ShardDebugInfo { RemoteShardCounts: make(map[string]int), } } + +func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) { + if proxy != nil { + if ackCh, ok := proxy.GetLocalAckChan(srcShard); ok { + select { + case ackCh <- *routedAck: + case <-shutdownChan.Channel(): + return + } + } + } +} diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go index 66b368c4..bbd4d977 100644 --- a/proxy/stream_tracker.go +++ b/proxy/stream_tracker.go @@ -6,6 +6,12 @@ import ( "time" ) +const ( + StreamRoleSender = "Sender" + StreamRoleReceiver = 
"Receiver" + StreamRoleForwarder = "Forwarder" +) + // StreamTracker tracks active gRPC streams for debugging type StreamTracker struct { mu sync.RWMutex @@ -20,19 +26,22 @@ func NewStreamTracker() *StreamTracker { } // RegisterStream adds a new active stream -func (st *StreamTracker) RegisterStream(id, method, direction, clientShard, serverShard string) { +func (st *StreamTracker) RegisterStream(id, method, direction, clientShard, serverShard, role string) { st.mu.Lock() defer st.mu.Unlock() now := time.Now() st.streams[id] = &StreamInfo{ - ID: id, - Method: method, - Direction: direction, - ClientShard: clientShard, - ServerShard: serverShard, - StartTime: now, - LastSeen: now, + ID: id, + Method: method, + Direction: direction, + ClientShard: clientShard, + ServerShard: serverShard, + Role: role, + StartTime: now, + LastSeen: now, + SenderDebug: &SenderDebugInfo{}, + ReceiverDebug: &ReceiverDebugInfo{}, } } @@ -69,6 +78,17 @@ func (st *StreamTracker) UpdateStreamReplicationMessages(id string, exclusiveHig } } +// UpdateStreamLastTaskIDs updates the last seen task ids for a stream +func (st *StreamTracker) UpdateStreamLastTaskIDs(id string, taskIDs []int64) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + stream.LastTaskIDs = taskIDs + } +} + // UnregisterStream removes a stream from tracking func (st *StreamTracker) UnregisterStream(id string) { st.mu.Lock() @@ -77,6 +97,55 @@ func (st *StreamTracker) UnregisterStream(id string) { delete(st.streams, id) } +// SenderDebugInfo captures proxy-stream-sender internals for debugging +type SenderDebugInfo struct { + RingStartProxyID int64 `json:"ring_start_proxy_id,omitempty"` + RingSize int `json:"ring_size,omitempty"` + RingCapacity int `json:"ring_capacity,omitempty"` + RingHead int `json:"ring_head,omitempty"` + NextProxyTaskID int64 `json:"next_proxy_task_id,omitempty"` + PrevAckBySource map[string]int64 `json:"prev_ack_by_source,omitempty"` + LastHighBySource map[string]int64 `json:"last_high_by_source,omitempty"` + LastProxyHighBySource map[string]int64 `json:"last_proxy_high_by_source,omitempty"` + EntriesPreview []ProxyIDEntry `json:"entries_preview,omitempty"` +} + +// ProxyIDEntry is a preview of a ring buffer entry +type ProxyIDEntry struct { + ProxyID int64 `json:"proxy_id"` + SourceShard string `json:"source_shard"` + SourceTask int64 `json:"source_task"` +} + +// ReceiverDebugInfo captures proxy-stream-receiver ack aggregation state +type ReceiverDebugInfo struct { + AckByTarget map[string]int64 `json:"ack_by_target,omitempty"` + LastAggregatedMin int64 `json:"last_aggregated_min,omitempty"` + LastExclusiveHighOriginal int64 `json:"last_exclusive_high_original,omitempty"` +} + +// UpdateStreamSenderDebug sets the sender debug snapshot for a stream +func (st *StreamTracker) UpdateStreamSenderDebug(id string, info *SenderDebugInfo) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.SenderDebug = info + stream.LastSeen = time.Now() + } +} + +// UpdateStreamReceiverDebug sets the receiver debug snapshot for a stream +func (st *StreamTracker) UpdateStreamReceiverDebug(id string, info *ReceiverDebugInfo) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.ReceiverDebug = info + stream.LastSeen = time.Now() + } +} + // GetActiveStreams returns a copy of all active streams func (st *StreamTracker) GetActiveStreams() []StreamInfo { st.mu.RLock() From 
1b20ad0b178153f474d0624ed484623324d707c1 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Thu, 4 Sep 2025 13:58:00 -0700 Subject: [PATCH 05/38] update --- proxy/debug.go | 39 +++++++++++++++++++++++++++++++++------ proxy/proxy_streams.go | 6 ++++-- proxy/stream_tracker.go | 27 --------------------------- 3 files changed, 37 insertions(+), 35 deletions(-) diff --git a/proxy/debug.go b/proxy/debug.go index d5fcb6f6..586bbdd7 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -12,24 +12,51 @@ import ( type ( + // ProxyIDEntry is a preview of a ring buffer entry + ProxyIDEntry struct { + ProxyID int64 `json:"proxy_id"` + SourceShard string `json:"source_shard"` + SourceTask int64 `json:"source_task"` + } + + // SenderDebugInfo captures proxy-stream-sender internals for debugging + SenderDebugInfo struct { + RingStartProxyID int64 `json:"ring_start_proxy_id"` + RingSize int `json:"ring_size"` + RingCapacity int `json:"ring_capacity"` + RingHead int `json:"ring_head"` + NextProxyTaskID int64 `json:"next_proxy_task_id"` + PrevAckBySource map[string]int64 `json:"prev_ack_by_source"` + LastHighBySource map[string]int64 `json:"last_high_by_source"` + LastProxyHighBySource map[string]int64 `json:"last_proxy_high_by_source"` + EntriesPreview []ProxyIDEntry `json:"entries_preview"` + } + + // ReceiverDebugInfo captures proxy-stream-receiver ack aggregation state + ReceiverDebugInfo struct { + AckByTarget map[string]int64 `json:"ack_by_target"` + LastAggregatedMin int64 `json:"last_aggregated_min"` + LastExclusiveHighOriginal int64 `json:"last_exclusive_high_original"` + } + // StreamInfo represents information about an active gRPC stream StreamInfo struct { ID string `json:"id"` Method string `json:"method"` Direction string `json:"direction"` - Role string `json:"role,omitempty"` + Role string `json:"role"` ClientShard string `json:"client_shard"` ServerShard string `json:"server_shard"` StartTime time.Time `json:"start_time"` LastSeen time.Time `json:"last_seen"` TotalDuration string `json:"total_duration"` IdleDuration string `json:"idle_duration"` - LastSyncWatermark *int64 `json:"last_sync_watermark,omitempty"` - LastSyncWatermarkTime *time.Time `json:"last_sync_watermark_time,omitempty"` - LastExclusiveHighWatermark *int64 `json:"last_exclusive_high_watermark,omitempty"` + LastSyncWatermark *int64 `json:"last_sync_watermark"` + LastSyncWatermarkTime *time.Time `json:"last_sync_watermark_time"` + LastExclusiveHighWatermark *int64 `json:"last_exclusive_high_watermark"` LastTaskIDs []int64 `json:"last_task_ids"` - SenderDebug *SenderDebugInfo `json:"sender_debug,omitempty"` - ReceiverDebug *ReceiverDebugInfo `json:"receiver_debug,omitempty"` + SenderDebug *SenderDebugInfo `json:"sender_debug"` + ReceiverDebug *ReceiverDebugInfo `json:"receiver_debug"` } // ShardDebugInfo contains debug information about shard distribution diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index f05ba829..5084a72b 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -426,7 +426,9 @@ func (s *proxyStreamSender) sendReplicationMessages( s.idRing.Append(proxyID, routed.SourceShard, original) // rewrite id t.SourceTaskId = proxyID - t.RawTaskInfo.TaskId = proxyID + if t.RawTaskInfo != nil { + t.RawTaskInfo.TaskId = proxyID + } proxyIDs = append(proxyIDs, proxyID) } s.mu.Unlock() @@ -435,7 +437,7 @@ func (s *proxyStreamSender) sendReplicationMessages( // Ensure exclusive high watermark is in proxy task ID space if len(m.Messages.ReplicationTasks) > 0 { - m.Messages.ExclusiveHighWatermark = 
m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].RawTaskInfo.TaskId + 1 + m.Messages.ExclusiveHighWatermark = m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].SourceTaskId + 1 } else { // No tasks in this batch: allocate a synthetic proxy task id mapping s.mu.Lock() diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go index bbd4d977..ede493f0 100644 --- a/proxy/stream_tracker.go +++ b/proxy/stream_tracker.go @@ -97,33 +97,6 @@ func (st *StreamTracker) UnregisterStream(id string) { delete(st.streams, id) } -// SenderDebugInfo captures proxy-stream-sender internals for debugging -type SenderDebugInfo struct { - RingStartProxyID int64 `json:"ring_start_proxy_id,omitempty"` - RingSize int `json:"ring_size,omitempty"` - RingCapacity int `json:"ring_capacity,omitempty"` - RingHead int `json:"ring_head,omitempty"` - NextProxyTaskID int64 `json:"next_proxy_task_id,omitempty"` - PrevAckBySource map[string]int64 `json:"prev_ack_by_source,omitempty"` - LastHighBySource map[string]int64 `json:"last_high_by_source,omitempty"` - LastProxyHighBySource map[string]int64 `json:"last_proxy_high_by_source,omitempty"` - EntriesPreview []ProxyIDEntry `json:"entries_preview,omitempty"` -} - -// ProxyIDEntry is a preview of a ring buffer entry -type ProxyIDEntry struct { - ProxyID int64 `json:"proxy_id"` - SourceShard string `json:"source_shard"` - SourceTask int64 `json:"source_task"` -} - -// ReceiverDebugInfo captures proxy-stream-receiver ack aggregation state -type ReceiverDebugInfo struct { - AckByTarget map[string]int64 `json:"ack_by_target,omitempty"` - LastAggregatedMin int64 `json:"last_aggregated_min,omitempty"` - LastExclusiveHighOriginal int64 `json:"last_exclusive_high_original,omitempty"` -} - // UpdateStreamSenderDebug sets the sender debug snapshot for a stream func (st *StreamTracker) UpdateStreamSenderDebug(id string, info *SenderDebugInfo) { st.mu.Lock() From fbc4af4b132798bce5ddb67cbe15f787c24db201 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 5 Sep 2025 15:41:04 -0700 Subject: [PATCH 06/38] retry when shard not available --- proxy/proxy_streams.go | 202 +++++++++++++++++++++++++++++------------ proxy/shard_manager.go | 36 +++----- 2 files changed, 158 insertions(+), 80 deletions(-) diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 5084a72b..86dfe118 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -326,29 +326,56 @@ func (s *proxyStreamSender) recvAck( s.mu.Unlock() if len(shardToAck) > 0 { - for srcShard, originalAck := range shardToAck { - // If proxy watermark has passed an empty-batch proxy-high, translate it to original-high - s.mu.Lock() - // record last ack per source shard - s.prevAckBySource[srcShard] = originalAck - s.mu.Unlock() - - routedAck := &RoutedAck{ - TargetShard: s.targetShardID, - Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ - Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ - SyncReplicationState: &replicationv1.SyncReplicationState{ - InclusiveLowWatermark: originalAck, - InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + sent := make(map[history.ClusterShardID]bool, len(shardToAck)) + logged := make(map[history.ClusterShardID]bool, len(shardToAck)) + numRemaining := len(shardToAck) + backoff := 10 * time.Millisecond + for numRemaining > 0 { + select { + case <-shutdownChan.Channel(): + return nil + default: + } + progress := false + for srcShard, originalAck := range shardToAck { + if sent[srcShard] { + 
continue + } + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: originalAck, + InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + }, }, }, - }, - } - - // Log outgoing ACK for this source shard - s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) + } - s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) + s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) + + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) { + sent[srcShard] = true + numRemaining-- + progress = true + // record last ack per source shard after forwarding + s.mu.Lock() + s.prevAckBySource[srcShard] = originalAck + s.mu.Unlock() + } else if !logged[srcShard] { + s.logger.Warn("No local ack channel for source shard; retrying until available", tag.NewStringTag("shard", ClusterShardIDtoString(srcShard))) + logged[srcShard] = true + } + } + if !progress { + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 + } + } else if backoff > 10*time.Millisecond { + backoff = 10 * time.Millisecond + } } // TODO: ack to idle shards using prevAckBySource @@ -356,23 +383,58 @@ func (s *proxyStreamSender) recvAck( } else { // No new shards to ACK: send previous ack levels per source shard (if known) s.mu.Lock() + pendingPrev := make(map[history.ClusterShardID]int64, len(s.prevAckBySource)) for srcShard, prev := range s.prevAckBySource { - routedAck := &RoutedAck{ - TargetShard: s.targetShardID, - Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ - Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ - SyncReplicationState: &replicationv1.SyncReplicationState{ - InclusiveLowWatermark: prev, - InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + pendingPrev[srcShard] = prev + } + s.mu.Unlock() + + sent := make(map[history.ClusterShardID]bool, len(pendingPrev)) + logged := make(map[history.ClusterShardID]bool, len(pendingPrev)) + numRemaining := len(pendingPrev) + backoff := 10 * time.Millisecond + for numRemaining > 0 { + select { + case <-shutdownChan.Channel(): + return nil + default: + } + progress := false + for srcShard, prev := range pendingPrev { + if sent[srcShard] { + continue + } + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: prev, + InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + }, }, }, - }, + } + // Log fallback ACK for this source shard + s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) { + sent[srcShard] = true + numRemaining-- + progress = true + } else if !logged[srcShard] { + 
s.logger.Warn("No local ack channel for source shard; retrying until available", tag.NewStringTag("shard", ClusterShardIDtoString(srcShard))) + logged[srcShard] = true + } + } + if !progress { + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 + } + } else if backoff > 10*time.Millisecond { + backoff = 10 * time.Millisecond } - // Log fallback ACK for this source shard - s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) - s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) } - s.mu.Unlock() } // Only after forwarding ACKs, discard the entries from the ring buffer @@ -648,36 +710,60 @@ func (r *proxyStreamReceiver) recvReplicationMessages( continue } - for targetShardID, tasks := range tasksByTargetShard { - forwardResp := &adminservice.StreamWorkflowReplicationMessagesResponse{ - Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ - Messages: &replicationv1.WorkflowReplicationMessages{ - ReplicationTasks: tasks, - ExclusiveHighWatermark: tasks[len(tasks)-1].RawTaskInfo.TaskId + 1, - Priority: attr.Messages.Priority, - }, - }, + // Retry across the whole target set until all sends succeed (or shutdown) + sentByTarget := make(map[history.ClusterShardID]bool, len(tasksByTargetShard)) + loggedByTarget := make(map[history.ClusterShardID]bool, len(tasksByTargetShard)) + for targetShardID := range tasksByTargetShard { + sentByTarget[targetShardID] = false + } + numRemaining := len(tasksByTargetShard) + backoff := 10 * time.Millisecond + for numRemaining > 0 { + select { + case <-shutdownChan.Channel(): + return nil + default: } - - if r.shardManager != nil && r.shardManager.IsLocalShard(targetShardID) { - if r.proxy != nil { - if sendCh, ok := r.proxy.GetRemoteSendChan(targetShardID); ok { - select { - case sendCh <- RoutedMessage{SourceShard: r.sourceShardID, Resp: forwardResp}: - case <-shutdownChan.Channel(): - return nil - } - } else { - r.logger.Error("No send channel found for target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) - } + progress := false + for targetShardID, tasks := range tasksByTargetShard { + if sentByTarget[targetShardID] { + continue } - } else if r.shardManager != nil { - if owner, ok := r.shardManager.GetShardOwner(targetShardID); ok { - r.logger.Info("Target shard owned by remote node", tag.NewStringTag("owner", owner), tag.NewStringTag("shard", ClusterShardIDtoString(targetShardID))) - // TODO: forward via inter-proxy transport + if ch, ok := r.proxy.GetRemoteSendChan(targetShardID); ok { + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ReplicationTasks: tasks, + ExclusiveHighWatermark: tasks[len(tasks)-1].RawTaskInfo.TaskId + 1, + Priority: attr.Messages.Priority, + }, + }, + }, + } + select { + case ch <- msg: + sentByTarget[targetShardID] = true + numRemaining-- + progress = true + case <-shutdownChan.Channel(): + return nil + } } else { - r.logger.Warn("Unable to determine owner for target shard", tag.NewStringTag("shard", ClusterShardIDtoString(targetShardID))) + if !loggedByTarget[targetShardID] { + r.logger.Warn("No send channel found for target shard; retrying until available", tag.NewStringTag("targetShard", 
ClusterShardIDtoString(targetShardID))) + loggedByTarget[targetShardID] = true + } + } + } + if !progress { + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 } + } else if backoff > 10*time.Millisecond { + backoff = 10 * time.Millisecond } } } diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 70758270..804c194b 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -43,7 +43,7 @@ type ( // GetShardInfo returns debug information about shard distribution GetShardInfo() ShardDebugInfo // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) - DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) + DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool } shardManagerImpl struct { @@ -338,28 +338,18 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger, -) { - if sm.IsLocalShard(sourceShard) { - if proxy != nil { - if ackCh, ok := proxy.GetLocalAckChan(sourceShard); ok { - select { - case ackCh <- *routedAck: - case <-shutdownChan.Channel(): - return - } - } else { - logger.Warn("No local ack channel for source shard", tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) - } +) bool { + if ackCh, ok := proxy.GetLocalAckChan(sourceShard); ok { + select { + case ackCh <- *routedAck: + return true + case <-shutdownChan.Channel(): + return false } - return - } - - // TODO: forward to remote owner via inter-proxy transport. - if owner, ok := sm.GetShardOwner(sourceShard); ok { - logger.Info("ACK belongs to remote source shard owner", tag.NewStringTag("owner", owner), tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) } else { - logger.Warn("Unable to determine source shard owner for ACK", tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) + logger.Warn("No local ack channel for source shard", tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) } + return false } func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.ClusterShardID) { @@ -544,14 +534,16 @@ func (nsm *noopShardManager) GetShardInfo() ShardDebugInfo { } } -func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) { +func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool { if proxy != nil { if ackCh, ok := proxy.GetLocalAckChan(srcShard); ok { select { case ackCh <- *routedAck: + return true case <-shutdownChan.Channel(): - return + return false } } } + return false } From 9f5cd62bfdcdfe0e0805ab10c68e7c9f17d69c8a Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 5 Sep 2025 16:53:24 -0700 Subject: [PATCH 07/38] add tags to log --- proxy/proxy_streams.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 86dfe118..e9754c81 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -168,8 +168,7 @@ func (b *proxyIDRingBuffer) Discard(count int) { // (another proxy or a target server) and receiving ACKs back. // This is scaffolding only – the concrete behavior will be wired in later. 
type proxyStreamSender struct { - logger log.Logger - // shardID history.ClusterShardID + logger log.Logger shardManager ShardManager proxy *Proxy targetShardID history.ClusterShardID @@ -235,6 +234,9 @@ func (s *proxyStreamSender) Run( targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, shutdownChan channel.ShutdownOnce, ) { + s.logger = log.With(s.logger, + tag.NewStringTag("role", "sender"), + ) // Register this sender as the owner of the shard for the duration of the stream if s.shardManager != nil { @@ -532,8 +534,7 @@ func (s *proxyStreamSender) sendReplicationMessages( // proxyStreamReceiver receives replication messages from a local/remote server and // produces ACKs destined for the original sender. type proxyStreamReceiver struct { - logger log.Logger - // shardID history.ClusterShardID + logger log.Logger shardManager ShardManager proxy *Proxy adminClient adminservice.AdminServiceClient @@ -573,6 +574,9 @@ func (r *proxyStreamReceiver) Run( r.logger = log.With(r.logger, tag.NewStringTag("client", ClusterShardIDtoString(r.targetShardID)), tag.NewStringTag("server", ClusterShardIDtoString(r.sourceShardID)), + tag.NewStringTag("stream-source-shard", ClusterShardIDtoString(r.sourceShardID)), + tag.NewStringTag("stream-target-shard", ClusterShardIDtoString(r.targetShardID)), + tag.NewStringTag("role", "receiver"), ) // Build metadata for local server stream From b51520a8ee69c301fdac848b284d204f01d59fc6 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Tue, 23 Sep 2025 15:40:31 -0700 Subject: [PATCH 08/38] add intra proxy streams --- common/intra_headers.go | 37 + config/config.go | 2 - .../config/cluster-a-mux-client-proxy-1.yaml | 18 +- .../config/cluster-a-mux-client-proxy-2.yaml | 18 +- .../config/cluster-b-mux-server-proxy-1.yaml | 1 - .../config/cluster-b-mux-server-proxy-2.yaml | 1 - endtoendtest/echo_server.go | 1 - interceptor/translation_interceptor.go | 11 + proxy/adminservice.go | 4 + proxy/debug.go | 25 +- proxy/intra_proxy_router.go | 731 ++++++++++++++++++ proxy/proxy.go | 66 +- proxy/proxy_streams.go | 50 +- proxy/shard_manager.go | 456 ++++++++--- proxy/stream_tracker.go | 28 + proxy/test/replication_failover_test.go | 2 +- 16 files changed, 1236 insertions(+), 215 deletions(-) create mode 100644 common/intra_headers.go create mode 100644 proxy/intra_proxy_router.go diff --git a/common/intra_headers.go b/common/intra_headers.go new file mode 100644 index 00000000..7bc2731f --- /dev/null +++ b/common/intra_headers.go @@ -0,0 +1,37 @@ +package common + +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +const ( + // Intra-proxy identification and tracing headers + IntraProxyHeaderKey = "x-s2s-intra-proxy" + IntraProxyHeaderValue = "1" + IntraProxyOriginProxyIDHeader = "x-s2s-origin-proxy-id" + IntraProxyHopCountHeader = "x-s2s-hop-count" + IntraProxyTraceIDHeader = "x-s2s-trace-id" +) + +// IsIntraProxy checks incoming context metadata for intra-proxy marker. +func IsIntraProxy(ctx context.Context) bool { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if vals := md.Get(IntraProxyHeaderKey); len(vals) > 0 && vals[0] == IntraProxyHeaderValue { + return true + } + } + return false +} + +// WithIntraProxyHeaders returns a new outgoing context with intra-proxy headers set. 
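+// The peer proxy detects these headers via IsIntraProxy, e.g. to bypass
+// namespace translation for intra-proxy streams. Caller-side sketch, mirroring
+// how the intra-proxy stream receiver tags its outgoing context (nodeName is a
+// placeholder for this proxy's memberlist node name):
+//
+//	ctx = common.WithIntraProxyHeaders(ctx, map[string]string{
+//		common.IntraProxyOriginProxyIDHeader: nodeName,
+//	})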
+func WithIntraProxyHeaders(ctx context.Context, headers map[string]string) context.Context { + md, _ := metadata.FromOutgoingContext(ctx) + md = md.Copy() + md.Set(IntraProxyHeaderKey, IntraProxyHeaderValue) + for k, v := range headers { + md.Set(k, v) + } + return metadata.NewOutgoingContext(ctx, md) +} diff --git a/config/config.go b/config/config.go index f33fa42b..bf1fee43 100644 --- a/config/config.go +++ b/config/config.go @@ -224,8 +224,6 @@ type ( MemberlistConfig struct { // Enable distributed shard management using memberlist Enabled bool `yaml:"enabled"` - // Enable proxy-to-proxy forwarding (requires Enabled=true) - EnableForwarding bool `yaml:"enableForwarding"` // Node name for this proxy instance in the cluster NodeName string `yaml:"nodeName"` // Bind address for memberlist cluster communication diff --git a/develop/config/cluster-a-mux-client-proxy-1.yaml b/develop/config/cluster-a-mux-client-proxy-1.yaml index bc79e067..3c54388b 100644 --- a/develop/config/cluster-a-mux-client-proxy-1.yaml +++ b/develop/config/cluster-a-mux-client-proxy-1.yaml @@ -28,20 +28,4 @@ mux: # localShardCount: 2 # remoteShardCount: 3 profiling: - pprofAddress: "localhost:6060" -memberlist: - enabled: true - enableForwarding: true - nodeName: "proxy-node-a-1" - bindAddr: "0.0.0.0" - bindPort: 6135 - joinAddrs: - - "localhost:6235" - proxyAddresses: - "proxy-node-a-1": "localhost:6133" - "proxy-node-a-2": "localhost:6233" - # TCP-only configuration for restricted networks - tcpOnly: true # Use TCP transport only, disable UDP - disableTCPPings: true # Disable TCP pings for faster convergence - probeTimeoutMs: 1000 # Longer timeout for network latency - probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file + pprofAddress: "localhost:6060" \ No newline at end of file diff --git a/develop/config/cluster-a-mux-client-proxy-2.yaml b/develop/config/cluster-a-mux-client-proxy-2.yaml index b9cf33a7..8bdbfbb1 100644 --- a/develop/config/cluster-a-mux-client-proxy-2.yaml +++ b/develop/config/cluster-a-mux-client-proxy-2.yaml @@ -28,20 +28,4 @@ mux: # localShardCount: 2 # remoteShardCount: 3 profiling: - pprofAddress: "localhost:6061" -memberlist: - enabled: true - enableForwarding: true - nodeName: "proxy-node-a-2" - bindAddr: "0.0.0.0" - bindPort: 6235 - joinAddrs: - - "localhost:6135" - proxyAddresses: - "proxy-node-a-1": "localhost:6133" - "proxy-node-a-2": "localhost:6233" - # TCP-only configuration for restricted networks - tcpOnly: true # Use TCP transport only, disable UDP - disableTCPPings: true # Disable TCP pings for faster convergence - probeTimeoutMs: 1000 # Longer timeout for network latency - probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file + pprofAddress: "localhost:6061" \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-1.yaml b/develop/config/cluster-b-mux-server-proxy-1.yaml index a9912a22..4e2b6da9 100644 --- a/develop/config/cluster-b-mux-server-proxy-1.yaml +++ b/develop/config/cluster-b-mux-server-proxy-1.yaml @@ -31,7 +31,6 @@ profiling: pprofAddress: "localhost:6070" memberlist: enabled: true - enableForwarding: true nodeName: "proxy-node-b-1" bindAddr: "0.0.0.0" bindPort: 6335 diff --git a/develop/config/cluster-b-mux-server-proxy-2.yaml b/develop/config/cluster-b-mux-server-proxy-2.yaml index 96689bb4..2d7b111a 100644 --- a/develop/config/cluster-b-mux-server-proxy-2.yaml +++ b/develop/config/cluster-b-mux-server-proxy-2.yaml @@ -31,7 +31,6 @@ profiling: 
pprofAddress: "localhost:6071" memberlist: enabled: true - enableForwarding: true nodeName: "proxy-node-b-2" bindAddr: "0.0.0.0" bindPort: 6435 diff --git a/endtoendtest/echo_server.go b/endtoendtest/echo_server.go index 83288ad8..22e9504f 100644 --- a/endtoendtest/echo_server.go +++ b/endtoendtest/echo_server.go @@ -115,7 +115,6 @@ func NewEchoServer( configProvider := config.NewMockConfigProvider(*localClusterInfo.S2sProxyConfig) proxy = s2sproxy.NewProxy( configProvider, - nil, logger, ) diff --git a/interceptor/translation_interceptor.go b/interceptor/translation_interceptor.go index 56cf9622..3fc2b836 100644 --- a/interceptor/translation_interceptor.go +++ b/interceptor/translation_interceptor.go @@ -10,6 +10,7 @@ import ( "go.temporal.io/server/common/log/tag" "google.golang.org/grpc" + "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/metrics" ) @@ -75,6 +76,16 @@ func (i *TranslationInterceptor) InterceptStream( info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { + + i.logger.Debug("InterceptStream", tag.NewAnyTag("method", info.FullMethod)) + // Skip translation for intra-proxy streams + if common.IsIntraProxy(ss.Context()) { + err := handler(srv, ss) + if err != nil { + i.logger.Error("grpc handler with error: %v", tag.Error(err)) + } + return err + } return handler(srv, newStreamTranslator(ss, i.logger, i.translators)) } diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 41bf42be..2490f6b8 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -243,6 +243,10 @@ func ClusterShardIDtoString(sd history.ClusterShardID) string { return fmt.Sprintf("(id: %d, shard: %d)", sd.ClusterID, sd.ShardID) } +func ClusterShardIDtoShortString(sd history.ClusterShardID) string { + return fmt.Sprintf("%d:%d", sd.ClusterID, sd.ShardID) +} + // StreamWorkflowReplicationMessages establishes an HTTP/2 stream. gRPC passes us a stream that represents the initiating server, // and we can freely Send and Recv on that "server". Because this is a proxy, we also establish a bidirectional // stream using our configured adminClient. When we Recv on the initiator, we Send to the client. 
diff --git a/proxy/debug.go b/proxy/debug.go index 586bbdd7..60a5abc9 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -61,15 +61,14 @@ type ( // ShardDebugInfo contains debug information about shard distribution ShardDebugInfo struct { - Enabled bool `json:"enabled"` - ForwardingEnabled bool `json:"forwarding_enabled"` - NodeName string `json:"node_name"` - LocalShards []history.ClusterShardID `json:"local_shards"` - LocalShardCount int `json:"local_shard_count"` - ClusterNodes []string `json:"cluster_nodes"` - ClusterSize int `json:"cluster_size"` - RemoteShards map[string]string `json:"remote_shards"` // shard_id -> node_name - RemoteShardCounts map[string]int `json:"remote_shard_counts"` // node_name -> shard_count + Enabled bool `json:"enabled"` + NodeName string `json:"node_name"` + LocalShards map[string]history.ClusterShardID `json:"local_shards"` // key: "clusterID:shardID" + LocalShardCount int `json:"local_shard_count"` + ClusterNodes []string `json:"cluster_nodes"` + ClusterSize int `json:"cluster_size"` + RemoteShards map[string]string `json:"remote_shards"` // shard_id -> node_name + RemoteShardCounts map[string]int `json:"remote_shard_counts"` // node_name -> shard_count } // ChannelDebugInfo holds debug information about channels @@ -84,7 +83,7 @@ type ( Timestamp time.Time `json:"timestamp"` ActiveStreams []StreamInfo `json:"active_streams"` StreamCount int `json:"stream_count"` - ShardInfo ShardDebugInfo `json:"shard_info"` + ShardInfos []ShardDebugInfo `json:"shard_infos"` ChannelInfo ChannelDebugInfo `json:"channel_info"` } ) @@ -94,21 +93,21 @@ func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Prox var activeStreams []StreamInfo var streamCount int - var shardInfo ShardDebugInfo + var shardInfos []ShardDebugInfo var channelInfo ChannelDebugInfo // Get active streams information streamTracker := GetGlobalStreamTracker() activeStreams = streamTracker.GetActiveStreams() streamCount = streamTracker.GetStreamCount() - shardInfo = proxyInstance.GetShardInfo() + shardInfos = proxyInstance.GetShardInfos() channelInfo = proxyInstance.GetChannelInfo() response := DebugResponse{ Timestamp: time.Now(), ActiveStreams: activeStreams, StreamCount: streamCount, - ShardInfo: shardInfo, + ShardInfos: shardInfos, ChannelInfo: channelInfo, } diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go new file mode 100644 index 00000000..4ccf058b --- /dev/null +++ b/proxy/intra_proxy_router.go @@ -0,0 +1,731 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "sync" + "time" + + "go.temporal.io/server/api/adminservice/v1" + replicationv1 "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/client/history" + "go.temporal.io/server/common/channel" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/temporalio/s2s-proxy/common" + "github.com/temporalio/s2s-proxy/config" +) + +// intraProxyManager maintains long-lived intra-proxy streams to peer proxies and +// provides simple send helpers (e.g., forwarding ACKs). 
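+// State is grouped per peer node name: each peerState shares a single gRPC
+// connection to that peer and tracks its receivers, senders, and receiver
+// shutdown handles, keyed by the stream's (clientShard, serverShard) pair.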
+type intraProxyManager struct { + logger log.Logger + streamsMu sync.RWMutex + shardManager ShardManager + shardCountConfig config.ShardCountConfig + // Group state by remote peer for unified lifecycle ops + peers map[string]*peerState +} + +type peerState struct { + conn *grpc.ClientConn + receivers map[peerStreamKey]*intraProxyStreamReceiver + senders map[peerStreamKey]*intraProxyStreamSender + recvShutdown map[peerStreamKey]channel.ShutdownOnce +} + +type peerStreamKey struct { + clientShard history.ClusterShardID + serverShard history.ClusterShardID +} + +func newIntraProxyManager(logger log.Logger, shardManager ShardManager, shardCountConfig config.ShardCountConfig) *intraProxyManager { + return &intraProxyManager{ + logger: logger, + shardManager: shardManager, + shardCountConfig: shardCountConfig, + peers: make(map[string]*peerState), + } +} + +// intraProxyStreamSender registers server stream and forwards upstream ACKs to shard owners (local or remote). +// Replication messages are sent by intraProxyManager.sendMessages using the registered server stream. +type intraProxyStreamSender struct { + logger log.Logger + shardManager ShardManager + proxy *Proxy + intraMgr *intraProxyManager + peerNodeName string + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + streamID string + server adminservice.AdminService_StreamWorkflowReplicationMessagesServer +} + +func (s *intraProxyStreamSender) Run( + targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) error { + s.logger.Info("intraProxyStreamSender Run") + defer s.logger.Info("intraProxyStreamSender Run finished") + + // Register server-side intra-proxy stream in tracker + s.streamID = BuildIntraProxySenderStreamID(s.peerNodeName, s.sourceShardID, s.targetShardID) + s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID)) + st := GetGlobalStreamTracker() + st.RegisterStream(s.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(s.targetShardID), ClusterShardIDtoString(s.sourceShardID), StreamRoleForwarder) + defer st.UnregisterStream(s.streamID) + + s.server = targetStreamServer + + // register this sender so sendMessages can use it + s.intraMgr.RegisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID, s) + defer s.intraMgr.UnregisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID) + + // recv ACKs from peer and route to original source shard owner + return s.recvAck(shutdownChan) +} + +// recvAck reads ACKs from the peer and routes them to the source shard owner. 
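+// Each SyncReplicationState received from the peer is recorded in the stream
+// tracker, wrapped in a RoutedAck targeted at targetShardID, and handed to the
+// shard manager to route to the owner of sourceShardID; if that delivery fails,
+// the error tears the stream down (see the FIXME below about retrying).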
+func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) error { + s.logger.Info("intraProxyStreamSender recvAck") + defer s.logger.Info("intraProxyStreamSender recvAck finished") + + for !shutdownChan.IsShutdown() { + req, err := s.server.Recv() + if err == io.EOF { + s.logger.Info("intraProxyStreamSender recvAck encountered EOF") + return nil + } + if err != nil { + shutdownChan.Shutdown() + s.logger.Error("intraProxyStreamSender recvAck encountered error", tag.Error(err)) + return err + } + if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + ack := attr.SyncReplicationState.InclusiveLowWatermark + + s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", ack)) + + // Update server-side intra-proxy stream tracker with sync watermark + st := GetGlobalStreamTracker() + st.UpdateStreamSyncReplicationState(s.streamID, ack, nil) + st.UpdateStream(s.streamID) + + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{InclusiveLowWatermark: ack}, + }, + }, + } + + s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) + // FIXME: should retry. If not succeed, return and shutdown the stream + sent := s.shardManager.DeliverAckToShardOwner(s.sourceShardID, routedAck, s.proxy, shutdownChan, s.logger, ack, false) + if !sent { + s.logger.Error("Sender failed to forward ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) + return fmt.Errorf("failed to forward ACK to source shard") + } + } + } + return nil +} + +// sendReplicationMessages sends replication messages to the peer via the server stream. +func (s *intraProxyStreamSender) sendReplicationMessages(resp *adminservice.StreamWorkflowReplicationMessagesResponse) error { + // Update server-side intra-proxy tracker for outgoing messages + if msgs, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && msgs.Messages != nil { + st := GetGlobalStreamTracker() + ids := make([]int64, 0, len(msgs.Messages.ReplicationTasks)) + for _, t := range msgs.Messages.ReplicationTasks { + ids = append(ids, t.SourceTaskId) + } + st.UpdateStreamLastTaskIDs(s.streamID, ids) + st.UpdateStreamReplicationMessages(s.streamID, msgs.Messages.ExclusiveHighWatermark) + st.UpdateStream(s.streamID) + } + if err := s.server.Send(resp); err != nil { + return err + } + return nil +} + +// intraProxyStreamReceiver ensures a client stream to peer exists and sends aggregated ACKs upstream. +type intraProxyStreamReceiver struct { + logger log.Logger + shardManager ShardManager + proxy *Proxy + intraMgr *intraProxyManager + peerNodeName string + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + streamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient + streamID string + shutdown channel.ShutdownOnce + cancel context.CancelFunc +} + +// Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. 
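+// The outgoing metadata follows the receiver convention (client = targetShardID,
+// server = sourceShardID) and is tagged with the intra-proxy headers, including
+// the origin proxy's node name. The context is made cancelable so a blocked
+// Recv can be interrupted when the peer or shard pair is torn down.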
+func (r *intraProxyStreamReceiver) Run(ctx context.Context, self *Proxy, conn *grpc.ClientConn) error { + r.logger.Info("intraProxyStreamReceiver Run") + // Build metadata according to receiver pattern: client=targetShard, server=sourceShard + md := metadata.New(map[string]string{}) + md.Set(history.MetadataKeyClientClusterID, fmt.Sprintf("%d", r.targetShardID.ClusterID)) + md.Set(history.MetadataKeyClientShardID, fmt.Sprintf("%d", r.targetShardID.ShardID)) + md.Set(history.MetadataKeyServerClusterID, fmt.Sprintf("%d", r.sourceShardID.ClusterID)) + md.Set(history.MetadataKeyServerShardID, fmt.Sprintf("%d", r.sourceShardID.ShardID)) + ctx = metadata.NewOutgoingContext(ctx, md) + ctx = common.WithIntraProxyHeaders(ctx, map[string]string{ + common.IntraProxyOriginProxyIDHeader: r.shardManager.GetShardInfo().NodeName, + }) + + // Ensure we can cancel Recv() by canceling the context when tearing down + ctx, cancel := context.WithCancel(ctx) + r.cancel = cancel + + client := adminservice.NewAdminServiceClient(conn) + stream, err := client.StreamWorkflowReplicationMessages(ctx) + if err != nil { + if r.cancel != nil { + r.cancel() + } + return err + } + r.streamClient = stream + + // Register client-side intra-proxy stream in tracker + r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.targetShardID, r.sourceShardID) + r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) + st := GetGlobalStreamTracker() + st.RegisterStream(r.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(r.targetShardID), ClusterShardIDtoString(r.sourceShardID), StreamRoleForwarder) + defer st.UnregisterStream(r.streamID) + + // Start replication receiver loop + return r.recvReplicationMessages(self) +} + +// recvReplicationMessages receives replication messages and forwards to local shard owner. 
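+// If no local send channel is registered yet for targetShardID, forwarding is
+// retried with exponential backoff (starting at 10ms and doubling while below
+// one second) until the channel appears or the stream shuts down; the backoff
+// resets after each successfully forwarded batch.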
+func (r *intraProxyStreamReceiver) recvReplicationMessages(self *Proxy) error { + r.logger.Info("recvReplicationMessages started") + defer r.logger.Info("recvReplicationMessages finished") + + shutdown := r.shutdown + defer shutdown.Shutdown() + backoff := 10 * time.Millisecond + for !shutdown.IsShutdown() { + resp, err := r.streamClient.Recv() + if err == io.EOF { + r.logger.Info("recvReplicationMessages encountered EOF") + return nil + } + if err != nil { + r.logger.Error("intra-proxy stream Recv error", tag.Error(err)) + return err + } + if msgs, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && msgs.Messages != nil { + // Update client-side intra-proxy tracker for received messages + st := GetGlobalStreamTracker() + ids := make([]int64, 0, len(msgs.Messages.ReplicationTasks)) + for _, t := range msgs.Messages.ReplicationTasks { + ids = append(ids, t.SourceTaskId) + } + st.UpdateStreamLastTaskIDs(r.streamID, ids) + st.UpdateStreamReplicationMessages(r.streamID, msgs.Messages.ExclusiveHighWatermark) + st.UpdateStream(r.streamID) + + r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", msgs.Messages.ExclusiveHighWatermark, ids)) + + msg := RoutedMessage{SourceShard: r.sourceShardID, Resp: resp} + sent := false + logged := false + for !sent { + if ch, ok := self.GetRemoteSendChan(r.targetShardID); ok { + select { + case ch <- msg: + sent = true + r.logger.Info("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", msgs.Messages.ExclusiveHighWatermark)) + case <-shutdown.Channel(): + return nil + } + } else { + if !logged { + r.logger.Warn("No local send channel yet for target shard; waiting", + tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID))) + logged = true + } + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 + } + } + } + backoff = 10 * time.Millisecond + } + } + return nil +} + +// sendAck sends an ACK upstream via the client stream and updates tracker. 
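+// Besides sending, the stream tracker entry is updated with the ACK's inclusive
+// low watermark so the forwarder stream reflects the latest sync state.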
+func (r *intraProxyStreamReceiver) sendAck(req *adminservice.StreamWorkflowReplicationMessagesRequest) error { + if err := r.streamClient.Send(req); err != nil { + return err + } + if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + st := GetGlobalStreamTracker() + st.UpdateStreamSyncReplicationState(r.streamID, attr.SyncReplicationState.InclusiveLowWatermark, nil) + st.UpdateStream(r.streamID) + } + return nil +} + +func (m *intraProxyManager) RegisterSender( + peerNodeName string, + clientShard history.ClusterShardID, + serverShard history.ClusterShardID, + sender *intraProxyStreamSender, +) { + // Cross-cluster only + if clientShard.ClusterID == serverShard.ClusterID { + return + } + key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + m.streamsMu.Lock() + ps := m.peers[peerNodeName] + if ps == nil { + ps = &peerState{receivers: make(map[peerStreamKey]*intraProxyStreamReceiver), senders: make(map[peerStreamKey]*intraProxyStreamSender), recvShutdown: make(map[peerStreamKey]channel.ShutdownOnce)} + m.peers[peerNodeName] = ps + } + if ps.senders == nil { + ps.senders = make(map[peerStreamKey]*intraProxyStreamSender) + } + ps.senders[key] = sender + m.streamsMu.Unlock() +} + +func (m *intraProxyManager) UnregisterSender( + peerNodeName string, + clientShard history.ClusterShardID, + serverShard history.ClusterShardID, +) { + key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + m.streamsMu.Lock() + if ps := m.peers[peerNodeName]; ps != nil && ps.senders != nil { + delete(ps.senders, key) + } + m.streamsMu.Unlock() +} + +// EnsureReceiverForPeerShard ensures a client stream and an ACK aggregator exist for the given peer/shard pair. +func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName string, clientShard history.ClusterShardID, serverShard history.ClusterShardID) { + logger := log.With(m.logger, + tag.NewStringTag("peerNodeName", peerNodeName), + tag.NewStringTag("clientShard", ClusterShardIDtoString(clientShard)), + tag.NewStringTag("serverShard", ClusterShardIDtoString(serverShard))) + logger.Info("EnsureReceiverForPeerShard") + + // Cross-cluster only + if clientShard.ClusterID == serverShard.ClusterID { + return + } + // Do not create intra-proxy streams to self instance + if peerNodeName == m.shardManager.GetNodeName() { + return + } + // Require at least one shard to be local to this instance + if !m.shardManager.IsLocalShard(clientShard) && !m.shardManager.IsLocalShard(serverShard) { + return + } + // Consolidated path: ensure stream and background loops + err := m.ensureStream(context.Background(), logger, peerNodeName, clientShard, serverShard, p) + if err != nil { + logger.Error("failed to ensureStream", tag.Error(err)) + } +} + +// ensurePeer ensures a per-peer state with a shared gRPC connection exists. 
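+// The dial target is resolved via the shard manager's proxy address for
+// peerNodeName; TLS/credential dial options are still a TODO here (see the
+// commented-out block below). An existing peer entry keeps its per-stream maps,
+// and any previously stored connection is closed and replaced.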
+func (m *intraProxyManager) ensurePeer( + ctx context.Context, + peerNodeName string, + p *Proxy, +) (*peerState, error) { + m.streamsMu.RLock() + if ps, ok := m.peers[peerNodeName]; ok && ps != nil && ps.conn != nil { + m.streamsMu.RUnlock() + return ps, nil + } + m.streamsMu.RUnlock() + + // Build TLS from this proxy's outbound client TLS config if available + var dialOpts []grpc.DialOption + + // TODO: FIX this for new config format + // var tlsCfg *config.ClientTLSConfig + // if p.outboundServer != nil { + // t := p.outboundServer.config.Client.TLS + // tlsCfg = &t + // } else if p.inboundServer != nil { + // t := p.inboundServer.config.Client.TLS + // tlsCfg = &t + // } + // if tlsCfg != nil && tlsCfg.IsEnabled() { + // cfg, e := encryption.GetClientTLSConfig(*tlsCfg) + // if e != nil { + // return nil, e + // } + // dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(cfg))) + // } else { + // dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + // } + // // Reuse default grpc options from transport + // dialOpts = append(dialOpts, + // grpc.WithDefaultServiceConfig(transport.DefaultServiceConfig), + // grpc.WithDisableServiceConfig(), + // ) + + proxyAddresses, ok := m.shardManager.GetProxyAddress(peerNodeName) + if !ok { + return nil, fmt.Errorf("proxy address not found") + } + + cc, err := grpc.DialContext(ctx, proxyAddresses, dialOpts...) //nolint:staticcheck // acceptable here + if err != nil { + return nil, err + } + + m.streamsMu.Lock() + ps := m.peers[peerNodeName] + if ps == nil { + ps = &peerState{conn: cc, receivers: make(map[peerStreamKey]*intraProxyStreamReceiver), senders: make(map[peerStreamKey]*intraProxyStreamSender), recvShutdown: make(map[peerStreamKey]channel.ShutdownOnce)} + m.peers[peerNodeName] = ps + } else { + old := ps.conn + ps.conn = cc + if old != nil { + _ = old.Close() + } + if ps.receivers == nil { + ps.receivers = make(map[peerStreamKey]*intraProxyStreamReceiver) + } + if ps.senders == nil { + ps.senders = make(map[peerStreamKey]*intraProxyStreamSender) + } + if ps.recvShutdown == nil { + ps.recvShutdown = make(map[peerStreamKey]channel.ShutdownOnce) + } + } + m.streamsMu.Unlock() + return ps, nil +} + +// ensureStream dials a peer proxy outbound server and opens a replication stream. 
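+// The fast path reuses an already-registered receiver for the (clientShard,
+// serverShard) pair; otherwise the shared per-peer connection is ensured, a new
+// intraProxyStreamReceiver plus its shutdown handle are registered under that
+// key, and the receiver is started in its own goroutine.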
+func (m *intraProxyManager) ensureStream( + ctx context.Context, + logger log.Logger, + peerNodeName string, + clientShard history.ClusterShardID, + serverShard history.ClusterShardID, + p *Proxy, +) error { + logger.Info("ensureStream") + key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + + // Fast path: already exists + m.streamsMu.RLock() + if ps, ok := m.peers[peerNodeName]; ok && ps != nil { + if r, ok2 := ps.receivers[key]; ok2 && r != nil && r.streamClient != nil { + m.streamsMu.RUnlock() + logger.Info("ensureStream reused") + return nil + } + } + m.streamsMu.RUnlock() + + // Reuse shared connection per peer + ps, err := m.ensurePeer(ctx, peerNodeName, p) + if err != nil { + logger.Error("Failed to ensure peer", tag.Error(err)) + return err + } + + // Create receiver and register tracking + recv := &intraProxyStreamReceiver{ + logger: log.With(m.logger, + tag.NewStringTag("peerNodeName", peerNodeName), + tag.NewStringTag("targetShardID", ClusterShardIDtoString(clientShard)), + tag.NewStringTag("sourceShardID", ClusterShardIDtoString(serverShard))), + shardManager: m.shardManager, + proxy: p, + intraMgr: m, + peerNodeName: peerNodeName, + targetShardID: clientShard, + sourceShardID: serverShard, + } + // initialize shutdown handle and register it for lifecycle management + recv.shutdown = channel.NewShutdownOnce() + m.streamsMu.Lock() + ps.receivers[key] = recv + ps.recvShutdown[key] = recv.shutdown + m.streamsMu.Unlock() + + // Let the receiver open stream, register tracking, and start goroutines + go func() { + if err := recv.Run(ctx, p, ps.conn); err != nil { + recv.logger.Error("intraProxyStreamReceiver Run failed", tag.Error(err)) + } + }() + return nil +} + +// sendAck forwards an ACK to the specified peer stream (creates it on demand). +func (m *intraProxyManager) sendAck( + ctx context.Context, + peerNodeName string, + clientShard history.ClusterShardID, + serverShard history.ClusterShardID, + p *Proxy, + req *adminservice.StreamWorkflowReplicationMessagesRequest, +) error { + key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + m.streamsMu.RLock() + defer m.streamsMu.RUnlock() + if ps, ok := m.peers[peerNodeName]; ok && ps != nil { + if r, ok2 := ps.receivers[key]; ok2 && r != nil && r.streamClient != nil { + if err := r.sendAck(req); err != nil { + m.logger.Error("Failed to send intra-proxy ACK", tag.Error(err)) + return err + } + return nil + } + } + return fmt.Errorf("peer not found") +} + +// sendReplicationMessages sends replication messages to the peer via the server stream. 
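+// The registered server-side sender for the (clientShard, serverShard) pair is
+// looked up first; if it has not been registered yet, the lookup is retried
+// with a short backoff for up to two seconds before giving up.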
+func (m *intraProxyManager) sendReplicationMessages( + ctx context.Context, + peerNodeName string, + clientShard history.ClusterShardID, + serverShard history.ClusterShardID, + self *Proxy, + resp *adminservice.StreamWorkflowReplicationMessagesResponse, +) error { + key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + + // Try server stream first with short retry/backoff to await registration + deadline := time.Now().Add(2 * time.Second) + backoff := 10 * time.Millisecond + for { + var sender *intraProxyStreamSender + m.streamsMu.RLock() + ps, ok := m.peers[peerNodeName] + if ok && ps != nil && ps.senders != nil { + if s, ok2 := ps.senders[key]; ok2 && s != nil { + sender = s + } + } + m.streamsMu.RUnlock() + + if sender != nil { + if err := sender.sendReplicationMessages(resp); err != nil { + m.logger.Error("Failed to send intra-proxy replication messages via server stream", tag.Error(err)) + return err + } + return nil + } + + if time.Now().After(deadline) { + break + } + time.Sleep(backoff) + if backoff < 200*time.Millisecond { + backoff *= 2 + } + } + + return fmt.Errorf("stream does not support SendMsg for responses") +} + +// closePeerLocked shuts down and removes all resources for a peer. Caller must hold m.streamsMu. +func (m *intraProxyManager) closePeerLocked(peer string, ps *peerState) { + // Shutdown receivers and unregister client-side tracker entries + for key, shut := range ps.recvShutdown { + if shut != nil { + shut.Shutdown() + } + st := GetGlobalStreamTracker() + cliID := BuildIntraProxyReceiverStreamID(peer, key.clientShard, key.serverShard) + st.UnregisterStream(cliID) + delete(ps.recvShutdown, key) + } + // Close client streams (receiver cleanup is handled by its own goroutine) + for key := range ps.receivers { + delete(ps.receivers, key) + } + // Unregister server-side tracker entries + for key := range ps.senders { + st := GetGlobalStreamTracker() + srvID := BuildIntraProxySenderStreamID(peer, key.clientShard, key.serverShard) + st.UnregisterStream(srvID) + delete(ps.senders, key) + } + if ps.conn != nil { + _ = ps.conn.Close() + ps.conn = nil + } + delete(m.peers, peer) +} + +// closePeerShardLocked shuts down and removes resources for a specific peer/shard pair. Caller must hold m.streamsMu. +func (m *intraProxyManager) closePeerShardLocked(peer string, ps *peerState, key peerStreamKey) { + m.logger.Info("closePeerShardLocked", tag.NewStringTag("peer", peer), tag.NewStringTag("clientShard", ClusterShardIDtoString(key.clientShard)), tag.NewStringTag("serverShard", ClusterShardIDtoString(key.serverShard))) + if shut, ok := ps.recvShutdown[key]; ok && shut != nil { + shut.Shutdown() + st := GetGlobalStreamTracker() + cliID := BuildIntraProxyReceiverStreamID(peer, key.clientShard, key.serverShard) + st.UnregisterStream(cliID) + delete(ps.recvShutdown, key) + } + if r, ok := ps.receivers[key]; ok { + // cancel stream context and attempt to close client send side + if r.cancel != nil { + r.cancel() + } + if r.streamClient != nil { + _ = r.streamClient.CloseSend() + } + delete(ps.receivers, key) + } + st := GetGlobalStreamTracker() + srvID := BuildIntraProxySenderStreamID(peer, key.clientShard, key.serverShard) + st.UnregisterStream(srvID) + delete(ps.senders, key) +} + +// ClosePeer closes and removes all resources for a specific peer. 
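The wait loop in sendReplicationMessages above retries the sender lookup with exponential backoff capped at 200ms until a 2-second deadline passes. A standalone sketch of that bounded wait; waitFor and the ready callback are illustrative names only.

package main

import (
	"errors"
	"fmt"
	"time"
)

// waitFor polls ready with exponential backoff capped at 200ms and gives up after
// a 2s deadline, mirroring the retry loop used before giving up on a peer sender.
func waitFor(ready func() bool) error {
	deadline := time.Now().Add(2 * time.Second)
	backoff := 10 * time.Millisecond
	for {
		if ready() {
			return nil
		}
		if time.Now().After(deadline) {
			return errors.New("timed out waiting for sender registration")
		}
		time.Sleep(backoff)
		if backoff < 200*time.Millisecond {
			backoff *= 2
		}
	}
}

func main() {
	start := time.Now()
	err := waitFor(func() bool { return time.Since(start) > 50*time.Millisecond })
	fmt.Println(err) // <nil>: the condition became true before the deadline
}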
+func (m *intraProxyManager) ClosePeer(peer string) { + m.streamsMu.Lock() + defer m.streamsMu.Unlock() + if ps, ok := m.peers[peer]; ok { + m.closePeerLocked(peer, ps) + } +} + +// ClosePeerShard closes resources for a specific peer/shard pair. +func (m *intraProxyManager) ClosePeerShard(peer string, clientShard, serverShard history.ClusterShardID) { + key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + m.streamsMu.Lock() + defer m.streamsMu.Unlock() + if ps, ok := m.peers[peer]; ok { + m.closePeerShardLocked(peer, ps, key) + } +} + +// CloseShardAcrossPeers closes all sender/receiver streams for any peer that involve the specified shard +// as either client or server shard. Useful when a local shard is unregistered. +func (m *intraProxyManager) CloseShardAcrossPeers(shard history.ClusterShardID) { + m.streamsMu.Lock() + defer m.streamsMu.Unlock() + for peer, ps := range m.peers { + // Collect keys to avoid mutating map during iteration + toClose := make([]peerStreamKey, 0) + for key := range ps.receivers { + if (key.clientShard.ClusterID == shard.ClusterID && key.clientShard.ShardID == shard.ShardID) || + (key.serverShard.ClusterID == shard.ClusterID && key.serverShard.ShardID == shard.ShardID) { + toClose = append(toClose, key) + } + } + for key := range ps.senders { + if (key.clientShard.ClusterID == shard.ClusterID && key.clientShard.ShardID == shard.ShardID) || + (key.serverShard.ClusterID == shard.ClusterID && key.serverShard.ShardID == shard.ShardID) { + // ensure key is present in toClose for unified cleanup + toClose = append(toClose, key) + } + } + for _, key := range toClose { + m.closePeerShardLocked(peer, ps, key) + } + } +} + +// ReconcilePeerStreams ensures receivers exist for desired (local shard, remote shard) pairs +// for a given peer and closes any sender/receiver not in the desired set. +// This mirrors the Temporal StreamReceiverMonitor approach. 
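CloseShardAcrossPeers above collects matching keys into a slice and deletes them in a second pass rather than mutating the maps while ranging over them. A small self-contained sketch of that collect-then-delete step; streamKey stands in for peerStreamKey.

package main

import "fmt"

// streamKey mirrors peerStreamKey: a comparable struct used directly as a map key.
type streamKey struct{ clusterID, shardID int32 }

func main() {
	streams := map[streamKey]string{
		{1, 1}: "recv-a",
		{1, 2}: "recv-b",
		{2, 1}: "recv-c",
	}
	// Collect matching keys first, then delete in a second pass, as CloseShardAcrossPeers does.
	toClose := make([]streamKey, 0)
	for k := range streams {
		if k.clusterID == 1 {
			toClose = append(toClose, k)
		}
	}
	for _, k := range toClose {
		delete(streams, k)
	}
	fmt.Println(len(streams)) // 1
}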
+func (m *intraProxyManager) ReconcilePeerStreams( + p *Proxy, + peerNodeName string, +) { + f := func() { + m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("peerNodeName", peerNodeName)) + defer m.logger.Info("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) + + if mode := m.shardCountConfig.Mode; mode != config.ShardCountRouting { + return + } + localShards := m.shardManager.GetLocalShards() + remoteShards, err := m.shardManager.GetRemoteShardsForPeer(peerNodeName) + if err != nil { + m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) + return + } + m.logger.Info("ReconcilePeerStreams", + tag.NewStringTag("peerNodeName", peerNodeName), + tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards)), + tag.NewStringTag("localShards", fmt.Sprintf("%v", localShards)), + ) + + // Build desired set of cross-cluster pairs + desired := make(map[peerStreamKey]string) + for _, l := range localShards { + for peer, shards := range remoteShards { + for _, r := range shards.Shards { + if l.ClusterID == r.ID.ClusterID { + continue + } + desired[peerStreamKey{clientShard: l, serverShard: r.ID}] = peer + } + } + } + + m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("desired", fmt.Sprintf("%v", desired))) + + // Ensure all desired receivers exist + for key := range desired { + m.EnsureReceiverForPeerShard(p, desired[key], key.clientShard, key.serverShard) + } + + // Prune anything not desired + check := func(ps *peerState) { + // Collect keys to close for receivers + for key := range ps.receivers { + if _, ok2 := desired[key]; !ok2 { + m.closePeerShardLocked(peerNodeName, ps, key) + } + } + // And for server-side senders, if they don't belong to desired pairs + for key := range ps.senders { + if _, ok2 := desired[key]; !ok2 { + m.closePeerShardLocked(peerNodeName, ps, key) + } + } + } + + m.streamsMu.Lock() + if peerNodeName != "" { + if ps, ok := m.peers[peerNodeName]; ok && ps != nil { + check(ps) + } + } else { + for _, ps := range m.peers { + check(ps) + } + } + m.streamsMu.Unlock() + } + go f() +} diff --git a/proxy/proxy.go b/proxy/proxy.go index da678824..8ce1e68d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -46,8 +46,9 @@ type ( inboundHealthCheckServer *http.Server outboundHealthCheckServer *http.Server metricsServer *http.Server - shardManager ShardManager logger log.Logger + shardManagers map[migrationId]ShardManager + intraMgrs map[migrationId]*intraProxyManager // remoteSendChannels maps shard IDs to send channels for replication message routing remoteSendChannels map[history.ClusterShardID]chan RoutedMessage @@ -63,14 +64,15 @@ type ( } ) -func NewProxy(configProvider config.ConfigProvider, shardManager ShardManager, logger log.Logger) *Proxy { +func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { s2sConfig := config.ToClusterConnConfig(configProvider.GetS2SProxyConfig()) ctx, cancel := context.WithCancel(context.Background()) proxy := &Proxy{ lifetime: ctx, cancel: cancel, clusterConnections: make(map[migrationId]*ClusterConnection, len(s2sConfig.MuxTransports)), - shardManager: shardManager, + intraMgrs: make(map[migrationId]*intraProxyManager), + shardManagers: make(map[migrationId]ShardManager), logger: log.NewThrottledLogger( logger, func() float64 { @@ -87,13 +89,52 @@ func NewProxy(configProvider config.ConfigProvider, shardManager ShardManager, l if s2sConfig.Metrics != nil { proxy.metricsConfig = s2sConfig.Metrics } + + // TODO: Wire intra-proxy manager callbacks + // // Wire 
memberlist peer-join callback to reconcile intra-proxy receivers for local/remote pairs + // shardManager.SetOnPeerJoin(func(nodeName string) { + // logger.Info("OnPeerJoin", tag.NewStringTag("nodeName", nodeName)) + // defer logger.Info("OnPeerJoin done", tag.NewStringTag("nodeName", nodeName)) + // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) + // }) + + // // Wire peer-leave to cleanup intra-proxy resources for that peer + // shardManager.SetOnPeerLeave(func(nodeName string) { + // logger.Info("OnPeerLeave", tag.NewStringTag("nodeName", nodeName)) + // defer logger.Info("OnPeerLeave done", tag.NewStringTag("nodeName", nodeName)) + // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) + // }) + + // // Wire local shard changes to reconcile intra-proxy receivers + // shardManager.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { + // logger.Info("OnLocalShardChange", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + // defer logger.Info("OnLocalShardChange done", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + // proxy.intraMgr.ReconcilePeerStreams(proxy, "") + // }) + + // // Wire remote shard changes to reconcile intra-proxy receivers + // shardManager.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { + // logger.Info("OnRemoteShardChange", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + // defer logger.Info("OnRemoteShardChange done", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + // proxy.intraMgr.ReconcilePeerStreams(proxy, peer) + // }) + for _, clusterCfg := range s2sConfig.ClusterConnections { + shardManager, err := NewShardManager(configProvider, logger) + if err != nil { + logger.Fatal("Failed to create shard manager", tag.Error(err)) + continue + } cc, err := NewClusterConnection(ctx, clusterCfg, shardManager, logger) if err != nil { logger.Fatal("Incorrectly configured Mux cluster connection", tag.Error(err), tag.NewStringTag("name", clusterCfg.Name)) continue } - proxy.clusterConnections[migrationId{clusterCfg.Name}] = cc + migrationId := migrationId{clusterCfg.Name} + proxy.clusterConnections[migrationId] = cc + proxy.intraMgrs[migrationId] = newIntraProxyManager(logger, shardManager, clusterCfg.ShardCountConfig) + proxy.shardManagers[migrationId] = shardManager + shardManager.SetIntraProxyManager(proxy.intraMgrs[migrationId]) } // TODO: correctly host multiple health checks if len(s2sConfig.ClusterConnections) > 0 && s2sConfig.ClusterConnections[0].InboundHealthCheck.ListenAddress != "" { @@ -202,8 +243,8 @@ func (s *Proxy) Start() error { ` it needs at least the following path: metrics.prometheus.listenAddress`) } - if s.shardManager != nil { - if err := s.shardManager.Start(s.lifetime); err != nil { + for _, shardManager := range s.shardManagers { + if err := shardManager.Start(s.lifetime); err != nil { return err } } @@ -240,8 +281,12 @@ func (s *Proxy) Describe() string { } // GetShardInfo returns debug information about shard distribution -func (s *Proxy) GetShardInfo() ShardDebugInfo { - return s.shardManager.GetShardInfo() +func (s *Proxy) GetShardInfos() []ShardDebugInfo { + var shardInfos []ShardDebugInfo + for _, shardManager := range s.shardManagers { + shardInfos = 
append(shardInfos, shardManager.GetShardInfo()) + } + return shardInfos } // GetChannelInfo returns debug information about active channels @@ -278,6 +323,11 @@ func (s *Proxy) GetChannelInfo() ChannelDebugInfo { } } +// GetIntraProxyManager returns the intra-proxy manager instance +func (s *Proxy) GetIntraProxyManager(migrationId migrationId) *intraProxyManager { + return s.intraMgrs[migrationId] +} + // SetRemoteSendChan registers a send channel for a specific shard ID func (s *Proxy) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { s.remoteSendChannelsMu.Lock() diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index e9754c81..1615859c 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "sync" + "time" "go.temporal.io/server/api/adminservice/v1" replicationv1 "go.temporal.io/server/api/replication/v1" @@ -246,10 +247,8 @@ func (s *proxyStreamSender) Run( // Register local stream tracking for sender (short id, include role) s.streamTracker = GetGlobalStreamTracker() - s.streamID = fmt.Sprintf("snd-%s-%s", - ClusterShardIDtoString(s.sourceShardID), - ClusterShardIDtoString(s.targetShardID), - ) + s.streamID = BuildSenderStreamID(s.sourceShardID, s.targetShardID) + s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID)) s.streamTracker.RegisterStream( s.streamID, "StreamWorkflowReplicationMessages", @@ -357,7 +356,7 @@ func (s *proxyStreamSender) recvAck( s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) - if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) { + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger, originalAck, true) { sent[srcShard] = true numRemaining-- progress = true @@ -419,7 +418,7 @@ func (s *proxyStreamSender) recvAck( } // Log fallback ACK for this source shard s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) - if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger) { + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger, prev, true) { sent[srcShard] = true numRemaining-- progress = true @@ -613,10 +612,8 @@ func (r *proxyStreamReceiver) Run( r.lastSentMin = 0 // Register a new local stream for tracking (short id, include role) - r.streamID = fmt.Sprintf("rcv-%s-%s", - ClusterShardIDtoString(r.sourceShardID), - ClusterShardIDtoString(r.targetShardID), - ) + r.streamID = BuildReceiverStreamID(r.sourceShardID, r.targetShardID) + r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) r.streamTracker = GetGlobalStreamTracker() r.streamTracker.RegisterStream( r.streamID, @@ -733,27 +730,22 @@ func (r *proxyStreamReceiver) recvReplicationMessages( if sentByTarget[targetShardID] { continue } - if ch, ok := r.proxy.GetRemoteSendChan(targetShardID); ok { - msg := RoutedMessage{ - SourceShard: r.sourceShardID, - Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ - Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ - Messages: &replicationv1.WorkflowReplicationMessages{ - ReplicationTasks: tasks, - ExclusiveHighWatermark: tasks[len(tasks)-1].RawTaskInfo.TaskId + 1, - Priority: attr.Messages.Priority, - }, + msg := RoutedMessage{ + SourceShard: 
r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ReplicationTasks: tasks, + ExclusiveHighWatermark: tasks[len(tasks)-1].RawTaskInfo.TaskId + 1, + Priority: attr.Messages.Priority, }, }, - } - select { - case ch <- msg: - sentByTarget[targetShardID] = true - numRemaining-- - progress = true - case <-shutdownChan.Channel(): - return nil - } + }, + } + if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &msg, r.proxy, shutdownChan, r.logger) { + sentByTarget[targetShardID] = true + numRemaining-- + progress = true } else { if !loggedByTarget[targetShardID] { r.logger.Warn("No send channel found for target shard; retrying until available", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 804c194b..ee2bec68 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "hash/fnv" - "sort" "strconv" "sync" "time" @@ -30,38 +28,67 @@ type ( RegisterShard(clientShardID history.ClusterShardID) // UnregisterShard removes a clientShardID from this proxy's ownership UnregisterShard(clientShardID history.ClusterShardID) - // GetShardOwner returns the proxy node name that owns the given shard - GetShardOwner(clientShardID history.ClusterShardID) (string, bool) // GetProxyAddress returns the proxy service address for the given node name GetProxyAddress(nodeName string) (string, bool) // IsLocalShard checks if this proxy instance owns the given shard IsLocalShard(clientShardID history.ClusterShardID) bool + // GetNodeName returns the name of this proxy instance + GetNodeName() string // GetMemberNodes returns all active proxy nodes in the cluster GetMemberNodes() []string - // GetLocalShards returns all shards currently handled by this proxy instance - GetLocalShards() []history.ClusterShardID + // GetLocalShards returns all shards currently handled by this proxy instance, keyed by short id + GetLocalShards() map[string]history.ClusterShardID + // GetRemoteShardsForPeer returns all shards owned by the specified peer node, keyed by short id + GetRemoteShardsForPeer(peerNodeName string) (map[string]NodeShardState, error) // GetShardInfo returns debug information about shard distribution GetShardInfo() ShardDebugInfo + // GetShardOwner returns the node name that owns the given shard + GetShardOwner(shard history.ClusterShardID) (string, bool) // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) - DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool + DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool + // DeliverMessagesToShardOwner routes replication messages to the appropriate shard owner (local or remote) + DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool + // SetOnPeerJoin registers a callback invoked when a new peer joins + SetOnPeerJoin(handler func(nodeName string)) + // SetOnPeerLeave registers a callback invoked when a peer leaves. 
+ SetOnPeerLeave(handler func(nodeName string)) + // New: notify when local shard set changes + SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) + // New: notify when remote shard set changes for a peer + SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) + + SetIntraProxyManager(intraMgr *intraProxyManager) + GetIntraProxyManager() *intraProxyManager } shardManagerImpl struct { - config *config.MemberlistConfig - logger log.Logger - ml *memberlist.Memberlist - delegate *shardDelegate - mutex sync.RWMutex - localAddr string - started bool + config *config.MemberlistConfig + logger log.Logger + ml *memberlist.Memberlist + delegate *shardDelegate + mutex sync.RWMutex + localAddr string + started bool + onPeerJoin func(nodeName string) + onPeerLeave func(nodeName string) + // New callbacks + onLocalShardChange func(shard history.ClusterShardID, added bool) + onRemoteShardChange func(peer string, shard history.ClusterShardID, added bool) + // Local shards owned by this node, keyed by short id + localShards map[string]ShardInfo + intraMgr *intraProxyManager } // shardDelegate implements memberlist.Delegate for shard state management shardDelegate struct { - manager *shardManagerImpl - logger log.Logger - localShards map[string]history.ClusterShardID // key: "clusterID:shardID" - mutex sync.RWMutex + manager *shardManagerImpl + logger log.Logger + } + + // ShardInfo describes a local shard and its creation time + ShardInfo struct { + ID history.ClusterShardID `json:"id"` + Created time.Time `json:"created"` } // ShardMessage represents shard ownership changes broadcast to cluster @@ -74,9 +101,9 @@ type ( // NodeShardState represents all shards owned by a node NodeShardState struct { - NodeName string `json:"node"` - Shards map[string]history.ClusterShardID `json:"shards"` - Updated time.Time `json:"updated"` + NodeName string `json:"node"` + Shards map[string]ShardInfo `json:"shards"` + Updated time.Time `json:"updated"` } ) @@ -88,14 +115,15 @@ func NewShardManager(configProvider config.ConfigProvider, logger log.Logger) (S } delegate := &shardDelegate{ - logger: logger, - localShards: make(map[string]history.ClusterShardID), + logger: logger, } sm := &shardManagerImpl{ - config: cfg, - logger: logger, - delegate: delegate, + config: cfg, + logger: logger, + delegate: delegate, + localShards: make(map[string]ShardInfo), + intraMgr: nil, } delegate.manager = sm @@ -103,11 +131,42 @@ func NewShardManager(configProvider config.ConfigProvider, logger log.Logger) (S return sm, nil } +// SetOnPeerJoin registers a callback invoked on new peer joins. +func (sm *shardManagerImpl) SetOnPeerJoin(handler func(nodeName string)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onPeerJoin = handler +} + +// SetOnPeerLeave registers a callback invoked when a peer leaves. +func (sm *shardManagerImpl) SetOnPeerLeave(handler func(nodeName string)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onPeerLeave = handler +} + +// SetOnLocalShardChange registers local shard change callback. +func (sm *shardManagerImpl) SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onLocalShardChange = handler +} + +// SetOnRemoteShardChange registers remote shard change callback. 
+func (sm *shardManagerImpl) SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onRemoteShardChange = handler +} + func (sm *shardManagerImpl) Start(lifetime context.Context) error { + sm.logger.Info("Starting shard manager") + sm.mutex.Lock() defer sm.mutex.Unlock() if sm.started { + sm.logger.Info("Shard manager already started") return nil } @@ -188,7 +247,7 @@ func (sm *shardManagerImpl) Stop() { } func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) { - sm.delegate.addLocalShard(clientShardID) + sm.addLocalShard(clientShardID) sm.broadcastShardChange("register", clientShardID) // Trigger memberlist metadata update to propagate NodeMeta to other nodes @@ -197,10 +256,17 @@ func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) } } + // Notify listeners + if sm.onLocalShardChange != nil { + sm.onLocalShardChange(clientShardID, true) + } } func (sm *shardManagerImpl) UnregisterShard(clientShardID history.ClusterShardID) { - sm.delegate.removeLocalShard(clientShardID) + sm.removeLocalShard(clientShardID) + sm.mutex.Lock() + delete(sm.localShards, ClusterShardIDtoShortString(clientShardID)) + sm.mutex.Unlock() sm.broadcastShardChange("unregister", clientShardID) // Trigger memberlist metadata update to propagate NodeMeta to other nodes @@ -209,15 +275,10 @@ func (sm *shardManagerImpl) UnregisterShard(clientShardID history.ClusterShardID sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) } } -} - -func (sm *shardManagerImpl) GetShardOwner(clientShardID history.ClusterShardID) (string, bool) { - if !sm.started { - return "", false + // Notify listeners + if sm.onLocalShardChange != nil { + sm.onLocalShardChange(clientShardID, false) } - - // Use consistent hashing to determine shard owner - return sm.consistentHashOwner(clientShardID), true } func (sm *shardManagerImpl) IsLocalShard(clientShardID history.ClusterShardID) bool { @@ -225,11 +286,15 @@ func (sm *shardManagerImpl) IsLocalShard(clientShardID history.ClusterShardID) b return true // If not using memberlist, handle locally } - owner, found := sm.GetShardOwner(clientShardID) - return found && owner == sm.config.NodeName + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + _, found := sm.localShards[ClusterShardIDtoShortString(clientShardID)] + return found } func (sm *shardManagerImpl) GetProxyAddress(nodeName string) (string, bool) { + // TODO: get the proxy address from the memberlist metadata if sm.config.ProxyAddresses == nil { return "", false } @@ -237,6 +302,10 @@ func (sm *shardManagerImpl) GetProxyAddress(nodeName string) (string, bool) { return addr, found } +func (sm *shardManagerImpl) GetNodeName() string { + return sm.config.NodeName +} + func (sm *shardManagerImpl) GetMemberNodes() []string { if !sm.started || sm.ml == nil { return []string{sm.config.NodeName} @@ -268,69 +337,103 @@ func (sm *shardManagerImpl) GetMemberNodes() []string { } } -func (sm *shardManagerImpl) GetLocalShards() []history.ClusterShardID { - sm.delegate.mutex.RLock() - defer sm.delegate.mutex.RUnlock() - - shards := make([]history.ClusterShardID, 0, len(sm.delegate.localShards)) - for _, shard := range sm.delegate.localShards { - shards = append(shards, shard) +func (sm *shardManagerImpl) GetLocalShards() map[string]history.ClusterShardID { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + shards := 
make(map[string]history.ClusterShardID, len(sm.localShards)) + for k, v := range sm.localShards { + shards[k] = v.ID } return shards } func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { - localShards := sm.GetLocalShards() - clusterNodes := sm.GetMemberNodes() - - // Build remote shard maps by querying memberlist metadata directly - remoteShards := make(map[string]string) - remoteShardCounts := make(map[string]int) - - // Initialize counts for all nodes - for _, node := range clusterNodes { - remoteShardCounts[node] = 0 + localShardMap := sm.GetLocalShards() + remoteShards, err := sm.GetRemoteShardsForPeer("") + if err != nil { + sm.logger.Error("Failed to get remote shards", tag.Error(err)) } - // Count local shards for this node - remoteShardCounts[sm.config.NodeName] = len(localShards) + remoteShardsMap := make(map[string]string) + remoteShardCounts := make(map[string]int) - // Collect shard ownership information from all cluster members - if sm.ml != nil { - for _, member := range sm.ml.Members() { - if len(member.Meta) > 0 { - var nodeState NodeShardState - if err := json.Unmarshal(member.Meta, &nodeState); err == nil { - nodeName := nodeState.NodeName - if nodeName != "" { - remoteShardCounts[nodeName] = len(nodeState.Shards) - - // Add remote shards (exclude local node) - if nodeName != sm.config.NodeName { - for _, shard := range nodeState.Shards { - shardKey := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) - remoteShards[shardKey] = nodeName - } - } - } - } - } + for nodeName, shards := range remoteShards { + for _, shard := range shards.Shards { + shardKey := ClusterShardIDtoShortString(shard.ID) + remoteShardsMap[shardKey] = nodeName } + remoteShardCounts[nodeName] = len(shards.Shards) } return ShardDebugInfo{ Enabled: true, - ForwardingEnabled: sm.config.EnableForwarding, NodeName: sm.config.NodeName, - LocalShards: localShards, - LocalShardCount: len(localShards), - ClusterNodes: clusterNodes, - ClusterSize: len(clusterNodes), - RemoteShards: remoteShards, + LocalShards: localShardMap, + LocalShardCount: len(localShardMap), + RemoteShards: remoteShardsMap, RemoteShardCounts: remoteShardCounts, } } +func (sm *shardManagerImpl) GetShardOwner(shard history.ClusterShardID) (string, bool) { + // FIXME: improve this: store remote shards in a map in the shardManagerImpl + remoteShards, err := sm.GetRemoteShardsForPeer("") + if err != nil { + sm.logger.Error("Failed to get remote shards", tag.Error(err)) + } + for nodeName, shards := range remoteShards { + for _, s := range shards.Shards { + if s.ID == shard { + return nodeName, true + } + } + } + return "", false +} + +// GetRemoteShardsForPeer returns all shards owned by the specified peer node. +// Non-blocking: uses memberlist metadata and tolerates timeouts by returning a best-effort set. 
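GetRemoteShardsForPeer below reads memberlist members through a goroutine and a buffered channel so the call can be abandoned after a short timeout. A standalone sketch of that guard; membersWithTimeout and fetch are illustrative names, and the 100ms budget matches the value used below.

package main

import (
	"fmt"
	"time"
)

// membersWithTimeout runs fetch in a goroutine and waits at most 100ms for the
// result; the buffered channel lets the goroutine finish even if the caller has
// already timed out.
func membersWithTimeout(fetch func() []string) ([]string, error) {
	ch := make(chan []string, 1)
	go func() { ch <- fetch() }()
	select {
	case members := <-ch:
		return members, nil
	case <-time.After(100 * time.Millisecond):
		return nil, fmt.Errorf("timeout")
	}
}

func main() {
	members, err := membersWithTimeout(func() []string { return []string{"proxy-a", "proxy-b"} })
	fmt.Println(members, err)
}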
+func (sm *shardManagerImpl) GetRemoteShardsForPeer(peerNodeName string) (map[string]NodeShardState, error) { + result := make(map[string]NodeShardState) + if sm.ml == nil { + return result, nil + } + + // Read members with a short timeout to avoid blocking debug paths + membersChan := make(chan []*memberlist.Node, 1) + go func() { + defer func() { _ = recover() }() + membersChan <- sm.ml.Members() + }() + + var members []*memberlist.Node + select { + case members = <-membersChan: + case <-time.After(100 * time.Millisecond): + sm.logger.Warn("GetRemoteShardsForPeer timeout") + return result, fmt.Errorf("timeout") + } + + for _, member := range members { + if member == nil || len(member.Meta) == 0 { + continue + } + if member.Name == sm.GetNodeName() { + continue + } + if peerNodeName != "" && member.Name != peerNodeName { + continue + } + var nodeState NodeShardState + if err := json.Unmarshal(member.Meta, &nodeState); err != nil { + continue + } + result[member.Name] = nodeState + } + + return result, nil +} + // DeliverAckToShardOwner routes an ACK to the local shard owner or records intent for remote forwarding. func (sm *shardManagerImpl) DeliverAckToShardOwner( sourceShard history.ClusterShardID, @@ -338,20 +441,95 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger, + ack int64, + allowForward bool, ) bool { + logger = log.With(logger, tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewInt64("ack", ack)) if ackCh, ok := proxy.GetLocalAckChan(sourceShard); ok { select { case ackCh <- *routedAck: + logger.Info("Delivered ACK to local shard owner") return true case <-shutdownChan.Channel(): return false } - } else { - logger.Warn("No local ack channel for source shard", tag.NewStringTag("shard", ClusterShardIDtoString(sourceShard))) } + if !allowForward { + logger.Warn("No local ack channel for source shard, forwarding ACK to shard owner is not allowed") + return false + } + + // Attempt remote delivery via intra-proxy when enabled and shard is remote + if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.config.NodeName { + if addr, found := sm.GetProxyAddress(owner); found { + clientShard := routedAck.TargetShard + serverShard := sourceShard + mgr := proxy.GetIntraProxyManager(migrationId{owner}) + // Synchronous send to preserve ordering + if err := mgr.sendAck(context.Background(), owner, clientShard, serverShard, proxy, routedAck.Req); err != nil { + logger.Error("Failed to forward ACK to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return false + } + logger.Info("Forwarded ACK to shard owner via intra-proxy", tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return true + } + logger.Warn("Owner proxy address not found for shard") + return false + } + + logger.Warn("No remote shard owner found for source shard") + return false +} + +// DeliverMessagesToShardOwner routes replication messages to the local target shard owner +// or forwards to the remote owner via intra-proxy stream synchronously. 
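Both Deliver* helpers try local delivery first with a select that also honors the shutdown signal. A minimal sketch of that select; routed and deliverLocal are stand-ins for the RoutedAck/RoutedMessage types and the helpers here.

package main

import "fmt"

// routed is a stand-in for RoutedAck / RoutedMessage.
type routed struct{ ack int64 }

// deliverLocal blocks until either the local channel accepts the message or the
// shutdown channel closes, mirroring the select used by the Deliver* helpers.
func deliverLocal(ch chan<- routed, shutdown <-chan struct{}, msg routed) bool {
	select {
	case ch <- msg:
		return true
	case <-shutdown:
		return false
	}
}

func main() {
	ch := make(chan routed, 1)
	shutdown := make(chan struct{})
	fmt.Println(deliverLocal(ch, shutdown, routed{ack: 42})) // true

	close(shutdown)
	blocked := make(chan routed) // unbuffered and never read: the send cannot proceed
	fmt.Println(deliverLocal(blocked, shutdown, routed{ack: 43})) // false
}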
+func (sm *shardManagerImpl) DeliverMessagesToShardOwner( + targetShard history.ClusterShardID, + routedMsg *RoutedMessage, + proxy *Proxy, + shutdownChan channel.ShutdownOnce, + logger log.Logger, +) bool { + // Try local delivery first + if ch, ok := proxy.GetRemoteSendChan(targetShard); ok { + select { + case ch <- *routedMsg: + return true + case <-shutdownChan.Channel(): + return false + } + } + + // Attempt remote delivery via intra-proxy when enabled and shard is remote + if sm.config != nil { + if owner, ok := sm.GetShardOwner(targetShard); ok && owner != sm.config.NodeName { + if addr, found := sm.GetProxyAddress(owner); found { + if mgr := sm.GetIntraProxyManager(); mgr != nil { + resp := routedMsg.Resp + if err := mgr.sendReplicationMessages(context.Background(), owner, targetShard, routedMsg.SourceShard, proxy, resp); err != nil { + logger.Error("Failed to forward replication messages to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return false + } + return true + } + } else { + logger.Warn("Owner proxy address not found for target shard", tag.NewStringTag("owner", owner), tag.NewStringTag("shard", ClusterShardIDtoString(targetShard))) + } + } + } + + logger.Warn("No local send channel for target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard))) return false } +func (sm *shardManagerImpl) SetIntraProxyManager(intraMgr *intraProxyManager) { + sm.intraMgr = intraMgr +} + +func (sm *shardManagerImpl) GetIntraProxyManager() *intraProxyManager { + return sm.intraMgr +} + func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.ClusterShardID) { if !sm.started || sm.ml == nil { return @@ -388,33 +566,11 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C } } -func (sm *shardManagerImpl) consistentHashOwner(shard history.ClusterShardID) string { - nodes := sm.GetMemberNodes() - if len(nodes) == 0 { - return sm.config.NodeName - } - - // Sort nodes for consistent ordering - sort.Strings(nodes) - - // Hash the shard ID - h := fnv.New32a() - shardKey := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) - h.Write([]byte(shardKey)) - hash := h.Sum32() - - // Use consistent hashing to determine owner - return nodes[hash%uint32(len(nodes))] -} - // shardDelegate implements memberlist.Delegate func (sd *shardDelegate) NodeMeta(limit int) []byte { - sd.mutex.RLock() - defer sd.mutex.RUnlock() - state := NodeShardState{ NodeName: sd.manager.config.NodeName, - Shards: sd.localShards, + Shards: sd.manager.localShards, Updated: time.Now(), } @@ -443,6 +599,24 @@ func (sd *shardDelegate) NotifyMsg(data []byte) { tag.NewStringTag("type", msg.Type), tag.NewStringTag("node", msg.NodeName), tag.NewStringTag("shard", ClusterShardIDtoString(msg.ClientShard))) + + // Inform listeners about remote shard changes + if sd.manager != nil && sd.manager.onRemoteShardChange != nil { + added := msg.Type == "register" + + // if shard is previously registered as local shard, but now is registered as remote shard, + // check if the remote shard is newer than the local shard. If so, unregister the local shard. 
+ if added { + localShard, ok := sd.manager.localShards[ClusterShardIDtoShortString(msg.ClientShard)] + if ok { + if localShard.Created.Before(msg.Timestamp) { + sd.manager.UnregisterShard(msg.ClientShard) + } + } + } + + sd.manager.onRemoteShardChange(msg.NodeName, msg.ClientShard, added) + } } func (sd *shardDelegate) GetBroadcasts(overhead, limit int) [][]byte { @@ -451,7 +625,7 @@ func (sd *shardDelegate) GetBroadcasts(overhead, limit int) [][]byte { } func (sd *shardDelegate) LocalState(join bool) []byte { - return sd.NodeMeta(512) + return sd.NodeMeta(4096) // TODO: set this to a reasonable value } func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { @@ -463,23 +637,25 @@ func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { sd.logger.Info("Merged remote shard state", tag.NewStringTag("node", state.NodeName), - tag.NewStringTag("shards", strconv.Itoa(len(state.Shards)))) + tag.NewStringTag("shards", strconv.Itoa(len(state.Shards))), + tag.NewStringTag("state", fmt.Sprintf("%+v", state))) } -func (sd *shardDelegate) addLocalShard(shard history.ClusterShardID) { - sd.mutex.Lock() - defer sd.mutex.Unlock() +func (sm *shardManagerImpl) addLocalShard(shard history.ClusterShardID) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + key := ClusterShardIDtoShortString(shard) + sm.localShards[key] = ShardInfo{ID: shard, Created: time.Now()} - key := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) - sd.localShards[key] = shard } -func (sd *shardDelegate) removeLocalShard(shard history.ClusterShardID) { - sd.mutex.Lock() - defer sd.mutex.Unlock() +func (sm *shardManagerImpl) removeLocalShard(shard history.ClusterShardID) { + sm.mutex.Lock() + defer sm.mutex.Unlock() - key := fmt.Sprintf("%d:%d", shard.ClusterID, shard.ShardID) - delete(sd.localShards, key) + key := ClusterShardIDtoShortString(shard) + delete(sm.localShards, key) } // shardEventDelegate handles memberlist cluster events @@ -516,16 +692,19 @@ func (nsm *noopShardManager) UnregisterShard(history.ClusterShardID) func (nsm *noopShardManager) GetShardOwner(history.ClusterShardID) (string, bool) { return "", false } func (nsm *noopShardManager) GetProxyAddress(string) (string, bool) { return "", false } func (nsm *noopShardManager) IsLocalShard(history.ClusterShardID) bool { return true } +func (nsm *noopShardManager) GetNodeName() string { return "" } func (nsm *noopShardManager) GetMemberNodes() []string { return []string{} } -func (nsm *noopShardManager) GetLocalShards() []history.ClusterShardID { - return []history.ClusterShardID{} +func (nsm *noopShardManager) GetLocalShards() map[string]history.ClusterShardID { + return make(map[string]history.ClusterShardID) +} +func (nsm *noopShardManager) GetRemoteShardsForPeer(string) (map[string]NodeShardState, error) { + return make(map[string]NodeShardState), nil } func (nsm *noopShardManager) GetShardInfo() ShardDebugInfo { return ShardDebugInfo{ Enabled: false, - ForwardingEnabled: false, NodeName: "", - LocalShards: []history.ClusterShardID{}, + LocalShards: make(map[string]history.ClusterShardID), LocalShardCount: 0, ClusterNodes: []string{}, ClusterSize: 0, @@ -534,7 +713,14 @@ func (nsm *noopShardManager) GetShardInfo() ShardDebugInfo { } } -func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool { +func (nsm *noopShardManager) SetOnPeerJoin(handler func(nodeName string)) {} +func (nsm *noopShardManager) SetOnPeerLeave(handler 
func(nodeName string)) {} +func (nsm *noopShardManager) SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) { +} +func (nsm *noopShardManager) SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) { +} + +func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool { if proxy != nil { if ackCh, ok := proxy.GetLocalAckChan(srcShard); ok { select { @@ -547,3 +733,23 @@ func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShar } return false } + +func (nsm *noopShardManager) DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool { + if proxy != nil { + if ch, ok := proxy.GetRemoteSendChan(targetShard); ok { + select { + case ch <- *routedMsg: + return true + case <-shutdownChan.Channel(): + return false + } + } + } + return false +} + +func (nsm *noopShardManager) SetIntraProxyManager(intraMgr *intraProxyManager) { +} +func (nsm *noopShardManager) GetIntraProxyManager() *intraProxyManager { + return nil +} diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go index ede493f0..b6214814 100644 --- a/proxy/stream_tracker.go +++ b/proxy/stream_tracker.go @@ -4,6 +4,8 @@ import ( "fmt" "sync" "time" + + "go.temporal.io/server/client/history" ) const ( @@ -155,6 +157,32 @@ func GetGlobalStreamTracker() *StreamTracker { return globalStreamTracker } +// BuildSenderStreamID returns the canonical sender stream ID. +func BuildSenderStreamID(source, target history.ClusterShardID) string { + return fmt.Sprintf("snd-%s-%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) +} + +// BuildReceiverStreamID returns the canonical receiver stream ID. +func BuildReceiverStreamID(source, target history.ClusterShardID) string { + return fmt.Sprintf("rcv-%s-%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) +} + +// BuildForwarderStreamID returns the canonical forwarder stream ID. +// Note: forwarder uses server-first ordering in the ID. +func BuildForwarderStreamID(client, server history.ClusterShardID) string { + return fmt.Sprintf("fwd-%s-%s", ClusterShardIDtoShortString(server), ClusterShardIDtoShortString(client)) +} + +// BuildIntraProxySenderStreamID returns the server-side intra-proxy stream ID for a peer and shard pair. +func BuildIntraProxySenderStreamID(peer string, source, target history.ClusterShardID) string { + return fmt.Sprintf("ip-snd-%s-%s|%s", peer, ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) +} + +// BuildIntraProxyReceiverStreamID returns the client-side intra-proxy stream ID for a peer and shard pair. 
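The Build*StreamID helpers here give every tracked stream a stable, human-readable identifier. A small sketch of the receiver-side scheme; shortString is a stand-in for ClusterShardIDtoShortString, whose exact output format is not shown in this patch, so the "cluster:shard" rendering is only an assumption.

package main

import "fmt"

type shard struct{ clusterID, shardID int32 }

// shortString is a placeholder for ClusterShardIDtoShortString; the real format is assumed.
func shortString(s shard) string { return fmt.Sprintf("%d:%d", s.clusterID, s.shardID) }

// intraProxyReceiverStreamID follows the "ip-rcv-<peer>-<source>|<target>" layout
// used by the intra-proxy receiver ID builder.
func intraProxyReceiverStreamID(peer string, source, target shard) string {
	return fmt.Sprintf("ip-rcv-%s-%s|%s", peer, shortString(source), shortString(target))
}

func main() {
	fmt.Println(intraProxyReceiverStreamID("proxy-b-1", shard{2, 7}, shard{1, 4}))
	// e.g. ip-rcv-proxy-b-1-2:7|1:4 under the assumed shard formatting
}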
+func BuildIntraProxyReceiverStreamID(peer string, source, target history.ClusterShardID) string { + return fmt.Sprintf("ip-rcv-%s-%s|%s", peer, ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) +} + // formatDurationSeconds formats a duration in seconds to a readable string func formatDurationSeconds(totalSeconds int) string { if totalSeconds < 60 { diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index 6f3232fd..c2918520 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -287,7 +287,7 @@ func (s *ReplicationTestSuite) createProxy( } configProvider := &simpleConfigProvider{cfg: *cfg} - proxy := s2sproxy.NewProxy(configProvider, nil, s.logger) + proxy := s2sproxy.NewProxy(configProvider, s.logger) s.NotNil(proxy) err := proxy.Start() From 85c6a97102a23b6b5b399679e8d143d95e03a486 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 26 Sep 2025 11:14:06 -0700 Subject: [PATCH 09/38] fix incorrect stream pair --- proxy/intra_proxy_router.go | 202 ++++++++++++++++++++---------------- proxy/proxy.go | 37 ++----- proxy/proxy_streams.go | 51 ++++++--- proxy/shard_manager.go | 34 ++++++ 4 files changed, 190 insertions(+), 134 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 4ccf058b..1ef525b5 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -27,6 +27,8 @@ type intraProxyManager struct { streamsMu sync.RWMutex shardManager ShardManager shardCountConfig config.ShardCountConfig + proxy *Proxy + notifyCh chan struct{} // Group state by remote peer for unified lifecycle ops peers map[string]*peerState } @@ -39,16 +41,18 @@ type peerState struct { } type peerStreamKey struct { - clientShard history.ClusterShardID - serverShard history.ClusterShardID + targetShard history.ClusterShardID + sourceShard history.ClusterShardID } -func newIntraProxyManager(logger log.Logger, shardManager ShardManager, shardCountConfig config.ShardCountConfig) *intraProxyManager { +func newIntraProxyManager(logger log.Logger, proxy *Proxy, shardManager ShardManager, shardCountConfig config.ShardCountConfig) *intraProxyManager { return &intraProxyManager{ logger: logger, + proxy: proxy, shardManager: shardManager, shardCountConfig: shardCountConfig, peers: make(map[string]*peerState), + notifyCh: make(chan struct{}), } } @@ -172,7 +176,7 @@ type intraProxyStreamReceiver struct { } // Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. 
-func (r *intraProxyStreamReceiver) Run(ctx context.Context, self *Proxy, conn *grpc.ClientConn) error { +func (r *intraProxyStreamReceiver) Run(ctx context.Context, p *Proxy, conn *grpc.ClientConn) error { r.logger.Info("intraProxyStreamReceiver Run") // Build metadata according to receiver pattern: client=targetShard, server=sourceShard md := metadata.New(map[string]string{}) @@ -200,14 +204,14 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, self *Proxy, conn *g r.streamClient = stream // Register client-side intra-proxy stream in tracker - r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.targetShardID, r.sourceShardID) + r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.sourceShardID, r.targetShardID) r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) st := GetGlobalStreamTracker() st.RegisterStream(r.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(r.targetShardID), ClusterShardIDtoString(r.sourceShardID), StreamRoleForwarder) defer st.UnregisterStream(r.streamID) // Start replication receiver loop - return r.recvReplicationMessages(self) + return r.recvReplicationMessages(p) } // recvReplicationMessages receives replication messages and forwards to local shard owner. @@ -294,7 +298,7 @@ func (m *intraProxyManager) RegisterSender( if clientShard.ClusterID == serverShard.ClusterID { return } - key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} m.streamsMu.Lock() ps := m.peers[peerNodeName] if ps == nil { @@ -313,7 +317,7 @@ func (m *intraProxyManager) UnregisterSender( clientShard history.ClusterShardID, serverShard history.ClusterShardID, ) { - key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} m.streamsMu.Lock() if ps := m.peers[peerNodeName]; ps != nil && ps.senders != nil { delete(ps.senders, key) @@ -322,15 +326,15 @@ func (m *intraProxyManager) UnregisterSender( } // EnsureReceiverForPeerShard ensures a client stream and an ACK aggregator exist for the given peer/shard pair. 
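Run above builds gRPC metadata carrying the shard pair before opening the client stream (client = target shard, server = source shard). A standalone sketch of attaching such metadata to an outgoing context; the header names used here are placeholders, not the keys the proxy actually sends.

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/metadata"
)

// withShardMetadata attaches shard identifiers to the outgoing context before a
// client stream is opened. "target-shard-id" and "source-shard-id" are illustrative
// header names only.
func withShardMetadata(ctx context.Context, targetShard, sourceShard string) context.Context {
	md := metadata.New(map[string]string{
		"target-shard-id": targetShard,
		"source-shard-id": sourceShard,
	})
	return metadata.NewOutgoingContext(ctx, md)
}

func main() {
	ctx := withShardMetadata(context.Background(), "1:2", "2:7")
	md, _ := metadata.FromOutgoingContext(ctx)
	fmt.Println(md.Get("target-shard-id")) // [1:2]
}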
-func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName string, clientShard history.ClusterShardID, serverShard history.ClusterShardID) { +func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID) { logger := log.With(m.logger, tag.NewStringTag("peerNodeName", peerNodeName), - tag.NewStringTag("clientShard", ClusterShardIDtoString(clientShard)), - tag.NewStringTag("serverShard", ClusterShardIDtoString(serverShard))) + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), + tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard))) logger.Info("EnsureReceiverForPeerShard") // Cross-cluster only - if clientShard.ClusterID == serverShard.ClusterID { + if targetShard.ClusterID == sourceShard.ClusterID { return } // Do not create intra-proxy streams to self instance @@ -338,11 +342,11 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName st return } // Require at least one shard to be local to this instance - if !m.shardManager.IsLocalShard(clientShard) && !m.shardManager.IsLocalShard(serverShard) { + if !m.shardManager.IsLocalShard(targetShard) && !m.shardManager.IsLocalShard(sourceShard) { return } // Consolidated path: ensure stream and background loops - err := m.ensureStream(context.Background(), logger, peerNodeName, clientShard, serverShard, p) + err := m.ensureStream(context.Background(), logger, peerNodeName, targetShard, sourceShard, p) if err != nil { logger.Error("failed to ensureStream", tag.Error(err)) } @@ -428,12 +432,12 @@ func (m *intraProxyManager) ensureStream( ctx context.Context, logger log.Logger, peerNodeName string, - clientShard history.ClusterShardID, - serverShard history.ClusterShardID, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, p *Proxy, ) error { logger.Info("ensureStream") - key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} // Fast path: already exists m.streamsMu.RLock() @@ -457,14 +461,14 @@ func (m *intraProxyManager) ensureStream( recv := &intraProxyStreamReceiver{ logger: log.With(m.logger, tag.NewStringTag("peerNodeName", peerNodeName), - tag.NewStringTag("targetShardID", ClusterShardIDtoString(clientShard)), - tag.NewStringTag("sourceShardID", ClusterShardIDtoString(serverShard))), + tag.NewStringTag("targetShardID", ClusterShardIDtoString(targetShard)), + tag.NewStringTag("sourceShardID", ClusterShardIDtoString(sourceShard))), shardManager: m.shardManager, proxy: p, intraMgr: m, peerNodeName: peerNodeName, - targetShardID: clientShard, - sourceShardID: serverShard, + targetShardID: targetShard, + sourceShardID: sourceShard, } // initialize shutdown handle and register it for lifecycle management recv.shutdown = channel.NewShutdownOnce() @@ -491,7 +495,7 @@ func (m *intraProxyManager) sendAck( p *Proxy, req *adminservice.StreamWorkflowReplicationMessagesRequest, ) error { - key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} m.streamsMu.RLock() defer m.streamsMu.RUnlock() if ps, ok := m.peers[peerNodeName]; ok && ps != nil { @@ -512,10 +516,10 @@ func (m *intraProxyManager) sendReplicationMessages( peerNodeName string, clientShard history.ClusterShardID, serverShard history.ClusterShardID, - self *Proxy, + p *Proxy, resp 
*adminservice.StreamWorkflowReplicationMessagesResponse, ) error { - key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} // Try server stream first with short retry/backoff to await registration deadline := time.Now().Add(2 * time.Second) @@ -559,7 +563,7 @@ func (m *intraProxyManager) closePeerLocked(peer string, ps *peerState) { shut.Shutdown() } st := GetGlobalStreamTracker() - cliID := BuildIntraProxyReceiverStreamID(peer, key.clientShard, key.serverShard) + cliID := BuildIntraProxyReceiverStreamID(peer, key.targetShard, key.sourceShard) st.UnregisterStream(cliID) delete(ps.recvShutdown, key) } @@ -570,7 +574,7 @@ func (m *intraProxyManager) closePeerLocked(peer string, ps *peerState) { // Unregister server-side tracker entries for key := range ps.senders { st := GetGlobalStreamTracker() - srvID := BuildIntraProxySenderStreamID(peer, key.clientShard, key.serverShard) + srvID := BuildIntraProxySenderStreamID(peer, key.targetShard, key.sourceShard) st.UnregisterStream(srvID) delete(ps.senders, key) } @@ -583,11 +587,11 @@ func (m *intraProxyManager) closePeerLocked(peer string, ps *peerState) { // closePeerShardLocked shuts down and removes resources for a specific peer/shard pair. Caller must hold m.streamsMu. func (m *intraProxyManager) closePeerShardLocked(peer string, ps *peerState, key peerStreamKey) { - m.logger.Info("closePeerShardLocked", tag.NewStringTag("peer", peer), tag.NewStringTag("clientShard", ClusterShardIDtoString(key.clientShard)), tag.NewStringTag("serverShard", ClusterShardIDtoString(key.serverShard))) + m.logger.Info("closePeerShardLocked", tag.NewStringTag("peer", peer), tag.NewStringTag("clientShard", ClusterShardIDtoString(key.targetShard)), tag.NewStringTag("serverShard", ClusterShardIDtoString(key.sourceShard))) if shut, ok := ps.recvShutdown[key]; ok && shut != nil { shut.Shutdown() st := GetGlobalStreamTracker() - cliID := BuildIntraProxyReceiverStreamID(peer, key.clientShard, key.serverShard) + cliID := BuildIntraProxyReceiverStreamID(peer, key.targetShard, key.sourceShard) st.UnregisterStream(cliID) delete(ps.recvShutdown, key) } @@ -602,7 +606,7 @@ func (m *intraProxyManager) closePeerShardLocked(peer string, ps *peerState, key delete(ps.receivers, key) } st := GetGlobalStreamTracker() - srvID := BuildIntraProxySenderStreamID(peer, key.clientShard, key.serverShard) + srvID := BuildIntraProxySenderStreamID(peer, key.targetShard, key.sourceShard) st.UnregisterStream(srvID) delete(ps.senders, key) } @@ -618,7 +622,7 @@ func (m *intraProxyManager) ClosePeer(peer string) { // ClosePeerShard closes resources for a specific peer/shard pair. 
func (m *intraProxyManager) ClosePeerShard(peer string, clientShard, serverShard history.ClusterShardID) { - key := peerStreamKey{clientShard: clientShard, serverShard: serverShard} + key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} m.streamsMu.Lock() defer m.streamsMu.Unlock() if ps, ok := m.peers[peer]; ok { @@ -635,14 +639,14 @@ func (m *intraProxyManager) CloseShardAcrossPeers(shard history.ClusterShardID) // Collect keys to avoid mutating map during iteration toClose := make([]peerStreamKey, 0) for key := range ps.receivers { - if (key.clientShard.ClusterID == shard.ClusterID && key.clientShard.ShardID == shard.ShardID) || - (key.serverShard.ClusterID == shard.ClusterID && key.serverShard.ShardID == shard.ShardID) { + if (key.targetShard.ClusterID == shard.ClusterID && key.targetShard.ShardID == shard.ShardID) || + (key.sourceShard.ClusterID == shard.ClusterID && key.sourceShard.ShardID == shard.ShardID) { toClose = append(toClose, key) } } for key := range ps.senders { - if (key.clientShard.ClusterID == shard.ClusterID && key.clientShard.ShardID == shard.ShardID) || - (key.serverShard.ClusterID == shard.ClusterID && key.serverShard.ShardID == shard.ShardID) { + if (key.targetShard.ClusterID == shard.ClusterID && key.targetShard.ShardID == shard.ShardID) || + (key.sourceShard.ClusterID == shard.ClusterID && key.sourceShard.ShardID == shard.ShardID) { // ensure key is present in toClose for unified cleanup toClose = append(toClose, key) } @@ -653,6 +657,29 @@ func (m *intraProxyManager) CloseShardAcrossPeers(shard history.ClusterShardID) } } +func (m *intraProxyManager) Start() error { + go func() { + for { + // timer + timer := time.NewTimer(1 * time.Second) + select { + case <-timer.C: + m.ReconcilePeerStreams(m.proxy, "") + case <-m.notifyCh: + m.ReconcilePeerStreams(m.proxy, "") + } + } + }() + return nil +} + +func (m *intraProxyManager) Notify() { + select { + case m.notifyCh <- struct{}{}: + default: + } +} + // ReconcilePeerStreams ensures receivers exist for desired (local shard, remote shard) pairs // for a given peer and closes any sender/receiver not in the desired set. // This mirrors the Temporal StreamReceiverMonitor approach. 
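Start and Notify above drive reconciliation from both a periodic timer and an on-demand wake-up, with a non-blocking send so Notify never stalls the caller. A standalone sketch of that loop; it uses a time.Ticker and a capacity-1 notify channel, which keeps one wake-up pending while a pass is running, whereas the patch creates a fresh timer each iteration and uses an unbuffered channel.

package main

import (
	"fmt"
	"time"
)

// reconciler mirrors the Start/Notify pair: a loop that reconciles on a periodic
// tick and whenever a notification arrives.
type reconciler struct {
	notifyCh chan struct{}
}

func (r *reconciler) run(reconcile func()) {
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			reconcile()
		case <-r.notifyCh:
			reconcile()
		}
	}
}

// notify wakes the loop without ever blocking the caller; redundant wake-ups are
// coalesced into the single pending signal.
func (r *reconciler) notify() {
	select {
	case r.notifyCh <- struct{}{}:
	default:
	}
}

func main() {
	r := &reconciler{notifyCh: make(chan struct{}, 1)}
	go r.run(func() { fmt.Println("reconcile pass") })
	r.notify()
	time.Sleep(100 * time.Millisecond)
}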
@@ -660,72 +687,69 @@ func (m *intraProxyManager) ReconcilePeerStreams( p *Proxy, peerNodeName string, ) { - f := func() { - m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("peerNodeName", peerNodeName)) - defer m.logger.Info("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) + m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("peerNodeName", peerNodeName)) + defer m.logger.Info("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) - if mode := m.shardCountConfig.Mode; mode != config.ShardCountRouting { - return - } - localShards := m.shardManager.GetLocalShards() - remoteShards, err := m.shardManager.GetRemoteShardsForPeer(peerNodeName) - if err != nil { - m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) - return - } - m.logger.Info("ReconcilePeerStreams", - tag.NewStringTag("peerNodeName", peerNodeName), - tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards)), - tag.NewStringTag("localShards", fmt.Sprintf("%v", localShards)), - ) - - // Build desired set of cross-cluster pairs - desired := make(map[peerStreamKey]string) - for _, l := range localShards { - for peer, shards := range remoteShards { - for _, r := range shards.Shards { - if l.ClusterID == r.ID.ClusterID { - continue - } - desired[peerStreamKey{clientShard: l, serverShard: r.ID}] = peer + if mode := m.shardCountConfig.Mode; mode != config.ShardCountRouting { + return + } + localShards := m.shardManager.GetLocalShards() + remoteShards, err := m.shardManager.GetRemoteShardsForPeer(peerNodeName) + if err != nil { + m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) + return + } + m.logger.Info("ReconcilePeerStreams", + tag.NewStringTag("peerNodeName", peerNodeName), + tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards)), + tag.NewStringTag("localShards", fmt.Sprintf("%v", localShards)), + ) + + // Build desired set of cross-cluster pairs + desired := make(map[peerStreamKey]string) + for _, l := range localShards { + for peer, shards := range remoteShards { + for _, r := range shards.Shards { + if l.ClusterID == r.ID.ClusterID { + continue } + desired[peerStreamKey{targetShard: l, sourceShard: r.ID}] = peer } } + } - m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("desired", fmt.Sprintf("%v", desired))) + m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("desired", fmt.Sprintf("%v", desired))) - // Ensure all desired receivers exist - for key := range desired { - m.EnsureReceiverForPeerShard(p, desired[key], key.clientShard, key.serverShard) - } + // Ensure all desired receivers exist + for key := range desired { + m.EnsureReceiverForPeerShard(p, desired[key], key.targetShard, key.sourceShard) + } - // Prune anything not desired - check := func(ps *peerState) { - // Collect keys to close for receivers - for key := range ps.receivers { - if _, ok2 := desired[key]; !ok2 { - m.closePeerShardLocked(peerNodeName, ps, key) - } + // Prune anything not desired + check := func(ps *peerState) { + // Collect keys to close for receivers + for key := range ps.receivers { + if _, ok2 := desired[key]; !ok2 { + m.closePeerShardLocked(peerNodeName, ps, key) } - // And for server-side senders, if they don't belong to desired pairs - for key := range ps.senders { - if _, ok2 := desired[key]; !ok2 { - m.closePeerShardLocked(peerNodeName, ps, key) - } + } + // And for server-side senders, if they don't belong to desired pairs + for key := range ps.senders { + if _, ok2 := desired[key]; !ok2 { + 
m.closePeerShardLocked(peerNodeName, ps, key) } } + } - m.streamsMu.Lock() - if peerNodeName != "" { - if ps, ok := m.peers[peerNodeName]; ok && ps != nil { - check(ps) - } - } else { - for _, ps := range m.peers { - check(ps) - } + m.streamsMu.Lock() + if peerNodeName != "" { + if ps, ok := m.peers[peerNodeName]; ok && ps != nil { + check(ps) + } + } else { + for _, ps := range m.peers { + check(ps) } - m.streamsMu.Unlock() } - go f() + m.streamsMu.Unlock() } diff --git a/proxy/proxy.go b/proxy/proxy.go index 8ce1e68d..de08ec15 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -90,35 +90,6 @@ func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { proxy.metricsConfig = s2sConfig.Metrics } - // TODO: Wire intra-proxy manager callbacks - // // Wire memberlist peer-join callback to reconcile intra-proxy receivers for local/remote pairs - // shardManager.SetOnPeerJoin(func(nodeName string) { - // logger.Info("OnPeerJoin", tag.NewStringTag("nodeName", nodeName)) - // defer logger.Info("OnPeerJoin done", tag.NewStringTag("nodeName", nodeName)) - // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) - // }) - - // // Wire peer-leave to cleanup intra-proxy resources for that peer - // shardManager.SetOnPeerLeave(func(nodeName string) { - // logger.Info("OnPeerLeave", tag.NewStringTag("nodeName", nodeName)) - // defer logger.Info("OnPeerLeave done", tag.NewStringTag("nodeName", nodeName)) - // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) - // }) - - // // Wire local shard changes to reconcile intra-proxy receivers - // shardManager.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { - // logger.Info("OnLocalShardChange", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) - // defer logger.Info("OnLocalShardChange done", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) - // proxy.intraMgr.ReconcilePeerStreams(proxy, "") - // }) - - // // Wire remote shard changes to reconcile intra-proxy receivers - // shardManager.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { - // logger.Info("OnRemoteShardChange", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) - // defer logger.Info("OnRemoteShardChange done", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) - // proxy.intraMgr.ReconcilePeerStreams(proxy, peer) - // }) - for _, clusterCfg := range s2sConfig.ClusterConnections { shardManager, err := NewShardManager(configProvider, logger) if err != nil { @@ -132,7 +103,7 @@ func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { } migrationId := migrationId{clusterCfg.Name} proxy.clusterConnections[migrationId] = cc - proxy.intraMgrs[migrationId] = newIntraProxyManager(logger, shardManager, clusterCfg.ShardCountConfig) + proxy.intraMgrs[migrationId] = newIntraProxyManager(logger, proxy, shardManager, clusterCfg.ShardCountConfig) proxy.shardManagers[migrationId] = shardManager shardManager.SetIntraProxyManager(proxy.intraMgrs[migrationId]) } @@ -249,6 +220,12 @@ func (s *Proxy) Start() error { } } + for _, intraMgr := range s.intraMgrs { + if err := intraMgr.Start(); err != nil { + return err + } + } + for _, v := range s.clusterConnections { v.Start() } diff --git a/proxy/proxy_streams.go 
b/proxy/proxy_streams.go index 1615859c..222ba9aa 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -238,12 +238,12 @@ func (s *proxyStreamSender) Run( s.logger = log.With(s.logger, tag.NewStringTag("role", "sender"), ) + s.logger.Info("proxyStreamSender Run") + defer s.logger.Info("proxyStreamSender Run finished") // Register this sender as the owner of the shard for the duration of the stream - if s.shardManager != nil { - s.shardManager.RegisterShard(s.targetShardID) - defer s.shardManager.UnregisterShard(s.targetShardID) - } + s.shardManager.RegisterShard(s.targetShardID) + defer s.shardManager.UnregisterShard(s.targetShardID) // Register local stream tracking for sender (short id, include role) s.streamTracker = GetGlobalStreamTracker() @@ -577,6 +577,8 @@ func (r *proxyStreamReceiver) Run( tag.NewStringTag("stream-target-shard", ClusterShardIDtoString(r.targetShardID)), tag.NewStringTag("role", "receiver"), ) + r.logger.Info("proxyStreamReceiver Run") + defer r.logger.Info("proxyStreamReceiver Run finished") // Build metadata for local server stream md := metadata.New(map[string]string{}) @@ -694,21 +696,40 @@ func (r *proxyStreamReceiver) recvReplicationMessages( // If replication tasks are empty, still log the empty batch and send watermark if len(attr.Messages.ReplicationTasks) == 0 { r.logger.Info("Receiver received empty replication batch", tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) - for targetShardID, sendChan := range r.proxy.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) { - r.logger.Info("Sending high watermark to target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) - sendChan <- RoutedMessage{ - SourceShard: r.sourceShardID, - Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ - Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ - Messages: &replicationv1.WorkflowReplicationMessages{ - ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, - Priority: attr.Messages.Priority, - }, + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, + Priority: attr.Messages.Priority, }, }, + }, + } + localShardsToSend := r.proxy.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) + r.logger.Info("Going to broadcast high watermark to local shards", tag.NewStringTag("localShardsToSend", fmt.Sprintf("%v", localShardsToSend))) + for targetShardID, sendChan := range localShardsToSend { + r.logger.Info("Sending high watermark to target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + sendChan <- msg + } + // send to all remote shards on other nodes as well + remoteShards, err := r.shardManager.GetRemoteShardsForPeer("") + if err != nil { + r.logger.Error("Failed to get remote shards", tag.Error(err)) + return err + } + r.logger.Info("Going to broadcast high watermark to remote shards", tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards))) + for _, shards := range remoteShards { + for _, shard := range shards.Shards { + if shard.ID.ClusterID != r.targetShardID.ClusterID { + continue + } + if 
!r.shardManager.DeliverMessagesToShardOwner(shard.ID, &msg, r.proxy, shutdownChan, r.logger) { + r.logger.Warn("Failed to send ReplicationTasks to remote shard", tag.NewStringTag("shard", ClusterShardIDtoString(shard.ID))) + } } } - continue } // Retry across the whole target set until all sends succeed (or shutdown) diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index ee2bec68..dd05baec 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -247,6 +247,7 @@ func (sm *shardManagerImpl) Stop() { } func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) { + sm.logger.Info("RegisterShard", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) sm.addLocalShard(clientShardID) sm.broadcastShardChange("register", clientShardID) @@ -263,6 +264,7 @@ func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) } func (sm *shardManagerImpl) UnregisterShard(clientShardID history.ClusterShardID) { + sm.logger.Info("UnregisterShard", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) sm.removeLocalShard(clientShardID) sm.mutex.Lock() delete(sm.localShards, ClusterShardIDtoShortString(clientShardID)) @@ -524,6 +526,38 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( func (sm *shardManagerImpl) SetIntraProxyManager(intraMgr *intraProxyManager) { sm.intraMgr = intraMgr + + // Wire memberlist peer-join callback to reconcile intra-proxy receivers for local/remote pairs + sm.SetOnPeerJoin(func(nodeName string) { + sm.logger.Info("OnPeerJoin", tag.NewStringTag("nodeName", nodeName)) + defer sm.logger.Info("OnPeerJoin done", tag.NewStringTag("nodeName", nodeName)) + sm.intraMgr.Notify() + // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) + }) + + // Wire peer-leave to cleanup intra-proxy resources for that peer + sm.SetOnPeerLeave(func(nodeName string) { + sm.logger.Info("OnPeerLeave", tag.NewStringTag("nodeName", nodeName)) + defer sm.logger.Info("OnPeerLeave done", tag.NewStringTag("nodeName", nodeName)) + sm.intraMgr.Notify() + // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) + }) + + // Wire local shard changes to reconcile intra-proxy receivers + sm.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { + sm.logger.Info("OnLocalShardChange", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + defer sm.logger.Info("OnLocalShardChange done", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + sm.intraMgr.Notify() + // proxy.intraMgr.ReconcilePeerStreams(proxy, "") + }) + + // Wire remote shard changes to reconcile intra-proxy receivers + sm.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { + sm.logger.Info("OnRemoteShardChange", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + defer sm.logger.Info("OnRemoteShardChange done", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + sm.intraMgr.Notify() + // proxy.intraMgr.ReconcilePeerStreams(proxy, peer) + }) } func (sm *shardManagerImpl) GetIntraProxyManager() *intraProxyManager { From 04046e5292c1c5710bd9f738643d2e095ff0f9e3 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 26 Sep 2025 16:54:33 -0700 Subject: [PATCH 10/38] add log for debugging --- proxy/intra_proxy_router.go | 61 
+++++++++++++------------------------ proxy/proxy.go | 10 +++--- proxy/proxy_streams.go | 8 ++++- proxy/shard_manager.go | 3 ++ 4 files changed, 37 insertions(+), 45 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 1ef525b5..25ca7e16 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -290,15 +290,16 @@ func (r *intraProxyStreamReceiver) sendAck(req *adminservice.StreamWorkflowRepli func (m *intraProxyManager) RegisterSender( peerNodeName string, - clientShard history.ClusterShardID, - serverShard history.ClusterShardID, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, sender *intraProxyStreamSender, ) { // Cross-cluster only - if clientShard.ClusterID == serverShard.ClusterID { + if targetShard.ClusterID == sourceShard.ClusterID { return } - key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} + m.logger.Info("RegisterSender", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("sender", sender.streamID)) m.streamsMu.Lock() ps := m.peers[peerNodeName] if ps == nil { @@ -314,10 +315,11 @@ func (m *intraProxyManager) RegisterSender( func (m *intraProxyManager) UnregisterSender( peerNodeName string, - clientShard history.ClusterShardID, - serverShard history.ClusterShardID, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, ) { - key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} + m.logger.Info("UnregisterSender", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", fmt.Sprintf("%v", key))) m.streamsMu.Lock() if ps := m.peers[peerNodeName]; ps != nil && ps.senders != nil { delete(ps.senders, key) @@ -476,6 +478,7 @@ func (m *intraProxyManager) ensureStream( ps.receivers[key] = recv ps.recvShutdown[key] = recv.shutdown m.streamsMu.Unlock() + m.logger.Info("intraProxyStreamReceiver added", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", recv.streamID)) // Let the receiver open stream, register tracking, and start goroutines go func() { @@ -514,12 +517,15 @@ func (m *intraProxyManager) sendAck( func (m *intraProxyManager) sendReplicationMessages( ctx context.Context, peerNodeName string, - clientShard history.ClusterShardID, - serverShard history.ClusterShardID, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, p *Proxy, resp *adminservice.StreamWorkflowReplicationMessagesResponse, ) error { - key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} + logger := log.With(m.logger, tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("task-source-shard", ClusterShardIDtoString(sourceShard))) + logger.Info("sendReplicationMessages") + defer logger.Info("sendReplicationMessages finished") // Try server stream first with short retry/backoff to await registration deadline := time.Now().Add(2 * time.Second) @@ -529,15 +535,17 @@ func (m *intraProxyManager) sendReplicationMessages( m.streamsMu.RLock() ps, ok := m.peers[peerNodeName] if ok && ps != nil && ps.senders != nil { + logger.Info("sendReplicationMessages senders for node", tag.NewStringTag("node", peerNodeName), 
tag.NewStringTag("senders", fmt.Sprintf("%v", ps.senders))) if s, ok2 := ps.senders[key]; ok2 && s != nil { sender = s } } m.streamsMu.RUnlock() + logger.Info("sendReplicationMessages sender", tag.NewStringTag("sender", fmt.Sprintf("%v", sender))) if sender != nil { if err := sender.sendReplicationMessages(resp); err != nil { - m.logger.Error("Failed to send intra-proxy replication messages via server stream", tag.Error(err)) + logger.Error("Failed to send intra-proxy replication messages via server stream", tag.Error(err)) return err } return nil @@ -552,7 +560,7 @@ func (m *intraProxyManager) sendReplicationMessages( } } - return fmt.Errorf("stream does not support SendMsg for responses") + return fmt.Errorf("failed to send replication messages") } // closePeerLocked shuts down and removes all resources for a peer. Caller must hold m.streamsMu. @@ -569,6 +577,7 @@ func (m *intraProxyManager) closePeerLocked(peer string, ps *peerState) { } // Close client streams (receiver cleanup is handled by its own goroutine) for key := range ps.receivers { + m.logger.Info("intraProxyStreamReceiver deleted", tag.NewStringTag("peerNodeName", peer), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", ps.receivers[key].streamID)) delete(ps.receivers, key) } // Unregister server-side tracker entries @@ -603,6 +612,7 @@ func (m *intraProxyManager) closePeerShardLocked(peer string, ps *peerState, key if r.streamClient != nil { _ = r.streamClient.CloseSend() } + m.logger.Info("intraProxyStreamReceiver deleted", tag.NewStringTag("peerNodeName", peer), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", r.streamID)) delete(ps.receivers, key) } st := GetGlobalStreamTracker() @@ -630,33 +640,6 @@ func (m *intraProxyManager) ClosePeerShard(peer string, clientShard, serverShard } } -// CloseShardAcrossPeers closes all sender/receiver streams for any peer that involve the specified shard -// as either client or server shard. Useful when a local shard is unregistered. 
-func (m *intraProxyManager) CloseShardAcrossPeers(shard history.ClusterShardID) { - m.streamsMu.Lock() - defer m.streamsMu.Unlock() - for peer, ps := range m.peers { - // Collect keys to avoid mutating map during iteration - toClose := make([]peerStreamKey, 0) - for key := range ps.receivers { - if (key.targetShard.ClusterID == shard.ClusterID && key.targetShard.ShardID == shard.ShardID) || - (key.sourceShard.ClusterID == shard.ClusterID && key.sourceShard.ShardID == shard.ShardID) { - toClose = append(toClose, key) - } - } - for key := range ps.senders { - if (key.targetShard.ClusterID == shard.ClusterID && key.targetShard.ShardID == shard.ShardID) || - (key.sourceShard.ClusterID == shard.ClusterID && key.sourceShard.ShardID == shard.ShardID) { - // ensure key is present in toClose for unified cleanup - toClose = append(toClose, key) - } - } - for _, key := range toClose { - m.closePeerShardLocked(peer, ps, key) - } - } -} - func (m *intraProxyManager) Start() error { go func() { for { diff --git a/proxy/proxy.go b/proxy/proxy.go index de08ec15..41a9122b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -307,10 +307,10 @@ func (s *Proxy) GetIntraProxyManager(migrationId migrationId) *intraProxyManager // SetRemoteSendChan registers a send channel for a specific shard ID func (s *Proxy) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { + s.logger.Info("Register remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) s.remoteSendChannelsMu.Lock() defer s.remoteSendChannelsMu.Unlock() s.remoteSendChannels[shardID] = sendChan - s.logger.Info("Registered remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) } // GetRemoteSendChan retrieves the send channel for a specific shard ID @@ -358,10 +358,10 @@ func (s *Proxy) RemoveRemoteSendChan(shardID history.ClusterShardID) { // SetLocalAckChan registers an ack channel for a specific shard ID func (s *Proxy) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) { + s.logger.Info("Register local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) s.localAckChannelsMu.Lock() defer s.localAckChannelsMu.Unlock() s.localAckChannels[shardID] = ackChan - s.logger.Info("Registered local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) } // GetLocalAckChan retrieves the ack channel for a specific shard ID @@ -374,18 +374,18 @@ func (s *Proxy) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, // RemoveLocalAckChan removes the ack channel for a specific shard ID func (s *Proxy) RemoveLocalAckChan(shardID history.ClusterShardID) { + s.logger.Info("Remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) s.localAckChannelsMu.Lock() defer s.localAckChannelsMu.Unlock() delete(s.localAckChannels, shardID) - s.logger.Info("Removed local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) } // SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID func (s *Proxy) SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) { + s.logger.Info("Register local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) s.localReceiverCancelFuncsMu.Lock() defer s.localReceiverCancelFuncsMu.Unlock() s.localReceiverCancelFuncs[shardID] = cancelFunc - s.logger.Info("Registered 
local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) } // GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID @@ -398,10 +398,10 @@ func (s *Proxy) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (cont // RemoveLocalReceiverCancelFunc removes the cancel function for a local receiver for a specific shard ID func (s *Proxy) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { + s.logger.Info("Remove local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) s.localReceiverCancelFuncsMu.Lock() defer s.localReceiverCancelFuncsMu.Unlock() delete(s.localReceiverCancelFuncs, shardID) - s.logger.Info("Removed local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) } // TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 222ba9aa..2519ff62 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -591,6 +591,8 @@ func (r *proxyStreamReceiver) Run( outgoingContext, cancel := context.WithCancel(outgoingContext) defer cancel() + r.logger.Info("proxyStreamReceiver outgoingContext created") + // Open stream receiver -> local server's stream sender for clientShardID var sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient var err error @@ -600,6 +602,8 @@ func (r *proxyStreamReceiver) Run( return } + r.logger.Info("proxyStreamReceiver sourceStreamClient created") + // Setup ack channel and cancel func bookkeeping r.ackChan = make(chan RoutedAck, 100) r.proxy.SetLocalAckChan(r.sourceShardID, r.ackChan) @@ -730,6 +734,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( } } } + continue } // Retry across the whole target set until all sends succeed (or shutdown) @@ -738,6 +743,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( for targetShardID := range tasksByTargetShard { sentByTarget[targetShardID] = false } + r.logger.Info("Going to broadcast ReplicationTasks to target shards", tag.NewStringTag("tasksByTargetShard", fmt.Sprintf("%v", tasksByTargetShard))) numRemaining := len(tasksByTargetShard) backoff := 10 * time.Millisecond for numRemaining > 0 { @@ -769,7 +775,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( progress = true } else { if !loggedByTarget[targetShardID] { - r.logger.Warn("No send channel found for target shard; retrying until available", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + r.logger.Warn("No send channel found for target shard; retrying until available", tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShardID))) loggedByTarget[targetShardID] = true } } diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index dd05baec..cb71fab9 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -492,10 +492,13 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( shutdownChan channel.ShutdownOnce, logger log.Logger, ) bool { + logger = log.With(logger, tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShard))) + // Try local delivery first if ch, ok := proxy.GetRemoteSendChan(targetShard); ok { select { case ch <- *routedMsg: + logger.Info("Delivered messages to local shard owner") return true case <-shutdownChan.Channel(): return false From 796bef6650918ae38c756ffdca4b374e433fb8d7 Mon 
Sep 17 00:00:00 2001 From: Hai Zhao Date: Sun, 12 Oct 2025 20:50:53 -0700 Subject: [PATCH 11/38] fix issues --- .../config/cluster-b-mux-server-proxy-1.yaml | 14 +-- .../config/cluster-b-mux-server-proxy-2.yaml | 12 +-- proxy/intra_proxy_router.go | 14 ++- proxy/shard_manager.go | 87 +++++++++++++++---- 4 files changed, 96 insertions(+), 31 deletions(-) diff --git a/develop/config/cluster-b-mux-server-proxy-1.yaml b/develop/config/cluster-b-mux-server-proxy-1.yaml index 4e2b6da9..e204309d 100644 --- a/develop/config/cluster-b-mux-server-proxy-1.yaml +++ b/develop/config/cluster-b-mux-server-proxy-1.yaml @@ -32,15 +32,15 @@ profiling: memberlist: enabled: true nodeName: "proxy-node-b-1" - bindAddr: "0.0.0.0" + bindAddr: "127.0.0.1" bindPort: 6335 - joinAddrs: - - "localhost:6435" + # joinAddrs: + # - "localhost:6435" proxyAddresses: "proxy-node-b-1": "localhost:6333" "proxy-node-b-2": "localhost:6433" - # TCP-only configuration for restricted networks + # # TCP-only configuration for restricted networks tcpOnly: true # Use TCP transport only, disable UDP - disableTCPPings: true # Disable TCP pings for faster convergence - probeTimeoutMs: 1000 # Longer timeout for network latency - probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file + # disableTCPPings: true # Disable TCP pings for faster convergence + # probeTimeoutMs: 1000 # Longer timeout for network latency + # probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-2.yaml b/develop/config/cluster-b-mux-server-proxy-2.yaml index 2d7b111a..e37b5006 100644 --- a/develop/config/cluster-b-mux-server-proxy-2.yaml +++ b/develop/config/cluster-b-mux-server-proxy-2.yaml @@ -32,15 +32,15 @@ profiling: memberlist: enabled: true nodeName: "proxy-node-b-2" - bindAddr: "0.0.0.0" + bindAddr: "127.0.0.1" bindPort: 6435 joinAddrs: - "localhost:6335" proxyAddresses: "proxy-node-b-1": "localhost:6333" "proxy-node-b-2": "localhost:6433" - # TCP-only configuration for restricted networks - tcpOnly: true # Use TCP transport only, disable UDP - disableTCPPings: true # Disable TCP pings for faster convergence - probeTimeoutMs: 1000 # Longer timeout for network latency - probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file + # # TCP-only configuration for restricted networks + # tcpOnly: true # Use TCP transport only, disable UDP + # disableTCPPings: true # Disable TCP pings for faster convergence + # probeTimeoutMs: 1000 # Longer timeout for network latency + # probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 25ca7e16..ab80dbdc 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -344,7 +344,10 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName st return } // Require at least one shard to be local to this instance - if !m.shardManager.IsLocalShard(targetShard) && !m.shardManager.IsLocalShard(sourceShard) { + isLocalTargetShard := m.shardManager.IsLocalShard(targetShard) + isLocalSourceShard := m.shardManager.IsLocalShard(sourceShard) + if !isLocalTargetShard && !isLocalSourceShard { + logger.Info("EnsureReceiverForPeerShard skipping because neither shard is local", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), 
tag.NewBoolTag("isLocalTargetShard", isLocalTargetShard), tag.NewBoolTag("isLocalSourceShard", isLocalSourceShard)) return } // Consolidated path: ensure stream and background loops @@ -483,8 +486,13 @@ func (m *intraProxyManager) ensureStream( // Let the receiver open stream, register tracking, and start goroutines go func() { if err := recv.Run(ctx, p, ps.conn); err != nil { - recv.logger.Error("intraProxyStreamReceiver Run failed", tag.Error(err)) + m.logger.Error("intraProxyStreamReceiver.Run error", tag.Error(err)) } + // remove the receiver from the peer state + m.streamsMu.Lock() + delete(ps.receivers, key) + delete(ps.recvShutdown, key) + m.streamsMu.Unlock() }() return nil } @@ -641,6 +649,8 @@ func (m *intraProxyManager) ClosePeerShard(peer string, clientShard, serverShard } func (m *intraProxyManager) Start() error { + m.logger.Info("intraProxyManager started") + defer m.logger.Info("intraProxyManager stopped") go func() { for { // timer diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index cb71fab9..025c8616 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -162,9 +162,6 @@ func (sm *shardManagerImpl) SetOnRemoteShardChange(handler func(peer string, sha func (sm *shardManagerImpl) Start(lifetime context.Context) error { sm.logger.Info("Starting shard manager") - sm.mutex.Lock() - defer sm.mutex.Unlock() - if sm.started { sm.logger.Info("Shard manager already started") return nil @@ -173,20 +170,29 @@ func (sm *shardManagerImpl) Start(lifetime context.Context) error { // Configure memberlist var mlConfig *memberlist.Config if sm.config.TCPOnly { - mlConfig = memberlist.DefaultWANConfig() - // Disable UDP for restricted networks + // Use LAN config as base for TCP-only mode + mlConfig = memberlist.DefaultLANConfig() mlConfig.DisableTcpPings = sm.config.DisableTCPPings + // Set default timeouts for TCP-only if not specified + if sm.config.ProbeTimeoutMs == 0 { + mlConfig.ProbeTimeout = 1 * time.Second + } + if sm.config.ProbeIntervalMs == 0 { + mlConfig.ProbeInterval = 2 * time.Second + } } else { mlConfig = memberlist.DefaultLocalConfig() } - mlConfig.Name = sm.config.NodeName mlConfig.BindAddr = sm.config.BindAddr mlConfig.BindPort = sm.config.BindPort + mlConfig.AdvertiseAddr = sm.config.BindAddr + mlConfig.AdvertisePort = sm.config.BindPort + mlConfig.Delegate = sm.delegate mlConfig.Events = &shardEventDelegate{manager: sm, logger: sm.logger} - // Configure timeouts if specified + // Configure custom timeouts if specified if sm.config.ProbeTimeoutMs > 0 { mlConfig.ProbeTimeout = time.Duration(sm.config.ProbeTimeoutMs) * time.Millisecond } @@ -194,25 +200,60 @@ func (sm *shardManagerImpl) Start(lifetime context.Context) error { mlConfig.ProbeInterval = time.Duration(sm.config.ProbeIntervalMs) * time.Millisecond } - // Create memberlist - ml, err := memberlist.Create(mlConfig) - if err != nil { - return fmt.Errorf("failed to create memberlist: %w", err) + sm.logger.Info("Creating memberlist", + tag.NewStringTag("nodeName", mlConfig.Name), + tag.NewStringTag("bindAddr", mlConfig.BindAddr), + tag.NewStringTag("bindPort", fmt.Sprintf("%d", mlConfig.BindPort)), + tag.NewBoolTag("tcpOnly", sm.config.TCPOnly), + tag.NewBoolTag("disableTcpPings", mlConfig.DisableTcpPings), + tag.NewStringTag("probeTimeout", mlConfig.ProbeTimeout.String()), + tag.NewStringTag("probeInterval", mlConfig.ProbeInterval.String())) + + // Create memberlist with timeout protection + type result struct { + ml *memberlist.Memberlist + err error + } + resultCh := make(chan result, 
1) + go func() { + ml, err := memberlist.Create(mlConfig) + resultCh <- result{ml: ml, err: err} + }() + + var ml *memberlist.Memberlist + select { + case res := <-resultCh: + ml = res.ml + if res.err != nil { + return fmt.Errorf("failed to create memberlist: %w", res.err) + } + sm.logger.Info("Memberlist created successfully") + case <-time.After(10 * time.Second): + return fmt.Errorf("memberlist.Create() timed out after 10s - check bind address/port availability") } + sm.mutex.Lock() sm.ml = ml sm.localAddr = fmt.Sprintf("%s:%d", sm.config.BindAddr, sm.config.BindPort) + sm.started = true + + sm.logger.Info("Shard manager base initialization complete", + tag.NewStringTag("node", sm.config.NodeName), + tag.NewStringTag("addr", sm.localAddr)) + + sm.mutex.Unlock() // Join existing cluster if configured if len(sm.config.JoinAddrs) > 0 { + sm.logger.Info("Attempting to join cluster", tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", sm.config.JoinAddrs))) num, err := ml.Join(sm.config.JoinAddrs) if err != nil { sm.logger.Warn("Failed to join some cluster members", tag.Error(err)) + } else { + sm.logger.Info("Joined memberlist cluster", tag.NewStringTag("members", strconv.Itoa(num))) } - sm.logger.Info("Joined memberlist cluster", tag.NewStringTag("members", strconv.Itoa(num))) } - sm.started = true sm.logger.Info("Shard manager started", tag.NewStringTag("node", sm.config.NodeName), tag.NewStringTag("addr", sm.localAddr)) @@ -225,11 +266,12 @@ func (sm *shardManagerImpl) Start(lifetime context.Context) error { func (sm *shardManagerImpl) Stop() { sm.mutex.Lock() - defer sm.mutex.Unlock() if !sm.started || sm.ml == nil { + sm.mutex.Unlock() return } + sm.mutex.Unlock() // Leave the cluster gracefully err := sm.ml.Leave(5 * time.Second) @@ -242,7 +284,9 @@ func (sm *shardManagerImpl) Stop() { sm.logger.Error("Error shutting down memberlist", tag.Error(err)) } + sm.mutex.Lock() sm.started = false + sm.mutex.Unlock() sm.logger.Info("Shard manager stopped") } @@ -405,6 +449,8 @@ func (sm *shardManagerImpl) GetRemoteShardsForPeer(peerNodeName string) (map[str membersChan := make(chan []*memberlist.Node, 1) go func() { defer func() { _ = recover() }() + sm.mutex.RLock() + defer sm.mutex.RUnlock() membersChan <- sm.ml.Members() }() @@ -605,9 +651,18 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C // shardDelegate implements memberlist.Delegate func (sd *shardDelegate) NodeMeta(limit int) []byte { + // Copy shard map under read lock to avoid concurrent map iteration/modification + sd.manager.mutex.RLock() + shardsCopy := make(map[string]ShardInfo, len(sd.manager.localShards)) + for k, v := range sd.manager.localShards { + shardsCopy[k] = v + } + nodeName := sd.manager.config.NodeName + sd.manager.mutex.RUnlock() + state := NodeShardState{ - NodeName: sd.manager.config.NodeName, - Shards: sd.manager.localShards, + NodeName: nodeName, + Shards: shardsCopy, Updated: time.Now(), } From 6e98448ce55ffe80619ac7aadc1c41fd852c8ccc Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Mon, 13 Oct 2025 03:05:37 -0700 Subject: [PATCH 12/38] fix streams --- develop/config/dynamic-config.yaml | 4 +- proxy/adminservice.go | 8 +- proxy/intra_proxy_router.go | 73 +++++--- proxy/proxy.go | 38 ++-- proxy/proxy_streams.go | 282 +++++++++++++++++++---------- proxy/shard_manager.go | 151 +++++++++++---- proxy/stream_tracker.go | 10 +- 7 files changed, 378 insertions(+), 188 deletions(-) diff --git a/develop/config/dynamic-config.yaml b/develop/config/dynamic-config.yaml index dbe95f8b..f95073c9 
100644 --- a/develop/config/dynamic-config.yaml +++ b/develop/config/dynamic-config.yaml @@ -28,4 +28,6 @@ frontend.persistenceMaxQPS: - value: 100000 constraints: {} history.shardUpdateMinInterval: - - value: 1s \ No newline at end of file + - value: 1s +history.ReplicationStreamSendEmptyTaskDuration: + - value: 10s \ No newline at end of file diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 2490f6b8..24c890d5 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -252,16 +252,16 @@ func ClusterShardIDtoShortString(sd history.ClusterShardID) string { // stream using our configured adminClient. When we Recv on the initiator, we Send to the client. // When we Recv on the client, we Send to the initiator func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( - targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, ) (retError error) { defer log.CapturePanic(s.logger, &retError) - targetMetadata, ok := metadata.FromIncomingContext(targetStreamServer.Context()) + targetMetadata, ok := metadata.FromIncomingContext(streamServer.Context()) if !ok { return serviceerror.NewInvalidArgument("missing cluster & shard ID metadata") } targetClusterShardID, sourceClusterShardID, err := history.DecodeClusterShardMD( - headers.NewGRPCHeaderGetter(targetStreamServer.Context()), + headers.NewGRPCHeaderGetter(streamServer.Context()), ) if err != nil { return err @@ -313,7 +313,7 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( forwarder := newStreamForwarder( s.adminClient, - targetStreamServer, + streamServer, targetMetadata, sourceClusterShardID, targetClusterShardID, diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index ab80dbdc..48b475c9 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -59,32 +59,33 @@ func newIntraProxyManager(logger log.Logger, proxy *Proxy, shardManager ShardMan // intraProxyStreamSender registers server stream and forwards upstream ACKs to shard owners (local or remote). // Replication messages are sent by intraProxyManager.sendMessages using the registered server stream. 
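The struct declared below is the server-side half of the intra-proxy pair; the manager sends replication messages through whichever sender has registered itself for a (peer, target shard, source shard) key. A rough sketch of that register/lookup pattern follows, with stand-in types instead of the real peerStreamKey and gRPC stream interfaces.

package main

import (
	"fmt"
	"sync"
)

// shardID and streamKey stand in for history.ClusterShardID and peerStreamKey;
// sender stands in for the registered server-side stream.
type shardID struct{ ClusterID, ShardID int32 }

type streamKey struct{ target, source shardID }

type sender interface{ Send(msg string) error }

// senderRegistry keeps one registered sender per (peer, target, source) so a
// manager-level Send can route outgoing messages through the stream that the
// peer opened.
type senderRegistry struct {
	mu      sync.RWMutex
	senders map[string]map[streamKey]sender
}

func newSenderRegistry() *senderRegistry {
	return &senderRegistry{senders: make(map[string]map[streamKey]sender)}
}

func (r *senderRegistry) Register(peer string, key streamKey, s sender) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.senders[peer] == nil {
		r.senders[peer] = make(map[streamKey]sender)
	}
	r.senders[peer][key] = s
}

func (r *senderRegistry) Unregister(peer string, key streamKey) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.senders[peer], key) // no-op if peer is unknown
}

// Send looks up the registered sender and forwards the message; callers decide
// how to retry when no sender has registered yet.
func (r *senderRegistry) Send(peer string, key streamKey, msg string) error {
	r.mu.RLock()
	s := r.senders[peer][key]
	r.mu.RUnlock()
	if s == nil {
		return fmt.Errorf("no sender registered for peer %s", peer)
	}
	return s.Send(msg)
}

type printSender struct{}

func (printSender) Send(msg string) error { fmt.Println("sent:", msg); return nil }

func main() {
	reg := newSenderRegistry()
	key := streamKey{target: shardID{1, 1}, source: shardID{2, 3}}
	reg.Register("proxy-node-b-2", key, printSender{})
	_ = reg.Send("proxy-node-b-2", key, "replication batch")
}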
type intraProxyStreamSender struct { - logger log.Logger - shardManager ShardManager - proxy *Proxy - intraMgr *intraProxyManager - peerNodeName string - targetShardID history.ClusterShardID - sourceShardID history.ClusterShardID - streamID string - server adminservice.AdminService_StreamWorkflowReplicationMessagesServer + logger log.Logger + shardManager ShardManager + proxy *Proxy + intraMgr *intraProxyManager + peerNodeName string + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + streamID string + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer } func (s *intraProxyStreamSender) Run( - targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, shutdownChan channel.ShutdownOnce, ) error { + s.streamID = BuildIntraProxySenderStreamID(s.peerNodeName, s.sourceShardID, s.targetShardID) + s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID)) + s.logger.Info("intraProxyStreamSender Run") defer s.logger.Info("intraProxyStreamSender Run finished") // Register server-side intra-proxy stream in tracker - s.streamID = BuildIntraProxySenderStreamID(s.peerNodeName, s.sourceShardID, s.targetShardID) - s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID)) st := GetGlobalStreamTracker() st.RegisterStream(s.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(s.targetShardID), ClusterShardIDtoString(s.sourceShardID), StreamRoleForwarder) defer st.UnregisterStream(s.streamID) - s.server = targetStreamServer + s.sourceStreamServer = sourceStreamServer // register this sender so sendMessages can use it s.intraMgr.RegisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID, s) @@ -100,7 +101,7 @@ func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) erro defer s.logger.Info("intraProxyStreamSender recvAck finished") for !shutdownChan.IsShutdown() { - req, err := s.server.Recv() + req, err := s.sourceStreamServer.Recv() if err == io.EOF { s.logger.Info("intraProxyStreamSender recvAck encountered EOF") return nil @@ -143,6 +144,9 @@ func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) erro // sendReplicationMessages sends replication messages to the peer via the server stream. func (s *intraProxyStreamSender) sendReplicationMessages(resp *adminservice.StreamWorkflowReplicationMessagesResponse) error { + s.logger.Info("intraProxyStreamSender sendReplicationMessages started") + defer s.logger.Info("intraProxyStreamSender sendReplicationMessages finished") + // Update server-side intra-proxy tracker for outgoing messages if msgs, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && msgs.Messages != nil { st := GetGlobalStreamTracker() @@ -154,7 +158,7 @@ func (s *intraProxyStreamSender) sendReplicationMessages(resp *adminservice.Stre st.UpdateStreamReplicationMessages(s.streamID, msgs.Messages.ExclusiveHighWatermark) st.UpdateStream(s.streamID) } - if err := s.server.Send(resp); err != nil { + if err := s.sourceStreamServer.Send(resp); err != nil { return err } return nil @@ -177,6 +181,9 @@ type intraProxyStreamReceiver struct { // Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. 
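Before opening the client stream, Run attaches the target/source shard identity as outgoing gRPC metadata. The sketch below shows only that step; the header keys here are made-up placeholders (the proxy itself encodes them with Temporal's cluster-shard metadata helpers, the counterpart of DecodeClusterShardMD on the server side), so only the metadata.NewOutgoingContext mechanics carry over.

package main

import (
	"context"
	"fmt"
	"strconv"

	"google.golang.org/grpc/metadata"
)

// attachShardMetadata returns a context carrying client/server cluster and
// shard IDs as outgoing gRPC metadata, the same idea the receiver uses before
// calling StreamWorkflowReplicationMessages on the peer connection.
// The header names are illustrative only.
func attachShardMetadata(ctx context.Context, clientCluster, clientShard, serverCluster, serverShard int32) context.Context {
	md := metadata.New(map[string]string{
		"example-client-cluster-id": strconv.Itoa(int(clientCluster)),
		"example-client-shard-id":   strconv.Itoa(int(clientShard)),
		"example-server-cluster-id": strconv.Itoa(int(serverCluster)),
		"example-server-shard-id":   strconv.Itoa(int(serverShard)),
	})
	return metadata.NewOutgoingContext(ctx, md)
}

func main() {
	ctx := attachShardMetadata(context.Background(), 1, 2, 2, 5)
	md, _ := metadata.FromOutgoingContext(ctx)
	fmt.Println(md)
}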
func (r *intraProxyStreamReceiver) Run(ctx context.Context, p *Proxy, conn *grpc.ClientConn) error { + r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.sourceShardID, r.targetShardID) + r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) + r.logger.Info("intraProxyStreamReceiver Run") // Build metadata according to receiver pattern: client=targetShard, server=sourceShard md := metadata.New(map[string]string{}) @@ -204,8 +211,6 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, p *Proxy, conn *grpc r.streamClient = stream // Register client-side intra-proxy stream in tracker - r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.sourceShardID, r.targetShardID) - r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) st := GetGlobalStreamTracker() st.RegisterStream(r.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(r.targetShardID), ClusterShardIDtoString(r.sourceShardID), StreamRoleForwarder) defer st.UnregisterStream(r.streamID) @@ -216,8 +221,8 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, p *Proxy, conn *grpc // recvReplicationMessages receives replication messages and forwards to local shard owner. func (r *intraProxyStreamReceiver) recvReplicationMessages(self *Proxy) error { - r.logger.Info("recvReplicationMessages started") - defer r.logger.Info("recvReplicationMessages finished") + r.logger.Info("intraProxyStreamReceiver recvReplicationMessages started") + defer r.logger.Info("intraProxyStreamReceiver recvReplicationMessages finished") shutdown := r.shutdown defer shutdown.Shutdown() @@ -250,11 +255,22 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages(self *Proxy) error { logged := false for !sent { if ch, ok := self.GetRemoteSendChan(r.targetShardID); ok { - select { - case ch <- msg: - sent = true - r.logger.Info("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", msgs.Messages.ExclusiveHighWatermark)) - case <-shutdown.Channel(): + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + r.logger.Warn("Failed to send to local target shard (channel closed)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID))) + } + }() + select { + case ch <- msg: + sent = true + r.logger.Info("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", msgs.Messages.ExclusiveHighWatermark)) + case <-shutdown.Channel(): + // Will be handled outside the func + } + }() + if shutdown.IsShutdown() { return nil } } else { @@ -277,6 +293,9 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages(self *Proxy) error { // sendAck sends an ACK upstream via the client stream and updates tracker. 
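Several hunks in this patch wrap channel sends in a recover so a batch aimed at a shard owner whose channel was just torn down degrades to a logged warning instead of a crash. The generic helper below sketches that guard in isolation, using the non-blocking variant with a default case; trySend and its exact semantics are illustrative, not an API from this repo.

package main

import "fmt"

// trySend attempts a non-blocking send on ch and recovers if the channel has
// already been closed by its owner, returning false in either case.
func trySend[T any](ch chan T, v T) (sent bool) {
	defer func() {
		if recover() != nil {
			sent = false // channel was closed while sending
		}
	}()
	select {
	case ch <- v:
		return true
	default:
		return false // channel full (or nil)
	}
}

func main() {
	ch := make(chan int, 1)
	fmt.Println(trySend(ch, 1)) // true
	fmt.Println(trySend(ch, 2)) // false: buffer full
	close(ch)
	fmt.Println(trySend(ch, 3)) // false: closed, panic recovered
}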
func (r *intraProxyStreamReceiver) sendAck(req *adminservice.StreamWorkflowReplicationMessagesRequest) error { + r.logger.Info("intraProxyStreamReceiver sendAck started") + defer r.logger.Info("intraProxyStreamReceiver sendAck finished") + if err := r.streamClient.Send(req); err != nil { return err } @@ -726,12 +745,6 @@ func (m *intraProxyManager) ReconcilePeerStreams( m.closePeerShardLocked(peerNodeName, ps, key) } } - // And for server-side senders, if they don't belong to desired pairs - for key := range ps.senders { - if _, ok2 := desired[key]; !ok2 { - m.closePeerShardLocked(peerNodeName, ps, key) - } - } } m.streamsMu.Lock() diff --git a/proxy/proxy.go b/proxy/proxy.go index 41a9122b..5cdf78b9 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -348,12 +348,16 @@ func (s *Proxy) GetRemoteSendChansByCluster(clusterID int32) map[history.Cluster return result } -// RemoveRemoteSendChan removes the send channel for a specific shard ID -func (s *Proxy) RemoveRemoteSendChan(shardID history.ClusterShardID) { +// RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel +func (s *Proxy) RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) { s.remoteSendChannelsMu.Lock() defer s.remoteSendChannelsMu.Unlock() - delete(s.remoteSendChannels, shardID) - s.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + if currentChan, exists := s.remoteSendChannels[shardID]; exists && currentChan == expectedChan { + delete(s.remoteSendChannels, shardID) + s.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } else { + s.logger.Info("Skipped removing remote send channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } } // SetLocalAckChan registers an ack channel for a specific shard ID @@ -372,11 +376,23 @@ func (s *Proxy) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, return ch, exists } -// RemoveLocalAckChan removes the ack channel for a specific shard ID -func (s *Proxy) RemoveLocalAckChan(shardID history.ClusterShardID) { +// RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel +func (s *Proxy) RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) { s.logger.Info("Remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) s.localAckChannelsMu.Lock() defer s.localAckChannelsMu.Unlock() + if currentChan, exists := s.localAckChannels[shardID]; exists && currentChan == expectedChan { + delete(s.localAckChannels, shardID) + } else { + s.logger.Info("Skipped removing local ack channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } +} + +// ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID +func (s *Proxy) ForceRemoveLocalAckChan(shardID history.ClusterShardID) { + s.logger.Info("Force remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + s.localAckChannelsMu.Lock() + defer s.localAckChannelsMu.Unlock() delete(s.localAckChannels, shardID) } @@ -396,7 +412,9 @@ func (s *Proxy) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (cont return cancelFunc, exists } -// RemoveLocalReceiverCancelFunc removes the cancel function for a local 
receiver for a specific shard ID +// RemoveLocalReceiverCancelFunc unconditionally removes the cancel function for a local receiver for a specific shard ID +// Note: Functions cannot be compared in Go, so we use unconditional removal. +// The race condition is primarily with channels; TerminatePreviousLocalReceiver handles forced cleanup. func (s *Proxy) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { s.logger.Info("Remove local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) s.localReceiverCancelFuncsMu.Lock() @@ -413,10 +431,8 @@ func (s *Proxy) TerminatePreviousLocalReceiver(serverShardID history.ClusterShar // Cancel the previous receiver's context prevCancelFunc() - // Remove the cancel function from tracking + // Force remove the cancel function and ack channel from tracking s.RemoveLocalReceiverCancelFunc(serverShardID) - - // Also clean up the associated ack channel if it exists - s.RemoveLocalAckChan(serverShardID) + s.ForceRemoveLocalAckChan(serverShardID) } } diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 2519ff62..844783a6 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -15,6 +15,7 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" ) // proxyIDMapping stores the original source shard and task for a given proxy task ID @@ -84,43 +85,6 @@ func (b *proxyIDRingBuffer) Append(proxyID int64, sourceShard history.ClusterSha b.size++ } -// PopUpTo pops and aggregates mappings up to and including the given watermark (proxy ID). -// Returns per-source-shard the maximal original source task acknowledged. -func (b *proxyIDRingBuffer) PopUpTo(watermark int64) map[history.ClusterShardID]int64 { - result := make(map[history.ClusterShardID]int64) - if b.size == 0 { - return result - } - // if watermark is before head, nothing to pop - if watermark < b.startProxyID { - return result - } - count64 := watermark - b.startProxyID + 1 - if count64 <= 0 { - return result - } - count := int(count64) - if count > b.size { - count = b.size - } - for i := 0; i < count; i++ { - idx := (b.head + i) % len(b.entries) - m := b.entries[idx] - // Skip zero entries (shouldn't happen unless contiguity fix inserted holes) - if m.sourceShard.ClusterID == 0 && m.sourceShard.ShardID == 0 { - continue - } - if current, ok := result[m.sourceShard]; !ok || m.sourceTask > current { - result[m.sourceShard] = m.sourceTask - } - } - // advance head - b.head = (b.head + count) % len(b.entries) - b.size -= count - b.startProxyID += int64(count) - return result -} - // AggregateUpTo computes the per-shard aggregation up to watermark without removing entries. // Returns (aggregation, count) where count is the number of entries covered. 
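AggregateUpTo is the piece that turns an upstream ACK, expressed in proxy task IDs, back into per-source-shard ACK levels. The sketch below performs the same aggregation over a plain ordered slice instead of the ring buffer, with string shard keys and simplified boundary handling (entries at or below the watermark), so it illustrates the idea rather than the exact implementation.

package main

import "fmt"

// mapping records which original task on which source shard a proxy task ID
// stands for; proxy IDs are allocated in increasing order.
type mapping struct {
	proxyID      int64
	sourceShard  string
	originalTask int64
}

// aggregateUpTo walks mappings whose proxy ID is at or below the watermark and
// returns, per source shard, the highest original task ID covered, plus how
// many entries were covered.
func aggregateUpTo(entries []mapping, watermark int64) (map[string]int64, int) {
	acked := make(map[string]int64)
	count := 0
	for _, m := range entries {
		if m.proxyID > watermark {
			break // entries are ordered by proxy ID
		}
		if cur, ok := acked[m.sourceShard]; !ok || m.originalTask > cur {
			acked[m.sourceShard] = m.originalTask
		}
		count++
	}
	return acked, count
}

func main() {
	entries := []mapping{
		{proxyID: 101, sourceShard: "c1s1", originalTask: 40},
		{proxyID: 102, sourceShard: "c1s2", originalTask: 17},
		{proxyID: 103, sourceShard: "c1s1", originalTask: 41},
	}
	acked, n := aggregateUpTo(entries, 102)
	fmt.Println(acked, n) // map[c1s1:40 c1s2:17] 2
}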
func (b *proxyIDRingBuffer) AggregateUpTo(watermark int64) (map[history.ClusterShardID]int64, int) { @@ -185,6 +149,9 @@ type proxyStreamSender struct { idRing *proxyIDRingBuffer // prevAckBySource tracks the last ack level sent per original source shard prevAckBySource map[history.ClusterShardID]int64 + // keepalive state + lastMsgSendTime time.Time + lastSentWatermark int64 } // buildSenderDebugSnapshot returns a snapshot of the sender's ring buffer and related state @@ -232,23 +199,16 @@ func (s *proxyStreamSender) buildSenderDebugSnapshot(maxEntries int) *SenderDebu } func (s *proxyStreamSender) Run( - targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, shutdownChan channel.ShutdownOnce, ) { - s.logger = log.With(s.logger, - tag.NewStringTag("role", "sender"), - ) + s.streamID = BuildSenderStreamID(s.sourceShardID, s.targetShardID) + s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID), tag.NewStringTag("role", "sender")) + s.logger.Info("proxyStreamSender Run") defer s.logger.Info("proxyStreamSender Run finished") - // Register this sender as the owner of the shard for the duration of the stream - s.shardManager.RegisterShard(s.targetShardID) - defer s.shardManager.UnregisterShard(s.targetShardID) - - // Register local stream tracking for sender (short id, include role) s.streamTracker = GetGlobalStreamTracker() - s.streamID = BuildSenderStreamID(s.sourceShardID, s.targetShardID) - s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID)) s.streamTracker.RegisterStream( s.streamID, "StreamWorkflowReplicationMessages", @@ -259,7 +219,6 @@ func (s *proxyStreamSender) Run( ) defer s.streamTracker.UnregisterStream(s.streamID) - wg := sync.WaitGroup{} // lazy init maps s.mu.Lock() if s.idRing == nil { @@ -274,16 +233,26 @@ func (s *proxyStreamSender) Run( s.sendMsgChan = make(chan RoutedMessage, 100) s.proxy.SetRemoteSendChan(s.targetShardID, s.sendMsgChan) - defer s.proxy.RemoveRemoteSendChan(s.targetShardID) + defer s.proxy.RemoveRemoteSendChan(s.targetShardID, s.sendMsgChan) + + registeredAt := s.shardManager.RegisterShard(s.targetShardID) + defer s.shardManager.UnregisterShard(s.targetShardID, registeredAt) + wg := sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() - _ = s.sendReplicationMessages(targetStreamServer, shutdownChan) + err := s.sendReplicationMessages(sourceStreamServer, shutdownChan) + if err != nil { + s.logger.Error("proxyStreamSender sendReplicationMessages error", tag.Error(err)) + } }() go func() { defer wg.Done() - _ = s.recvAck(targetStreamServer, shutdownChan) + err := s.recvAck(sourceStreamServer, shutdownChan) + if err != nil { + s.logger.Error("proxyStreamSender recvAck error", tag.Error(err)) + } }() // Wait for shutdown signal (triggered by receiver or stream errors) <-shutdownChan.Channel() @@ -296,15 +265,16 @@ func (s *proxyStreamSender) Run( // channel for aggregation/routing. Non-blocking shutdown is coordinated via // shutdownChan. This is a placeholder implementation. 
func (s *proxyStreamSender) recvAck( - targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, shutdownChan channel.ShutdownOnce, ) error { + s.logger.Info("proxyStreamSender recvAck started") defer func() { - s.logger.Info("Shutdown targetStreamServer.Recv loop.") + s.logger.Info("proxyStreamSender recvAck finished") shutdownChan.Shutdown() }() for !shutdownChan.IsShutdown() { - req, err := targetStreamServer.Recv() + req, err := sourceStreamServer.Recv() if err == io.EOF { return nil } @@ -316,8 +286,6 @@ func (s *proxyStreamSender) recvAck( if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { proxyAckWatermark := attr.SyncReplicationState.InclusiveLowWatermark - // Log incoming upstream ACK watermark - s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", proxyAckWatermark)) // track sync watermark s.streamTracker.UpdateStreamSyncReplicationState(s.streamID, proxyAckWatermark, nil) s.streamTracker.UpdateStream(s.streamID) @@ -326,6 +294,8 @@ func (s *proxyStreamSender) recvAck( shardToAck, pendingDiscard := s.idRing.AggregateUpTo(proxyAckWatermark) s.mu.Unlock() + s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", proxyAckWatermark), tag.NewStringTag("shardToAck", fmt.Sprintf("%v", shardToAck)), tag.NewInt("pendingDiscard", pendingDiscard)) + if len(shardToAck) > 0 { sent := make(map[history.ClusterShardID]bool, len(shardToAck)) logged := make(map[history.ClusterShardID]bool, len(shardToAck)) @@ -455,13 +425,18 @@ func (s *proxyStreamSender) recvAck( // sendReplicationMessages sends replication messages read from sendMsgChan to // the remote side. This is a placeholder implementation. 
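Alongside forwarding real batches, the reworked sendReplicationMessages loop below re-sends the last exclusive high watermark as an empty batch once the stream has been idle for about a second. A self-contained sketch of that keepalive bookkeeping follows; the types, the one-second threshold, and the send callback are stand-ins for the sender's fields and its gRPC stream.

package main

import (
	"fmt"
	"sync"
	"time"
)

// keepaliveState tracks when the last real batch went out and what watermark
// it carried, so an idle stream can periodically re-send an empty batch with
// that watermark.
type keepaliveState struct {
	mu            sync.Mutex
	lastSendTime  time.Time
	lastWatermark int64
}

func (k *keepaliveState) recordSend(watermark int64) {
	k.mu.Lock()
	defer k.mu.Unlock()
	k.lastSendTime = time.Now()
	k.lastWatermark = watermark
}

// maybeKeepalive calls send with the last watermark if nothing has been sent
// for at least idle, returning true when a keepalive went out.
func (k *keepaliveState) maybeKeepalive(idle time.Duration, send func(watermark int64) error) (bool, error) {
	k.mu.Lock()
	due := k.lastWatermark > 0 && time.Since(k.lastSendTime) >= idle
	wm := k.lastWatermark
	k.mu.Unlock()
	if !due {
		return false, nil
	}
	if err := send(wm); err != nil {
		return false, err
	}
	k.recordSend(wm)
	return true, nil
}

func main() {
	var k keepaliveState
	k.recordSend(42)
	time.Sleep(1100 * time.Millisecond)
	sent, _ := k.maybeKeepalive(time.Second, func(wm int64) error {
		fmt.Println("keepalive with watermark", wm)
		return nil
	})
	fmt.Println("sent:", sent)
}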
func (s *proxyStreamSender) sendReplicationMessages( - targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, shutdownChan channel.ShutdownOnce, ) error { + s.logger.Info("proxyStreamSender sendReplicationMessages started") defer func() { - s.logger.Info("Shutdown sendMsgChan loop.") + s.logger.Info("proxyStreamSender sendReplicationMessages finished") shutdownChan.Shutdown() }() + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + for !shutdownChan.IsShutdown() { if s.sendMsgChan == nil { return nil @@ -471,58 +446,97 @@ func (s *proxyStreamSender) sendReplicationMessages( if !ok { return nil } + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: routed.Resp=%p", routed.Resp), tag.NewStringTag("routed", fmt.Sprintf("%v", routed))) resp := routed.Resp - if m, ok := resp.Attributes.(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && m.Messages != nil { - // rewrite task ids - s.mu.Lock() - var originalIDs []int64 - var proxyIDs []int64 - // capture original exclusive high watermark before rewriting - originalHigh := m.Messages.ExclusiveHighWatermark + m, ok := resp.Attributes.(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages) + if !ok || m.Messages == nil { + return nil + } + + sourceTaskIds := make([]int64, 0, len(m.Messages.ReplicationTasks)) + for _, t := range m.Messages.ReplicationTasks { + sourceTaskIds = append(sourceTaskIds, t.SourceTaskId) + } + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d ids=%v", m.Messages.ExclusiveHighWatermark, sourceTaskIds)) + + // rewrite task ids + s.mu.Lock() + var originalIDs []int64 + var proxyIDs []int64 + // capture original exclusive high watermark before rewriting + originalHigh := m.Messages.ExclusiveHighWatermark + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d", m.Messages.ExclusiveHighWatermark, originalHigh)) + // Ensure exclusive high watermark is in proxy task ID space + if len(m.Messages.ReplicationTasks) > 0 { for _, t := range m.Messages.ReplicationTasks { // allocate proxy task id s.nextProxyTaskID++ proxyID := s.nextProxyTaskID // remember original original := t.SourceTaskId - originalIDs = append(originalIDs, original) s.idRing.Append(proxyID, routed.SourceShard, original) // rewrite id t.SourceTaskId = proxyID if t.RawTaskInfo != nil { t.RawTaskInfo.TaskId = proxyID } + originalIDs = append(originalIDs, original) proxyIDs = append(proxyIDs, proxyID) } - s.mu.Unlock() - // Log mapping from original -> proxy IDs - s.logger.Info(fmt.Sprintf("Sender forwarding ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs)) - - // Ensure exclusive high watermark is in proxy task ID space - if len(m.Messages.ReplicationTasks) > 0 { - m.Messages.ExclusiveHighWatermark = m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].SourceTaskId + 1 - } else { - // No tasks in this batch: allocate a synthetic proxy task id mapping - s.mu.Lock() - s.nextProxyTaskID++ - proxyHigh := s.nextProxyTaskID - s.idRing.Append(proxyHigh, routed.SourceShard, originalHigh) - m.Messages.ExclusiveHighWatermark = proxyHigh - s.mu.Unlock() - } - // track sent tasks ids and high watermark - ids := make([]int64, 0, len(m.Messages.ReplicationTasks)) - for _, t := range m.Messages.ReplicationTasks { - ids = append(ids, t.SourceTaskId) - } - 
s.streamTracker.UpdateStreamLastTaskIDs(s.streamID, ids) - s.streamTracker.UpdateStreamReplicationMessages(s.streamID, m.Messages.ExclusiveHighWatermark) - s.streamTracker.UpdateStreamSenderDebug(s.streamID, s.buildSenderDebugSnapshot(20)) - s.streamTracker.UpdateStream(s.streamID) + m.Messages.ExclusiveHighWatermark = m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].SourceTaskId + 1 + } else { + // No tasks in this batch: allocate a synthetic proxy task id mapping + s.nextProxyTaskID++ + proxyHigh := s.nextProxyTaskID + s.idRing.Append(proxyHigh, routed.SourceShard, originalHigh) + originalIDs = append(originalIDs, originalHigh) + proxyIDs = append(proxyIDs, proxyHigh) + m.Messages.ExclusiveHighWatermark = proxyHigh + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d proxy_high=%d original", m.Messages.ExclusiveHighWatermark, originalHigh, proxyHigh)) } - if err := targetStreamServer.Send(resp); err != nil { + s.mu.Unlock() + // Log mapping from original -> proxy IDs + s.logger.Info(fmt.Sprintf("Sender sending ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs), tag.NewInt64("exclusive_high", m.Messages.ExclusiveHighWatermark)) + + if err := sourceStreamServer.Send(resp); err != nil { return err } + s.logger.Info("Sender sent ReplicationTasks", tag.NewStringTag("sourceShard", ClusterShardIDtoString(routed.SourceShard)), tag.NewInt64("exclusive_high", m.Messages.ExclusiveHighWatermark)) + + // Update keepalive state + s.mu.Lock() + s.lastMsgSendTime = time.Now() + s.lastSentWatermark = m.Messages.ExclusiveHighWatermark + s.mu.Unlock() + + s.streamTracker.UpdateStreamLastTaskIDs(s.streamID, sourceTaskIds) + s.streamTracker.UpdateStreamReplicationMessages(s.streamID, m.Messages.ExclusiveHighWatermark) + s.streamTracker.UpdateStreamSenderDebug(s.streamID, s.buildSenderDebugSnapshot(20)) + s.streamTracker.UpdateStream(s.streamID) + case <-ticker.C: + // Send keepalive if idle for 1 second + s.mu.Lock() + shouldSendKeepalive := s.lastSentWatermark > 0 && time.Since(s.lastMsgSendTime) >= 1*time.Second + watermark := s.lastSentWatermark + s.mu.Unlock() + + if shouldSendKeepalive { + keepaliveResp := &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ReplicationTasks: []*replicationv1.ReplicationTask{}, + ExclusiveHighWatermark: watermark, + }, + }, + } + s.logger.Info("Sender sending keepalive message", tag.NewInt64("watermark", watermark)) + if err := sourceStreamServer.Send(keepaliveResp); err != nil { + return err + } + s.mu.Lock() + s.lastMsgSendTime = time.Now() + s.mu.Unlock() + } case <-shutdownChan.Channel(): return nil } @@ -549,6 +563,10 @@ type proxyStreamReceiver struct { lastExclusiveHighOriginal int64 streamID string streamTracker *StreamTracker + // keepalive state + ackMu sync.Mutex + lastAckSendTime time.Time + lastSentAck *adminservice.StreamWorkflowReplicationMessagesRequest } // buildReceiverDebugSnapshot builds receiver ACK aggregation state for debugging @@ -570,7 +588,9 @@ func (r *proxyStreamReceiver) Run( // Terminate any previous local receiver for this shard r.proxy.TerminatePreviousLocalReceiver(r.sourceShardID) + r.streamID = BuildReceiverStreamID(r.sourceShardID, r.targetShardID) r.logger = log.With(r.logger, + tag.NewStringTag("streamID", r.streamID), tag.NewStringTag("client", 
ClusterShardIDtoString(r.targetShardID)), tag.NewStringTag("server", ClusterShardIDtoString(r.sourceShardID)), tag.NewStringTag("stream-source-shard", ClusterShardIDtoString(r.sourceShardID)), @@ -609,7 +629,7 @@ func (r *proxyStreamReceiver) Run( r.proxy.SetLocalAckChan(r.sourceShardID, r.ackChan) r.proxy.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) defer func() { - r.proxy.RemoveLocalAckChan(r.sourceShardID) + r.proxy.RemoveLocalAckChan(r.sourceShardID, r.ackChan) r.proxy.RemoveLocalReceiverCancelFunc(r.sourceShardID) }() @@ -618,8 +638,6 @@ func (r *proxyStreamReceiver) Run( r.lastSentMin = 0 // Register a new local stream for tracking (short id, include role) - r.streamID = BuildReceiverStreamID(r.sourceShardID, r.targetShardID) - r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) r.streamTracker = GetGlobalStreamTracker() r.streamTracker.RegisterStream( r.streamID, @@ -659,6 +677,9 @@ func (r *proxyStreamReceiver) recvReplicationMessages( sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, shutdownChan channel.ShutdownOnce, ) error { + r.logger.Info("proxyStreamReceiver recvReplicationMessages started") + defer r.logger.Info("proxyStreamReceiver recvReplicationMessages finished") + for !shutdownChan.IsShutdown() { resp, err := sourceStreamClient.Recv() if err == io.EOF { @@ -714,8 +735,33 @@ func (r *proxyStreamReceiver) recvReplicationMessages( localShardsToSend := r.proxy.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) r.logger.Info("Going to broadcast high watermark to local shards", tag.NewStringTag("localShardsToSend", fmt.Sprintf("%v", localShardsToSend))) for targetShardID, sendChan := range localShardsToSend { - r.logger.Info("Sending high watermark to target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) - sendChan <- msg + // Clone the message for each recipient to prevent shared mutation + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + r.logger.Info(fmt.Sprintf("Sending high watermark to target shard, msg.Resp=%p", clonedMsg.Resp), tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark), tag.NewStringTag("msg", fmt.Sprintf("%v", clonedMsg))) + // Use non-blocking send with recover to handle closed channels + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + // Channel was closed while we were trying to send + r.logger.Warn("Failed to send high watermark to target shard (channel closed)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + } + }() + select { + case sendChan <- clonedMsg: + // Message sent successfully + default: + // Channel is full or closed, log and skip + r.logger.Warn("Failed to send high watermark to target shard (channel full)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + } + }() } // send to all remote shards on other nodes as well remoteShards, err := r.shardManager.GetRemoteShardsForPeer("") @@ -729,7 +775,13 @@ func (r *proxyStreamReceiver) recvReplicationMessages( if shard.ID.ClusterID != r.targetShardID.ClusterID { continue } - if 
!r.shardManager.DeliverMessagesToShardOwner(shard.ID, &msg, r.proxy, shutdownChan, r.logger) { + // Clone the message for each remote recipient + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + if !r.shardManager.DeliverMessagesToShardOwner(shard.ID, &clonedMsg, r.proxy, shutdownChan, r.logger) { r.logger.Warn("Failed to send ReplicationTasks to remote shard", tag.NewStringTag("shard", ClusterShardIDtoString(shard.ID))) } } @@ -799,11 +851,18 @@ func (r *proxyStreamReceiver) sendAck( sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, shutdownChan channel.ShutdownOnce, ) error { + r.logger.Info("proxyStreamReceiver sendAck started") + defer r.logger.Info("proxyStreamReceiver sendAck finished") + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + for !shutdownChan.IsShutdown() { select { case routed := <-r.ackChan: // Update per-target watermark if attr, ok := routed.Req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + r.logger.Info("Receiver received upstream ACK", tag.NewInt64("inclusive_low", attr.SyncReplicationState.InclusiveLowWatermark), tag.NewStringTag("targetShard", ClusterShardIDtoString(routed.TargetShard))) r.ackByTarget[routed.TargetShard] = attr.SyncReplicationState.InclusiveLowWatermark // Compute minimal watermark across targets min := int64(0) @@ -847,7 +906,32 @@ func (r *proxyStreamReceiver) sendAck( r.streamTracker.UpdateStreamReceiverDebug(r.streamID, r.buildReceiverDebugSnapshot()) } r.lastSentMin = min + + // Update keepalive state + r.ackMu.Lock() + r.lastAckSendTime = time.Now() + r.lastSentAck = aggregated + r.ackMu.Unlock() + } + } + case <-ticker.C: + // Send keepalive if idle for 1 second + r.ackMu.Lock() + shouldSendKeepalive := r.lastSentAck != nil && time.Since(r.lastAckSendTime) >= 1*time.Second + lastAck := r.lastSentAck + r.ackMu.Unlock() + + if shouldSendKeepalive { + r.logger.Info("Receiver sending keepalive ACK") + if err := sourceStreamClient.Send(lastAck); err != nil { + if err != io.EOF { + r.logger.Error("sourceStreamClient.Send keepalive encountered error", tag.Error(err)) + } + return err } + r.ackMu.Lock() + r.lastAckSendTime = time.Now() + r.ackMu.Unlock() } case <-shutdownChan.Channel(): return nil diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 025c8616..8fcb7ac6 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -24,10 +24,10 @@ type ( Start(lifetime context.Context) error // Stop shuts down the manager and leaves the cluster Stop() - // RegisterShard registers a clientShardID as owned by this proxy instance - RegisterShard(clientShardID history.ClusterShardID) - // UnregisterShard removes a clientShardID from this proxy's ownership - UnregisterShard(clientShardID history.ClusterShardID) + // RegisterShard registers a clientShardID as owned by this proxy instance and returns the registration timestamp + RegisterShard(clientShardID history.ClusterShardID) time.Time + // UnregisterShard removes a clientShardID from this proxy's ownership only if the timestamp matches + UnregisterShard(clientShardID history.ClusterShardID, expectedRegisteredAt time.Time) // GetProxyAddress returns the proxy service address for the given node name GetProxyAddress(nodeName string) (string, bool) // IsLocalShard checks if this proxy instance owns the given 
shard @@ -290,9 +290,9 @@ func (sm *shardManagerImpl) Stop() { sm.logger.Info("Shard manager stopped") } -func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) { +func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) time.Time { sm.logger.Info("RegisterShard", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) - sm.addLocalShard(clientShardID) + registeredAt := sm.addLocalShard(clientShardID) sm.broadcastShardChange("register", clientShardID) // Trigger memberlist metadata update to propagate NodeMeta to other nodes @@ -305,25 +305,37 @@ func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) if sm.onLocalShardChange != nil { sm.onLocalShardChange(clientShardID, true) } + return registeredAt } -func (sm *shardManagerImpl) UnregisterShard(clientShardID history.ClusterShardID) { +func (sm *shardManagerImpl) UnregisterShard(clientShardID history.ClusterShardID, expectedRegisteredAt time.Time) { sm.logger.Info("UnregisterShard", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) - sm.removeLocalShard(clientShardID) + + // Only unregister if the registration timestamp matches (prevents old senders from removing new registrations) sm.mutex.Lock() - delete(sm.localShards, ClusterShardIDtoShortString(clientShardID)) - sm.mutex.Unlock() - sm.broadcastShardChange("unregister", clientShardID) + key := ClusterShardIDtoShortString(clientShardID) + if shardInfo, exists := sm.localShards[key]; exists && shardInfo.Created.Equal(expectedRegisteredAt) { + delete(sm.localShards, key) + // Update metrics after local shards change + sm.mutex.Unlock() - // Trigger memberlist metadata update to propagate NodeMeta to other nodes - if sm.ml != nil { - if err := sm.ml.UpdateNode(0); err != nil { // 0 timeout means immediate update - sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + sm.removeLocalShard(clientShardID) + sm.broadcastShardChange("unregister", clientShardID) + + // Trigger memberlist metadata update to propagate NodeMeta to other nodes + if sm.ml != nil { + if err := sm.ml.UpdateNode(0); err != nil { // 0 timeout means immediate update + sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + } } - } - // Notify listeners - if sm.onLocalShardChange != nil { - sm.onLocalShardChange(clientShardID, false) + // Notify listeners + if sm.onLocalShardChange != nil { + sm.onLocalShardChange(clientShardID, false) + } + sm.logger.Info("UnregisterShard completed", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) + } else { + sm.mutex.Unlock() + sm.logger.Info("Skipped unregistering shard (timestamp mismatch or already unregistered)", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) } } @@ -494,11 +506,25 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( ) bool { logger = log.With(logger, tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewInt64("ack", ack)) if ackCh, ok := proxy.GetLocalAckChan(sourceShard); ok { - select { - case ackCh <- *routedAck: - logger.Info("Delivered ACK to local shard owner") + delivered := false + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + logger.Warn("Failed to deliver ACK to local shard owner (channel closed)") + } + }() + select { + case ackCh <- *routedAck: + logger.Info("Delivered ACK to local shard owner") + delivered = true + case <-shutdownChan.Channel(): + // Shutdown signal received + } + }() + if delivered { return true - case 
<-shutdownChan.Channel(): + } + if shutdownChan.IsShutdown() { return false } } @@ -542,11 +568,25 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( // Try local delivery first if ch, ok := proxy.GetRemoteSendChan(targetShard); ok { - select { - case ch <- *routedMsg: - logger.Info("Delivered messages to local shard owner") + delivered := false + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + logger.Warn("Failed to deliver messages to local shard owner (channel closed)") + } + }() + select { + case ch <- *routedMsg: + logger.Info("Delivered messages to local shard owner") + delivered = true + case <-shutdownChan.Channel(): + // Shutdown signal received + } + }() + if delivered { return true - case <-shutdownChan.Channel(): + } + if shutdownChan.IsShutdown() { return false } } @@ -702,7 +742,8 @@ func (sd *shardDelegate) NotifyMsg(data []byte) { localShard, ok := sd.manager.localShards[ClusterShardIDtoShortString(msg.ClientShard)] if ok { if localShard.Created.Before(msg.Timestamp) { - sd.manager.UnregisterShard(msg.ClientShard) + // Force unregister the local shard by passing its own timestamp + sd.manager.UnregisterShard(msg.ClientShard, localShard.Created) } } } @@ -733,13 +774,15 @@ func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { tag.NewStringTag("state", fmt.Sprintf("%+v", state))) } -func (sm *shardManagerImpl) addLocalShard(shard history.ClusterShardID) { +func (sm *shardManagerImpl) addLocalShard(shard history.ClusterShardID) time.Time { sm.mutex.Lock() defer sm.mutex.Unlock() key := ClusterShardIDtoShortString(shard) - sm.localShards[key] = ShardInfo{ID: shard, Created: time.Now()} + now := time.Now() + sm.localShards[key] = ShardInfo{ID: shard, Created: now} + return now } func (sm *shardManagerImpl) removeLocalShard(shard history.ClusterShardID) { @@ -779,8 +822,8 @@ type noopShardManager struct{} func (nsm *noopShardManager) Start(_ context.Context) error { return nil } func (nsm *noopShardManager) Stop() {} -func (nsm *noopShardManager) RegisterShard(history.ClusterShardID) {} -func (nsm *noopShardManager) UnregisterShard(history.ClusterShardID) {} +func (nsm *noopShardManager) RegisterShard(history.ClusterShardID) time.Time { return time.Now() } +func (nsm *noopShardManager) UnregisterShard(history.ClusterShardID, time.Time) {} func (nsm *noopShardManager) GetShardOwner(history.ClusterShardID) (string, bool) { return "", false } func (nsm *noopShardManager) GetProxyAddress(string) (string, bool) { return "", false } func (nsm *noopShardManager) IsLocalShard(history.ClusterShardID) bool { return true } @@ -815,10 +858,26 @@ func (nsm *noopShardManager) SetOnRemoteShardChange(handler func(peer string, sh func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool { if proxy != nil { if ackCh, ok := proxy.GetLocalAckChan(srcShard); ok { - select { - case ackCh <- *routedAck: + delivered := false + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + if logger != nil { + logger.Warn("Failed to deliver ACK to local shard owner (channel closed)") + } + } + }() + select { + case ackCh <- *routedAck: + delivered = true + case <-shutdownChan.Channel(): + // Shutdown signal received + } + }() + if delivered { return true - case <-shutdownChan.Channel(): + } + if shutdownChan.IsShutdown() { return false } } @@ -829,10 +888,26 @@ func (nsm *noopShardManager) 
DeliverAckToShardOwner(srcShard history.ClusterShar func (nsm *noopShardManager) DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool { if proxy != nil { if ch, ok := proxy.GetRemoteSendChan(targetShard); ok { - select { - case ch <- *routedMsg: + delivered := false + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + if logger != nil { + logger.Warn("Failed to deliver messages to local shard owner (channel closed)") + } + } + }() + select { + case ch <- *routedMsg: + delivered = true + case <-shutdownChan.Channel(): + // Shutdown signal received + } + }() + if delivered { return true - case <-shutdownChan.Channel(): + } + if shutdownChan.IsShutdown() { return false } } diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go index b6214814..b53086c5 100644 --- a/proxy/stream_tracker.go +++ b/proxy/stream_tracker.go @@ -159,28 +159,28 @@ func GetGlobalStreamTracker() *StreamTracker { // BuildSenderStreamID returns the canonical sender stream ID. func BuildSenderStreamID(source, target history.ClusterShardID) string { - return fmt.Sprintf("snd-%s-%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) + return fmt.Sprintf("snd-%s", ClusterShardIDtoShortString(target)) } // BuildReceiverStreamID returns the canonical receiver stream ID. func BuildReceiverStreamID(source, target history.ClusterShardID) string { - return fmt.Sprintf("rcv-%s-%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) + return fmt.Sprintf("rcv-%s", ClusterShardIDtoShortString(source)) } // BuildForwarderStreamID returns the canonical forwarder stream ID. // Note: forwarder uses server-first ordering in the ID. func BuildForwarderStreamID(client, server history.ClusterShardID) string { - return fmt.Sprintf("fwd-%s-%s", ClusterShardIDtoShortString(server), ClusterShardIDtoShortString(client)) + return fmt.Sprintf("fwd-snd-%s", ClusterShardIDtoShortString(server)) } // BuildIntraProxySenderStreamID returns the server-side intra-proxy stream ID for a peer and shard pair. func BuildIntraProxySenderStreamID(peer string, source, target history.ClusterShardID) string { - return fmt.Sprintf("ip-snd-%s-%s|%s", peer, ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) + return fmt.Sprintf("ip-snd-%s-%s|%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target), peer) } // BuildIntraProxyReceiverStreamID returns the client-side intra-proxy stream ID for a peer and shard pair. 
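The delivery helpers above repeat one defensive pattern: a channel send wrapped in a closure that recovers from a send on a closed channel and also honors the shutdown signal, so a racing stream teardown cannot panic the caller. A condensed, self-contained sketch of that pattern, using a plain chan struct{} in place of channel.ShutdownOnce:

package main

import "fmt"

// trySend attempts to deliver msg on ch, returning true only on success.
// It survives a concurrently closed channel (via recover) and gives up once
// shutdown is signalled, mirroring the DeliverAck/DeliverMessages helpers.
func trySend[T any](ch chan T, msg T, shutdown <-chan struct{}) (delivered bool) {
	defer func() {
		if r := recover(); r != nil {
			// send on closed channel: treat as not delivered
			delivered = false
		}
	}()
	select {
	case ch <- msg:
		return true
	case <-shutdown:
		return false
	}
}

func main() {
	ch := make(chan int, 1)
	shutdown := make(chan struct{})
	fmt.Println(trySend(ch, 42, shutdown)) // true

	close(ch)
	fmt.Println(trySend(ch, 43, shutdown)) // false: recovered from send on closed channel
}

A sketch only; the real helpers additionally log which shard the delivery was meant for and fall back to remote forwarding when no local channel exists.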
func BuildIntraProxyReceiverStreamID(peer string, source, target history.ClusterShardID) string { - return fmt.Sprintf("ip-rcv-%s-%s|%s", peer, ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target)) + return fmt.Sprintf("ip-rcv-%s-%s|%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target), peer) } // formatDurationSeconds formats a duration in seconds to a readable string From ccdb29df435d72728e278d94be80153769e1bc02 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Mon, 13 Oct 2025 10:06:42 -0700 Subject: [PATCH 13/38] add ring_max_size to debug --- proxy/debug.go | 1 + proxy/proxy_streams.go | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/proxy/debug.go b/proxy/debug.go index 60a5abc9..8add32e8 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -23,6 +23,7 @@ type ( SenderDebugInfo struct { RingStartProxyID int64 `json:"ring_start_proxy_id"` RingSize int `json:"ring_size"` + RingMaxSize int `json:"ring_max_size"` RingCapacity int `json:"ring_capacity"` RingHead int `json:"ring_head"` NextProxyTaskID int64 `json:"next_proxy_task_id"` diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 844783a6..1835ee37 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -31,6 +31,7 @@ type proxyIDRingBuffer struct { entries []proxyIDMapping head int size int + maxSize int // Maximum size ever reached startProxyID int64 // proxyID of the current head element when size > 0 } @@ -76,6 +77,9 @@ func (b *proxyIDRingBuffer) Append(proxyID int64, sourceShard history.ClusterSha pos := (b.head + b.size) % len(b.entries) b.entries[pos] = proxyIDMapping{sourceShard: history.ClusterShardID{}, sourceTask: 0} b.size++ + if b.size > b.maxSize { + b.maxSize = b.size + } expected++ } } @@ -83,6 +87,9 @@ func (b *proxyIDRingBuffer) Append(proxyID int64, sourceShard history.ClusterSha pos := (b.head + b.size) % len(b.entries) b.entries[pos] = proxyIDMapping{sourceShard: sourceShard, sourceTask: sourceTask} b.size++ + if b.size > b.maxSize { + b.maxSize = b.size + } } // AggregateUpTo computes the per-shard aggregation up to watermark without removing entries. @@ -172,6 +179,7 @@ func (s *proxyStreamSender) buildSenderDebugSnapshot(maxEntries int) *SenderDebu if s.idRing != nil { info.RingStartProxyID = s.idRing.startProxyID info.RingSize = s.idRing.size + info.RingMaxSize = s.idRing.maxSize info.RingCapacity = len(s.idRing.entries) info.RingHead = s.idRing.head From 939d6b2aef57991d42314975dc723b18adacaa1a Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 15 Oct 2025 21:47:17 -0700 Subject: [PATCH 14/38] fix panic; fix memberlist join issue --- proxy/intra_proxy_router.go | 6 +- proxy/shard_manager.go | 114 ++++++++++++++++++++++++++++++++---- 2 files changed, 106 insertions(+), 14 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 48b475c9..27bd9d50 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -98,7 +98,10 @@ func (s *intraProxyStreamSender) Run( // recvAck reads ACKs from the peer and routes them to the source shard owner. 
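The fix that follows moves shutdownChan.Shutdown() into the deferred cleanup of recvAck, so whichever direction of the intra-proxy stream exits first also tears down its sibling. channel.ShutdownOnce comes from the Temporal server libraries; a minimal stand-in that captures only the semantics relied on here (idempotent Shutdown, a Channel to select on, an IsShutdown probe) might look like:

package main

import (
	"fmt"
	"sync"
)

// shutdownOnce is a tiny stand-in for channel.ShutdownOnce: Shutdown may be
// called any number of times, Channel() is closed exactly once, and
// IsShutdown reports whether Shutdown has been called.
type shutdownOnce struct {
	once sync.Once
	ch   chan struct{}
}

func newShutdownOnce() *shutdownOnce {
	return &shutdownOnce{ch: make(chan struct{})}
}

func (s *shutdownOnce) Shutdown()                { s.once.Do(func() { close(s.ch) }) }
func (s *shutdownOnce) Channel() <-chan struct{} { return s.ch }

func (s *shutdownOnce) IsShutdown() bool {
	select {
	case <-s.ch:
		return true
	default:
		return false
	}
}

func main() {
	sd := newShutdownOnce()
	defer func() {
		// Mirrors the recvAck defer: always signal shutdown on exit so the
		// sibling goroutine's loop observes IsShutdown()/Channel() and stops.
		sd.Shutdown()
		fmt.Println("shutdown:", sd.IsShutdown())
	}()
	sd.Shutdown() // calling twice is safe
}

This is an illustrative equivalent under stated assumptions, not the actual channel.ShutdownOnce implementation.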
func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) error { s.logger.Info("intraProxyStreamSender recvAck") - defer s.logger.Info("intraProxyStreamSender recvAck finished") + defer func() { + s.logger.Info("intraProxyStreamSender recvAck finished") + shutdownChan.Shutdown() + }() for !shutdownChan.IsShutdown() { req, err := s.sourceStreamServer.Recv() @@ -107,7 +110,6 @@ func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) erro return nil } if err != nil { - shutdownChan.Shutdown() s.logger.Error("intraProxyStreamSender recvAck encountered error", tag.Error(err)) return err } diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 8fcb7ac6..e71edd35 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -77,6 +77,10 @@ type ( // Local shards owned by this node, keyed by short id localShards map[string]ShardInfo intraMgr *intraProxyManager + // Join retry control + stopJoinRetry chan struct{} + joinWg sync.WaitGroup + joinLoopRunning bool } // shardDelegate implements memberlist.Delegate for shard state management @@ -119,11 +123,12 @@ func NewShardManager(configProvider config.ConfigProvider, logger log.Logger) (S } sm := &shardManagerImpl{ - config: cfg, - logger: logger, - delegate: delegate, - localShards: make(map[string]ShardInfo), - intraMgr: nil, + config: cfg, + logger: logger, + delegate: delegate, + localShards: make(map[string]ShardInfo), + intraMgr: nil, + stopJoinRetry: make(chan struct{}), } delegate.manager = sm @@ -245,13 +250,7 @@ func (sm *shardManagerImpl) Start(lifetime context.Context) error { // Join existing cluster if configured if len(sm.config.JoinAddrs) > 0 { - sm.logger.Info("Attempting to join cluster", tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", sm.config.JoinAddrs))) - num, err := ml.Join(sm.config.JoinAddrs) - if err != nil { - sm.logger.Warn("Failed to join some cluster members", tag.Error(err)) - } else { - sm.logger.Info("Joined memberlist cluster", tag.NewStringTag("members", strconv.Itoa(num))) - } + sm.startJoinLoop() } sm.logger.Info("Shard manager started", @@ -273,6 +272,10 @@ func (sm *shardManagerImpl) Stop() { } sm.mutex.Unlock() + // Stop any ongoing join retry + close(sm.stopJoinRetry) + sm.joinWg.Wait() + // Leave the cluster gracefully err := sm.ml.Leave(5 * time.Second) if err != nil { @@ -290,6 +293,83 @@ func (sm *shardManagerImpl) Stop() { sm.logger.Info("Shard manager stopped") } +// startJoinLoop starts the join retry loop if not already running +func (sm *shardManagerImpl) startJoinLoop() { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + if sm.joinLoopRunning { + sm.logger.Info("Join loop already running, skipping") + return + } + + sm.logger.Info("Starting join loop") + sm.joinLoopRunning = true + sm.joinWg.Add(1) + go sm.retryJoinCluster() +} + +// retryJoinCluster attempts to join the cluster infinitely with exponential backoff +func (sm *shardManagerImpl) retryJoinCluster() { + defer func() { + sm.joinWg.Done() + sm.mutex.Lock() + sm.joinLoopRunning = false + sm.mutex.Unlock() + }() + + const ( + initialInterval = 2 * time.Second + maxInterval = 60 * time.Second + ) + + interval := initialInterval + attempt := 0 + + sm.logger.Info("Starting join retry loop", + tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", sm.config.JoinAddrs))) + + for { + attempt++ + + sm.mutex.RLock() + ml := sm.ml + joinAddrs := sm.config.JoinAddrs + sm.mutex.RUnlock() + + if ml == nil { + sm.logger.Warn("Memberlist not initialized, stopping retry") + return + } + + 
sm.logger.Info("Attempting to join cluster", + tag.NewStringTag("attempt", strconv.Itoa(attempt)), + tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", joinAddrs))) + + num, err := ml.Join(joinAddrs) + if err != nil { + sm.logger.Warn("Failed to join cluster", tag.Error(err)) + + // Exponential backoff with cap + select { + case <-sm.stopJoinRetry: + sm.logger.Info("Join retry cancelled") + return + case <-time.After(interval): + interval *= 2 + if interval > maxInterval { + interval = maxInterval + } + } + } else { + sm.logger.Info("Successfully joined memberlist cluster", + tag.NewStringTag("members", strconv.Itoa(num)), + tag.NewStringTag("attempt", strconv.Itoa(attempt))) + return + } + } +} + func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) time.Time { sm.logger.Info("RegisterShard", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) registeredAt := sm.addLocalShard(clientShardID) @@ -809,6 +889,16 @@ func (sed *shardEventDelegate) NotifyLeave(node *memberlist.Node) { sed.logger.Info("Node left cluster", tag.NewStringTag("node", node.Name), tag.NewStringTag("addr", node.Addr.String())) + + // If we're now isolated and have join addresses configured, restart join loop + if sed.manager != nil && sed.manager.ml != nil { + numMembers := sed.manager.ml.NumMembers() + if numMembers == 1 && len(sed.manager.config.JoinAddrs) > 0 { + sed.logger.Info("Node is now isolated, restarting join loop", + tag.NewStringTag("numMembers", strconv.Itoa(numMembers))) + sed.manager.startJoinLoop() + } + } } func (sed *shardEventDelegate) NotifyUpdate(node *memberlist.Node) { From 4ed7e501b26d17ffdac8a23d734269f6c86b09d7 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 17 Oct 2025 15:35:49 -0700 Subject: [PATCH 15/38] fix panic --- proxy/intra_proxy_router.go | 52 ++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 27bd9d50..33bbf98e 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -719,44 +719,72 @@ func (m *intraProxyManager) ReconcilePeerStreams( tag.NewStringTag("localShards", fmt.Sprintf("%v", localShards)), ) - // Build desired set of cross-cluster pairs - desired := make(map[peerStreamKey]string) + // Build desiredReceivers receiver set of cross-cluster pairs + desiredReceivers := make(map[peerStreamKey]string) for _, l := range localShards { for peer, shards := range remoteShards { for _, r := range shards.Shards { if l.ClusterID == r.ID.ClusterID { continue } - desired[peerStreamKey{targetShard: l, sourceShard: r.ID}] = peer + desiredReceivers[peerStreamKey{targetShard: l, sourceShard: r.ID}] = peer } } } - m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("desired", fmt.Sprintf("%v", desired))) + // Build desiredSenders set: inverted direction of desiredReceivers + // Senders exist when remote shard is the target and local shard is the source + desiredSenders := make(map[peerStreamKey]string) + for _, l := range localShards { + for peer, shards := range remoteShards { + for _, r := range shards.Shards { + if l.ClusterID == r.ID.ClusterID { + continue + } + desiredSenders[peerStreamKey{targetShard: r.ID, sourceShard: l}] = peer + } + } + } + + m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("desiredReceivers", fmt.Sprintf("%v", desiredReceivers)), tag.NewStringTag("desiredSenders", fmt.Sprintf("%v", desiredSenders))) // Ensure all desired receivers exist - for key := range desired { - 
m.EnsureReceiverForPeerShard(p, desired[key], key.targetShard, key.sourceShard) + for key := range desiredReceivers { + m.EnsureReceiverForPeerShard(p, desiredReceivers[key], key.targetShard, key.sourceShard) } // Prune anything not desired - check := func(ps *peerState) { + check := func(peer string, ps *peerState) { // Collect keys to close for receivers + var receiversToClose []peerStreamKey for key := range ps.receivers { - if _, ok2 := desired[key]; !ok2 { - m.closePeerShardLocked(peerNodeName, ps, key) + if _, ok2 := desiredReceivers[key]; !ok2 { + receiversToClose = append(receiversToClose, key) } } + for _, key := range receiversToClose { + m.closePeerShardLocked(peer, ps, key) + } + // Collect keys to close for senders + var sendersToClose []peerStreamKey + for key := range ps.senders { + if _, ok2 := desiredSenders[key]; !ok2 { + sendersToClose = append(sendersToClose, key) + } + } + for _, key := range sendersToClose { + m.closePeerShardLocked(peer, ps, key) + } } m.streamsMu.Lock() if peerNodeName != "" { if ps, ok := m.peers[peerNodeName]; ok && ps != nil { - check(ps) + check(peerNodeName, ps) } } else { - for _, ps := range m.peers { - check(ps) + for peer, ps := range m.peers { + check(peer, ps) } } m.streamsMu.Unlock() From 5fb2b88716e5f270ac39f4f5ab44d28707e5d616 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Sun, 7 Dec 2025 23:43:19 -0800 Subject: [PATCH 16/38] update regarding cluster_connection --- config/cluster_conn_config.go | 1 + config/config.go | 8 +- config/converter.go | 2 + .../config/cluster-a-mux-client-proxy-1.yaml | 8 - .../config/cluster-a-mux-client-proxy-2.yaml | 8 - develop/config/nginx.conf | 43 +++ proxy/adminservice.go | 158 +++++++-- proxy/adminservice_test.go | 3 +- proxy/cluster_connection.go | 311 ++++++++++++++++-- proxy/cluster_connection_test.go | 10 +- proxy/debug.go | 20 +- proxy/intra_proxy_router.go | 108 +++--- proxy/proxy.go | 222 +------------ proxy/proxy_streams.go | 62 ++-- proxy/shard_manager.go | 235 ++++--------- 15 files changed, 625 insertions(+), 574 deletions(-) create mode 100644 develop/config/nginx.conf diff --git a/config/cluster_conn_config.go b/config/cluster_conn_config.go index 636e83f4..2bcf1388 100644 --- a/config/cluster_conn_config.go +++ b/config/cluster_conn_config.go @@ -16,6 +16,7 @@ type ( OutboundHealthCheck HealthCheckConfig `yaml:"outboundHealthCheck"` InboundHealthCheck HealthCheckConfig `yaml:"inboundHealthCheck"` ShardCountConfig ShardCountConfig `yaml:"shardCount"` + MemberlistConfig *MemberlistConfig `yaml:"memberlist"` } StringTranslator struct { Mappings []StringMapping `yaml:"mappings"` diff --git a/config/config.go b/config/config.go index bf1fee43..d1b28116 100644 --- a/config/config.go +++ b/config/config.go @@ -40,7 +40,6 @@ type ShardCountMode string const ( ShardCountDefault ShardCountMode = "" ShardCountLCM ShardCountMode = "lcm" - ShardCountFixed ShardCountMode = "fixed" ShardCountRouting ShardCountMode = "routing" ) @@ -155,10 +154,13 @@ type ( // TODO: Soon to be deprecated! Create an item in ClusterConnections instead HealthCheck *HealthCheckConfig `yaml:"healthCheck"` // TODO: Soon to be deprecated! Create an item in ClusterConnections instead - OutboundHealthCheck *HealthCheckConfig `yaml:"outboundHealthCheck"` + OutboundHealthCheck *HealthCheckConfig `yaml:"outboundHealthCheck"` + // TODO: Soon to be deprecated! Create an item in ClusterConnections instead + ShardCountConfig ShardCountConfig `yaml:"shardCount"` + // TODO: Soon to be deprecated! 
Create an item in ClusterConnections instead + MemberlistConfig *MemberlistConfig `yaml:"memberlist"` NamespaceNameTranslation NameTranslationConfig `yaml:"namespaceNameTranslation"` SearchAttributeTranslation SATranslationConfig `yaml:"searchAttributeTranslation"` - MemberlistConfig *MemberlistConfig `yaml:"memberlist"` Metrics *MetricsConfig `yaml:"metrics"` ProfilingConfig ProfilingConfig `yaml:"profiling"` Logging LoggingConfig `yaml:"logging"` diff --git a/config/converter.go b/config/converter.go index 0a07f50a..93412caf 100644 --- a/config/converter.go +++ b/config/converter.go @@ -41,6 +41,8 @@ func ToClusterConnConfig(config S2SProxyConfig) S2SProxyConfig { SearchAttributeTranslation: config.SearchAttributeTranslation, OutboundHealthCheck: flattenNilHealthCheck(config.OutboundHealthCheck), InboundHealthCheck: flattenNilHealthCheck(config.HealthCheck), + ShardCountConfig: config.ShardCountConfig, + MemberlistConfig: config.MemberlistConfig, }, }, Metrics: config.Metrics, diff --git a/develop/config/cluster-a-mux-client-proxy-1.yaml b/develop/config/cluster-a-mux-client-proxy-1.yaml index 3c54388b..4855af00 100644 --- a/develop/config/cluster-a-mux-client-proxy-1.yaml +++ b/develop/config/cluster-a-mux-client-proxy-1.yaml @@ -19,13 +19,5 @@ mux: mode: "client" client: serverAddress: "localhost:7003" -# shardCount: -# mode: "lcm" -# localShardCount: 2 -# remoteShardCount: 3 -# shardCount: -# mode: "fixed" -# localShardCount: 2 -# remoteShardCount: 3 profiling: pprofAddress: "localhost:6060" \ No newline at end of file diff --git a/develop/config/cluster-a-mux-client-proxy-2.yaml b/develop/config/cluster-a-mux-client-proxy-2.yaml index 8bdbfbb1..85aa9bf6 100644 --- a/develop/config/cluster-a-mux-client-proxy-2.yaml +++ b/develop/config/cluster-a-mux-client-proxy-2.yaml @@ -19,13 +19,5 @@ mux: mode: "client" client: serverAddress: "localhost:7003" -# shardCount: -# mode: "lcm" -# localShardCount: 2 -# remoteShardCount: 3 -# shardCount: -# mode: "fixed" -# localShardCount: 2 -# remoteShardCount: 3 profiling: pprofAddress: "localhost:6061" \ No newline at end of file diff --git a/develop/config/nginx.conf b/develop/config/nginx.conf new file mode 100644 index 00000000..fd62b4d1 --- /dev/null +++ b/develop/config/nginx.conf @@ -0,0 +1,43 @@ +worker_processes 1; + +events { + worker_connections 1024; +} + +stream { + # Proxy for source outbound (6133, 6233) => exposed at 7001 + upstream source_outbound { + least_conn; + server host.docker.internal:6133; + server host.docker.internal:6233; + } + + server { + listen 7001; + proxy_pass source_outbound; + } + + # Proxy for target outbound (6333, 6433) => exposed at 7002 + upstream target_outbound { + least_conn; + server host.docker.internal:6333; + server host.docker.internal:6433; + } + + server { + listen 7002; + proxy_pass target_outbound; + } + + # Proxy for target server ports (6334, 6434) => exposed at 7003 + upstream target_server { + least_conn; + server host.docker.internal:6334; + server host.docker.internal:6434; + } + + server { + listen 7003; + proxy_pass target_server; + } +} diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 24c890d5..5955f943 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -4,11 +4,13 @@ import ( "context" "fmt" "strconv" + "sync" "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/client/history" servercommon "go.temporal.io/server/common" + "go.temporal.io/server/common/channel" "go.temporal.io/server/common/headers" 
"go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" @@ -27,13 +29,16 @@ type ( adminServiceProxyServer struct { adminservice.UnimplementedAdminServiceServer - adminClient adminservice.AdminServiceClient - logger log.Logger - apiOverrides *config.APIOverridesConfig - metricLabelValues []string - reportStreamValue func(idx int32, value int32) - shardCountConfig config.ShardCountConfig - lcmParameters LCMParameters + clusterConnection *ClusterConnection + adminClient adminservice.AdminServiceClient + adminClientReverse adminservice.AdminServiceClient + logger log.Logger + apiOverrides *config.APIOverridesConfig + metricLabelValues []string + reportStreamValue func(idx int32, value int32) + shardCountConfig config.ShardCountConfig + lcmParameters LCMParameters + overrideShardCount int32 } ) @@ -41,25 +46,31 @@ type ( func NewAdminServiceProxyServer( serviceName string, adminClient adminservice.AdminServiceClient, + adminClientReverse adminservice.AdminServiceClient, apiOverrides *config.APIOverridesConfig, metricLabelValues []string, reportStreamValue func(idx int32, value int32), shardCountConfig config.ShardCountConfig, lcmParameters LCMParameters, logger log.Logger, + clusterConnection *ClusterConnection, + overrideShardCount int32, ) adminservice.AdminServiceServer { // The AdminServiceStreams will duplicate the same output for an underlying connection issue hundreds of times. // Limit their output to three times per minute logger = log.NewThrottledLogger(log.With(logger, common.ServiceTag(serviceName)), func() float64 { return 3.0 / 60.0 }) return &adminServiceProxyServer{ - adminClient: adminClient, - logger: logger, - apiOverrides: apiOverrides, - metricLabelValues: metricLabelValues, - reportStreamValue: reportStreamValue, - shardCountConfig: shardCountConfig, - lcmParameters: lcmParameters, + clusterConnection: clusterConnection, + adminClient: adminClient, + adminClientReverse: adminClientReverse, + logger: logger, + apiOverrides: apiOverrides, + metricLabelValues: metricLabelValues, + reportStreamValue: reportStreamValue, + shardCountConfig: shardCountConfig, + lcmParameters: lcmParameters, + overrideShardCount: overrideShardCount, } } @@ -103,10 +114,15 @@ func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *admi return resp, err } - if s.shardCountConfig.Mode == config.ShardCountLCM { + switch s.shardCountConfig.Mode { + case config.ShardCountLCM: // Present a fake number of shards. In LCM mode, we present the least // common multiple of both cluster shard counts. 
resp.HistoryShardCount = s.lcmParameters.LCM + case config.ShardCountRouting: + if s.overrideShardCount > 0 { + resp.HistoryShardCount = s.overrideShardCount + } } if s.apiOverrides != nil && s.apiOverrides.AdminService.DescribeCluster != nil { @@ -267,6 +283,9 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( return err } + // Detect intra-proxy streams early for logging/behavior toggles + isIntraProxy := common.IsIntraProxy(streamServer.Context()) + logger := log.With(s.logger, tag.NewStringTag("source", ClusterShardIDtoString(sourceClusterShardID)), tag.NewStringTag("target", ClusterShardIDtoString(targetClusterShardID))) @@ -307,8 +326,12 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( targetMetadata.Set(history.MetadataKeyServerShardID, strconv.Itoa(int(newSourceShardID.ShardID))) } + if isIntraProxy { + return s.streamIntraProxyRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID) + } + if s.shardCountConfig.Mode == config.ShardCountRouting { - return s.streamRouting() + return s.streamRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID) } forwarder := newStreamForwarder( @@ -329,10 +352,105 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( return nil } -// streamRouting: placeholder for future stream routing implementation -func (s *adminServiceProxyServer) streamRouting() error { - _ = &proxyStreamSender{} - _ = &proxyStreamReceiver{} +func (s *adminServiceProxyServer) streamIntraProxyRouting( + logger log.Logger, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + clientShardID history.ClusterShardID, + serverShardID history.ClusterShardID, +) error { + logger.Info("streamIntraProxyRouting started") + defer logger.Info("streamIntraProxyRouting finished") + + // Determine remote peer identity from intra-proxy headers + peerNodeName := "" + if md, ok := metadata.FromIncomingContext(streamServer.Context()); ok { + vals := md.Get(common.IntraProxyOriginProxyIDHeader) + if len(vals) > 0 { + peerNodeName = vals[0] + } + } + + // Only allow intra-proxy when at least one shard is local to this proxy instance + isLocalClient := s.clusterConnection.shardManager.IsLocalShard(clientShardID) + isLocalServer := s.clusterConnection.shardManager.IsLocalShard(serverShardID) + if isLocalClient || !isLocalServer { + logger.Info("Skipping intra-proxy between two local shards or two remote shards. 
Client may use outdated shard info.", + tag.NewStringTag("client", ClusterShardIDtoString(clientShardID)), + tag.NewStringTag("server", ClusterShardIDtoString(serverShardID)), + tag.NewBoolTag("isLocalClient", isLocalClient), + tag.NewBoolTag("isLocalServer", isLocalServer), + ) + return nil + } + + // Sender: handle ACKs coming from peer and forward to original owner + sender := &intraProxyStreamSender{ + logger: logger, + clusterConnection: s.clusterConnection, + peerNodeName: peerNodeName, + targetShardID: clientShardID, + sourceShardID: serverShardID, + } + + shutdownChan := channel.NewShutdownOnce() + go func() { + if err := sender.Run(streamServer, shutdownChan); err != nil { + logger.Error("intraProxyStreamSender.Run error", tag.Error(err)) + } + }() + <-shutdownChan.Channel() + return nil +} + +func (s *adminServiceProxyServer) streamRouting( + logger log.Logger, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + clientShardID history.ClusterShardID, + serverShardID history.ClusterShardID, +) error { + logger.Info("streamRouting started") + defer logger.Info("streamRouting stopped") + + // client: stream receiver + // server: stream sender + proxyStreamSender := &proxyStreamSender{ + logger: logger, + clusterConnection: s.clusterConnection, + sourceShardID: clientShardID, + targetShardID: serverShardID, + directionLabel: "routing", + } + + var localShardCount int32 + if s.shardCountConfig.Mode == config.ShardCountRouting { + localShardCount = s.shardCountConfig.LocalShardCount + } else { + localShardCount = s.shardCountConfig.RemoteShardCount + } + // receiver for reverse direction + proxyStreamReceiverReverse := &proxyStreamReceiver{ + logger: s.logger, + clusterConnection: s.clusterConnection, + adminClient: s.adminClientReverse, + localShardCount: localShardCount, + sourceShardID: serverShardID, + targetShardID: clientShardID, + directionLabel: "routing", + } + + shutdownChan := channel.NewShutdownOnce() + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + proxyStreamSender.Run(streamServer, shutdownChan) + }() + go func() { + defer wg.Done() + proxyStreamReceiverReverse.Run(shutdownChan) + }() + wg.Wait() + return nil } diff --git a/proxy/adminservice_test.go b/proxy/adminservice_test.go index 456b4cc2..22a7bcab 100644 --- a/proxy/adminservice_test.go +++ b/proxy/adminservice_test.go @@ -45,7 +45,8 @@ type adminProxyServerInput struct { func (s *adminserviceSuite) newAdminServiceProxyServer(in adminProxyServerInput, observer *ReplicationStreamObserver) adminservice.AdminServiceServer { return NewAdminServiceProxyServer("test-service-name", s.adminClientMock, - in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, log.NewTestLogger()) + s.adminClientMock, + in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, log.NewTestLogger(), nil, 0) } func (s *adminserviceSuite) TestAddOrUpdateRemoteCluster() { diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index 90004a5d..ade187f4 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -7,12 +7,14 @@ import ( "fmt" "io" "net" + "sync" "time" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" "github.com/prometheus/client_golang/prometheus" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" 
"go.temporal.io/server/common/log/tag" "google.golang.org/grpc" @@ -60,7 +62,20 @@ type ( inboundObserver *ReplicationStreamObserver outboundObserver *ReplicationStreamObserver shardManager ShardManager + intraMgr *intraProxyManager logger log.Logger + + // remoteSendChannels maps shard IDs to send channels for replication message routing + remoteSendChannels map[history.ClusterShardID]chan RoutedMessage + remoteSendChannelsMu sync.RWMutex + + // localAckChannels maps shard IDs to ack channels for local acknowledgment handling + localAckChannels map[history.ClusterShardID]chan RoutedAck + localAckChannelsMu sync.RWMutex + + // localReceiverCancelFuncs maps shard IDs to context cancel functions for local receiver termination + localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc + localReceiverCancelFuncsMu sync.RWMutex } // contextAwareServer represents a startable gRPC server used to provide the Temporal interface on some connection. // IsUsable and Describe allow the caller to know and log the current state of the server. @@ -93,8 +108,11 @@ type ( nsTranslations collect.StaticBiMap[string, string] saTranslations config.SearchAttributeTranslation shardCountConfig config.ShardCountConfig - targetShardCount int32 logger log.Logger + + clusterConnection *ClusterConnection + lcmParameters LCMParameters + overrideShardCount int32 } ) @@ -104,13 +122,15 @@ func sanitizeConnectionName(name string) string { } // NewClusterConnection unpacks the connConfig and creates the inbound and outbound clients and servers. -func NewClusterConnection(lifetime context.Context, connConfig config.ClusterConnConfig, shardManager ShardManager, logger log.Logger) (*ClusterConnection, error) { +func NewClusterConnection(lifetime context.Context, connConfig config.ClusterConnConfig, logger log.Logger) (*ClusterConnection, error) { // The name is used in metrics and in the protocol for identifying the multi-client-conn. Sanitize it or else grpc.Dial will be very unhappy. 
sanitizedConnectionName := sanitizeConnectionName(connConfig.Name) cc := &ClusterConnection{ - lifetime: lifetime, - logger: log.With(logger, tag.NewStringTag("clusterConn", sanitizedConnectionName)), - shardManager: shardManager, + lifetime: lifetime, + logger: log.With(logger, tag.NewStringTag("clusterConn", sanitizedConnectionName)), + remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), + localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), + localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), } var err error cc.inboundClient, err = createClient(lifetime, sanitizedConnectionName, connConfig.LocalServer.Connection, "inbound") @@ -130,36 +150,70 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon return nil, err } + var lcmParameters LCMParameters + if connConfig.ShardCountConfig.Mode == config.ShardCountLCM { + lcmParameters = LCMParameters{ + LCM: common.LCM(connConfig.ShardCountConfig.LocalShardCount, connConfig.ShardCountConfig.RemoteShardCount), + TargetShardCount: connConfig.ShardCountConfig.LocalShardCount, + } + } + getOverrideShardCount := func(shardCountConfig config.ShardCountConfig, reverse bool) int32 { + switch shardCountConfig.Mode { + case config.ShardCountLCM: + return lcmParameters.LCM + case config.ShardCountRouting: + if reverse { + return shardCountConfig.RemoteShardCount + } + return shardCountConfig.LocalShardCount + } + return 0 + } + cc.inboundServer, cc.inboundObserver, err = createServer(lifetime, serverConfiguration{ - name: sanitizedConnectionName, - clusterDefinition: connConfig.RemoteServer, - directionLabel: "inbound", - client: cc.inboundClient, - managedClient: cc.outboundClient, - nsTranslations: nsTranslations.Inverse(), - saTranslations: saTranslations.Inverse(), - shardCountConfig: connConfig.ShardCountConfig, - targetShardCount: connConfig.ShardCountConfig.LocalShardCount, - logger: cc.logger, + name: sanitizedConnectionName, + clusterDefinition: connConfig.RemoteServer, + directionLabel: "inbound", + client: cc.inboundClient, + managedClient: cc.outboundClient, + nsTranslations: nsTranslations.Inverse(), + saTranslations: saTranslations.Inverse(), + shardCountConfig: connConfig.ShardCountConfig, + logger: cc.logger, + clusterConnection: cc, + overrideShardCount: getOverrideShardCount(connConfig.ShardCountConfig, true), + lcmParameters: lcmParameters, }) if err != nil { return nil, err } + + if connConfig.ShardCountConfig.Mode == config.ShardCountLCM { + lcmParameters.TargetShardCount = connConfig.ShardCountConfig.RemoteShardCount + } cc.outboundServer, cc.outboundObserver, err = createServer(lifetime, serverConfiguration{ - name: sanitizedConnectionName, - clusterDefinition: connConfig.LocalServer, - directionLabel: "outbound", - client: cc.outboundClient, - managedClient: cc.inboundClient, - nsTranslations: nsTranslations, - saTranslations: saTranslations, - shardCountConfig: connConfig.ShardCountConfig, - targetShardCount: connConfig.ShardCountConfig.RemoteShardCount, - logger: cc.logger, + name: sanitizedConnectionName, + clusterDefinition: connConfig.LocalServer, + directionLabel: "outbound", + client: cc.outboundClient, + managedClient: cc.inboundClient, + nsTranslations: nsTranslations, + saTranslations: saTranslations, + shardCountConfig: connConfig.ShardCountConfig, + logger: cc.logger, + clusterConnection: cc, + overrideShardCount: getOverrideShardCount(connConfig.ShardCountConfig, false), + lcmParameters: lcmParameters, }) if err != nil { return 
nil, err } + + if connConfig.MemberlistConfig != nil { + cc.shardManager = NewShardManager(connConfig.MemberlistConfig, logger) + cc.intraMgr = newIntraProxyManager(logger, cc, connConfig.ShardCountConfig) + } + return cc, nil } @@ -245,6 +299,15 @@ func (c *ClusterConnection) Start() { c.inboundObserver.Start(c.lifetime, c.inboundServer.Name(), "inbound") c.outboundServer.Start() c.outboundObserver.Start(c.lifetime, c.outboundServer.Name(), "outbound") + if c.shardManager != nil { + err := c.shardManager.Start(c.lifetime) + if err != nil { + c.logger.Error("Failed to start shard manager", tag.Error(err)) + } + } + if c.intraMgr != nil { + c.intraMgr.Start() + } } func (c *ClusterConnection) Describe() string { return fmt.Sprintf("[ClusterConnection connects outbound server %s to outbound client %s, inbound server %s to inbound client %s]", @@ -258,6 +321,181 @@ func (c *ClusterConnection) AcceptingOutboundTraffic() bool { return c.outboundClient.CanMakeCalls() && c.outboundServer.CanAcceptConnections() } +// GetShardInfo returns debug information about shard distribution +func (c *ClusterConnection) GetShardInfos() []ShardDebugInfo { + var shardInfos []ShardDebugInfo + if c.shardManager != nil { + shardInfos = append(shardInfos, c.shardManager.GetShardInfo()) + } + return shardInfos +} + +// GetChannelInfo returns debug information about active channels +func (c *ClusterConnection) GetChannelInfo() ChannelDebugInfo { + remoteSendChannels := make(map[string]int) + var totalSendChannels int + + // Collect remote send channel info first + c.remoteSendChannelsMu.RLock() + for shardID, ch := range c.remoteSendChannels { + shardKey := ClusterShardIDtoString(shardID) + remoteSendChannels[shardKey] = len(ch) + } + totalSendChannels = len(c.remoteSendChannels) + c.remoteSendChannelsMu.RUnlock() + + localAckChannels := make(map[string]int) + var totalAckChannels int + + // Collect local ack channel info separately + c.localAckChannelsMu.RLock() + for shardID, ch := range c.localAckChannels { + shardKey := ClusterShardIDtoString(shardID) + localAckChannels[shardKey] = len(ch) + } + totalAckChannels = len(c.localAckChannels) + c.localAckChannelsMu.RUnlock() + + return ChannelDebugInfo{ + RemoteSendChannels: remoteSendChannels, + LocalAckChannels: localAckChannels, + TotalSendChannels: totalSendChannels, + TotalAckChannels: totalAckChannels, + } +} + +// SetRemoteSendChan registers a send channel for a specific shard ID +func (c *ClusterConnection) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { + c.logger.Info("Register remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + c.remoteSendChannelsMu.Lock() + defer c.remoteSendChannelsMu.Unlock() + c.remoteSendChannels[shardID] = sendChan +} + +// GetRemoteSendChan retrieves the send channel for a specific shard ID +func (c *ClusterConnection) GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) { + c.remoteSendChannelsMu.RLock() + defer c.remoteSendChannelsMu.RUnlock() + ch, exists := c.remoteSendChannels[shardID] + return ch, exists +} + +// GetAllRemoteSendChans returns a map of all remote send channels +func (c *ClusterConnection) GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { + c.remoteSendChannelsMu.RLock() + defer c.remoteSendChannelsMu.RUnlock() + + // Create a copy of the map + result := make(map[history.ClusterShardID]chan RoutedMessage, len(c.remoteSendChannels)) + for k, v := range c.remoteSendChannels { + result[k] = v + 
} + return result +} + +// GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID +func (c *ClusterConnection) GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage { + c.remoteSendChannelsMu.RLock() + defer c.remoteSendChannelsMu.RUnlock() + + result := make(map[history.ClusterShardID]chan RoutedMessage) + for k, v := range c.remoteSendChannels { + if k.ClusterID == clusterID { + result[k] = v + } + } + return result +} + +// RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel +func (c *ClusterConnection) RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) { + c.remoteSendChannelsMu.Lock() + defer c.remoteSendChannelsMu.Unlock() + if currentChan, exists := c.remoteSendChannels[shardID]; exists && currentChan == expectedChan { + delete(c.remoteSendChannels, shardID) + c.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } else { + c.logger.Info("Skipped removing remote send channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } +} + +// SetLocalAckChan registers an ack channel for a specific shard ID +func (c *ClusterConnection) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) { + c.logger.Info("Register local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + c.localAckChannelsMu.Lock() + defer c.localAckChannelsMu.Unlock() + c.localAckChannels[shardID] = ackChan +} + +// GetLocalAckChan retrieves the ack channel for a specific shard ID +func (c *ClusterConnection) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) { + c.localAckChannelsMu.RLock() + defer c.localAckChannelsMu.RUnlock() + ch, exists := c.localAckChannels[shardID] + return ch, exists +} + +// RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel +func (c *ClusterConnection) RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) { + c.logger.Info("Remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + c.localAckChannelsMu.Lock() + defer c.localAckChannelsMu.Unlock() + if currentChan, exists := c.localAckChannels[shardID]; exists && currentChan == expectedChan { + delete(c.localAckChannels, shardID) + } else { + c.logger.Info("Skipped removing local ack channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } +} + +// ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID +func (c *ClusterConnection) ForceRemoveLocalAckChan(shardID history.ClusterShardID) { + c.logger.Info("Force remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + c.localAckChannelsMu.Lock() + defer c.localAckChannelsMu.Unlock() + delete(c.localAckChannels, shardID) +} + +// SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID +func (c *ClusterConnection) SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) { + c.logger.Info("Register local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + c.localReceiverCancelFuncsMu.Lock() + defer c.localReceiverCancelFuncsMu.Unlock() + 
c.localReceiverCancelFuncs[shardID] = cancelFunc +} + +// GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID +func (c *ClusterConnection) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) { + c.localReceiverCancelFuncsMu.RLock() + defer c.localReceiverCancelFuncsMu.RUnlock() + cancelFunc, exists := c.localReceiverCancelFuncs[shardID] + return cancelFunc, exists +} + +// RemoveLocalReceiverCancelFunc unconditionally removes the cancel function for a local receiver for a specific shard ID +// Note: Functions cannot be compared in Go, so we use unconditional removal. +// The race condition is primarily with channels; TerminatePreviousLocalReceiver handles forced cleanup. +func (c *ClusterConnection) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { + c.logger.Info("Remove local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + c.localReceiverCancelFuncsMu.Lock() + defer c.localReceiverCancelFuncsMu.Unlock() + delete(c.localReceiverCancelFuncs, shardID) +} + +// TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed +func (c *ClusterConnection) TerminatePreviousLocalReceiver(serverShardID history.ClusterShardID) { + // Check if there's a previous cancel function for this shard + if prevCancelFunc, exists := c.GetLocalReceiverCancelFunc(serverShardID); exists { + c.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(serverShardID))) + + // Cancel the previous receiver's context + prevCancelFunc() + + // Force remove the cancel function and ack channel from tracking + c.RemoveLocalReceiverCancelFunc(serverShardID) + c.ForceRemoveLocalAckChan(serverShardID) + } +} + // buildProxyServer uses the provided grpc.ClientConnInterface and config.ProxyConfig to create a grpc.Server that proxies // the Temporal API across the ClientConnInterface. func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, observeFn func(int32, int32)) (*grpc.Server, error) { @@ -267,16 +505,19 @@ func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, obs } server := grpc.NewServer(serverOpts...) 
- var lcmParameters LCMParameters - if c.shardCountConfig.Mode == config.ShardCountLCM { - lcmParameters = LCMParameters{ - LCM: common.LCM(c.shardCountConfig.LocalShardCount, c.shardCountConfig.RemoteShardCount), - TargetShardCount: c.targetShardCount, - } - } - - adminServiceImpl := NewAdminServiceProxyServer(fmt.Sprintf("%sAdminService", c.directionLabel), adminservice.NewAdminServiceClient(c.client), - c.clusterDefinition.APIOverrides, []string{c.directionLabel}, observeFn, c.shardCountConfig, lcmParameters, c.logger) + adminServiceImpl := NewAdminServiceProxyServer( + fmt.Sprintf("%sAdminService", c.directionLabel), + adminservice.NewAdminServiceClient(c.client), + adminservice.NewAdminServiceClient(c.managedClient), + c.clusterDefinition.APIOverrides, + []string{c.directionLabel}, + observeFn, + c.shardCountConfig, + c.lcmParameters, + c.logger, + c.clusterConnection, + c.overrideShardCount, + ) var accessControl *auth.AccessControl if c.clusterDefinition.ACLPolicy != nil { accessControl = auth.NewAccesControl(c.clusterDefinition.ACLPolicy.AllowedNamespaces) diff --git a/proxy/cluster_connection_test.go b/proxy/cluster_connection_test.go index 27db4275..cf45a56d 100644 --- a/proxy/cluster_connection_test.go +++ b/proxy/cluster_connection_test.go @@ -172,25 +172,25 @@ func newPairedLocalClusterConnection(t *testing.T, isMux bool, logger log.Logger var localCtx context.Context localCtx, cancelLocalCC = context.WithCancel(t.Context()) localCC, err = NewClusterConnection(localCtx, makeTCPClusterConfig("TCP-only Connection Local Proxy", - a.localTemporalAddr, a.localProxyInbound, a.localProxyOutbound, a.remoteProxyInbound), nil, logger) + a.localTemporalAddr, a.localProxyInbound, a.localProxyOutbound, a.remoteProxyInbound), logger) require.NoError(t, err) var remoteCtx context.Context remoteCtx, cancelRemoteCC = context.WithCancel(t.Context()) remoteCC, err = NewClusterConnection(remoteCtx, makeTCPClusterConfig("TCP-only Connection Remote Proxy", - a.remoteTemporalAddr, a.remoteProxyInbound, a.remoteProxyOutbound, a.localProxyInbound), nil, logger) + a.remoteTemporalAddr, a.remoteProxyInbound, a.remoteProxyOutbound, a.localProxyInbound), logger) require.NoError(t, err) } else { var localCtx context.Context localCtx, cancelLocalCC = context.WithCancel(t.Context()) localCC, err = NewClusterConnection(localCtx, makeMuxClusterConfig("Mux Connection Local Establishing Proxy", - config.ConnTypeMuxClient, a.localTemporalAddr, a.localProxyOutbound, a.remoteProxyInbound), nil, logger) + config.ConnTypeMuxClient, a.localTemporalAddr, a.localProxyOutbound, a.remoteProxyInbound), logger) require.NoError(t, err) var remoteCtx context.Context remoteCtx, cancelRemoteCC = context.WithCancel(t.Context()) remoteCC, err = NewClusterConnection(remoteCtx, makeMuxClusterConfig("Mux Connection Remote Receiving Proxy", - config.ConnTypeMuxServer, a.remoteTemporalAddr, a.remoteProxyOutbound, a.remoteProxyInbound), nil, logger) + config.ConnTypeMuxServer, a.remoteTemporalAddr, a.remoteProxyOutbound, a.remoteProxyInbound), logger) require.NoError(t, err) } clientFromLocal, err := grpc.NewClient(a.localProxyOutbound, grpcutil.MakeDialOptions(nil, metrics.GetStandardGRPCClientInterceptor("outbound-local"))...) 
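The per-shard channel registries exercised by these tests rely on compare-before-delete removal: RemoveRemoteSendChan and RemoveLocalAckChan only drop an entry when the caller still holds the channel that is currently registered, so a stale goroutine cannot evict the channel a newer stream registered for the same shard. A minimal, self-contained Go sketch of that pattern follows; shardKey, registry, and chan string are illustrative stand-ins, not the repository's types or API.

package main

import (
	"fmt"
	"sync"
)

// shardKey stands in for the shard identifier the real maps are keyed by.
type shardKey struct {
	ClusterID int32
	ShardID   int32
}

// registry mirrors the Set/Remove shape used above: removal is conditional
// on the caller still owning the registered channel.
type registry struct {
	mu    sync.RWMutex
	chans map[shardKey]chan string
}

func newRegistry() *registry {
	return &registry{chans: make(map[shardKey]chan string)}
}

func (r *registry) Set(k shardKey, ch chan string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.chans[k] = ch
}

// Remove deletes the entry only if it still points at expected,
// reporting whether anything was removed.
func (r *registry) Remove(k shardKey, expected chan string) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	if cur, ok := r.chans[k]; ok && cur == expected {
		delete(r.chans, k)
		return true
	}
	return false
}

func main() {
	reg := newRegistry()
	k := shardKey{ClusterID: 1, ShardID: 2}

	oldCh := make(chan string, 1)
	reg.Set(k, oldCh)

	// A reconnect registers a new channel for the same shard.
	newCh := make(chan string, 1)
	reg.Set(k, newCh)

	// The old owner's deferred cleanup becomes a no-op: it no longer matches.
	fmt.Println(reg.Remove(k, oldCh)) // false
	fmt.Println(reg.Remove(k, newCh)) // true
}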
@@ -259,7 +259,7 @@ func TestMuxCCFailover(t *testing.T) { cancel() newConnection, err := NewClusterConnection(t.Context(), makeMuxClusterConfig("newRemoteMux", config.ConnTypeMuxServer, plcc.addresses.remoteTemporalAddr, plcc.addresses.remoteProxyOutbound, plcc.addresses.remoteProxyInbound, - func(cc *config.ClusterConnConfig) { cc.RemoteServer.Connection.MuxCount = 5 }), nil, logger) + func(cc *config.ClusterConnConfig) { cc.RemoteServer.Connection.MuxCount = 5 }), logger) require.NoError(t, err) newConnection.Start() // Wait for localCC's client retry... diff --git a/proxy/debug.go b/proxy/debug.go index 8add32e8..cf04b90a 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -81,11 +81,11 @@ type ( } DebugResponse struct { - Timestamp time.Time `json:"timestamp"` - ActiveStreams []StreamInfo `json:"active_streams"` - StreamCount int `json:"stream_count"` - ShardInfos []ShardDebugInfo `json:"shard_infos"` - ChannelInfo ChannelDebugInfo `json:"channel_info"` + Timestamp time.Time `json:"timestamp"` + ActiveStreams []StreamInfo `json:"active_streams"` + StreamCount int `json:"stream_count"` + ShardInfos []ShardDebugInfo `json:"shard_infos"` + ChannelInfos []ChannelDebugInfo `json:"channel_infos"` } ) @@ -95,21 +95,23 @@ func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Prox var activeStreams []StreamInfo var streamCount int var shardInfos []ShardDebugInfo - var channelInfo ChannelDebugInfo + var channelInfos []ChannelDebugInfo // Get active streams information streamTracker := GetGlobalStreamTracker() activeStreams = streamTracker.GetActiveStreams() streamCount = streamTracker.GetStreamCount() - shardInfos = proxyInstance.GetShardInfos() - channelInfo = proxyInstance.GetChannelInfo() + for _, clusterConnection := range proxyInstance.clusterConnections { + shardInfos = append(shardInfos, clusterConnection.GetShardInfos()...) + channelInfos = append(channelInfos, clusterConnection.GetChannelInfo()) + } response := DebugResponse{ Timestamp: time.Now(), ActiveStreams: activeStreams, StreamCount: streamCount, ShardInfos: shardInfos, - ChannelInfo: channelInfo, + ChannelInfos: channelInfos, } if err := json.NewEncoder(w).Encode(response); err != nil { diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 33bbf98e..5751ff87 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -23,12 +23,11 @@ import ( // intraProxyManager maintains long-lived intra-proxy streams to peer proxies and // provides simple send helpers (e.g., forwarding ACKs). 
type intraProxyManager struct { - logger log.Logger - streamsMu sync.RWMutex - shardManager ShardManager - shardCountConfig config.ShardCountConfig - proxy *Proxy - notifyCh chan struct{} + logger log.Logger + streamsMu sync.RWMutex + shardCountConfig config.ShardCountConfig + clusterConnection *ClusterConnection + notifyCh chan struct{} // Group state by remote peer for unified lifecycle ops peers map[string]*peerState } @@ -45,14 +44,13 @@ type peerStreamKey struct { sourceShard history.ClusterShardID } -func newIntraProxyManager(logger log.Logger, proxy *Proxy, shardManager ShardManager, shardCountConfig config.ShardCountConfig) *intraProxyManager { +func newIntraProxyManager(logger log.Logger, clusterConnection *ClusterConnection, shardCountConfig config.ShardCountConfig) *intraProxyManager { return &intraProxyManager{ - logger: logger, - proxy: proxy, - shardManager: shardManager, - shardCountConfig: shardCountConfig, - peers: make(map[string]*peerState), - notifyCh: make(chan struct{}), + logger: logger, + clusterConnection: clusterConnection, + shardCountConfig: shardCountConfig, + peers: make(map[string]*peerState), + notifyCh: make(chan struct{}), } } @@ -60,8 +58,7 @@ func newIntraProxyManager(logger log.Logger, proxy *Proxy, shardManager ShardMan // Replication messages are sent by intraProxyManager.sendMessages using the registered server stream. type intraProxyStreamSender struct { logger log.Logger - shardManager ShardManager - proxy *Proxy + clusterConnection *ClusterConnection intraMgr *intraProxyManager peerNodeName string targetShardID history.ClusterShardID @@ -134,7 +131,7 @@ func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) erro s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) // FIXME: should retry. If not succeed, return and shutdown the stream - sent := s.shardManager.DeliverAckToShardOwner(s.sourceShardID, routedAck, s.proxy, shutdownChan, s.logger, ack, false) + sent := s.clusterConnection.shardManager.DeliverAckToShardOwner(s.sourceShardID, routedAck, s.clusterConnection, shutdownChan, s.logger, ack, false) if !sent { s.logger.Error("Sender failed to forward ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) return fmt.Errorf("failed to forward ACK to source shard") @@ -168,21 +165,20 @@ func (s *intraProxyStreamSender) sendReplicationMessages(resp *adminservice.Stre // intraProxyStreamReceiver ensures a client stream to peer exists and sends aggregated ACKs upstream. type intraProxyStreamReceiver struct { - logger log.Logger - shardManager ShardManager - proxy *Proxy - intraMgr *intraProxyManager - peerNodeName string - targetShardID history.ClusterShardID - sourceShardID history.ClusterShardID - streamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient - streamID string - shutdown channel.ShutdownOnce - cancel context.CancelFunc + logger log.Logger + clusterConnection *ClusterConnection + intraMgr *intraProxyManager + peerNodeName string + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + streamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient + streamID string + shutdown channel.ShutdownOnce + cancel context.CancelFunc } // Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. 
-func (r *intraProxyStreamReceiver) Run(ctx context.Context, p *Proxy, conn *grpc.ClientConn) error { +func (r *intraProxyStreamReceiver) Run(ctx context.Context, clusterConnection *ClusterConnection, conn *grpc.ClientConn) error { r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.sourceShardID, r.targetShardID) r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) @@ -195,7 +191,7 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, p *Proxy, conn *grpc md.Set(history.MetadataKeyServerShardID, fmt.Sprintf("%d", r.sourceShardID.ShardID)) ctx = metadata.NewOutgoingContext(ctx, md) ctx = common.WithIntraProxyHeaders(ctx, map[string]string{ - common.IntraProxyOriginProxyIDHeader: r.shardManager.GetShardInfo().NodeName, + common.IntraProxyOriginProxyIDHeader: clusterConnection.shardManager.GetShardInfo().NodeName, }) // Ensure we can cancel Recv() by canceling the context when tearing down @@ -218,11 +214,11 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, p *Proxy, conn *grpc defer st.UnregisterStream(r.streamID) // Start replication receiver loop - return r.recvReplicationMessages(p) + return r.recvReplicationMessages(r.clusterConnection) } // recvReplicationMessages receives replication messages and forwards to local shard owner. -func (r *intraProxyStreamReceiver) recvReplicationMessages(self *Proxy) error { +func (r *intraProxyStreamReceiver) recvReplicationMessages(clusterConnection *ClusterConnection) error { r.logger.Info("intraProxyStreamReceiver recvReplicationMessages started") defer r.logger.Info("intraProxyStreamReceiver recvReplicationMessages finished") @@ -256,7 +252,7 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages(self *Proxy) error { sent := false logged := false for !sent { - if ch, ok := self.GetRemoteSendChan(r.targetShardID); ok { + if ch, ok := clusterConnection.remoteSendChannels[r.targetShardID]; ok { func() { defer func() { if panicErr := recover(); panicErr != nil { @@ -349,7 +345,7 @@ func (m *intraProxyManager) UnregisterSender( } // EnsureReceiverForPeerShard ensures a client stream and an ACK aggregator exist for the given peer/shard pair. 
-func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID) { +func (m *intraProxyManager) EnsureReceiverForPeerShard(clusterConnection *ClusterConnection, peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID) { logger := log.With(m.logger, tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), @@ -361,18 +357,18 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName st return } // Do not create intra-proxy streams to self instance - if peerNodeName == m.shardManager.GetNodeName() { + if peerNodeName == m.clusterConnection.shardManager.GetNodeName() { return } // Require at least one shard to be local to this instance - isLocalTargetShard := m.shardManager.IsLocalShard(targetShard) - isLocalSourceShard := m.shardManager.IsLocalShard(sourceShard) + isLocalTargetShard := m.clusterConnection.shardManager.IsLocalShard(targetShard) + isLocalSourceShard := m.clusterConnection.shardManager.IsLocalShard(sourceShard) if !isLocalTargetShard && !isLocalSourceShard { logger.Info("EnsureReceiverForPeerShard skipping because neither shard is local", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewBoolTag("isLocalTargetShard", isLocalTargetShard), tag.NewBoolTag("isLocalSourceShard", isLocalSourceShard)) return } // Consolidated path: ensure stream and background loops - err := m.ensureStream(context.Background(), logger, peerNodeName, targetShard, sourceShard, p) + err := m.ensureStream(context.Background(), logger, peerNodeName, targetShard, sourceShard, m.clusterConnection) if err != nil { logger.Error("failed to ensureStream", tag.Error(err)) } @@ -382,7 +378,7 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(p *Proxy, peerNodeName st func (m *intraProxyManager) ensurePeer( ctx context.Context, peerNodeName string, - p *Proxy, + clusterConnection *ClusterConnection, ) (*peerState, error) { m.streamsMu.RLock() if ps, ok := m.peers[peerNodeName]; ok && ps != nil && ps.conn != nil { @@ -418,7 +414,7 @@ func (m *intraProxyManager) ensurePeer( // grpc.WithDisableServiceConfig(), // ) - proxyAddresses, ok := m.shardManager.GetProxyAddress(peerNodeName) + proxyAddresses, ok := clusterConnection.shardManager.GetProxyAddress(peerNodeName) if !ok { return nil, fmt.Errorf("proxy address not found") } @@ -460,7 +456,7 @@ func (m *intraProxyManager) ensureStream( peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID, - p *Proxy, + clusterConnection *ClusterConnection, ) error { logger.Info("ensureStream") key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} @@ -477,7 +473,7 @@ func (m *intraProxyManager) ensureStream( m.streamsMu.RUnlock() // Reuse shared connection per peer - ps, err := m.ensurePeer(ctx, peerNodeName, p) + ps, err := m.ensurePeer(ctx, peerNodeName, clusterConnection) if err != nil { logger.Error("Failed to ensure peer", tag.Error(err)) return err @@ -489,12 +485,11 @@ func (m *intraProxyManager) ensureStream( tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("targetShardID", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShardID", ClusterShardIDtoString(sourceShard))), - shardManager: m.shardManager, - proxy: p, - intraMgr: m, - peerNodeName: peerNodeName, - targetShardID: 
targetShard, - sourceShardID: sourceShard, + clusterConnection: clusterConnection, + intraMgr: m, + peerNodeName: peerNodeName, + targetShardID: targetShard, + sourceShardID: sourceShard, } // initialize shutdown handle and register it for lifecycle management recv.shutdown = channel.NewShutdownOnce() @@ -506,7 +501,7 @@ func (m *intraProxyManager) ensureStream( // Let the receiver open stream, register tracking, and start goroutines go func() { - if err := recv.Run(ctx, p, ps.conn); err != nil { + if err := recv.Run(ctx, clusterConnection, ps.conn); err != nil { m.logger.Error("intraProxyStreamReceiver.Run error", tag.Error(err)) } // remove the receiver from the peer state @@ -524,7 +519,6 @@ func (m *intraProxyManager) sendAck( peerNodeName string, clientShard history.ClusterShardID, serverShard history.ClusterShardID, - p *Proxy, req *adminservice.StreamWorkflowReplicationMessagesRequest, ) error { key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} @@ -548,7 +542,6 @@ func (m *intraProxyManager) sendReplicationMessages( peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID, - p *Proxy, resp *adminservice.StreamWorkflowReplicationMessagesResponse, ) error { key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} @@ -669,7 +662,7 @@ func (m *intraProxyManager) ClosePeerShard(peer string, clientShard, serverShard } } -func (m *intraProxyManager) Start() error { +func (m *intraProxyManager) Start() { m.logger.Info("intraProxyManager started") defer m.logger.Info("intraProxyManager stopped") go func() { @@ -678,13 +671,12 @@ func (m *intraProxyManager) Start() error { timer := time.NewTimer(1 * time.Second) select { case <-timer.C: - m.ReconcilePeerStreams(m.proxy, "") + m.ReconcilePeerStreams(m.clusterConnection, "") case <-m.notifyCh: - m.ReconcilePeerStreams(m.proxy, "") + m.ReconcilePeerStreams(m.clusterConnection, "") } } }() - return nil } func (m *intraProxyManager) Notify() { @@ -698,7 +690,7 @@ func (m *intraProxyManager) Notify() { // for a given peer and closes any sender/receiver not in the desired set. // This mirrors the Temporal StreamReceiverMonitor approach. 
func (m *intraProxyManager) ReconcilePeerStreams( - p *Proxy, + clusterConnection *ClusterConnection, peerNodeName string, ) { m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("peerNodeName", peerNodeName)) @@ -707,8 +699,8 @@ func (m *intraProxyManager) ReconcilePeerStreams( if mode := m.shardCountConfig.Mode; mode != config.ShardCountRouting { return } - localShards := m.shardManager.GetLocalShards() - remoteShards, err := m.shardManager.GetRemoteShardsForPeer(peerNodeName) + localShards := clusterConnection.shardManager.GetLocalShards() + remoteShards, err := clusterConnection.shardManager.GetRemoteShardsForPeer(peerNodeName) if err != nil { m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) return @@ -750,7 +742,7 @@ func (m *intraProxyManager) ReconcilePeerStreams( // Ensure all desired receivers exist for key := range desiredReceivers { - m.EnsureReceiverForPeerShard(p, desiredReceivers[key], key.targetShard, key.sourceShard) + m.EnsureReceiverForPeerShard(clusterConnection, desiredReceivers[key], key.targetShard, key.sourceShard) } // Prune anything not desired diff --git a/proxy/proxy.go b/proxy/proxy.go index 5cdf78b9..98613e50 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "strings" - "sync" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/client/history" @@ -47,20 +46,6 @@ type ( outboundHealthCheckServer *http.Server metricsServer *http.Server logger log.Logger - shardManagers map[migrationId]ShardManager - intraMgrs map[migrationId]*intraProxyManager - - // remoteSendChannels maps shard IDs to send channels for replication message routing - remoteSendChannels map[history.ClusterShardID]chan RoutedMessage - remoteSendChannelsMu sync.RWMutex - - // localAckChannels maps shard IDs to ack channels for local acknowledgment handling - localAckChannels map[history.ClusterShardID]chan RoutedAck - localAckChannelsMu sync.RWMutex - - // localReceiverCancelFuncs maps shard IDs to context cancel functions for local receiver termination - localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc - localReceiverCancelFuncsMu sync.RWMutex } ) @@ -71,17 +56,12 @@ func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { lifetime: ctx, cancel: cancel, clusterConnections: make(map[migrationId]*ClusterConnection, len(s2sConfig.MuxTransports)), - intraMgrs: make(map[migrationId]*intraProxyManager), - shardManagers: make(map[migrationId]ShardManager), logger: log.NewThrottledLogger( logger, func() float64 { return s2sConfig.Logging.GetThrottleMaxRPS() }, ), - remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), - localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), - localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), } if len(s2sConfig.ClusterConnections) == 0 { panic(errors.New("cannot create proxy without inbound and outbound config")) @@ -91,21 +71,13 @@ func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { } for _, clusterCfg := range s2sConfig.ClusterConnections { - shardManager, err := NewShardManager(configProvider, logger) - if err != nil { - logger.Fatal("Failed to create shard manager", tag.Error(err)) - continue - } - cc, err := NewClusterConnection(ctx, clusterCfg, shardManager, logger) + cc, err := NewClusterConnection(ctx, clusterCfg, logger) if err != nil { logger.Fatal("Incorrectly configured Mux cluster connection", tag.Error(err), tag.NewStringTag("name", clusterCfg.Name)) 
continue } migrationId := migrationId{clusterCfg.Name} proxy.clusterConnections[migrationId] = cc - proxy.intraMgrs[migrationId] = newIntraProxyManager(logger, proxy, shardManager, clusterCfg.ShardCountConfig) - proxy.shardManagers[migrationId] = shardManager - shardManager.SetIntraProxyManager(proxy.intraMgrs[migrationId]) } // TODO: correctly host multiple health checks if len(s2sConfig.ClusterConnections) > 0 && s2sConfig.ClusterConnections[0].InboundHealthCheck.ListenAddress != "" { @@ -214,18 +186,6 @@ func (s *Proxy) Start() error { ` it needs at least the following path: metrics.prometheus.listenAddress`) } - for _, shardManager := range s.shardManagers { - if err := shardManager.Start(s.lifetime); err != nil { - return err - } - } - - for _, intraMgr := range s.intraMgrs { - if err := intraMgr.Start(); err != nil { - return err - } - } - for _, v := range s.clusterConnections { v.Start() } @@ -256,183 +216,3 @@ func (s *Proxy) Describe() string { sb.WriteString("]") return sb.String() } - -// GetShardInfo returns debug information about shard distribution -func (s *Proxy) GetShardInfos() []ShardDebugInfo { - var shardInfos []ShardDebugInfo - for _, shardManager := range s.shardManagers { - shardInfos = append(shardInfos, shardManager.GetShardInfo()) - } - return shardInfos -} - -// GetChannelInfo returns debug information about active channels -func (s *Proxy) GetChannelInfo() ChannelDebugInfo { - remoteSendChannels := make(map[string]int) - var totalSendChannels int - - // Collect remote send channel info first - s.remoteSendChannelsMu.RLock() - for shardID, ch := range s.remoteSendChannels { - shardKey := ClusterShardIDtoString(shardID) - remoteSendChannels[shardKey] = len(ch) - } - totalSendChannels = len(s.remoteSendChannels) - s.remoteSendChannelsMu.RUnlock() - - localAckChannels := make(map[string]int) - var totalAckChannels int - - // Collect local ack channel info separately - s.localAckChannelsMu.RLock() - for shardID, ch := range s.localAckChannels { - shardKey := ClusterShardIDtoString(shardID) - localAckChannels[shardKey] = len(ch) - } - totalAckChannels = len(s.localAckChannels) - s.localAckChannelsMu.RUnlock() - - return ChannelDebugInfo{ - RemoteSendChannels: remoteSendChannels, - LocalAckChannels: localAckChannels, - TotalSendChannels: totalSendChannels, - TotalAckChannels: totalAckChannels, - } -} - -// GetIntraProxyManager returns the intra-proxy manager instance -func (s *Proxy) GetIntraProxyManager(migrationId migrationId) *intraProxyManager { - return s.intraMgrs[migrationId] -} - -// SetRemoteSendChan registers a send channel for a specific shard ID -func (s *Proxy) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { - s.logger.Info("Register remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - s.remoteSendChannelsMu.Lock() - defer s.remoteSendChannelsMu.Unlock() - s.remoteSendChannels[shardID] = sendChan -} - -// GetRemoteSendChan retrieves the send channel for a specific shard ID -func (s *Proxy) GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) { - s.remoteSendChannelsMu.RLock() - defer s.remoteSendChannelsMu.RUnlock() - ch, exists := s.remoteSendChannels[shardID] - return ch, exists -} - -// GetAllRemoteSendChans returns a map of all remote send channels -func (s *Proxy) GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { - s.remoteSendChannelsMu.RLock() - defer s.remoteSendChannelsMu.RUnlock() - - // Create a copy of the map - result := 
make(map[history.ClusterShardID]chan RoutedMessage, len(s.remoteSendChannels)) - for k, v := range s.remoteSendChannels { - result[k] = v - } - return result -} - -// GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID -func (s *Proxy) GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage { - s.remoteSendChannelsMu.RLock() - defer s.remoteSendChannelsMu.RUnlock() - - result := make(map[history.ClusterShardID]chan RoutedMessage) - for k, v := range s.remoteSendChannels { - if k.ClusterID == clusterID { - result[k] = v - } - } - return result -} - -// RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel -func (s *Proxy) RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) { - s.remoteSendChannelsMu.Lock() - defer s.remoteSendChannelsMu.Unlock() - if currentChan, exists := s.remoteSendChannels[shardID]; exists && currentChan == expectedChan { - delete(s.remoteSendChannels, shardID) - s.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - } else { - s.logger.Info("Skipped removing remote send channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - } -} - -// SetLocalAckChan registers an ack channel for a specific shard ID -func (s *Proxy) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) { - s.logger.Info("Register local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - s.localAckChannelsMu.Lock() - defer s.localAckChannelsMu.Unlock() - s.localAckChannels[shardID] = ackChan -} - -// GetLocalAckChan retrieves the ack channel for a specific shard ID -func (s *Proxy) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) { - s.localAckChannelsMu.RLock() - defer s.localAckChannelsMu.RUnlock() - ch, exists := s.localAckChannels[shardID] - return ch, exists -} - -// RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel -func (s *Proxy) RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) { - s.logger.Info("Remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - s.localAckChannelsMu.Lock() - defer s.localAckChannelsMu.Unlock() - if currentChan, exists := s.localAckChannels[shardID]; exists && currentChan == expectedChan { - delete(s.localAckChannels, shardID) - } else { - s.logger.Info("Skipped removing local ack channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - } -} - -// ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID -func (s *Proxy) ForceRemoveLocalAckChan(shardID history.ClusterShardID) { - s.logger.Info("Force remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - s.localAckChannelsMu.Lock() - defer s.localAckChannelsMu.Unlock() - delete(s.localAckChannels, shardID) -} - -// SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID -func (s *Proxy) SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) { - s.logger.Info("Register local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - s.localReceiverCancelFuncsMu.Lock() - defer 
s.localReceiverCancelFuncsMu.Unlock() - s.localReceiverCancelFuncs[shardID] = cancelFunc -} - -// GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID -func (s *Proxy) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) { - s.localReceiverCancelFuncsMu.RLock() - defer s.localReceiverCancelFuncsMu.RUnlock() - cancelFunc, exists := s.localReceiverCancelFuncs[shardID] - return cancelFunc, exists -} - -// RemoveLocalReceiverCancelFunc unconditionally removes the cancel function for a local receiver for a specific shard ID -// Note: Functions cannot be compared in Go, so we use unconditional removal. -// The race condition is primarily with channels; TerminatePreviousLocalReceiver handles forced cleanup. -func (s *Proxy) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { - s.logger.Info("Remove local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - s.localReceiverCancelFuncsMu.Lock() - defer s.localReceiverCancelFuncsMu.Unlock() - delete(s.localReceiverCancelFuncs, shardID) -} - -// TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed -func (s *Proxy) TerminatePreviousLocalReceiver(serverShardID history.ClusterShardID) { - // Check if there's a previous cancel function for this shard - if prevCancelFunc, exists := s.GetLocalReceiverCancelFunc(serverShardID); exists { - s.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(serverShardID))) - - // Cancel the previous receiver's context - prevCancelFunc() - - // Force remove the cancel function and ack channel from tracking - s.RemoveLocalReceiverCancelFunc(serverShardID) - s.ForceRemoveLocalAckChan(serverShardID) - } -} diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 1835ee37..36d84a89 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -140,14 +140,13 @@ func (b *proxyIDRingBuffer) Discard(count int) { // (another proxy or a target server) and receiving ACKs back. // This is scaffolding only – the concrete behavior will be wired in later. type proxyStreamSender struct { - logger log.Logger - shardManager ShardManager - proxy *Proxy - targetShardID history.ClusterShardID - sourceShardID history.ClusterShardID - directionLabel string - streamID string - streamTracker *StreamTracker + logger log.Logger + clusterConnection *ClusterConnection + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + streamID string + streamTracker *StreamTracker // sendMsgChan carries replication messages to be sent to the remote side. 
sendMsgChan chan RoutedMessage @@ -240,11 +239,11 @@ func (s *proxyStreamSender) Run( // Register remote send channel for this shard so receiver can forward tasks locally s.sendMsgChan = make(chan RoutedMessage, 100) - s.proxy.SetRemoteSendChan(s.targetShardID, s.sendMsgChan) - defer s.proxy.RemoveRemoteSendChan(s.targetShardID, s.sendMsgChan) + s.clusterConnection.SetRemoteSendChan(s.targetShardID, s.sendMsgChan) + defer s.clusterConnection.RemoveRemoteSendChan(s.targetShardID, s.sendMsgChan) - registeredAt := s.shardManager.RegisterShard(s.targetShardID) - defer s.shardManager.UnregisterShard(s.targetShardID, registeredAt) + registeredAt := s.clusterConnection.shardManager.RegisterShard(s.targetShardID) + defer s.clusterConnection.shardManager.UnregisterShard(s.targetShardID, registeredAt) wg := sync.WaitGroup{} wg.Add(2) @@ -334,7 +333,7 @@ func (s *proxyStreamSender) recvAck( s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) - if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger, originalAck, true) { + if s.clusterConnection.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.clusterConnection, shutdownChan, s.logger, originalAck, true) { sent[srcShard] = true numRemaining-- progress = true @@ -396,7 +395,7 @@ func (s *proxyStreamSender) recvAck( } // Log fallback ACK for this source shard s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) - if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.proxy, shutdownChan, s.logger, prev, true) { + if s.clusterConnection.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.clusterConnection, shutdownChan, s.logger, prev, true) { sent[srcShard] = true numRemaining-- progress = true @@ -555,15 +554,14 @@ func (s *proxyStreamSender) sendReplicationMessages( // proxyStreamReceiver receives replication messages from a local/remote server and // produces ACKs destined for the original sender. 
type proxyStreamReceiver struct { - logger log.Logger - shardManager ShardManager - proxy *Proxy - adminClient adminservice.AdminServiceClient - localShardCount int32 - targetShardID history.ClusterShardID - sourceShardID history.ClusterShardID - directionLabel string - ackChan chan RoutedAck + logger log.Logger + clusterConnection *ClusterConnection + adminClient adminservice.AdminServiceClient + localShardCount int32 + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + ackChan chan RoutedAck // ack aggregation across target shards ackByTarget map[history.ClusterShardID]int64 lastSentMin int64 @@ -594,7 +592,7 @@ func (r *proxyStreamReceiver) Run( shutdownChan channel.ShutdownOnce, ) { // Terminate any previous local receiver for this shard - r.proxy.TerminatePreviousLocalReceiver(r.sourceShardID) + r.clusterConnection.TerminatePreviousLocalReceiver(r.sourceShardID) r.streamID = BuildReceiverStreamID(r.sourceShardID, r.targetShardID) r.logger = log.With(r.logger, @@ -634,11 +632,11 @@ func (r *proxyStreamReceiver) Run( // Setup ack channel and cancel func bookkeeping r.ackChan = make(chan RoutedAck, 100) - r.proxy.SetLocalAckChan(r.sourceShardID, r.ackChan) - r.proxy.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) + r.clusterConnection.SetLocalAckChan(r.sourceShardID, r.ackChan) + r.clusterConnection.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) defer func() { - r.proxy.RemoveLocalAckChan(r.sourceShardID, r.ackChan) - r.proxy.RemoveLocalReceiverCancelFunc(r.sourceShardID) + r.clusterConnection.RemoveLocalAckChan(r.sourceShardID, r.ackChan) + r.clusterConnection.RemoveLocalReceiverCancelFunc(r.sourceShardID) }() // init aggregation state @@ -740,7 +738,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( }, }, } - localShardsToSend := r.proxy.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) + localShardsToSend := r.clusterConnection.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) r.logger.Info("Going to broadcast high watermark to local shards", tag.NewStringTag("localShardsToSend", fmt.Sprintf("%v", localShardsToSend))) for targetShardID, sendChan := range localShardsToSend { // Clone the message for each recipient to prevent shared mutation @@ -772,7 +770,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( }() } // send to all remote shards on other nodes as well - remoteShards, err := r.shardManager.GetRemoteShardsForPeer("") + remoteShards, err := r.clusterConnection.shardManager.GetRemoteShardsForPeer("") if err != nil { r.logger.Error("Failed to get remote shards", tag.Error(err)) return err @@ -789,7 +787,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( SourceShard: msg.SourceShard, Resp: clonedResp, } - if !r.shardManager.DeliverMessagesToShardOwner(shard.ID, &clonedMsg, r.proxy, shutdownChan, r.logger) { + if !r.clusterConnection.shardManager.DeliverMessagesToShardOwner(shard.ID, &clonedMsg, r.clusterConnection, shutdownChan, r.logger) { r.logger.Warn("Failed to send ReplicationTasks to remote shard", tag.NewStringTag("shard", ClusterShardIDtoString(shard.ID))) } } @@ -829,7 +827,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( }, }, } - if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &msg, r.proxy, shutdownChan, r.logger) { + if r.clusterConnection.shardManager.DeliverMessagesToShardOwner(targetShardID, &msg, r.clusterConnection, shutdownChan, r.logger) { sentByTarget[targetShardID] = true numRemaining-- progress = true diff --git a/proxy/shard_manager.go 
b/proxy/shard_manager.go index e71edd35..6dab0487 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -45,9 +45,9 @@ type ( // GetShardOwner returns the node name that owns the given shard GetShardOwner(shard history.ClusterShardID) (string, bool) // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) - DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool + DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool // DeliverMessagesToShardOwner routes replication messages to the appropriate shard owner (local or remote) - DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool + DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger) bool // SetOnPeerJoin registers a callback invoked when a new peer joins SetOnPeerJoin(handler func(nodeName string)) // SetOnPeerLeave registers a callback invoked when a peer leaves. @@ -56,21 +56,18 @@ type ( SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) // New: notify when remote shard set changes for a peer SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) - - SetIntraProxyManager(intraMgr *intraProxyManager) - GetIntraProxyManager() *intraProxyManager } shardManagerImpl struct { - config *config.MemberlistConfig - logger log.Logger - ml *memberlist.Memberlist - delegate *shardDelegate - mutex sync.RWMutex - localAddr string - started bool - onPeerJoin func(nodeName string) - onPeerLeave func(nodeName string) + memberlistConfig *config.MemberlistConfig + logger log.Logger + ml *memberlist.Memberlist + delegate *shardDelegate + mutex sync.RWMutex + localAddr string + started bool + onPeerJoin func(nodeName string) + onPeerLeave func(nodeName string) // New callbacks onLocalShardChange func(shard history.ClusterShardID, added bool) onRemoteShardChange func(peer string, shard history.ClusterShardID, added bool) @@ -112,28 +109,23 @@ type ( ) // NewShardManager creates a new shard manager instance -func NewShardManager(configProvider config.ConfigProvider, logger log.Logger) (ShardManager, error) { - cfg := configProvider.GetS2SProxyConfig().MemberlistConfig - if cfg == nil || !cfg.Enabled { - return &noopShardManager{}, nil - } - +func NewShardManager(memberlistConfig *config.MemberlistConfig, logger log.Logger) ShardManager { delegate := &shardDelegate{ logger: logger, } sm := &shardManagerImpl{ - config: cfg, - logger: logger, - delegate: delegate, - localShards: make(map[string]ShardInfo), - intraMgr: nil, - stopJoinRetry: make(chan struct{}), + memberlistConfig: memberlistConfig, + logger: logger, + delegate: delegate, + localShards: make(map[string]ShardInfo), + intraMgr: nil, + stopJoinRetry: make(chan struct{}), } delegate.manager = sm - return sm, nil + return sm } // SetOnPeerJoin registers a callback invoked on new peer joins. 
@@ -174,42 +166,42 @@ func (sm *shardManagerImpl) Start(lifetime context.Context) error { // Configure memberlist var mlConfig *memberlist.Config - if sm.config.TCPOnly { + if sm.memberlistConfig.TCPOnly { // Use LAN config as base for TCP-only mode mlConfig = memberlist.DefaultLANConfig() - mlConfig.DisableTcpPings = sm.config.DisableTCPPings + mlConfig.DisableTcpPings = sm.memberlistConfig.DisableTCPPings // Set default timeouts for TCP-only if not specified - if sm.config.ProbeTimeoutMs == 0 { + if sm.memberlistConfig.ProbeTimeoutMs == 0 { mlConfig.ProbeTimeout = 1 * time.Second } - if sm.config.ProbeIntervalMs == 0 { + if sm.memberlistConfig.ProbeIntervalMs == 0 { mlConfig.ProbeInterval = 2 * time.Second } } else { mlConfig = memberlist.DefaultLocalConfig() } - mlConfig.Name = sm.config.NodeName - mlConfig.BindAddr = sm.config.BindAddr - mlConfig.BindPort = sm.config.BindPort - mlConfig.AdvertiseAddr = sm.config.BindAddr - mlConfig.AdvertisePort = sm.config.BindPort + mlConfig.Name = sm.memberlistConfig.NodeName + mlConfig.BindAddr = sm.memberlistConfig.BindAddr + mlConfig.BindPort = sm.memberlistConfig.BindPort + mlConfig.AdvertiseAddr = sm.memberlistConfig.BindAddr + mlConfig.AdvertisePort = sm.memberlistConfig.BindPort mlConfig.Delegate = sm.delegate mlConfig.Events = &shardEventDelegate{manager: sm, logger: sm.logger} // Configure custom timeouts if specified - if sm.config.ProbeTimeoutMs > 0 { - mlConfig.ProbeTimeout = time.Duration(sm.config.ProbeTimeoutMs) * time.Millisecond + if sm.memberlistConfig.ProbeTimeoutMs > 0 { + mlConfig.ProbeTimeout = time.Duration(sm.memberlistConfig.ProbeTimeoutMs) * time.Millisecond } - if sm.config.ProbeIntervalMs > 0 { - mlConfig.ProbeInterval = time.Duration(sm.config.ProbeIntervalMs) * time.Millisecond + if sm.memberlistConfig.ProbeIntervalMs > 0 { + mlConfig.ProbeInterval = time.Duration(sm.memberlistConfig.ProbeIntervalMs) * time.Millisecond } sm.logger.Info("Creating memberlist", tag.NewStringTag("nodeName", mlConfig.Name), tag.NewStringTag("bindAddr", mlConfig.BindAddr), tag.NewStringTag("bindPort", fmt.Sprintf("%d", mlConfig.BindPort)), - tag.NewBoolTag("tcpOnly", sm.config.TCPOnly), + tag.NewBoolTag("tcpOnly", sm.memberlistConfig.TCPOnly), tag.NewBoolTag("disableTcpPings", mlConfig.DisableTcpPings), tag.NewStringTag("probeTimeout", mlConfig.ProbeTimeout.String()), tag.NewStringTag("probeInterval", mlConfig.ProbeInterval.String())) @@ -239,22 +231,22 @@ func (sm *shardManagerImpl) Start(lifetime context.Context) error { sm.mutex.Lock() sm.ml = ml - sm.localAddr = fmt.Sprintf("%s:%d", sm.config.BindAddr, sm.config.BindPort) + sm.localAddr = fmt.Sprintf("%s:%d", sm.memberlistConfig.BindAddr, sm.memberlistConfig.BindPort) sm.started = true sm.logger.Info("Shard manager base initialization complete", - tag.NewStringTag("node", sm.config.NodeName), + tag.NewStringTag("node", sm.memberlistConfig.NodeName), tag.NewStringTag("addr", sm.localAddr)) sm.mutex.Unlock() // Join existing cluster if configured - if len(sm.config.JoinAddrs) > 0 { + if len(sm.memberlistConfig.JoinAddrs) > 0 { sm.startJoinLoop() } sm.logger.Info("Shard manager started", - tag.NewStringTag("node", sm.config.NodeName), + tag.NewStringTag("node", sm.memberlistConfig.NodeName), tag.NewStringTag("addr", sm.localAddr)) context.AfterFunc(lifetime, func() { @@ -327,14 +319,14 @@ func (sm *shardManagerImpl) retryJoinCluster() { attempt := 0 sm.logger.Info("Starting join retry loop", - tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", sm.config.JoinAddrs))) + 
tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", sm.memberlistConfig.JoinAddrs))) for { attempt++ sm.mutex.RLock() ml := sm.ml - joinAddrs := sm.config.JoinAddrs + joinAddrs := sm.memberlistConfig.JoinAddrs sm.mutex.RUnlock() if ml == nil { @@ -433,20 +425,20 @@ func (sm *shardManagerImpl) IsLocalShard(clientShardID history.ClusterShardID) b func (sm *shardManagerImpl) GetProxyAddress(nodeName string) (string, bool) { // TODO: get the proxy address from the memberlist metadata - if sm.config.ProxyAddresses == nil { + if sm.memberlistConfig.ProxyAddresses == nil { return "", false } - addr, found := sm.config.ProxyAddresses[nodeName] + addr, found := sm.memberlistConfig.ProxyAddresses[nodeName] return addr, found } func (sm *shardManagerImpl) GetNodeName() string { - return sm.config.NodeName + return sm.memberlistConfig.NodeName } func (sm *shardManagerImpl) GetMemberNodes() []string { if !sm.started || sm.ml == nil { - return []string{sm.config.NodeName} + return []string{sm.memberlistConfig.NodeName} } // Use a timeout to prevent deadlocks when memberlist is busy @@ -470,8 +462,8 @@ func (sm *shardManagerImpl) GetMemberNodes() []string { case <-time.After(100 * time.Millisecond): // Timeout: return cached node name to prevent hanging sm.logger.Warn("GetMemberNodes timeout, returning self node", - tag.NewStringTag("node", sm.config.NodeName)) - return []string{sm.config.NodeName} + tag.NewStringTag("node", sm.memberlistConfig.NodeName)) + return []string{sm.memberlistConfig.NodeName} } } @@ -505,7 +497,7 @@ func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { return ShardDebugInfo{ Enabled: true, - NodeName: sm.config.NodeName, + NodeName: sm.memberlistConfig.NodeName, LocalShards: localShardMap, LocalShardCount: len(localShardMap), RemoteShards: remoteShardsMap, @@ -514,7 +506,6 @@ func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { } func (sm *shardManagerImpl) GetShardOwner(shard history.ClusterShardID) (string, bool) { - // FIXME: improve this: store remote shards in a map in the shardManagerImpl remoteShards, err := sm.GetRemoteShardsForPeer("") if err != nil { sm.logger.Error("Failed to get remote shards", tag.Error(err)) @@ -578,14 +569,14 @@ func (sm *shardManagerImpl) GetRemoteShardsForPeer(peerNodeName string) (map[str func (sm *shardManagerImpl) DeliverAckToShardOwner( sourceShard history.ClusterShardID, routedAck *RoutedAck, - proxy *Proxy, + clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool, ) bool { logger = log.With(logger, tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewInt64("ack", ack)) - if ackCh, ok := proxy.GetLocalAckChan(sourceShard); ok { + if ackCh, ok := clusterConnection.localAckChannels[sourceShard]; ok { delivered := false func() { defer func() { @@ -614,13 +605,13 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( } // Attempt remote delivery via intra-proxy when enabled and shard is remote - if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.config.NodeName { + if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.memberlistConfig.NodeName { if addr, found := sm.GetProxyAddress(owner); found { clientShard := routedAck.TargetShard serverShard := sourceShard - mgr := proxy.GetIntraProxyManager(migrationId{owner}) + mgr := clusterConnection.intraMgr // Synchronous send to preserve ordering - if err := mgr.sendAck(context.Background(), owner, clientShard, serverShard, proxy, routedAck.Req); err != nil { + if err := 
mgr.sendAck(context.Background(), owner, clientShard, serverShard, routedAck.Req); err != nil { logger.Error("Failed to forward ACK to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) return false } @@ -640,14 +631,14 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( func (sm *shardManagerImpl) DeliverMessagesToShardOwner( targetShard history.ClusterShardID, routedMsg *RoutedMessage, - proxy *Proxy, + clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger, ) bool { logger = log.With(logger, tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShard))) // Try local delivery first - if ch, ok := proxy.GetRemoteSendChan(targetShard); ok { + if ch, ok := clusterConnection.remoteSendChannels[targetShard]; ok { delivered := false func() { defer func() { @@ -672,12 +663,12 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( } // Attempt remote delivery via intra-proxy when enabled and shard is remote - if sm.config != nil { - if owner, ok := sm.GetShardOwner(targetShard); ok && owner != sm.config.NodeName { + if sm.memberlistConfig != nil { + if owner, ok := sm.GetShardOwner(targetShard); ok && owner != sm.memberlistConfig.NodeName { if addr, found := sm.GetProxyAddress(owner); found { if mgr := sm.GetIntraProxyManager(); mgr != nil { resp := routedMsg.Resp - if err := mgr.sendReplicationMessages(context.Background(), owner, targetShard, routedMsg.SourceShard, proxy, resp); err != nil { + if err := mgr.sendReplicationMessages(context.Background(), owner, targetShard, routedMsg.SourceShard, resp); err != nil { logger.Error("Failed to forward replication messages to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) return false } @@ -740,7 +731,7 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C msg := ShardMessage{ Type: msgType, - NodeName: sm.config.NodeName, + NodeName: sm.memberlistConfig.NodeName, ClientShard: shard, Timestamp: time.Now(), } @@ -753,7 +744,7 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C for _, member := range sm.ml.Members() { // Skip sending to self node - if member.Name == sm.config.NodeName { + if member.Name == sm.memberlistConfig.NodeName { continue } @@ -777,7 +768,7 @@ func (sd *shardDelegate) NodeMeta(limit int) []byte { for k, v := range sd.manager.localShards { shardsCopy[k] = v } - nodeName := sd.manager.config.NodeName + nodeName := sd.manager.memberlistConfig.NodeName sd.manager.mutex.RUnlock() state := NodeShardState{ @@ -794,7 +785,7 @@ func (sd *shardDelegate) NodeMeta(limit int) []byte { if len(data) > limit { // If metadata is too large, just send node name - return []byte(sd.manager.config.NodeName) + return []byte(sd.manager.memberlistConfig.NodeName) } return data @@ -893,7 +884,7 @@ func (sed *shardEventDelegate) NotifyLeave(node *memberlist.Node) { // If we're now isolated and have join addresses configured, restart join loop if sed.manager != nil && sed.manager.ml != nil { numMembers := sed.manager.ml.NumMembers() - if numMembers == 1 && len(sed.manager.config.JoinAddrs) > 0 { + if numMembers == 1 && len(sed.manager.memberlistConfig.JoinAddrs) > 0 { sed.logger.Info("Node is now isolated, restarting join loop", tag.NewStringTag("numMembers", strconv.Itoa(numMembers))) sed.manager.startJoinLoop() @@ -906,107 +897,3 @@ func (sed *shardEventDelegate) NotifyUpdate(node *memberlist.Node) { 
tag.NewStringTag("node", node.Name), tag.NewStringTag("addr", node.Addr.String())) } - -// noopShardManager provides a no-op implementation when memberlist is disabled -type noopShardManager struct{} - -func (nsm *noopShardManager) Start(_ context.Context) error { return nil } -func (nsm *noopShardManager) Stop() {} -func (nsm *noopShardManager) RegisterShard(history.ClusterShardID) time.Time { return time.Now() } -func (nsm *noopShardManager) UnregisterShard(history.ClusterShardID, time.Time) {} -func (nsm *noopShardManager) GetShardOwner(history.ClusterShardID) (string, bool) { return "", false } -func (nsm *noopShardManager) GetProxyAddress(string) (string, bool) { return "", false } -func (nsm *noopShardManager) IsLocalShard(history.ClusterShardID) bool { return true } -func (nsm *noopShardManager) GetNodeName() string { return "" } -func (nsm *noopShardManager) GetMemberNodes() []string { return []string{} } -func (nsm *noopShardManager) GetLocalShards() map[string]history.ClusterShardID { - return make(map[string]history.ClusterShardID) -} -func (nsm *noopShardManager) GetRemoteShardsForPeer(string) (map[string]NodeShardState, error) { - return make(map[string]NodeShardState), nil -} -func (nsm *noopShardManager) GetShardInfo() ShardDebugInfo { - return ShardDebugInfo{ - Enabled: false, - NodeName: "", - LocalShards: make(map[string]history.ClusterShardID), - LocalShardCount: 0, - ClusterNodes: []string{}, - ClusterSize: 0, - RemoteShards: make(map[string]string), - RemoteShardCounts: make(map[string]int), - } -} - -func (nsm *noopShardManager) SetOnPeerJoin(handler func(nodeName string)) {} -func (nsm *noopShardManager) SetOnPeerLeave(handler func(nodeName string)) {} -func (nsm *noopShardManager) SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) { -} -func (nsm *noopShardManager) SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) { -} - -func (nsm *noopShardManager) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool { - if proxy != nil { - if ackCh, ok := proxy.GetLocalAckChan(srcShard); ok { - delivered := false - func() { - defer func() { - if panicErr := recover(); panicErr != nil { - if logger != nil { - logger.Warn("Failed to deliver ACK to local shard owner (channel closed)") - } - } - }() - select { - case ackCh <- *routedAck: - delivered = true - case <-shutdownChan.Channel(): - // Shutdown signal received - } - }() - if delivered { - return true - } - if shutdownChan.IsShutdown() { - return false - } - } - } - return false -} - -func (nsm *noopShardManager) DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, proxy *Proxy, shutdownChan channel.ShutdownOnce, logger log.Logger) bool { - if proxy != nil { - if ch, ok := proxy.GetRemoteSendChan(targetShard); ok { - delivered := false - func() { - defer func() { - if panicErr := recover(); panicErr != nil { - if logger != nil { - logger.Warn("Failed to deliver messages to local shard owner (channel closed)") - } - } - }() - select { - case ch <- *routedMsg: - delivered = true - case <-shutdownChan.Channel(): - // Shutdown signal received - } - }() - if delivered { - return true - } - if shutdownChan.IsShutdown() { - return false - } - } - } - return false -} - -func (nsm *noopShardManager) SetIntraProxyManager(intraMgr *intraProxyManager) { -} -func (nsm *noopShardManager) GetIntraProxyManager() 
*intraProxyManager { - return nil -} From 4c716c986dec216a174f97257d7e271049c1839a Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Tue, 16 Dec 2025 00:30:12 -0800 Subject: [PATCH 17/38] fix for cluster_conn; fix routing --- proxy/admin_stream_transfer.go | 43 ++++++++- proxy/adminservice.go | 71 +++++++------- proxy/adminservice_test.go | 2 +- proxy/cluster_connection.go | 123 +++++++++++++----------- proxy/intra_proxy_router.go | 8 +- proxy/proxy_streams.go | 19 ++-- proxy/shard_manager.go | 46 +++++---- proxy/stream_tracker.go | 12 +-- proxy/test/replication_failover_test.go | 71 ++++++++++---- 9 files changed, 246 insertions(+), 149 deletions(-) diff --git a/proxy/admin_stream_transfer.go b/proxy/admin_stream_transfer.go index 03d47608..ceae61c6 100644 --- a/proxy/admin_stream_transfer.go +++ b/proxy/admin_stream_transfer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "strings" "sync" "time" @@ -58,6 +59,7 @@ type StreamForwarder struct { targetClusterShardID history.ClusterShardID metricLabelValues []string logger log.Logger + streamID string sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient shutdownChan channel.ShutdownOnce @@ -72,7 +74,10 @@ func newStreamForwarder( metricLabelValues []string, logger log.Logger, ) *StreamForwarder { + streamID := BuildForwarderStreamID(sourceClusterShardID, targetClusterShardID) + logger = log.With(logger, tag.NewStringTag("streamID", streamID)) return &StreamForwarder{ + streamID: streamID, adminClient: adminClient, targetStreamServer: targetStreamServer, targetMetadata: targetMetadata, @@ -87,6 +92,11 @@ func newStreamForwarder( // It sets up bidirectional forwarding with proper shutdown handling. // Returns the stream duration. func (f *StreamForwarder) Run() error { + f.logger = log.With(f.logger, + tag.NewStringTag("role", "forwarder"), + tag.NewStringTag("streamID", f.streamID), + ) + // simply forwarding target metadata outgoingContext := metadata.NewOutgoingContext(f.targetStreamServer.Context(), f.targetMetadata) outgoingContext, cancel := context.WithCancel(outgoingContext) @@ -105,6 +115,13 @@ func (f *StreamForwarder) Run() error { defer metrics.AdminServiceStreamsClosedCount.WithLabelValues(f.metricLabelValues...).Inc() streamStartTime := time.Now() + // Register the forwarder stream here + streamTracker := GetGlobalStreamTracker() + sourceShard := ClusterShardIDtoString(f.sourceClusterShardID) + targetShard := ClusterShardIDtoString(f.targetClusterShardID) + streamTracker.RegisterStream(f.streamID, "StreamWorkflowReplicationMessages", "forwarder", sourceShard, targetShard, StreamRoleForwarder) + defer streamTracker.UnregisterStream(f.streamID) + // When one side of the stream dies, we want to tell the other side to hang up // (see https://stackoverflow.com/questions/68218469/how-to-un-wedge-go-grpc-bidi-streaming-server-from-the-blocking-recv-call) // One call to StreamWorkflowReplicationMessages establishes a one-way channel through the proxy from one server to another. 
@@ -135,8 +152,10 @@ func (f *StreamForwarder) Run() error { } func (f *StreamForwarder) forwardReplicationMessages(wg *sync.WaitGroup) { + f.logger.Info("proxyStreamForwarder forwardReplicationMessages started") + defer f.logger.Info("proxyStreamForwarder forwardReplicationMessages finished") + defer func() { - f.logger.Debug("Shutdown sourceStreamClient.Recv loop.") f.shutdownChan.Shutdown() wg.Done() }() @@ -167,6 +186,15 @@ func (f *StreamForwarder) forwardReplicationMessages(wg *sync.WaitGroup) { case *adminservice.StreamWorkflowReplicationMessagesResponse_Messages: f.logger.Debug("forwarding ReplicationMessages", tag.NewInt64("exclusive", attr.Messages.GetExclusiveHighWatermark())) + msg := make([]string, 0, len(attr.Messages.ReplicationTasks)) + for i, task := range attr.Messages.ReplicationTasks { + msg = append(msg, fmt.Sprintf("[%d]: %v", i, task.SourceTaskId)) + } + f.logger.Info(fmt.Sprintf("forwarding ReplicationMessages: exclusive %v, tasks: %v", attr.Messages.ExclusiveHighWatermark, strings.Join(msg, ", "))) + + streamTracker := GetGlobalStreamTracker() + streamTracker.UpdateStreamReplicationMessages(f.streamID, attr.Messages.ExclusiveHighWatermark) + if err = f.targetStreamServer.Send(resp); err != nil { if err != io.EOF { f.logger.Error("targetStreamServer.Send encountered error", tag.Error(err)) @@ -188,7 +216,8 @@ func (f *StreamForwarder) forwardReplicationMessages(wg *sync.WaitGroup) { func (f *StreamForwarder) forwardAcks(wg *sync.WaitGroup) { defer func() { - f.logger.Debug("Shutdown targetStreamServer.Recv loop.") + f.logger.Info("StreamForwarder forwardAck started") + defer f.logger.Info("proxyStreamForwarder forwardAck finished") f.shutdownChan.Shutdown() var err error closeSent := make(chan struct{}) @@ -235,7 +264,15 @@ func (f *StreamForwarder) forwardAcks(wg *sync.WaitGroup) { switch attr := req.GetAttributes().(type) { case *adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState: - f.logger.Debug("forwarding SyncReplicationState", tag.NewInt64("inclusive", attr.SyncReplicationState.GetInclusiveLowWatermark())) + f.logger.Info(fmt.Sprintf("forwarding SyncReplicationState: inclusive %v, attr: %v", attr.SyncReplicationState.InclusiveLowWatermark, attr)) + + var watermarkTime *time.Time + if attr.SyncReplicationState.InclusiveLowWatermarkTime != nil { + t := attr.SyncReplicationState.InclusiveLowWatermarkTime.AsTime() + watermarkTime = &t + } + streamTracker := GetGlobalStreamTracker() + streamTracker.UpdateStreamSyncReplicationState(f.streamID, attr.SyncReplicationState.InclusiveLowWatermark, watermarkTime) if err = f.sourceStreamClient.Send(req); err != nil { if err != io.EOF { f.logger.Error("sourceStreamClient.Send encountered error", tag.Error(err)) diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 5955f943..8a5ebfd9 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -23,8 +23,14 @@ import ( type ( LCMParameters struct { - LCM int32 `yaml:"lcm"` - TargetShardCount int32 `yaml:"targetShardCount"` + LCM int32 + TargetShardCount int32 + } + + RoutingParameters struct { + OverrideShardCount int32 + RoutingLocalShardCount int32 + DirectionLabel string } adminServiceProxyServer struct { @@ -38,7 +44,7 @@ type ( reportStreamValue func(idx int32, value int32) shardCountConfig config.ShardCountConfig lcmParameters LCMParameters - overrideShardCount int32 + routingParameters RoutingParameters } ) @@ -52,9 +58,9 @@ func NewAdminServiceProxyServer( reportStreamValue func(idx int32, value int32), shardCountConfig 
config.ShardCountConfig, lcmParameters LCMParameters, + routingParameters RoutingParameters, logger log.Logger, clusterConnection *ClusterConnection, - overrideShardCount int32, ) adminservice.AdminServiceServer { // The AdminServiceStreams will duplicate the same output for an underlying connection issue hundreds of times. // Limit their output to three times per minute @@ -70,7 +76,7 @@ func NewAdminServiceProxyServer( reportStreamValue: reportStreamValue, shardCountConfig: shardCountConfig, lcmParameters: lcmParameters, - overrideShardCount: overrideShardCount, + routingParameters: routingParameters, } } @@ -120,8 +126,8 @@ func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *admi // common multiple of both cluster shard counts. resp.HistoryShardCount = s.lcmParameters.LCM case config.ShardCountRouting: - if s.overrideShardCount > 0 { - resp.HistoryShardCount = s.overrideShardCount + if s.routingParameters.OverrideShardCount > 0 { + resp.HistoryShardCount = s.routingParameters.OverrideShardCount } } @@ -132,6 +138,8 @@ func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *admi } } + s.logger.Info("DescribeCluster response", tag.NewStringTag("response", fmt.Sprintf("%v", resp))) + return resp, err } @@ -355,8 +363,8 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( func (s *adminServiceProxyServer) streamIntraProxyRouting( logger log.Logger, streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, - clientShardID history.ClusterShardID, - serverShardID history.ClusterShardID, + sourceShardID history.ClusterShardID, + targetShardID history.ClusterShardID, ) error { logger.Info("streamIntraProxyRouting started") defer logger.Info("streamIntraProxyRouting finished") @@ -371,14 +379,12 @@ func (s *adminServiceProxyServer) streamIntraProxyRouting( } // Only allow intra-proxy when at least one shard is local to this proxy instance - isLocalClient := s.clusterConnection.shardManager.IsLocalShard(clientShardID) - isLocalServer := s.clusterConnection.shardManager.IsLocalShard(serverShardID) - if isLocalClient || !isLocalServer { + isLocalSource := s.clusterConnection.shardManager.IsLocalShard(sourceShardID) + isLocalTarget := s.clusterConnection.shardManager.IsLocalShard(targetShardID) + if isLocalSource || !isLocalTarget { logger.Info("Skipping intra-proxy between two local shards or two remote shards. 
Client may use outdated shard info.", - tag.NewStringTag("client", ClusterShardIDtoString(clientShardID)), - tag.NewStringTag("server", ClusterShardIDtoString(serverShardID)), - tag.NewBoolTag("isLocalClient", isLocalClient), - tag.NewBoolTag("isLocalServer", isLocalServer), + tag.NewBoolTag("isLocalSource", isLocalSource), + tag.NewBoolTag("isLocalTarget", isLocalTarget), ) return nil } @@ -388,8 +394,8 @@ func (s *adminServiceProxyServer) streamIntraProxyRouting( logger: logger, clusterConnection: s.clusterConnection, peerNodeName: peerNodeName, - targetShardID: clientShardID, - sourceShardID: serverShardID, + sourceShardID: sourceShardID, + targetShardID: targetShardID, } shutdownChan := channel.NewShutdownOnce() @@ -405,8 +411,8 @@ func (s *adminServiceProxyServer) streamIntraProxyRouting( func (s *adminServiceProxyServer) streamRouting( logger log.Logger, streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, - clientShardID history.ClusterShardID, - serverShardID history.ClusterShardID, + sourceShardID history.ClusterShardID, + targetShardID history.ClusterShardID, ) error { logger.Info("streamRouting started") defer logger.Info("streamRouting stopped") @@ -416,26 +422,19 @@ func (s *adminServiceProxyServer) streamRouting( proxyStreamSender := &proxyStreamSender{ logger: logger, clusterConnection: s.clusterConnection, - sourceShardID: clientShardID, - targetShardID: serverShardID, - directionLabel: "routing", + sourceShardID: sourceShardID, + targetShardID: targetShardID, + directionLabel: s.routingParameters.DirectionLabel, } - var localShardCount int32 - if s.shardCountConfig.Mode == config.ShardCountRouting { - localShardCount = s.shardCountConfig.LocalShardCount - } else { - localShardCount = s.shardCountConfig.RemoteShardCount - } - // receiver for reverse direction - proxyStreamReceiverReverse := &proxyStreamReceiver{ + proxyStreamReceiver := &proxyStreamReceiver{ logger: s.logger, clusterConnection: s.clusterConnection, adminClient: s.adminClientReverse, - localShardCount: localShardCount, - sourceShardID: serverShardID, - targetShardID: clientShardID, - directionLabel: "routing", + localShardCount: s.routingParameters.RoutingLocalShardCount, + sourceShardID: targetShardID, // reverse direction + targetShardID: sourceShardID, // reverse direction + directionLabel: s.routingParameters.DirectionLabel, } shutdownChan := channel.NewShutdownOnce() @@ -447,7 +446,7 @@ func (s *adminServiceProxyServer) streamRouting( }() go func() { defer wg.Done() - proxyStreamReceiverReverse.Run(shutdownChan) + proxyStreamReceiver.Run(shutdownChan) }() wg.Wait() diff --git a/proxy/adminservice_test.go b/proxy/adminservice_test.go index 22a7bcab..453e5bef 100644 --- a/proxy/adminservice_test.go +++ b/proxy/adminservice_test.go @@ -46,7 +46,7 @@ type adminProxyServerInput struct { func (s *adminserviceSuite) newAdminServiceProxyServer(in adminProxyServerInput, observer *ReplicationStreamObserver) adminservice.AdminServiceServer { return NewAdminServiceProxyServer("test-service-name", s.adminClientMock, s.adminClientMock, - in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, log.NewTestLogger(), nil, 0) + in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, RoutingParameters{}, log.NewTestLogger(), nil) } func (s *adminserviceSuite) TestAddOrUpdateRemoteCluster() { diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index ade187f4..1d92177b 100644 --- 
a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -110,9 +110,9 @@ type ( shardCountConfig config.ShardCountConfig logger log.Logger - clusterConnection *ClusterConnection - lcmParameters LCMParameters - overrideShardCount int32 + clusterConnection *ClusterConnection + lcmParameters LCMParameters + routingParameters RoutingParameters } ) @@ -150,67 +150,78 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon return nil, err } - var lcmParameters LCMParameters - if connConfig.ShardCountConfig.Mode == config.ShardCountLCM { - lcmParameters = LCMParameters{ - LCM: common.LCM(connConfig.ShardCountConfig.LocalShardCount, connConfig.ShardCountConfig.RemoteShardCount), - TargetShardCount: connConfig.ShardCountConfig.LocalShardCount, + getLCMParameters := func(shardCountConfig config.ShardCountConfig, inverse bool) LCMParameters { + if shardCountConfig.Mode != config.ShardCountLCM { + return LCMParameters{} + } + lcm := common.LCM(shardCountConfig.LocalShardCount, shardCountConfig.RemoteShardCount) + if inverse { + return LCMParameters{ + LCM: lcm, + TargetShardCount: shardCountConfig.LocalShardCount, + } + } + return LCMParameters{ + LCM: lcm, + TargetShardCount: shardCountConfig.RemoteShardCount, } } - getOverrideShardCount := func(shardCountConfig config.ShardCountConfig, reverse bool) int32 { - switch shardCountConfig.Mode { - case config.ShardCountLCM: - return lcmParameters.LCM - case config.ShardCountRouting: - if reverse { - return shardCountConfig.RemoteShardCount + getRoutingParameters := func(shardCountConfig config.ShardCountConfig, inverse bool, directionLabel string) RoutingParameters { + if shardCountConfig.Mode != config.ShardCountRouting { + return RoutingParameters{} + } + if inverse { + return RoutingParameters{ + OverrideShardCount: shardCountConfig.RemoteShardCount, + RoutingLocalShardCount: shardCountConfig.LocalShardCount, + DirectionLabel: directionLabel, } - return shardCountConfig.LocalShardCount } - return 0 + return RoutingParameters{ + OverrideShardCount: shardCountConfig.LocalShardCount, + RoutingLocalShardCount: shardCountConfig.RemoteShardCount, + DirectionLabel: directionLabel, + } } cc.inboundServer, cc.inboundObserver, err = createServer(lifetime, serverConfiguration{ - name: sanitizedConnectionName, - clusterDefinition: connConfig.RemoteServer, - directionLabel: "inbound", - client: cc.inboundClient, - managedClient: cc.outboundClient, - nsTranslations: nsTranslations.Inverse(), - saTranslations: saTranslations.Inverse(), - shardCountConfig: connConfig.ShardCountConfig, - logger: cc.logger, - clusterConnection: cc, - overrideShardCount: getOverrideShardCount(connConfig.ShardCountConfig, true), - lcmParameters: lcmParameters, + name: sanitizedConnectionName, + clusterDefinition: connConfig.RemoteServer, + directionLabel: "inbound", + client: cc.inboundClient, + managedClient: cc.outboundClient, + nsTranslations: nsTranslations.Inverse(), + saTranslations: saTranslations.Inverse(), + shardCountConfig: connConfig.ShardCountConfig, + logger: cc.logger, + clusterConnection: cc, + lcmParameters: getLCMParameters(connConfig.ShardCountConfig, true), + routingParameters: getRoutingParameters(connConfig.ShardCountConfig, true, "inbound"), }) if err != nil { return nil, err } - if connConfig.ShardCountConfig.Mode == config.ShardCountLCM { - lcmParameters.TargetShardCount = connConfig.ShardCountConfig.RemoteShardCount - } cc.outboundServer, cc.outboundObserver, err = createServer(lifetime, serverConfiguration{ - name: 
sanitizedConnectionName, - clusterDefinition: connConfig.LocalServer, - directionLabel: "outbound", - client: cc.outboundClient, - managedClient: cc.inboundClient, - nsTranslations: nsTranslations, - saTranslations: saTranslations, - shardCountConfig: connConfig.ShardCountConfig, - logger: cc.logger, - clusterConnection: cc, - overrideShardCount: getOverrideShardCount(connConfig.ShardCountConfig, false), - lcmParameters: lcmParameters, + name: sanitizedConnectionName, + clusterDefinition: connConfig.LocalServer, + directionLabel: "outbound", + client: cc.outboundClient, + managedClient: cc.inboundClient, + nsTranslations: nsTranslations, + saTranslations: saTranslations, + shardCountConfig: connConfig.ShardCountConfig, + logger: cc.logger, + clusterConnection: cc, + lcmParameters: getLCMParameters(connConfig.ShardCountConfig, false), + routingParameters: getRoutingParameters(connConfig.ShardCountConfig, false, "outbound"), }) if err != nil { return nil, err } + cc.shardManager = NewShardManager(connConfig.MemberlistConfig, logger) if connConfig.MemberlistConfig != nil { - cc.shardManager = NewShardManager(connConfig.MemberlistConfig, logger) cc.intraMgr = newIntraProxyManager(logger, cc, connConfig.ShardCountConfig) } @@ -295,10 +306,6 @@ func buildTLSTCPClient(lifetime context.Context, serverAddress string, tlsCfg en } func (c *ClusterConnection) Start() { - c.inboundServer.Start() - c.inboundObserver.Start(c.lifetime, c.inboundServer.Name(), "inbound") - c.outboundServer.Start() - c.outboundObserver.Start(c.lifetime, c.outboundServer.Name(), "outbound") if c.shardManager != nil { err := c.shardManager.Start(c.lifetime) if err != nil { @@ -308,7 +315,12 @@ func (c *ClusterConnection) Start() { if c.intraMgr != nil { c.intraMgr.Start() } + c.inboundServer.Start() + c.inboundObserver.Start(c.lifetime, c.inboundServer.Name(), "inbound") + c.outboundServer.Start() + c.outboundObserver.Start(c.lifetime, c.outboundServer.Name(), "outbound") } + func (c *ClusterConnection) Describe() string { return fmt.Sprintf("[ClusterConnection connects outbound server %s to outbound client %s, inbound server %s to inbound client %s]", c.outboundServer.Describe(), c.outboundClient.Describe(), c.inboundServer.Describe(), c.inboundClient.Describe()) @@ -317,6 +329,7 @@ func (c *ClusterConnection) Describe() string { func (c *ClusterConnection) AcceptingInboundTraffic() bool { return c.inboundClient.CanMakeCalls() && c.inboundServer.CanAcceptConnections() } + func (c *ClusterConnection) AcceptingOutboundTraffic() bool { return c.outboundClient.CanMakeCalls() && c.outboundServer.CanAcceptConnections() } @@ -482,17 +495,17 @@ func (c *ClusterConnection) RemoveLocalReceiverCancelFunc(shardID history.Cluste } // TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed -func (c *ClusterConnection) TerminatePreviousLocalReceiver(serverShardID history.ClusterShardID) { +func (c *ClusterConnection) TerminatePreviousLocalReceiver(shardID history.ClusterShardID) { // Check if there's a previous cancel function for this shard - if prevCancelFunc, exists := c.GetLocalReceiverCancelFunc(serverShardID); exists { - c.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(serverShardID))) + if prevCancelFunc, exists := c.GetLocalReceiverCancelFunc(shardID); exists { + c.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) // Cancel the 
previous receiver's context prevCancelFunc() // Force remove the cancel function and ack channel from tracking - c.RemoveLocalReceiverCancelFunc(serverShardID) - c.ForceRemoveLocalAckChan(serverShardID) + c.RemoveLocalReceiverCancelFunc(shardID) + c.ForceRemoveLocalAckChan(shardID) } } @@ -514,9 +527,9 @@ func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, obs observeFn, c.shardCountConfig, c.lcmParameters, + c.routingParameters, c.logger, c.clusterConnection, - c.overrideShardCount, ) var accessControl *auth.AccessControl if c.clusterDefinition.ACLPolicy != nil { diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 5751ff87..499d12b8 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -79,7 +79,7 @@ func (s *intraProxyStreamSender) Run( // Register server-side intra-proxy stream in tracker st := GetGlobalStreamTracker() - st.RegisterStream(s.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(s.targetShardID), ClusterShardIDtoString(s.sourceShardID), StreamRoleForwarder) + st.RegisterStream(s.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(s.sourceShardID), ClusterShardIDtoString(s.targetShardID), StreamRoleForwarder) defer st.UnregisterStream(s.streamID) s.sourceStreamServer = sourceStreamServer @@ -210,7 +210,7 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, clusterConnection *C // Register client-side intra-proxy stream in tracker st := GetGlobalStreamTracker() - st.RegisterStream(r.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(r.targetShardID), ClusterShardIDtoString(r.sourceShardID), StreamRoleForwarder) + st.RegisterStream(r.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(r.sourceShardID), ClusterShardIDtoString(r.targetShardID), StreamRoleForwarder) defer st.UnregisterStream(r.streamID) // Start replication receiver loop @@ -663,8 +663,8 @@ func (m *intraProxyManager) ClosePeerShard(peer string, clientShard, serverShard } func (m *intraProxyManager) Start() { - m.logger.Info("intraProxyManager started") - defer m.logger.Info("intraProxyManager stopped") + m.logger.Info("intraProxyManager starting") + defer m.logger.Info("intraProxyManager started") go func() { for { // timer diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 36d84a89..6b6bd50c 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -570,13 +570,15 @@ type proxyStreamReceiver struct { streamID string streamTracker *StreamTracker // keepalive state - ackMu sync.Mutex + ackMu sync.RWMutex lastAckSendTime time.Time lastSentAck *adminservice.StreamWorkflowReplicationMessagesRequest } // buildReceiverDebugSnapshot builds receiver ACK aggregation state for debugging func (r *proxyStreamReceiver) buildReceiverDebugSnapshot() *ReceiverDebugInfo { + r.ackMu.RLock() + defer r.ackMu.RUnlock() info := &ReceiverDebugInfo{ AckByTarget: make(map[string]int64), } @@ -597,10 +599,8 @@ func (r *proxyStreamReceiver) Run( r.streamID = BuildReceiverStreamID(r.sourceShardID, r.targetShardID) r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID), - tag.NewStringTag("client", ClusterShardIDtoString(r.targetShardID)), - tag.NewStringTag("server", ClusterShardIDtoString(r.sourceShardID)), - tag.NewStringTag("stream-source-shard", ClusterShardIDtoString(r.sourceShardID)), - tag.NewStringTag("stream-target-shard", ClusterShardIDtoString(r.targetShardID)), + 
tag.NewStringTag("source", ClusterShardIDtoString(r.sourceShardID)), + tag.NewStringTag("target", ClusterShardIDtoString(r.targetShardID)), tag.NewStringTag("role", "receiver"), ) r.logger.Info("proxyStreamReceiver Run") @@ -869,6 +869,7 @@ func (r *proxyStreamReceiver) sendAck( // Update per-target watermark if attr, ok := routed.Req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { r.logger.Info("Receiver received upstream ACK", tag.NewInt64("inclusive_low", attr.SyncReplicationState.InclusiveLowWatermark), tag.NewStringTag("targetShard", ClusterShardIDtoString(routed.TargetShard))) + r.ackMu.Lock() r.ackByTarget[routed.TargetShard] = attr.SyncReplicationState.InclusiveLowWatermark // Compute minimal watermark across targets min := int64(0) @@ -879,7 +880,9 @@ func (r *proxyStreamReceiver) sendAck( first = false } } - if !first && min >= r.lastSentMin { + lastSentMin := r.lastSentMin + r.ackMu.Unlock() + if !first && min >= lastSentMin { // Clamp ACK to last known exclusive high watermark from source if r.lastExclusiveHighOriginal > 0 && min > r.lastExclusiveHighOriginal { r.logger.Warn("Aggregated ACK exceeds last source high watermark; clamping", @@ -922,10 +925,10 @@ func (r *proxyStreamReceiver) sendAck( } case <-ticker.C: // Send keepalive if idle for 1 second - r.ackMu.Lock() + r.ackMu.RLock() shouldSendKeepalive := r.lastSentAck != nil && time.Since(r.lastAckSendTime) >= 1*time.Second lastAck := r.lastSentAck - r.ackMu.Unlock() + r.ackMu.RUnlock() if shouldSendKeepalive { r.logger.Info("Receiver sending keepalive ACK") diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 6dab0487..3f2f3c87 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -158,6 +158,10 @@ func (sm *shardManagerImpl) SetOnRemoteShardChange(handler func(peer string, sha func (sm *shardManagerImpl) Start(lifetime context.Context) error { sm.logger.Info("Starting shard manager") + if sm.memberlistConfig == nil { + sm.logger.Info("Shard manager not configured, skipping") + return nil + } if sm.started { sm.logger.Info("Shard manager already started") @@ -425,7 +429,7 @@ func (sm *shardManagerImpl) IsLocalShard(clientShardID history.ClusterShardID) b func (sm *shardManagerImpl) GetProxyAddress(nodeName string) (string, bool) { // TODO: get the proxy address from the memberlist metadata - if sm.memberlistConfig.ProxyAddresses == nil { + if sm.memberlistConfig == nil || sm.memberlistConfig.ProxyAddresses == nil { return "", false } addr, found := sm.memberlistConfig.ProxyAddresses[nodeName] @@ -433,6 +437,9 @@ func (sm *shardManagerImpl) GetProxyAddress(nodeName string) (string, bool) { } func (sm *shardManagerImpl) GetNodeName() string { + if sm.memberlistConfig == nil { + return "" + } return sm.memberlistConfig.NodeName } @@ -497,7 +504,7 @@ func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { return ShardDebugInfo{ Enabled: true, - NodeName: sm.memberlistConfig.NodeName, + NodeName: sm.GetNodeName(), LocalShards: localShardMap, LocalShardCount: len(localShardMap), RemoteShards: remoteShardsMap, @@ -605,21 +612,23 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( } // Attempt remote delivery via intra-proxy when enabled and shard is remote - if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.memberlistConfig.NodeName { - if addr, found := sm.GetProxyAddress(owner); found { - clientShard := routedAck.TargetShard - serverShard := sourceShard - mgr := clusterConnection.intraMgr - 
// Synchronous send to preserve ordering - if err := mgr.sendAck(context.Background(), owner, clientShard, serverShard, routedAck.Req); err != nil { - logger.Error("Failed to forward ACK to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) - return false + if sm.memberlistConfig != nil { + if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.memberlistConfig.NodeName { + if addr, found := sm.GetProxyAddress(owner); found { + clientShard := routedAck.TargetShard + serverShard := sourceShard + mgr := clusterConnection.intraMgr + // Synchronous send to preserve ordering + if err := mgr.sendAck(context.Background(), owner, clientShard, serverShard, routedAck.Req); err != nil { + logger.Error("Failed to forward ACK to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return false + } + logger.Info("Forwarded ACK to shard owner via intra-proxy", tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return true } - logger.Info("Forwarded ACK to shard owner via intra-proxy", tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) - return true + logger.Warn("Owner proxy address not found for shard") + return false } - logger.Warn("Owner proxy address not found for shard") - return false } logger.Warn("No remote shard owner found for source shard") @@ -725,7 +734,7 @@ func (sm *shardManagerImpl) GetIntraProxyManager() *intraProxyManager { } func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.ClusterShardID) { - if !sm.started || sm.ml == nil { + if !sm.started || sm.ml == nil || sm.memberlistConfig == nil { return } @@ -762,6 +771,9 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C // shardDelegate implements memberlist.Delegate func (sd *shardDelegate) NodeMeta(limit int) []byte { + if sd.manager == nil || sd.manager.memberlistConfig == nil { + return nil + } // Copy shard map under read lock to avoid concurrent map iteration/modification sd.manager.mutex.RLock() shardsCopy := make(map[string]ShardInfo, len(sd.manager.localShards)) @@ -882,7 +894,7 @@ func (sed *shardEventDelegate) NotifyLeave(node *memberlist.Node) { tag.NewStringTag("addr", node.Addr.String())) // If we're now isolated and have join addresses configured, restart join loop - if sed.manager != nil && sed.manager.ml != nil { + if sed.manager != nil && sed.manager.ml != nil && sed.manager.memberlistConfig != nil { numMembers := sed.manager.ml.NumMembers() if numMembers == 1 && len(sed.manager.memberlistConfig.JoinAddrs) > 0 { sed.logger.Info("Node is now isolated, restarting join loop", diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go index b53086c5..8a09e179 100644 --- a/proxy/stream_tracker.go +++ b/proxy/stream_tracker.go @@ -28,7 +28,7 @@ func NewStreamTracker() *StreamTracker { } // RegisterStream adds a new active stream -func (st *StreamTracker) RegisterStream(id, method, direction, clientShard, serverShard, role string) { +func (st *StreamTracker) RegisterStream(id, method, direction, sourceShard, targetShard, role string) { st.mu.Lock() defer st.mu.Unlock() @@ -37,8 +37,8 @@ func (st *StreamTracker) RegisterStream(id, method, direction, clientShard, serv ID: id, Method: method, Direction: direction, - ClientShard: clientShard, - ServerShard: serverShard, + ClientShard: targetShard, + ServerShard: sourceShard, Role: role, StartTime: now, LastSeen: now, @@ -159,7 +159,7 @@ func 
GetGlobalStreamTracker() *StreamTracker { // BuildSenderStreamID returns the canonical sender stream ID. func BuildSenderStreamID(source, target history.ClusterShardID) string { - return fmt.Sprintf("snd-%s", ClusterShardIDtoShortString(target)) + return fmt.Sprintf("snd-%s", ClusterShardIDtoShortString(source)) } // BuildReceiverStreamID returns the canonical receiver stream ID. @@ -169,8 +169,8 @@ func BuildReceiverStreamID(source, target history.ClusterShardID) string { // BuildForwarderStreamID returns the canonical forwarder stream ID. // Note: forwarder uses server-first ordering in the ID. -func BuildForwarderStreamID(client, server history.ClusterShardID) string { - return fmt.Sprintf("fwd-snd-%s", ClusterShardIDtoShortString(server)) +func BuildForwarderStreamID(source, target history.ClusterShardID) string { + return fmt.Sprintf("fwd-snd-%s", ClusterShardIDtoShortString(source)) } // BuildIntraProxySenderStreamID returns the server-side intra-proxy stream ID for a peer and shard pair. diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index c2918520..b1f62d94 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -47,12 +47,12 @@ type ( proxyAAddress string proxyBAddress string - shardCountA int - shardCountB int - shardCountConfig config.ShardCountConfig - namespace string - namespaceID string - startTime time.Time + shardCountA int + shardCountB int + shardCountConfigB config.ShardCountConfig + namespace string + namespaceID string + startTime time.Time workflows []*WorkflowDistribution @@ -69,11 +69,11 @@ type ( } TestConfig struct { - Name string - ShardCountA int - ShardCountB int - WorkflowsPerPair int - ShardCountConfig config.ShardCountConfig + Name string + ShardCountA int + ShardCountB int + WorkflowsPerPair int + ShardCountConfigB config.ShardCountConfig } ) @@ -107,20 +107,29 @@ var testConfigs = []TestConfig{ ShardCountA: 2, ShardCountB: 3, WorkflowsPerPair: 1, - ShardCountConfig: config.ShardCountConfig{ + ShardCountConfigB: config.ShardCountConfig{ Mode: config.ShardCountLCM, }, }, + { + Name: "ArbitraryShards_2to3_Routing", + ShardCountA: 2, + ShardCountB: 3, + WorkflowsPerPair: 1, + ShardCountConfigB: config.ShardCountConfig{ + Mode: config.ShardCountRouting, + }, + }, } func TestReplicationFailoverTestSuite(t *testing.T) { for _, tc := range testConfigs { t.Run(tc.Name, func(t *testing.T) { s := &ReplicationTestSuite{ - shardCountA: tc.ShardCountA, - shardCountB: tc.ShardCountB, - shardCountConfig: tc.ShardCountConfig, - workflowsPerPair: tc.WorkflowsPerPair, + shardCountA: tc.ShardCountA, + shardCountB: tc.ShardCountB, + shardCountConfigB: tc.ShardCountConfigB, + workflowsPerPair: tc.WorkflowsPerPair, } suite.Run(t, s) }) @@ -147,8 +156,8 @@ func (s *ReplicationTestSuite) SetupSuite() { proxyBOutbound := fmt.Sprintf("localhost:%d", basePort+101) muxServerAddress := fmt.Sprintf("localhost:%d", basePort+200) - proxyBShardConfig := s.shardCountConfig - if proxyBShardConfig.Mode == config.ShardCountLCM { + proxyBShardConfig := s.shardCountConfigB + if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { proxyBShardConfig.LocalShardCount = int32(s.shardCountB) proxyBShardConfig.RemoteShardCount = int32(s.shardCountA) } @@ -156,6 +165,10 @@ func (s *ReplicationTestSuite) SetupSuite() { s.proxyA = s.createProxy("proxy-a", s.proxyAAddress, proxyAOutbound, muxServerAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}) 
s.proxyB = s.createProxy("proxy-b", s.proxyBAddress, proxyBOutbound, muxServerAddress, s.clusterB, config.ServerMode, proxyBShardConfig) + s.logger.Info("Waiting for proxies to start and connect") + time.Sleep(10 * time.Second) // TODO: remove this once we have a better way to wait for proxies to start and connect + + s.logger.Info("Configuring remote clusters") s.configureRemoteCluster(s.clusterA, s.clusterB.ClusterName(), proxyAOutbound) s.configureRemoteCluster(s.clusterB, s.clusterA.ClusterName(), proxyBOutbound) s.waitForReplicationReady() @@ -229,7 +242,8 @@ func (s *ReplicationTestSuite) createCluster( } testClusterFactory := testcore.NewTestClusterFactory() - cluster, err := testClusterFactory.NewCluster(s.T(), clusterConfig, s.logger) + logger := log.With(s.logger, tag.NewStringTag("clusterName", clusterName)) + cluster, err := testClusterFactory.NewCluster(s.T(), clusterConfig, logger) s.NoError(err, "Failed to create cluster %s", clusterName) s.NotNil(cluster) @@ -601,18 +615,37 @@ func (s *ReplicationTestSuite) waitForClusterConnected( s.logger.Debug("GetReplicationStatus failed", tag.Error(err)) return false } + s.logger.Info("GetReplicationStatus response", + tag.NewStringTag("response", fmt.Sprintf("%+v", resp)), + tag.NewStringTag("source", sourceCluster.ClusterName()), + tag.NewStringTag("target", targetClusterName), + ) if len(resp.Shards) == 0 { return false } for _, shard := range resp.Shards { + s.logger.Info("Replication status", + tag.NewStringTag("shard", fmt.Sprintf("%d", shard.ShardId)), + tag.NewInt64("maxTaskId", shard.MaxReplicationTaskId), + tag.NewStringTag("remoteClusters", fmt.Sprintf("%+v", shard.RemoteClusters)), + ) + if shard.MaxReplicationTaskId <= 0 { + s.logger.Info("Max replication task id is 0", + tag.NewStringTag("shard", fmt.Sprintf("%d", shard.ShardId)), + tag.NewInt64("maxTaskId", shard.MaxReplicationTaskId), + ) continue } remoteInfo, ok := shard.RemoteClusters[targetClusterName] if !ok || remoteInfo == nil { + s.logger.Info("Remote cluster not found", + tag.NewStringTag("shard", fmt.Sprintf("%d", shard.ShardId)), + tag.NewStringTag("targetClusterName", targetClusterName), + ) return false } From cf59b2dda77a7f661c301b1ac89ff6fe6c198d4b Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Tue, 16 Dec 2025 17:13:08 -0800 Subject: [PATCH 18/38] fix test error: handle case when no replication task after shard register --- MEMBERLIST_TROUBLESHOOTING.md | 229 ------------------------ PROXY_FORWARDING.md | 96 ---------- develop/config/nginx.conf | 43 ----- proxy/adminservice_test.go | 2 +- proxy/cluster_connection.go | 58 ++++++ proxy/proxy_streams.go | 78 ++++++++ proxy/test/replication_failover_test.go | 25 ++- 7 files changed, 155 insertions(+), 376 deletions(-) delete mode 100644 MEMBERLIST_TROUBLESHOOTING.md delete mode 100644 PROXY_FORWARDING.md delete mode 100644 develop/config/nginx.conf diff --git a/MEMBERLIST_TROUBLESHOOTING.md b/MEMBERLIST_TROUBLESHOOTING.md deleted file mode 100644 index e22fed3b..00000000 --- a/MEMBERLIST_TROUBLESHOOTING.md +++ /dev/null @@ -1,229 +0,0 @@ -# Memberlist Network Troubleshooting - -This guide helps resolve network connectivity issues with memberlist in the s2s-proxy. 
- -## Common Issues - -### UDP Ping Failures - -**Symptoms:** -``` -[DEBUG] memberlist: Failed UDP ping: proxy-node-a-2 (timeout reached) -[WARN] memberlist: Was able to connect to proxy-node-a-2 over TCP but UDP probes failed, network may be misconfigured -``` - -**Causes:** -- UDP traffic blocked by firewalls -- Running in containers without UDP port mapping -- Network security policies blocking UDP -- NAT/proxy configurations - -**Solutions:** - -#### 1. Use TCP-Only Mode (Recommended) - -Update your configuration to use TCP-only transport: - -```yaml -memberlist: - enabled: true - enableForwarding: true - nodeName: "proxy-node-1" - bindAddr: "0.0.0.0" - bindPort: 7946 - joinAddrs: - - "proxy-node-2:7946" - - "proxy-node-3:7946" - # TCP-only configuration - tcpOnly: true # Disable UDP entirely - disableTCPPings: true # Improve performance in TCP-only mode - probeTimeoutMs: 1000 # Adjust for network latency - probeIntervalMs: 2000 # Reduce probe frequency -``` - -#### 2. Open UDP Ports - -If you want to keep UDP enabled: - -**Docker/Kubernetes:** -```bash -# Expose UDP port in Docker -docker run -p 7946:7946/udp -p 7946:7946/tcp ... - -# Kubernetes service -apiVersion: v1 -kind: Service -spec: - ports: - - name: memberlist-tcp - port: 7946 - protocol: TCP - - name: memberlist-udp - port: 7946 - protocol: UDP -``` - -**Firewall:** -```bash -# Linux iptables -iptables -A INPUT -p udp --dport 7946 -j ACCEPT -iptables -A INPUT -p tcp --dport 7946 -j ACCEPT - -# AWS Security Groups - allow UDP/TCP 7946 -``` - -#### 3. Adjust Bind Address - -For container environments, use specific bind addresses: - -```yaml -memberlist: - bindAddr: "0.0.0.0" # Listen on all interfaces - # OR - bindAddr: "10.0.0.1" # Specific container IP -``` - -## Configuration Options - -### Network Timing - -```yaml -memberlist: - probeTimeoutMs: 500 # Time to wait for ping response (default: 500ms) - probeIntervalMs: 1000 # Time between health probes (default: 1s) -``` - -**Adjust based on network conditions:** -- **Fast networks**: Lower values (500ms timeout, 1s interval) -- **Slow/high-latency networks**: Higher values (1000ms timeout, 2s interval) -- **Unreliable networks**: Much higher values (2000ms timeout, 5s interval) - -### Transport Modes - -#### Local Network Mode (Default) -```yaml -memberlist: - tcpOnly: false # Uses both UDP and TCP -``` -- Best for local networks -- Fastest failure detection -- Requires UDP connectivity - -#### TCP-Only Mode -```yaml -memberlist: - tcpOnly: true # TCP transport only - disableTCPPings: true # Optimize for TCP-only -``` -- Works in restricted networks -- Slightly slower failure detection -- More reliable in containerized environments - -## Testing Connectivity - -### 1. Test TCP Connectivity -```bash -# Test if TCP port is reachable -telnet proxy-node-2 7946 -nc -zv proxy-node-2 7946 -``` - -### 2. Test UDP Connectivity -```bash -# Test UDP port (if not using tcpOnly) -nc -u -zv proxy-node-2 7946 -``` - -### 3. Monitor Memberlist Logs -Enable debug logging to see detailed memberlist behavior: -```bash -# Set log level to debug -export LOG_LEVEL=debug -./s2s-proxy start --config your-config.yaml -``` - -### 4. 
Check Debug Endpoint -Query the debug endpoint to see cluster status: -```bash -curl http://localhost:6060/debug/connections | jq .shard_info -``` - -## Example Configurations - -### Docker Compose -```yaml -version: '3.8' -services: - proxy1: - image: s2s-proxy - ports: - - "7946:7946/tcp" - - "7946:7946/udp" # Only if not using tcpOnly - environment: - - CONFIG_PATH=/config/proxy.yaml -``` - -### Kubernetes -```yaml -apiVersion: apps/v1 -kind: Deployment -spec: - template: - spec: - containers: - - name: s2s-proxy - ports: - - containerPort: 7946 - protocol: TCP - - containerPort: 7946 - protocol: UDP # Only if not using tcpOnly -``` - -## Performance Impact - -**UDP + TCP Mode:** -- Fastest failure detection (~1-2 seconds) -- Best for stable networks -- Requires UDP connectivity - -**TCP-Only Mode:** -- Slightly slower failure detection (~2-5 seconds) -- More reliable in restricted environments -- Works everywhere TCP works - -## Recommended Settings by Environment - -### Local Development -```yaml -memberlist: - tcpOnly: false - probeTimeoutMs: 500 - probeIntervalMs: 1000 -``` - -### Docker/Containers -```yaml -memberlist: - tcpOnly: true - disableTCPPings: true - probeTimeoutMs: 1000 - probeIntervalMs: 2000 -``` - -### Kubernetes -```yaml -memberlist: - tcpOnly: true - disableTCPPings: true - probeTimeoutMs: 1500 - probeIntervalMs: 3000 -``` - -### High-Latency/Unreliable Networks -```yaml -memberlist: - tcpOnly: true - disableTCPPings: true - probeTimeoutMs: 2000 - probeIntervalMs: 5000 -``` diff --git a/PROXY_FORWARDING.md b/PROXY_FORWARDING.md deleted file mode 100644 index 0ee8639f..00000000 --- a/PROXY_FORWARDING.md +++ /dev/null @@ -1,96 +0,0 @@ -# Proxy-to-Proxy Forwarding - -This document describes the proxy-to-proxy forwarding functionality that enables distributed shard management across multiple s2s-proxy instances. - -## Overview - -The proxy-to-proxy forwarding mechanism allows multiple proxy instances to work together as a cluster, where each proxy instance owns a subset of shards. When a replication stream request comes to a proxy that doesn't own the target shard, it automatically forwards the request to the proxy instance that does own that shard. - -## Architecture - -``` -Client → Proxy A (Inbound) → Proxy B (Inbound) → Target Server - (Forward) (Owner) -``` - -## How It Works - -1. **Shard Ownership**: Using consistent hashing via HashiCorp memberlist, each proxy instance is assigned ownership of specific shards -2. **Ownership Check**: When a `StreamWorkflowReplicationMessages` request arrives on an **inbound connection** with **forwarding enabled**, the proxy checks if it owns the required shard -3. **Forwarding**: If another proxy owns the shard, the request is forwarded to that proxy (only for inbound connections with forwarding enabled) -4. 
**Bidirectional Streaming**: The forwarding proxy acts as a transparent relay, forwarding both requests and responses - -## Key Components - -### Shard Manager -- **Interface**: `ShardManager` with methods for shard ownership and proxy address resolution -- **Implementation**: Uses memberlist for cluster membership and consistent hashing for shard distribution -- **Methods**: - - `IsLocalShard(shardID)` - Check if this proxy owns a shard - - `GetShardOwner(shardID)` - Get the node name that owns a shard - - `GetProxyAddress(nodeName)` - Get the service address for a proxy node - -### Forwarding Logic -- **Location**: `StreamWorkflowReplicationMessages` in `adminservice.go` -- **Conditions**: Forwards only when: - - **Inbound connection** (`s.IsInbound == true`) - - **Memberlist enabled** (`memberlist.enabled == true`) - - **Forwarding enabled** (`memberlist.enableForwarding == true`) -- **Checks**: Two shard ownership checks (only for inbound): - 1. `clientShardID` - the incoming shard from the client - 2. `serverShardID` - the target shard (after LCM remapping if applicable) -- **Forwarding Function**: `forwardToProxy()` handles the bidirectional streaming - -### Configuration - -```yaml -memberlist: - enabled: true - # Enable proxy-to-proxy forwarding - enableForwarding: true - nodeName: "proxy-node-1" - bindAddr: "0.0.0.0" - bindPort: 7946 - joinAddrs: - - "proxy-node-2:7946" - - "proxy-node-3:7946" - shardStrategy: "consistent" - proxyAddresses: - "proxy-node-1": "localhost:7001" - "proxy-node-2": "proxy-node-2:7001" - "proxy-node-3": "proxy-node-3:7001" -``` - -## Metrics - -The following Prometheus metrics track forwarding operations: - -- `shard_distribution` - Number of shards handled by each proxy instance -- `shard_forwarding_total` - Total forwarding operations (labels: from_node, to_node, result) -- `memberlist_cluster_size` - Number of nodes in the memberlist cluster -- `memberlist_events_total` - Memberlist events (join/leave) - -## Benefits - -1. **Horizontal Scaling**: Add more proxy instances to handle more shards -2. **High Availability**: Automatic shard redistribution when proxies fail -3. **Load Distribution**: Shards are evenly distributed across proxy instances -4. **Transparent**: Clients don't need to know about shard ownership -5. **Configurable**: Can enable cluster coordination without forwarding via `enableForwarding: false` -6. **Backward Compatible**: Works with existing setups when memberlist is disabled - -## Limitations - -- Forwarding adds one additional network hop for non-local shards -- Requires careful configuration of proxy addresses for inter-proxy communication -- Uses insecure gRPC connections for proxy-to-proxy communication (can be enhanced with TLS) - -## Example Deployment - -For a 3-proxy cluster handling temporal replication: - -1. **proxy-node-1**: Handles shards 0, 3, 6, 9, ... -2. **proxy-node-2**: Handles shards 1, 4, 7, 10, ... -3. **proxy-node-3**: Handles shards 2, 5, 8, 11, ... - -When a replication stream for shard 7 comes to proxy-node-1, it will automatically forward to proxy-node-2. 
\ No newline at end of file diff --git a/develop/config/nginx.conf b/develop/config/nginx.conf deleted file mode 100644 index fd62b4d1..00000000 --- a/develop/config/nginx.conf +++ /dev/null @@ -1,43 +0,0 @@ -worker_processes 1; - -events { - worker_connections 1024; -} - -stream { - # Proxy for source outbound (6133, 6233) => exposed at 7001 - upstream source_outbound { - least_conn; - server host.docker.internal:6133; - server host.docker.internal:6233; - } - - server { - listen 7001; - proxy_pass source_outbound; - } - - # Proxy for target outbound (6333, 6433) => exposed at 7002 - upstream target_outbound { - least_conn; - server host.docker.internal:6333; - server host.docker.internal:6433; - } - - server { - listen 7002; - proxy_pass target_outbound; - } - - # Proxy for target server ports (6334, 6434) => exposed at 7003 - upstream target_server { - least_conn; - server host.docker.internal:6334; - server host.docker.internal:6434; - } - - server { - listen 7003; - proxy_pass target_server; - } -} diff --git a/proxy/adminservice_test.go b/proxy/adminservice_test.go index 453e5bef..e964ba95 100644 --- a/proxy/adminservice_test.go +++ b/proxy/adminservice_test.go @@ -235,7 +235,7 @@ func (s *adminserviceSuite) TestAPIOverrides_FailoverVersionIncrement() { s.adminClientMock.EXPECT().DescribeCluster(ctx, gomock.Any()).Return(c.mockResp, nil) resp, err := server.DescribeCluster(ctx, req) s.NoError(err) - s.Equal(c.expResp, resp) + s.Equal(c.expResp.FailoverVersionIncrement, resp.FailoverVersionIncrement) s.Equal("[]", observer.PrintActiveStreams()) }) } diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index 1d92177b..06438f2a 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -76,6 +76,10 @@ type ( // localReceiverCancelFuncs maps shard IDs to context cancel functions for local receiver termination localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc localReceiverCancelFuncsMu sync.RWMutex + + // activeReceivers tracks active proxyStreamReceiver instances by source shard for watermark propagation + activeReceivers map[history.ClusterShardID]*proxyStreamReceiver + activeReceiversMu sync.RWMutex } // contextAwareServer represents a startable gRPC server used to provide the Temporal interface on some connection. // IsUsable and Describe allow the caller to know and log the current state of the server. 
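The activeReceivers field added above is an RWMutex-guarded map: a receiver registers itself when its Run loop starts, unregisters on exit, and the shard-change callbacks snapshot the map before notifying each receiver outside the lock. A generic sketch of that registry pattern follows, under the assumption that a plain map plus sync.RWMutex is all that is needed; the type and method names here are placeholders, not part of this patch.

```go
package sketch

import "sync"

// registry is a minimal RWMutex-guarded map in the style of the
// activeReceivers / localAckChannels bookkeeping in ClusterConnection.
type registry[K comparable, V any] struct {
	mu sync.RWMutex
	m  map[K]V
}

func newRegistry[K comparable, V any]() *registry[K, V] {
	return &registry[K, V]{m: make(map[K]V)}
}

func (r *registry[K, V]) Register(k K, v V) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.m[k] = v
}

func (r *registry[K, V]) Unregister(k K) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.m, k)
}

// Snapshot copies the current values so the caller can notify each entry
// without holding the lock (notification may block on channel sends).
func (r *registry[K, V]) Snapshot() []V {
	r.mu.RLock()
	defer r.mu.RUnlock()
	out := make([]V, 0, len(r.m))
	for _, v := range r.m {
		out = append(out, v)
	}
	return out
}
```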
@@ -131,6 +135,7 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), + activeReceivers: make(map[history.ClusterShardID]*proxyStreamReceiver), } var err error cc.inboundClient, err = createClient(lifetime, sanitizedConnectionName, connConfig.LocalServer.Connection, "inbound") @@ -311,6 +316,27 @@ func (c *ClusterConnection) Start() { if err != nil { c.logger.Error("Failed to start shard manager", tag.Error(err)) } + // Wire up shard change callbacks to propagate pending watermarks + // Note: This will overwrite any existing callbacks set by SetIntraProxyManager, + // but we ensure intra-proxy manager is notified via its Notify() method + c.shardManager.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { + // Notify intra-proxy manager if it exists + if c.intraMgr != nil { + c.intraMgr.Notify() + } + if added { + c.notifyReceiversOfNewShard(shard) + } + }) + c.shardManager.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { + // Notify intra-proxy manager if it exists + if c.intraMgr != nil { + c.intraMgr.Notify() + } + if added { + c.notifyReceiversOfNewShard(shard) + } + }) } if c.intraMgr != nil { c.intraMgr.Start() @@ -509,6 +535,38 @@ func (c *ClusterConnection) TerminatePreviousLocalReceiver(shardID history.Clust } } +// RegisterActiveReceiver registers an active receiver for watermark propagation +func (c *ClusterConnection) RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver *proxyStreamReceiver) { + c.activeReceiversMu.Lock() + defer c.activeReceiversMu.Unlock() + c.activeReceivers[sourceShardID] = receiver +} + +// UnregisterActiveReceiver removes an active receiver +func (c *ClusterConnection) UnregisterActiveReceiver(sourceShardID history.ClusterShardID) { + c.activeReceiversMu.Lock() + defer c.activeReceiversMu.Unlock() + delete(c.activeReceivers, sourceShardID) +} + +// notifyReceiversOfNewShard notifies all receivers about a newly registered target shard +// so they can send pending watermarks if available +func (c *ClusterConnection) notifyReceiversOfNewShard(targetShardID history.ClusterShardID) { + c.activeReceiversMu.RLock() + receivers := make([]*proxyStreamReceiver, 0, len(c.activeReceivers)) + for _, receiver := range c.activeReceivers { + receivers = append(receivers, receiver) + } + c.activeReceiversMu.RUnlock() + + for _, receiver := range receivers { + // Only notify receivers that route to the same cluster as the newly registered shard + if receiver.targetShardID.ClusterID == targetShardID.ClusterID { + receiver.sendPendingWatermarkToShard(targetShardID, c) + } + } +} + // buildProxyServer uses the provided grpc.ClientConnInterface and config.ProxyConfig to create a grpc.Server that proxies // the Temporal API across the ClientConnInterface. 
func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, observeFn func(int32, int32)) (*grpc.Server, error) { diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 6b6bd50c..5e366e51 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -573,6 +573,9 @@ type proxyStreamReceiver struct { ackMu sync.RWMutex lastAckSendTime time.Time lastSentAck *adminservice.StreamWorkflowReplicationMessagesRequest + // lastWatermark tracks the last watermark received from source shard for late-registering target shards + lastWatermarkMu sync.RWMutex + lastWatermark *replicationv1.WorkflowReplicationMessages } // buildReceiverDebugSnapshot builds receiver ACK aggregation state for debugging @@ -634,9 +637,12 @@ func (r *proxyStreamReceiver) Run( r.ackChan = make(chan RoutedAck, 100) r.clusterConnection.SetLocalAckChan(r.sourceShardID, r.ackChan) r.clusterConnection.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) + // Register receiver for watermark propagation to late-registering shards + r.clusterConnection.RegisterActiveReceiver(r.sourceShardID, r) defer func() { r.clusterConnection.RemoveLocalAckChan(r.sourceShardID, r.ackChan) r.clusterConnection.RemoveLocalReceiverCancelFunc(r.sourceShardID) + r.clusterConnection.UnregisterActiveReceiver(r.sourceShardID) }() // init aggregation state @@ -716,6 +722,14 @@ func (r *proxyStreamReceiver) recvReplicationMessages( // record last source exclusive high watermark (original id space) r.lastExclusiveHighOriginal = attr.Messages.ExclusiveHighWatermark + // Track last watermark for late-registering shards + r.lastWatermarkMu.Lock() + r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, + Priority: attr.Messages.Priority, + } + r.lastWatermarkMu.Unlock() + // update tracker for incoming messages if r.streamTracker != nil && r.streamID != "" { r.streamTracker.UpdateStreamLastTaskIDs(r.streamID, ids) @@ -852,6 +866,70 @@ func (r *proxyStreamReceiver) recvReplicationMessages( return nil } +// sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard +// This ensures late-registering shards receive watermarks that were sent before they registered +func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID, clusterConnection *ClusterConnection) { + r.lastWatermarkMu.RLock() + lastWatermark := r.lastWatermark + r.lastWatermarkMu.RUnlock() + + if lastWatermark == nil || lastWatermark.ExclusiveHighWatermark == 0 { + // No pending watermark to send + return + } + + r.logger.Info("Sending pending watermark to newly registered shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark)) + + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: lastWatermark.ExclusiveHighWatermark, + Priority: lastWatermark.Priority, + }, + }, + }, + } + + // Try to send to local shard first + if sendChan, exists := clusterConnection.GetRemoteSendChan(targetShardID); exists { + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + select { + case sendChan <- 
clonedMsg: + r.logger.Info("Sent pending watermark to local shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + default: + r.logger.Warn("Failed to send pending watermark to local shard (channel full)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + return + } + + // If not local, try to send to remote shard + if clusterConnection.shardManager != nil { + shutdownChan := channel.NewShutdownOnce() + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + if clusterConnection.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, clusterConnection, shutdownChan, r.logger) { + r.logger.Info("Sent pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } else { + r.logger.Warn("Failed to send pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + } +} + // sendAck forwards ACKs from local ack channel upstream to the local server. func (r *proxyStreamReceiver) sendAck( sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index b1f62d94..17930e1b 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -3,7 +3,7 @@ package proxy import ( "context" "fmt" - "math/rand" + "net" "sync" "testing" "time" @@ -136,6 +136,15 @@ func TestReplicationFailoverTestSuite(t *testing.T) { } } +func getFreePort() int { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(fmt.Sprintf("failed to get free port: %v", err)) + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port +} + func (s *ReplicationTestSuite) SetupSuite() { s.Assertions = require.New(s.T()) s.logger = log.NewTestLogger() @@ -149,12 +158,11 @@ func (s *ReplicationTestSuite) SetupSuite() { s.clusterA = s.createCluster("cluster-a", s.shardCountA, 1) s.clusterB = s.createCluster("cluster-b", s.shardCountB, 2) - basePort := 17000 + rand.Intn(10000) - s.proxyAAddress = fmt.Sprintf("localhost:%d", basePort) - proxyAOutbound := fmt.Sprintf("localhost:%d", basePort+1) - s.proxyBAddress = fmt.Sprintf("localhost:%d", basePort+100) - proxyBOutbound := fmt.Sprintf("localhost:%d", basePort+101) - muxServerAddress := fmt.Sprintf("localhost:%d", basePort+200) + s.proxyAAddress = fmt.Sprintf("localhost:%d", getFreePort()) + proxyAOutbound := fmt.Sprintf("localhost:%d", getFreePort()) + s.proxyBAddress = fmt.Sprintf("localhost:%d", getFreePort()) + proxyBOutbound := fmt.Sprintf("localhost:%d", getFreePort()) + muxServerAddress := fmt.Sprintf("localhost:%d", getFreePort()) proxyBShardConfig := s.shardCountConfigB if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { @@ -640,6 +648,9 @@ func (s *ReplicationTestSuite) waitForClusterConnected( continue } + s.NotNil(shard.ShardLocalTime) + s.WithinRange(shard.ShardLocalTime.AsTime(), s.startTime, time.Now()) + remoteInfo, ok := shard.RemoteClusters[targetClusterName] if !ok || remoteInfo == nil { s.logger.Info("Remote cluster not found", From 4813e58d8926ce6069445c124098109d14d27432 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Tue, 16 Dec 2025 19:51:15 -0800 Subject: [PATCH 19/38] Refactor: move channel management to ShardManager Move channel management and 
intra-proxy routing from ClusterConnection to ShardManager. --- proxy/adminservice.go | 34 +- proxy/cluster_connection.go | 263 +-------------- proxy/debug.go | 6 +- proxy/fx.go | 1 - proxy/intra_proxy_router.go | 106 +++--- proxy/proxy_streams.go | 82 ++--- proxy/shard_manager.go | 416 ++++++++++++++++++++---- proxy/test/replication_failover_test.go | 6 +- 8 files changed, 479 insertions(+), 435 deletions(-) diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 8a5ebfd9..ac52d215 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -391,11 +391,11 @@ func (s *adminServiceProxyServer) streamIntraProxyRouting( // Sender: handle ACKs coming from peer and forward to original owner sender := &intraProxyStreamSender{ - logger: logger, - clusterConnection: s.clusterConnection, - peerNodeName: peerNodeName, - sourceShardID: sourceShardID, - targetShardID: targetShardID, + logger: logger, + shardManager: s.clusterConnection.shardManager, + peerNodeName: peerNodeName, + sourceShardID: sourceShardID, + targetShardID: targetShardID, } shutdownChan := channel.NewShutdownOnce() @@ -420,21 +420,21 @@ func (s *adminServiceProxyServer) streamRouting( // client: stream receiver // server: stream sender proxyStreamSender := &proxyStreamSender{ - logger: logger, - clusterConnection: s.clusterConnection, - sourceShardID: sourceShardID, - targetShardID: targetShardID, - directionLabel: s.routingParameters.DirectionLabel, + logger: logger, + shardManager: s.clusterConnection.shardManager, + sourceShardID: sourceShardID, + targetShardID: targetShardID, + directionLabel: s.routingParameters.DirectionLabel, } proxyStreamReceiver := &proxyStreamReceiver{ - logger: s.logger, - clusterConnection: s.clusterConnection, - adminClient: s.adminClientReverse, - localShardCount: s.routingParameters.RoutingLocalShardCount, - sourceShardID: targetShardID, // reverse direction - targetShardID: sourceShardID, // reverse direction - directionLabel: s.routingParameters.DirectionLabel, + logger: s.logger, + shardManager: s.clusterConnection.shardManager, + adminClient: s.adminClientReverse, + localShardCount: s.routingParameters.RoutingLocalShardCount, + sourceShardID: targetShardID, // reverse direction + targetShardID: sourceShardID, // reverse direction + directionLabel: s.routingParameters.DirectionLabel, } shutdownChan := channel.NewShutdownOnce() diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index 06438f2a..b727d519 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -7,14 +7,12 @@ import ( "fmt" "io" "net" - "sync" "time" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" "github.com/prometheus/client_golang/prometheus" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/adminservice/v1" - "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "google.golang.org/grpc" @@ -62,24 +60,7 @@ type ( inboundObserver *ReplicationStreamObserver outboundObserver *ReplicationStreamObserver shardManager ShardManager - intraMgr *intraProxyManager logger log.Logger - - // remoteSendChannels maps shard IDs to send channels for replication message routing - remoteSendChannels map[history.ClusterShardID]chan RoutedMessage - remoteSendChannelsMu sync.RWMutex - - // localAckChannels maps shard IDs to ack channels for local acknowledgment handling - localAckChannels map[history.ClusterShardID]chan RoutedAck - localAckChannelsMu sync.RWMutex - - // localReceiverCancelFuncs 
maps shard IDs to context cancel functions for local receiver termination - localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc - localReceiverCancelFuncsMu sync.RWMutex - - // activeReceivers tracks active proxyStreamReceiver instances by source shard for watermark propagation - activeReceivers map[history.ClusterShardID]*proxyStreamReceiver - activeReceiversMu sync.RWMutex } // contextAwareServer represents a startable gRPC server used to provide the Temporal interface on some connection. // IsUsable and Describe allow the caller to know and log the current state of the server. @@ -130,12 +111,8 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon // The name is used in metrics and in the protocol for identifying the multi-client-conn. Sanitize it or else grpc.Dial will be very unhappy. sanitizedConnectionName := sanitizeConnectionName(connConfig.Name) cc := &ClusterConnection{ - lifetime: lifetime, - logger: log.With(logger, tag.NewStringTag("clusterConn", sanitizedConnectionName)), - remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), - localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), - localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), - activeReceivers: make(map[history.ClusterShardID]*proxyStreamReceiver), + lifetime: lifetime, + logger: log.With(logger, tag.NewStringTag("clusterConn", sanitizedConnectionName)), } var err error cc.inboundClient, err = createClient(lifetime, sanitizedConnectionName, connConfig.LocalServer.Connection, "inbound") @@ -225,10 +202,7 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon return nil, err } - cc.shardManager = NewShardManager(connConfig.MemberlistConfig, logger) - if connConfig.MemberlistConfig != nil { - cc.intraMgr = newIntraProxyManager(logger, cc, connConfig.ShardCountConfig) - } + cc.shardManager = NewShardManager(connConfig.MemberlistConfig, connConfig.ShardCountConfig, logger) return cc, nil } @@ -316,30 +290,6 @@ func (c *ClusterConnection) Start() { if err != nil { c.logger.Error("Failed to start shard manager", tag.Error(err)) } - // Wire up shard change callbacks to propagate pending watermarks - // Note: This will overwrite any existing callbacks set by SetIntraProxyManager, - // but we ensure intra-proxy manager is notified via its Notify() method - c.shardManager.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { - // Notify intra-proxy manager if it exists - if c.intraMgr != nil { - c.intraMgr.Notify() - } - if added { - c.notifyReceiversOfNewShard(shard) - } - }) - c.shardManager.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { - // Notify intra-proxy manager if it exists - if c.intraMgr != nil { - c.intraMgr.Notify() - } - if added { - c.notifyReceiversOfNewShard(shard) - } - }) - } - if c.intraMgr != nil { - c.intraMgr.Start() } c.inboundServer.Start() c.inboundObserver.Start(c.lifetime, c.inboundServer.Name(), "inbound") @@ -360,213 +310,6 @@ func (c *ClusterConnection) AcceptingOutboundTraffic() bool { return c.outboundClient.CanMakeCalls() && c.outboundServer.CanAcceptConnections() } -// GetShardInfo returns debug information about shard distribution -func (c *ClusterConnection) GetShardInfos() []ShardDebugInfo { - var shardInfos []ShardDebugInfo - if c.shardManager != nil { - shardInfos = append(shardInfos, c.shardManager.GetShardInfo()) - } - return shardInfos -} - -// GetChannelInfo returns debug information about 
active channels -func (c *ClusterConnection) GetChannelInfo() ChannelDebugInfo { - remoteSendChannels := make(map[string]int) - var totalSendChannels int - - // Collect remote send channel info first - c.remoteSendChannelsMu.RLock() - for shardID, ch := range c.remoteSendChannels { - shardKey := ClusterShardIDtoString(shardID) - remoteSendChannels[shardKey] = len(ch) - } - totalSendChannels = len(c.remoteSendChannels) - c.remoteSendChannelsMu.RUnlock() - - localAckChannels := make(map[string]int) - var totalAckChannels int - - // Collect local ack channel info separately - c.localAckChannelsMu.RLock() - for shardID, ch := range c.localAckChannels { - shardKey := ClusterShardIDtoString(shardID) - localAckChannels[shardKey] = len(ch) - } - totalAckChannels = len(c.localAckChannels) - c.localAckChannelsMu.RUnlock() - - return ChannelDebugInfo{ - RemoteSendChannels: remoteSendChannels, - LocalAckChannels: localAckChannels, - TotalSendChannels: totalSendChannels, - TotalAckChannels: totalAckChannels, - } -} - -// SetRemoteSendChan registers a send channel for a specific shard ID -func (c *ClusterConnection) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { - c.logger.Info("Register remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - c.remoteSendChannelsMu.Lock() - defer c.remoteSendChannelsMu.Unlock() - c.remoteSendChannels[shardID] = sendChan -} - -// GetRemoteSendChan retrieves the send channel for a specific shard ID -func (c *ClusterConnection) GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) { - c.remoteSendChannelsMu.RLock() - defer c.remoteSendChannelsMu.RUnlock() - ch, exists := c.remoteSendChannels[shardID] - return ch, exists -} - -// GetAllRemoteSendChans returns a map of all remote send channels -func (c *ClusterConnection) GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { - c.remoteSendChannelsMu.RLock() - defer c.remoteSendChannelsMu.RUnlock() - - // Create a copy of the map - result := make(map[history.ClusterShardID]chan RoutedMessage, len(c.remoteSendChannels)) - for k, v := range c.remoteSendChannels { - result[k] = v - } - return result -} - -// GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID -func (c *ClusterConnection) GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage { - c.remoteSendChannelsMu.RLock() - defer c.remoteSendChannelsMu.RUnlock() - - result := make(map[history.ClusterShardID]chan RoutedMessage) - for k, v := range c.remoteSendChannels { - if k.ClusterID == clusterID { - result[k] = v - } - } - return result -} - -// RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel -func (c *ClusterConnection) RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) { - c.remoteSendChannelsMu.Lock() - defer c.remoteSendChannelsMu.Unlock() - if currentChan, exists := c.remoteSendChannels[shardID]; exists && currentChan == expectedChan { - delete(c.remoteSendChannels, shardID) - c.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - } else { - c.logger.Info("Skipped removing remote send channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - } -} - -// SetLocalAckChan registers an ack channel for a specific shard ID -func (c *ClusterConnection) 
SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) { - c.logger.Info("Register local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - c.localAckChannelsMu.Lock() - defer c.localAckChannelsMu.Unlock() - c.localAckChannels[shardID] = ackChan -} - -// GetLocalAckChan retrieves the ack channel for a specific shard ID -func (c *ClusterConnection) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) { - c.localAckChannelsMu.RLock() - defer c.localAckChannelsMu.RUnlock() - ch, exists := c.localAckChannels[shardID] - return ch, exists -} - -// RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel -func (c *ClusterConnection) RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) { - c.logger.Info("Remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - c.localAckChannelsMu.Lock() - defer c.localAckChannelsMu.Unlock() - if currentChan, exists := c.localAckChannels[shardID]; exists && currentChan == expectedChan { - delete(c.localAckChannels, shardID) - } else { - c.logger.Info("Skipped removing local ack channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - } -} - -// ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID -func (c *ClusterConnection) ForceRemoveLocalAckChan(shardID history.ClusterShardID) { - c.logger.Info("Force remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - c.localAckChannelsMu.Lock() - defer c.localAckChannelsMu.Unlock() - delete(c.localAckChannels, shardID) -} - -// SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID -func (c *ClusterConnection) SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) { - c.logger.Info("Register local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - c.localReceiverCancelFuncsMu.Lock() - defer c.localReceiverCancelFuncsMu.Unlock() - c.localReceiverCancelFuncs[shardID] = cancelFunc -} - -// GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID -func (c *ClusterConnection) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) { - c.localReceiverCancelFuncsMu.RLock() - defer c.localReceiverCancelFuncsMu.RUnlock() - cancelFunc, exists := c.localReceiverCancelFuncs[shardID] - return cancelFunc, exists -} - -// RemoveLocalReceiverCancelFunc unconditionally removes the cancel function for a local receiver for a specific shard ID -// Note: Functions cannot be compared in Go, so we use unconditional removal. -// The race condition is primarily with channels; TerminatePreviousLocalReceiver handles forced cleanup. 
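A minimal standalone sketch of the guarded-removal pattern described in the comment above (not part of this patch; the map key and element types are simplified stand-ins). Channel values compare by identity, so a stale cleanup can verify it still owns the registered channel before deleting; func values in Go are only comparable to nil, which is why the cancel-func registry falls back to unconditional removal.

package main

import "fmt"

func main() {
	registry := map[int]chan int{}

	ch1 := make(chan int)
	registry[7] = ch1

	// A newer registration races with the cleanup of ch1.
	ch2 := make(chan int)
	registry[7] = ch2

	// Guarded removal: delete only if the stored channel is still ch1.
	if cur, ok := registry[7]; ok && cur == ch1 {
		delete(registry, 7)
	}
	fmt.Println(len(registry)) // prints 1: ch2 survives, the stale cleanup is a no-op

	// The same guard cannot be written for context.CancelFunc values:
	// "cur == someCancelFunc" does not compile, because func types are only comparable to nil.
}
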
-func (c *ClusterConnection) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { - c.logger.Info("Remove local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - c.localReceiverCancelFuncsMu.Lock() - defer c.localReceiverCancelFuncsMu.Unlock() - delete(c.localReceiverCancelFuncs, shardID) -} - -// TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed -func (c *ClusterConnection) TerminatePreviousLocalReceiver(shardID history.ClusterShardID) { - // Check if there's a previous cancel function for this shard - if prevCancelFunc, exists := c.GetLocalReceiverCancelFunc(shardID); exists { - c.logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) - - // Cancel the previous receiver's context - prevCancelFunc() - - // Force remove the cancel function and ack channel from tracking - c.RemoveLocalReceiverCancelFunc(shardID) - c.ForceRemoveLocalAckChan(shardID) - } -} - -// RegisterActiveReceiver registers an active receiver for watermark propagation -func (c *ClusterConnection) RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver *proxyStreamReceiver) { - c.activeReceiversMu.Lock() - defer c.activeReceiversMu.Unlock() - c.activeReceivers[sourceShardID] = receiver -} - -// UnregisterActiveReceiver removes an active receiver -func (c *ClusterConnection) UnregisterActiveReceiver(sourceShardID history.ClusterShardID) { - c.activeReceiversMu.Lock() - defer c.activeReceiversMu.Unlock() - delete(c.activeReceivers, sourceShardID) -} - -// notifyReceiversOfNewShard notifies all receivers about a newly registered target shard -// so they can send pending watermarks if available -func (c *ClusterConnection) notifyReceiversOfNewShard(targetShardID history.ClusterShardID) { - c.activeReceiversMu.RLock() - receivers := make([]*proxyStreamReceiver, 0, len(c.activeReceivers)) - for _, receiver := range c.activeReceivers { - receivers = append(receivers, receiver) - } - c.activeReceiversMu.RUnlock() - - for _, receiver := range receivers { - // Only notify receivers that route to the same cluster as the newly registered shard - if receiver.targetShardID.ClusterID == targetShardID.ClusterID { - receiver.sendPendingWatermarkToShard(targetShardID, c) - } - } -} - // buildProxyServer uses the provided grpc.ClientConnInterface and config.ProxyConfig to create a grpc.Server that proxies // the Temporal API across the ClientConnInterface. func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, observeFn func(int32, int32)) (*grpc.Server, error) { diff --git a/proxy/debug.go b/proxy/debug.go index cf04b90a..44ab5cab 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -102,8 +102,10 @@ func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Prox activeStreams = streamTracker.GetActiveStreams() streamCount = streamTracker.GetStreamCount() for _, clusterConnection := range proxyInstance.clusterConnections { - shardInfos = append(shardInfos, clusterConnection.GetShardInfos()...) - channelInfos = append(channelInfos, clusterConnection.GetChannelInfo()) + if clusterConnection.shardManager != nil { + shardInfos = append(shardInfos, clusterConnection.shardManager.GetShardInfos()...) 
+ channelInfos = append(channelInfos, clusterConnection.shardManager.GetChannelInfo()) + } } response := DebugResponse{ diff --git a/proxy/fx.go b/proxy/fx.go index 16afc573..aba1ded7 100644 --- a/proxy/fx.go +++ b/proxy/fx.go @@ -6,5 +6,4 @@ import ( var Module = fx.Options( fx.Provide(NewProxy), - fx.Provide(NewShardManager), ) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 499d12b8..a4715772 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -17,17 +17,15 @@ import ( "google.golang.org/grpc/metadata" "github.com/temporalio/s2s-proxy/common" - "github.com/temporalio/s2s-proxy/config" ) // intraProxyManager maintains long-lived intra-proxy streams to peer proxies and // provides simple send helpers (e.g., forwarding ACKs). type intraProxyManager struct { - logger log.Logger - streamsMu sync.RWMutex - shardCountConfig config.ShardCountConfig - clusterConnection *ClusterConnection - notifyCh chan struct{} + logger log.Logger + streamsMu sync.RWMutex + shardManager ShardManager + notifyCh chan struct{} // Group state by remote peer for unified lifecycle ops peers map[string]*peerState } @@ -44,13 +42,12 @@ type peerStreamKey struct { sourceShard history.ClusterShardID } -func newIntraProxyManager(logger log.Logger, clusterConnection *ClusterConnection, shardCountConfig config.ShardCountConfig) *intraProxyManager { +func newIntraProxyManager(logger log.Logger, shardManager ShardManager) *intraProxyManager { return &intraProxyManager{ - logger: logger, - clusterConnection: clusterConnection, - shardCountConfig: shardCountConfig, - peers: make(map[string]*peerState), - notifyCh: make(chan struct{}), + logger: logger, + shardManager: shardManager, + peers: make(map[string]*peerState), + notifyCh: make(chan struct{}), } } @@ -58,8 +55,7 @@ func newIntraProxyManager(logger log.Logger, clusterConnection *ClusterConnectio // Replication messages are sent by intraProxyManager.sendMessages using the registered server stream. type intraProxyStreamSender struct { logger log.Logger - clusterConnection *ClusterConnection - intraMgr *intraProxyManager + shardManager ShardManager peerNodeName string targetShardID history.ClusterShardID sourceShardID history.ClusterShardID @@ -85,8 +81,8 @@ func (s *intraProxyStreamSender) Run( s.sourceStreamServer = sourceStreamServer // register this sender so sendMessages can use it - s.intraMgr.RegisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID, s) - defer s.intraMgr.UnregisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID) + s.shardManager.GetIntraProxyManager().RegisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID, s) + defer s.shardManager.GetIntraProxyManager().UnregisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID) // recv ACKs from peer and route to original source shard owner return s.recvAck(shutdownChan) @@ -131,7 +127,7 @@ func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) erro s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) // FIXME: should retry. 
If not succeed, return and shutdown the stream - sent := s.clusterConnection.shardManager.DeliverAckToShardOwner(s.sourceShardID, routedAck, s.clusterConnection, shutdownChan, s.logger, ack, false) + sent := s.shardManager.DeliverAckToShardOwner(s.sourceShardID, routedAck, shutdownChan, s.logger, ack, false) if !sent { s.logger.Error("Sender failed to forward ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) return fmt.Errorf("failed to forward ACK to source shard") @@ -165,20 +161,20 @@ func (s *intraProxyStreamSender) sendReplicationMessages(resp *adminservice.Stre // intraProxyStreamReceiver ensures a client stream to peer exists and sends aggregated ACKs upstream. type intraProxyStreamReceiver struct { - logger log.Logger - clusterConnection *ClusterConnection - intraMgr *intraProxyManager - peerNodeName string - targetShardID history.ClusterShardID - sourceShardID history.ClusterShardID - streamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient - streamID string - shutdown channel.ShutdownOnce - cancel context.CancelFunc + logger log.Logger + shardManager ShardManager + intraMgr *intraProxyManager + peerNodeName string + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + streamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient + streamID string + shutdown channel.ShutdownOnce + cancel context.CancelFunc } // Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. -func (r *intraProxyStreamReceiver) Run(ctx context.Context, clusterConnection *ClusterConnection, conn *grpc.ClientConn) error { +func (r *intraProxyStreamReceiver) Run(ctx context.Context, shardManager ShardManager, conn *grpc.ClientConn) error { r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.sourceShardID, r.targetShardID) r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) @@ -191,7 +187,7 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, clusterConnection *C md.Set(history.MetadataKeyServerShardID, fmt.Sprintf("%d", r.sourceShardID.ShardID)) ctx = metadata.NewOutgoingContext(ctx, md) ctx = common.WithIntraProxyHeaders(ctx, map[string]string{ - common.IntraProxyOriginProxyIDHeader: clusterConnection.shardManager.GetShardInfo().NodeName, + common.IntraProxyOriginProxyIDHeader: shardManager.GetShardInfo().NodeName, }) // Ensure we can cancel Recv() by canceling the context when tearing down @@ -214,11 +210,11 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, clusterConnection *C defer st.UnregisterStream(r.streamID) // Start replication receiver loop - return r.recvReplicationMessages(r.clusterConnection) + return r.recvReplicationMessages() } // recvReplicationMessages receives replication messages and forwards to local shard owner. 
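The forwarding loop below sends onto a registered shard channel inside a recover guard, so a channel closed by a concurrent teardown does not crash the receiver. A minimal sketch of that send-or-report pattern (the generic trySend helper is an illustrative assumption, not an API of this codebase):

package main

import "fmt"

// trySend reports whether v was delivered; a send on a closed channel panics,
// and recover converts that panic into a false return instead of a crash.
func trySend[T any](ch chan T, v T) (ok bool) {
	defer func() {
		if recover() != nil {
			ok = false
		}
	}()
	ch <- v
	return true
}

func main() {
	ch := make(chan string, 1)
	fmt.Println(trySend(ch, "task")) // true: buffered send succeeds

	close(ch)
	fmt.Println(trySend(ch, "task")) // false: the panic is absorbed by recover
}
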
-func (r *intraProxyStreamReceiver) recvReplicationMessages(clusterConnection *ClusterConnection) error { +func (r *intraProxyStreamReceiver) recvReplicationMessages() error { r.logger.Info("intraProxyStreamReceiver recvReplicationMessages started") defer r.logger.Info("intraProxyStreamReceiver recvReplicationMessages finished") @@ -252,7 +248,7 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages(clusterConnection *Cl sent := false logged := false for !sent { - if ch, ok := clusterConnection.remoteSendChannels[r.targetShardID]; ok { + if ch, ok := r.shardManager.GetRemoteSendChan(r.targetShardID); ok { func() { defer func() { if panicErr := recover(); panicErr != nil { @@ -345,7 +341,7 @@ func (m *intraProxyManager) UnregisterSender( } // EnsureReceiverForPeerShard ensures a client stream and an ACK aggregator exist for the given peer/shard pair. -func (m *intraProxyManager) EnsureReceiverForPeerShard(clusterConnection *ClusterConnection, peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID) { +func (m *intraProxyManager) EnsureReceiverForPeerShard(peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID) { logger := log.With(m.logger, tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), @@ -357,18 +353,18 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(clusterConnection *Cluste return } // Do not create intra-proxy streams to self instance - if peerNodeName == m.clusterConnection.shardManager.GetNodeName() { + if peerNodeName == m.shardManager.GetNodeName() { return } // Require at least one shard to be local to this instance - isLocalTargetShard := m.clusterConnection.shardManager.IsLocalShard(targetShard) - isLocalSourceShard := m.clusterConnection.shardManager.IsLocalShard(sourceShard) + isLocalTargetShard := m.shardManager.IsLocalShard(targetShard) + isLocalSourceShard := m.shardManager.IsLocalShard(sourceShard) if !isLocalTargetShard && !isLocalSourceShard { logger.Info("EnsureReceiverForPeerShard skipping because neither shard is local", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewBoolTag("isLocalTargetShard", isLocalTargetShard), tag.NewBoolTag("isLocalSourceShard", isLocalSourceShard)) return } // Consolidated path: ensure stream and background loops - err := m.ensureStream(context.Background(), logger, peerNodeName, targetShard, sourceShard, m.clusterConnection) + err := m.ensureStream(context.Background(), logger, peerNodeName, targetShard, sourceShard) if err != nil { logger.Error("failed to ensureStream", tag.Error(err)) } @@ -378,7 +374,6 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(clusterConnection *Cluste func (m *intraProxyManager) ensurePeer( ctx context.Context, peerNodeName string, - clusterConnection *ClusterConnection, ) (*peerState, error) { m.streamsMu.RLock() if ps, ok := m.peers[peerNodeName]; ok && ps != nil && ps.conn != nil { @@ -414,7 +409,7 @@ func (m *intraProxyManager) ensurePeer( // grpc.WithDisableServiceConfig(), // ) - proxyAddresses, ok := clusterConnection.shardManager.GetProxyAddress(peerNodeName) + proxyAddresses, ok := m.shardManager.GetProxyAddress(peerNodeName) if !ok { return nil, fmt.Errorf("proxy address not found") } @@ -456,7 +451,6 @@ func (m *intraProxyManager) ensureStream( peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID, - 
clusterConnection *ClusterConnection, ) error { logger.Info("ensureStream") key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} @@ -473,7 +467,7 @@ func (m *intraProxyManager) ensureStream( m.streamsMu.RUnlock() // Reuse shared connection per peer - ps, err := m.ensurePeer(ctx, peerNodeName, clusterConnection) + ps, err := m.ensurePeer(ctx, peerNodeName) if err != nil { logger.Error("Failed to ensure peer", tag.Error(err)) return err @@ -485,11 +479,11 @@ func (m *intraProxyManager) ensureStream( tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("targetShardID", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShardID", ClusterShardIDtoString(sourceShard))), - clusterConnection: clusterConnection, - intraMgr: m, - peerNodeName: peerNodeName, - targetShardID: targetShard, - sourceShardID: sourceShard, + shardManager: m.shardManager, + intraMgr: m, + peerNodeName: peerNodeName, + targetShardID: targetShard, + sourceShardID: sourceShard, } // initialize shutdown handle and register it for lifecycle management recv.shutdown = channel.NewShutdownOnce() @@ -501,7 +495,7 @@ func (m *intraProxyManager) ensureStream( // Let the receiver open stream, register tracking, and start goroutines go func() { - if err := recv.Run(ctx, clusterConnection, ps.conn); err != nil { + if err := recv.Run(ctx, m.shardManager, ps.conn); err != nil { m.logger.Error("intraProxyStreamReceiver.Run error", tag.Error(err)) } // remove the receiver from the peer state @@ -671,9 +665,9 @@ func (m *intraProxyManager) Start() { timer := time.NewTimer(1 * time.Second) select { case <-timer.C: - m.ReconcilePeerStreams(m.clusterConnection, "") + m.ReconcilePeerStreams("") case <-m.notifyCh: - m.ReconcilePeerStreams(m.clusterConnection, "") + m.ReconcilePeerStreams("") } } }() @@ -689,18 +683,12 @@ func (m *intraProxyManager) Notify() { // ReconcilePeerStreams ensures receivers exist for desired (local shard, remote shard) pairs // for a given peer and closes any sender/receiver not in the desired set. // This mirrors the Temporal StreamReceiverMonitor approach. 
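The desired-vs-actual reconciliation described in the comment above can be summarized with a short sketch (streamKey and reconcile are illustrative stand-ins under simplified assumptions, not the actual ReconcilePeerStreams implementation): derive the desired stream set from current shard ownership, ensure anything missing, then prune anything running that is no longer desired.

package main

import "fmt"

// streamKey identifies one intra-proxy stream: a peer plus a shard pair.
type streamKey struct {
	peer                    string
	localShard, remoteShard int32
}

// reconcile ensures every desired stream is running and closes the rest.
func reconcile(desired map[streamKey]bool, running map[streamKey]func()) {
	// Ensure anything desired but not yet running.
	for k := range desired {
		if _, ok := running[k]; !ok {
			fmt.Println("ensure stream", k)
			running[k] = func() {} // stand-in for the real cancel/teardown hook
		}
	}
	// Prune anything running that is no longer desired.
	for k, cancel := range running {
		if !desired[k] {
			fmt.Println("close stream", k)
			cancel()
			delete(running, k)
		}
	}
}

func main() {
	desired := map[streamKey]bool{{peer: "proxy-b", localShard: 1, remoteShard: 4}: true}
	running := map[streamKey]func(){{peer: "proxy-b", localShard: 1, remoteShard: 5}: func() {}}
	reconcile(desired, running)
}
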
-func (m *intraProxyManager) ReconcilePeerStreams( - clusterConnection *ClusterConnection, - peerNodeName string, -) { +func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("peerNodeName", peerNodeName)) defer m.logger.Info("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) - if mode := m.shardCountConfig.Mode; mode != config.ShardCountRouting { - return - } - localShards := clusterConnection.shardManager.GetLocalShards() - remoteShards, err := clusterConnection.shardManager.GetRemoteShardsForPeer(peerNodeName) + localShards := m.shardManager.GetLocalShards() + remoteShards, err := m.shardManager.GetRemoteShardsForPeer(peerNodeName) if err != nil { m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) return @@ -742,7 +730,7 @@ func (m *intraProxyManager) ReconcilePeerStreams( // Ensure all desired receivers exist for key := range desiredReceivers { - m.EnsureReceiverForPeerShard(clusterConnection, desiredReceivers[key], key.targetShard, key.sourceShard) + m.EnsureReceiverForPeerShard(desiredReceivers[key], key.targetShard, key.sourceShard) } // Prune anything not desired diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 5e366e51..8ebbd370 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -140,13 +140,13 @@ func (b *proxyIDRingBuffer) Discard(count int) { // (another proxy or a target server) and receiving ACKs back. // This is scaffolding only – the concrete behavior will be wired in later. type proxyStreamSender struct { - logger log.Logger - clusterConnection *ClusterConnection - targetShardID history.ClusterShardID - sourceShardID history.ClusterShardID - directionLabel string - streamID string - streamTracker *StreamTracker + logger log.Logger + shardManager ShardManager + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + streamID string + streamTracker *StreamTracker // sendMsgChan carries replication messages to be sent to the remote side. 
sendMsgChan chan RoutedMessage @@ -239,11 +239,11 @@ func (s *proxyStreamSender) Run( // Register remote send channel for this shard so receiver can forward tasks locally s.sendMsgChan = make(chan RoutedMessage, 100) - s.clusterConnection.SetRemoteSendChan(s.targetShardID, s.sendMsgChan) - defer s.clusterConnection.RemoveRemoteSendChan(s.targetShardID, s.sendMsgChan) + s.shardManager.SetRemoteSendChan(s.targetShardID, s.sendMsgChan) + defer s.shardManager.RemoveRemoteSendChan(s.targetShardID, s.sendMsgChan) - registeredAt := s.clusterConnection.shardManager.RegisterShard(s.targetShardID) - defer s.clusterConnection.shardManager.UnregisterShard(s.targetShardID, registeredAt) + registeredAt := s.shardManager.RegisterShard(s.targetShardID) + defer s.shardManager.UnregisterShard(s.targetShardID, registeredAt) wg := sync.WaitGroup{} wg.Add(2) @@ -333,7 +333,7 @@ func (s *proxyStreamSender) recvAck( s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) - if s.clusterConnection.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.clusterConnection, shutdownChan, s.logger, originalAck, true) { + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, shutdownChan, s.logger, originalAck, true) { sent[srcShard] = true numRemaining-- progress = true @@ -395,7 +395,7 @@ func (s *proxyStreamSender) recvAck( } // Log fallback ACK for this source shard s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) - if s.clusterConnection.shardManager.DeliverAckToShardOwner(srcShard, routedAck, s.clusterConnection, shutdownChan, s.logger, prev, true) { + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, shutdownChan, s.logger, prev, true) { sent[srcShard] = true numRemaining-- progress = true @@ -554,14 +554,14 @@ func (s *proxyStreamSender) sendReplicationMessages( // proxyStreamReceiver receives replication messages from a local/remote server and // produces ACKs destined for the original sender. 
type proxyStreamReceiver struct { - logger log.Logger - clusterConnection *ClusterConnection - adminClient adminservice.AdminServiceClient - localShardCount int32 - targetShardID history.ClusterShardID - sourceShardID history.ClusterShardID - directionLabel string - ackChan chan RoutedAck + logger log.Logger + shardManager ShardManager + adminClient adminservice.AdminServiceClient + localShardCount int32 + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + ackChan chan RoutedAck // ack aggregation across target shards ackByTarget map[history.ClusterShardID]int64 lastSentMin int64 @@ -597,7 +597,9 @@ func (r *proxyStreamReceiver) Run( shutdownChan channel.ShutdownOnce, ) { // Terminate any previous local receiver for this shard - r.clusterConnection.TerminatePreviousLocalReceiver(r.sourceShardID) + if r.shardManager != nil { + r.shardManager.TerminatePreviousLocalReceiver(r.sourceShardID, r.logger) + } r.streamID = BuildReceiverStreamID(r.sourceShardID, r.targetShardID) r.logger = log.With(r.logger, @@ -635,15 +637,17 @@ func (r *proxyStreamReceiver) Run( // Setup ack channel and cancel func bookkeeping r.ackChan = make(chan RoutedAck, 100) - r.clusterConnection.SetLocalAckChan(r.sourceShardID, r.ackChan) - r.clusterConnection.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) - // Register receiver for watermark propagation to late-registering shards - r.clusterConnection.RegisterActiveReceiver(r.sourceShardID, r) - defer func() { - r.clusterConnection.RemoveLocalAckChan(r.sourceShardID, r.ackChan) - r.clusterConnection.RemoveLocalReceiverCancelFunc(r.sourceShardID) - r.clusterConnection.UnregisterActiveReceiver(r.sourceShardID) - }() + if r.shardManager != nil { + r.shardManager.SetLocalAckChan(r.sourceShardID, r.ackChan) + r.shardManager.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) + // Register receiver for watermark propagation to late-registering shards + r.shardManager.RegisterActiveReceiver(r.sourceShardID, r) + defer func() { + r.shardManager.RemoveLocalAckChan(r.sourceShardID, r.ackChan) + r.shardManager.RemoveLocalReceiverCancelFunc(r.sourceShardID) + r.shardManager.UnregisterActiveReceiver(r.sourceShardID) + }() + } // init aggregation state r.ackByTarget = make(map[history.ClusterShardID]int64) @@ -752,7 +756,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( }, }, } - localShardsToSend := r.clusterConnection.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) + localShardsToSend := r.shardManager.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) r.logger.Info("Going to broadcast high watermark to local shards", tag.NewStringTag("localShardsToSend", fmt.Sprintf("%v", localShardsToSend))) for targetShardID, sendChan := range localShardsToSend { // Clone the message for each recipient to prevent shared mutation @@ -784,7 +788,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( }() } // send to all remote shards on other nodes as well - remoteShards, err := r.clusterConnection.shardManager.GetRemoteShardsForPeer("") + remoteShards, err := r.shardManager.GetRemoteShardsForPeer("") if err != nil { r.logger.Error("Failed to get remote shards", tag.Error(err)) return err @@ -801,7 +805,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( SourceShard: msg.SourceShard, Resp: clonedResp, } - if !r.clusterConnection.shardManager.DeliverMessagesToShardOwner(shard.ID, &clonedMsg, r.clusterConnection, shutdownChan, r.logger) { + if !r.shardManager.DeliverMessagesToShardOwner(shard.ID, &clonedMsg, 
shutdownChan, r.logger) { r.logger.Warn("Failed to send ReplicationTasks to remote shard", tag.NewStringTag("shard", ClusterShardIDtoString(shard.ID))) } } @@ -841,7 +845,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( }, }, } - if r.clusterConnection.shardManager.DeliverMessagesToShardOwner(targetShardID, &msg, r.clusterConnection, shutdownChan, r.logger) { + if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &msg, shutdownChan, r.logger) { sentByTarget[targetShardID] = true numRemaining-- progress = true @@ -868,7 +872,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( // sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard // This ensures late-registering shards receive watermarks that were sent before they registered -func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID, clusterConnection *ClusterConnection) { +func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID) { r.lastWatermarkMu.RLock() lastWatermark := r.lastWatermark r.lastWatermarkMu.RUnlock() @@ -895,7 +899,7 @@ func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history. } // Try to send to local shard first - if sendChan, exists := clusterConnection.GetRemoteSendChan(targetShardID); exists { + if sendChan, exists := r.shardManager.GetRemoteSendChan(targetShardID); exists { clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) clonedMsg := RoutedMessage{ SourceShard: msg.SourceShard, @@ -913,14 +917,14 @@ func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history. } // If not local, try to send to remote shard - if clusterConnection.shardManager != nil { + if r.shardManager != nil { shutdownChan := channel.NewShutdownOnce() clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) clonedMsg := RoutedMessage{ SourceShard: msg.SourceShard, Resp: clonedResp, } - if clusterConnection.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, clusterConnection, shutdownChan, r.logger) { + if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, shutdownChan, r.logger) { r.logger.Info("Sent pending watermark to remote shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) } else { diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 3f2f3c87..91f17e17 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -42,12 +42,20 @@ type ( GetRemoteShardsForPeer(peerNodeName string) (map[string]NodeShardState, error) // GetShardInfo returns debug information about shard distribution GetShardInfo() ShardDebugInfo + // GetShardInfos returns debug information about shard distribution as a slice + GetShardInfos() []ShardDebugInfo + // GetChannelInfo returns debug information about active channels + GetChannelInfo() ChannelDebugInfo // GetShardOwner returns the node name that owns the given shard GetShardOwner(shard history.ClusterShardID) (string, bool) + // TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed + TerminatePreviousLocalReceiver(shardID history.ClusterShardID, logger log.Logger) + // GetIntraProxyManager returns the intra-proxy manager if it exists + GetIntraProxyManager() *intraProxyManager // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) - DeliverAckToShardOwner(srcShard 
history.ClusterShardID, routedAck *RoutedAck, clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool + DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool // DeliverMessagesToShardOwner routes replication messages to the appropriate shard owner (local or remote) - DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger) bool + DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, shutdownChan channel.ShutdownOnce, logger log.Logger) bool // SetOnPeerJoin registers a callback invoked when a new peer joins SetOnPeerJoin(handler func(nodeName string)) // SetOnPeerLeave registers a callback invoked when a peer leaves. @@ -56,6 +64,36 @@ type ( SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) // New: notify when remote shard set changes for a peer SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) + // RegisterActiveReceiver registers an active receiver for watermark propagation + RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver *proxyStreamReceiver) + // UnregisterActiveReceiver removes an active receiver + UnregisterActiveReceiver(sourceShardID history.ClusterShardID) + // SetRemoteSendChan registers a send channel for a specific shard ID + SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) + // GetRemoteSendChan retrieves the send channel for a specific shard ID + GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) + // GetAllRemoteSendChans returns a map of all remote send channels + GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage + // GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID + GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage + // RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel + RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) + // SetLocalAckChan registers an ack channel for a specific shard ID + SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) + // GetLocalAckChan retrieves the ack channel for a specific shard ID + GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) + // GetAllLocalAckChans returns a map of all local ack channels + GetAllLocalAckChans() map[history.ClusterShardID]chan RoutedAck + // RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel + RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) + // ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID + ForceRemoveLocalAckChan(shardID history.ClusterShardID) + // SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID + SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) + // GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID + GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) + // RemoveLocalReceiverCancelFunc 
unconditionally removes the cancel function for a local receiver for a specific shard ID + RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) } shardManagerImpl struct { @@ -78,6 +116,18 @@ type ( stopJoinRetry chan struct{} joinWg sync.WaitGroup joinLoopRunning bool + // activeReceivers tracks active proxyStreamReceiver instances by source shard for watermark propagation + activeReceivers map[history.ClusterShardID]*proxyStreamReceiver + activeReceiversMu sync.RWMutex + // remoteSendChannels maps shard IDs to send channels for replication message routing + remoteSendChannels map[history.ClusterShardID]chan RoutedMessage + remoteSendChannelsMu sync.RWMutex + // localAckChannels maps shard IDs to ack channels for local acknowledgment handling + localAckChannels map[history.ClusterShardID]chan RoutedAck + localAckChannelsMu sync.RWMutex + // localReceiverCancelFuncs maps shard IDs to context cancel functions for local receiver termination + localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc + localReceiverCancelFuncsMu sync.RWMutex } // shardDelegate implements memberlist.Delegate for shard state management @@ -109,22 +159,30 @@ type ( ) // NewShardManager creates a new shard manager instance -func NewShardManager(memberlistConfig *config.MemberlistConfig, logger log.Logger) ShardManager { +func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig config.ShardCountConfig, logger log.Logger) ShardManager { delegate := &shardDelegate{ logger: logger, } sm := &shardManagerImpl{ - memberlistConfig: memberlistConfig, - logger: logger, - delegate: delegate, - localShards: make(map[string]ShardInfo), - intraMgr: nil, - stopJoinRetry: make(chan struct{}), + memberlistConfig: memberlistConfig, + logger: logger, + delegate: delegate, + localShards: make(map[string]ShardInfo), + intraMgr: nil, + stopJoinRetry: make(chan struct{}), + activeReceivers: make(map[history.ClusterShardID]*proxyStreamReceiver), + remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), + localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), + localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), } delegate.manager = sm + if memberlistConfig != nil && shardCountConfig.Mode == config.ShardCountRouting { + sm.intraMgr = newIntraProxyManager(logger, sm) + } + return sm } @@ -158,16 +216,42 @@ func (sm *shardManagerImpl) SetOnRemoteShardChange(handler func(peer string, sha func (sm *shardManagerImpl) Start(lifetime context.Context) error { sm.logger.Info("Starting shard manager") - if sm.memberlistConfig == nil { - sm.logger.Info("Shard manager not configured, skipping") - return nil - } if sm.started { sm.logger.Info("Shard manager already started") return nil } + if sm.intraMgr != nil { + sm.intraMgr.Start() + } + + sm.SetupCallbacks() + + if err := sm.initializeMemberlist(); err != nil { + return err + } + + sm.mutex.Lock() + sm.started = true + sm.mutex.Unlock() + + sm.logger.Info("Shard manager started", + tag.NewStringTag("node", sm.GetNodeName()), + tag.NewStringTag("addr", sm.localAddr)) + + context.AfterFunc(lifetime, func() { + sm.Stop() + }) + return nil +} + +func (sm *shardManagerImpl) initializeMemberlist() error { + if sm.memberlistConfig == nil { + sm.logger.Info("Shard manager not configured, skipping") + return nil + } + // Configure memberlist var mlConfig *memberlist.Config if sm.memberlistConfig.TCPOnly { @@ -236,38 +320,42 @@ func (sm *shardManagerImpl) Start(lifetime context.Context) error { 
sm.mutex.Lock() sm.ml = ml sm.localAddr = fmt.Sprintf("%s:%d", sm.memberlistConfig.BindAddr, sm.memberlistConfig.BindPort) - sm.started = true + sm.mutex.Unlock() sm.logger.Info("Shard manager base initialization complete", - tag.NewStringTag("node", sm.memberlistConfig.NodeName), + tag.NewStringTag("node", sm.GetNodeName()), tag.NewStringTag("addr", sm.localAddr)) - sm.mutex.Unlock() - // Join existing cluster if configured if len(sm.memberlistConfig.JoinAddrs) > 0 { sm.startJoinLoop() } - sm.logger.Info("Shard manager started", - tag.NewStringTag("node", sm.memberlistConfig.NodeName), - tag.NewStringTag("addr", sm.localAddr)) - - context.AfterFunc(lifetime, func() { - sm.Stop() - }) return nil } func (sm *shardManagerImpl) Stop() { sm.mutex.Lock() - if !sm.started || sm.ml == nil { + if !sm.started { sm.mutex.Unlock() return } sm.mutex.Unlock() + sm.shutdownMemberlist() + + sm.mutex.Lock() + sm.started = false + sm.mutex.Unlock() + sm.logger.Info("Shard manager stopped") +} + +func (sm *shardManagerImpl) shutdownMemberlist() { + if sm.ml == nil { + return + } + // Stop any ongoing join retry close(sm.stopJoinRetry) sm.joinWg.Wait() @@ -282,11 +370,7 @@ func (sm *shardManagerImpl) Stop() { if err != nil { sm.logger.Error("Error shutting down memberlist", tag.Error(err)) } - - sm.mutex.Lock() - sm.started = false - sm.mutex.Unlock() - sm.logger.Info("Shard manager stopped") + sm.ml = nil } // startJoinLoop starts the join retry loop if not already running @@ -445,7 +529,7 @@ func (sm *shardManagerImpl) GetNodeName() string { func (sm *shardManagerImpl) GetMemberNodes() []string { if !sm.started || sm.ml == nil { - return []string{sm.memberlistConfig.NodeName} + return []string{sm.GetNodeName()} } // Use a timeout to prevent deadlocks when memberlist is busy @@ -469,8 +553,8 @@ func (sm *shardManagerImpl) GetMemberNodes() []string { case <-time.After(100 * time.Millisecond): // Timeout: return cached node name to prevent hanging sm.logger.Warn("GetMemberNodes timeout, returning self node", - tag.NewStringTag("node", sm.memberlistConfig.NodeName)) - return []string{sm.memberlistConfig.NodeName} + tag.NewStringTag("node", sm.GetNodeName())) + return []string{sm.GetNodeName()} } } @@ -512,6 +596,61 @@ func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { } } +// GetShardInfos returns debug information about shard distribution as a slice +func (sm *shardManagerImpl) GetShardInfos() []ShardDebugInfo { + if sm.memberlistConfig == nil { + return []ShardDebugInfo{} + } + return []ShardDebugInfo{sm.GetShardInfo()} +} + +// GetChannelInfo returns debug information about active channels +func (sm *shardManagerImpl) GetChannelInfo() ChannelDebugInfo { + remoteSendChannels := make(map[string]int) + var totalSendChannels int + + // Collect remote send channel info first + allSendChans := sm.GetAllRemoteSendChans() + for shardID, ch := range allSendChans { + shardKey := ClusterShardIDtoString(shardID) + remoteSendChannels[shardKey] = len(ch) + } + totalSendChannels = len(allSendChans) + + localAckChannels := make(map[string]int) + var totalAckChannels int + + // Collect local ack channel info separately + allAckChans := sm.GetAllLocalAckChans() + for shardID, ch := range allAckChans { + shardKey := ClusterShardIDtoString(shardID) + localAckChannels[shardKey] = len(ch) + } + totalAckChannels = len(allAckChans) + + return ChannelDebugInfo{ + RemoteSendChannels: remoteSendChannels, + LocalAckChannels: localAckChannels, + TotalSendChannels: totalSendChannels, + TotalAckChannels: totalAckChannels, + } +} 
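For context, the per-shard values GetChannelInfo reports are buffered-channel backlogs read with len(). A self-contained sketch of that reporting (shardID is a simplified stand-in for history.ClusterShardID; not code from this patch):

package main

import "fmt"

// shardID is a simplified stand-in for history.ClusterShardID.
type shardID struct{ ClusterID, ShardID int32 }

func main() {
	// Per-shard buffered send channels, as registered via SetRemoteSendChan.
	sendChans := map[shardID]chan string{
		{ClusterID: 1, ShardID: 1}: make(chan string, 100),
		{ClusterID: 1, ShardID: 2}: make(chan string, 100),
	}
	sendChans[shardID{ClusterID: 1, ShardID: 1}] <- "replication task" // one message queued

	// len() on a buffered channel is its current backlog; this is the per-shard
	// number surfaced in the channel debug info.
	depths := make(map[string]int)
	for id, ch := range sendChans {
		depths[fmt.Sprintf("%d_%d", id.ClusterID, id.ShardID)] = len(ch)
	}
	fmt.Println(depths) // e.g. map[1_1:1 1_2:0]
}
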
+ +// TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed +func (sm *shardManagerImpl) TerminatePreviousLocalReceiver(shardID history.ClusterShardID, logger log.Logger) { + // Check if there's a previous cancel function for this shard + if prevCancelFunc, exists := sm.GetLocalReceiverCancelFunc(shardID); exists { + logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + + // Cancel the previous receiver's context + prevCancelFunc() + + // Force remove the cancel function and ack channel from tracking + sm.RemoveLocalReceiverCancelFunc(shardID) + sm.ForceRemoveLocalAckChan(shardID) + } +} + func (sm *shardManagerImpl) GetShardOwner(shard history.ClusterShardID) (string, bool) { remoteShards, err := sm.GetRemoteShardsForPeer("") if err != nil { @@ -576,14 +715,13 @@ func (sm *shardManagerImpl) GetRemoteShardsForPeer(peerNodeName string) (map[str func (sm *shardManagerImpl) DeliverAckToShardOwner( sourceShard history.ClusterShardID, routedAck *RoutedAck, - clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool, ) bool { logger = log.With(logger, tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewInt64("ack", ack)) - if ackCh, ok := clusterConnection.localAckChannels[sourceShard]; ok { + if ackCh, ok := sm.GetLocalAckChan(sourceShard); ok { delivered := false func() { defer func() { @@ -613,13 +751,12 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( // Attempt remote delivery via intra-proxy when enabled and shard is remote if sm.memberlistConfig != nil { - if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.memberlistConfig.NodeName { + if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.GetNodeName() { if addr, found := sm.GetProxyAddress(owner); found { clientShard := routedAck.TargetShard serverShard := sourceShard - mgr := clusterConnection.intraMgr // Synchronous send to preserve ordering - if err := mgr.sendAck(context.Background(), owner, clientShard, serverShard, routedAck.Req); err != nil { + if err := sm.intraMgr.sendAck(context.Background(), owner, clientShard, serverShard, routedAck.Req); err != nil { logger.Error("Failed to forward ACK to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) return false } @@ -640,14 +777,13 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( func (sm *shardManagerImpl) DeliverMessagesToShardOwner( targetShard history.ClusterShardID, routedMsg *RoutedMessage, - clusterConnection *ClusterConnection, shutdownChan channel.ShutdownOnce, logger log.Logger, ) bool { logger = log.With(logger, tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShard))) // Try local delivery first - if ch, ok := clusterConnection.remoteSendChannels[targetShard]; ok { + if ch, ok := sm.GetRemoteSendChan(targetShard); ok { delivered := false func() { defer func() { @@ -673,7 +809,7 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( // Attempt remote delivery via intra-proxy when enabled and shard is remote if sm.memberlistConfig != nil { - if owner, ok := sm.GetShardOwner(targetShard); ok && owner != sm.memberlistConfig.NodeName { + if owner, ok := sm.GetShardOwner(targetShard); ok && owner != sm.GetNodeName() { if addr, found := sm.GetProxyAddress(owner); found { if mgr := sm.GetIntraProxyManager(); mgr != nil { resp := 
routedMsg.Resp @@ -693,39 +829,47 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( return false } -func (sm *shardManagerImpl) SetIntraProxyManager(intraMgr *intraProxyManager) { - sm.intraMgr = intraMgr - +func (sm *shardManagerImpl) SetupCallbacks() { // Wire memberlist peer-join callback to reconcile intra-proxy receivers for local/remote pairs sm.SetOnPeerJoin(func(nodeName string) { sm.logger.Info("OnPeerJoin", tag.NewStringTag("nodeName", nodeName)) defer sm.logger.Info("OnPeerJoin done", tag.NewStringTag("nodeName", nodeName)) - sm.intraMgr.Notify() - // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } }) // Wire peer-leave to cleanup intra-proxy resources for that peer sm.SetOnPeerLeave(func(nodeName string) { sm.logger.Info("OnPeerLeave", tag.NewStringTag("nodeName", nodeName)) defer sm.logger.Info("OnPeerLeave done", tag.NewStringTag("nodeName", nodeName)) - sm.intraMgr.Notify() - // proxy.intraMgr.ReconcilePeerStreams(proxy, nodeName) + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } }) // Wire local shard changes to reconcile intra-proxy receivers sm.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { sm.logger.Info("OnLocalShardChange", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) defer sm.logger.Info("OnLocalShardChange done", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) - sm.intraMgr.Notify() - // proxy.intraMgr.ReconcilePeerStreams(proxy, "") + if added { + sm.notifyReceiversOfNewShard(shard) + } + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } }) // Wire remote shard changes to reconcile intra-proxy receivers sm.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { sm.logger.Info("OnRemoteShardChange", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) defer sm.logger.Info("OnRemoteShardChange done", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) - sm.intraMgr.Notify() - // proxy.intraMgr.ReconcilePeerStreams(proxy, peer) + if added { + sm.notifyReceiversOfNewShard(shard) + } + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } }) } @@ -740,7 +884,7 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C msg := ShardMessage{ Type: msgType, - NodeName: sm.memberlistConfig.NodeName, + NodeName: sm.GetNodeName(), ClientShard: shard, Timestamp: time.Now(), } @@ -753,7 +897,7 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C for _, member := range sm.ml.Members() { // Skip sending to self node - if member.Name == sm.memberlistConfig.NodeName { + if member.Name == sm.GetNodeName() { continue } @@ -780,7 +924,7 @@ func (sd *shardDelegate) NodeMeta(limit int) []byte { for k, v := range sd.manager.localShards { shardsCopy[k] = v } - nodeName := sd.manager.memberlistConfig.NodeName + nodeName := sd.manager.GetNodeName() sd.manager.mutex.RUnlock() state := NodeShardState{ @@ -797,7 +941,7 @@ func (sd *shardDelegate) NodeMeta(limit int) []byte { if len(data) > limit { // If metadata is too large, just send node name - return []byte(sd.manager.memberlistConfig.NodeName) + return []byte(sd.manager.GetNodeName()) } return data @@ -876,6 +1020,166 @@ func (sm *shardManagerImpl) 
removeLocalShard(shard history.ClusterShardID) { delete(sm.localShards, key) } +// RegisterActiveReceiver registers an active receiver for watermark propagation +func (sm *shardManagerImpl) RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver *proxyStreamReceiver) { + sm.activeReceiversMu.Lock() + defer sm.activeReceiversMu.Unlock() + sm.activeReceivers[sourceShardID] = receiver +} + +// UnregisterActiveReceiver removes an active receiver +func (sm *shardManagerImpl) UnregisterActiveReceiver(sourceShardID history.ClusterShardID) { + sm.activeReceiversMu.Lock() + defer sm.activeReceiversMu.Unlock() + delete(sm.activeReceivers, sourceShardID) +} + +// notifyReceiversOfNewShard notifies all receivers about a newly registered target shard +// so they can send pending watermarks if available +func (sm *shardManagerImpl) notifyReceiversOfNewShard(targetShardID history.ClusterShardID) { + sm.activeReceiversMu.RLock() + receivers := make([]*proxyStreamReceiver, 0, len(sm.activeReceivers)) + for _, receiver := range sm.activeReceivers { + receivers = append(receivers, receiver) + } + sm.activeReceiversMu.RUnlock() + + for _, receiver := range receivers { + // Only notify receivers that route to the same cluster as the newly registered shard + if receiver.targetShardID.ClusterID == targetShardID.ClusterID { + receiver.sendPendingWatermarkToShard(targetShardID) + } + } +} + +// SetRemoteSendChan registers a send channel for a specific shard ID +func (sm *shardManagerImpl) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { + sm.logger.Info("Register remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.remoteSendChannelsMu.Lock() + defer sm.remoteSendChannelsMu.Unlock() + sm.remoteSendChannels[shardID] = sendChan +} + +// GetRemoteSendChan retrieves the send channel for a specific shard ID +func (sm *shardManagerImpl) GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) { + sm.remoteSendChannelsMu.RLock() + defer sm.remoteSendChannelsMu.RUnlock() + ch, exists := sm.remoteSendChannels[shardID] + return ch, exists +} + +// GetAllRemoteSendChans returns a map of all remote send channels +func (sm *shardManagerImpl) GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { + sm.remoteSendChannelsMu.RLock() + defer sm.remoteSendChannelsMu.RUnlock() + + // Create a copy of the map + result := make(map[history.ClusterShardID]chan RoutedMessage, len(sm.remoteSendChannels)) + for k, v := range sm.remoteSendChannels { + result[k] = v + } + return result +} + +// GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID +func (sm *shardManagerImpl) GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage { + sm.remoteSendChannelsMu.RLock() + defer sm.remoteSendChannelsMu.RUnlock() + + result := make(map[history.ClusterShardID]chan RoutedMessage) + for k, v := range sm.remoteSendChannels { + if k.ClusterID == clusterID { + result[k] = v + } + } + return result +} + +// RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel +func (sm *shardManagerImpl) RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) { + sm.remoteSendChannelsMu.Lock() + defer sm.remoteSendChannelsMu.Unlock() + if currentChan, exists := sm.remoteSendChannels[shardID]; exists && currentChan == expectedChan { + delete(sm.remoteSendChannels, shardID) + 
sm.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } else { + sm.logger.Info("Skipped removing remote send channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } +} + +// SetLocalAckChan registers an ack channel for a specific shard ID +func (sm *shardManagerImpl) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) { + sm.logger.Info("Register local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localAckChannelsMu.Lock() + defer sm.localAckChannelsMu.Unlock() + sm.localAckChannels[shardID] = ackChan +} + +// GetLocalAckChan retrieves the ack channel for a specific shard ID +func (sm *shardManagerImpl) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) { + sm.localAckChannelsMu.RLock() + defer sm.localAckChannelsMu.RUnlock() + ch, exists := sm.localAckChannels[shardID] + return ch, exists +} + +// GetAllLocalAckChans returns a map of all local ack channels +func (sm *shardManagerImpl) GetAllLocalAckChans() map[history.ClusterShardID]chan RoutedAck { + sm.localAckChannelsMu.RLock() + defer sm.localAckChannelsMu.RUnlock() + + // Create a copy of the map + result := make(map[history.ClusterShardID]chan RoutedAck, len(sm.localAckChannels)) + for k, v := range sm.localAckChannels { + result[k] = v + } + return result +} + +// RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel +func (sm *shardManagerImpl) RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) { + sm.logger.Info("Remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localAckChannelsMu.Lock() + defer sm.localAckChannelsMu.Unlock() + if currentChan, exists := sm.localAckChannels[shardID]; exists && currentChan == expectedChan { + delete(sm.localAckChannels, shardID) + } else { + sm.logger.Info("Skipped removing local ack channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } +} + +// ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID +func (sm *shardManagerImpl) ForceRemoveLocalAckChan(shardID history.ClusterShardID) { + sm.logger.Info("Force remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localAckChannelsMu.Lock() + defer sm.localAckChannelsMu.Unlock() + delete(sm.localAckChannels, shardID) +} + +// SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID +func (sm *shardManagerImpl) SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) { + sm.logger.Info("Register local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localReceiverCancelFuncsMu.Lock() + defer sm.localReceiverCancelFuncsMu.Unlock() + sm.localReceiverCancelFuncs[shardID] = cancelFunc +} + +// GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID +func (sm *shardManagerImpl) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) { + sm.localReceiverCancelFuncsMu.RLock() + defer sm.localReceiverCancelFuncsMu.RUnlock() + cancelFunc, exists := sm.localReceiverCancelFuncs[shardID] + return cancelFunc, exists +} + +// RemoveLocalReceiverCancelFunc 
unconditionally removes the cancel function for a local receiver for a specific shard ID +func (sm *shardManagerImpl) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { + sm.logger.Info("Remove local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localReceiverCancelFuncsMu.Lock() + defer sm.localReceiverCancelFuncsMu.Unlock() + delete(sm.localReceiverCancelFuncs, shardID) +} + // shardEventDelegate handles memberlist cluster events type shardEventDelegate struct { manager *shardManagerImpl diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index 17930e1b..692f5864 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -141,7 +141,11 @@ func getFreePort() int { if err != nil { panic(fmt.Sprintf("failed to get free port: %v", err)) } - defer l.Close() + defer func() { + if err := l.Close(); err != nil { + fmt.Printf("Failed to close listener: %v\n", err) + } + }() return l.Addr().(*net.TCPAddr).Port } From 286f51f2ad9ee6060e880fbbaef6427cd07c8d12 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 17 Dec 2025 09:42:00 -0800 Subject: [PATCH 20/38] update helm --- .github/workflows/pull-request.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index dda69249..3302d852 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -67,7 +67,7 @@ jobs: - name: Install helm uses: azure/setup-helm@v4.3.0 with: - version: v3.17.3 + version: v3.19.4 - name: Install helm-unittest run: helm plugin install https://github.com/helm-unittest/helm-unittest.git From 0b7b2d435a17cdd5f214716da118c0092c9b8d36 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 17 Dec 2025 10:10:57 -0800 Subject: [PATCH 21/38] remove clusterConnection from adminServiceProxyServer --- proxy/adminservice.go | 16 ++++++++-------- proxy/cluster_connection.go | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/proxy/adminservice.go b/proxy/adminservice.go index ac52d215..9721265c 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -35,7 +35,7 @@ type ( adminServiceProxyServer struct { adminservice.UnimplementedAdminServiceServer - clusterConnection *ClusterConnection + shardManager ShardManager adminClient adminservice.AdminServiceClient adminClientReverse adminservice.AdminServiceClient logger log.Logger @@ -60,14 +60,14 @@ func NewAdminServiceProxyServer( lcmParameters LCMParameters, routingParameters RoutingParameters, logger log.Logger, - clusterConnection *ClusterConnection, + shardManager ShardManager, ) adminservice.AdminServiceServer { // The AdminServiceStreams will duplicate the same output for an underlying connection issue hundreds of times. 
// Limit their output to three times per minute logger = log.NewThrottledLogger(log.With(logger, common.ServiceTag(serviceName)), func() float64 { return 3.0 / 60.0 }) return &adminServiceProxyServer{ - clusterConnection: clusterConnection, + shardManager: shardManager, adminClient: adminClient, adminClientReverse: adminClientReverse, logger: logger, @@ -379,8 +379,8 @@ func (s *adminServiceProxyServer) streamIntraProxyRouting( } // Only allow intra-proxy when at least one shard is local to this proxy instance - isLocalSource := s.clusterConnection.shardManager.IsLocalShard(sourceShardID) - isLocalTarget := s.clusterConnection.shardManager.IsLocalShard(targetShardID) + isLocalSource := s.shardManager.IsLocalShard(sourceShardID) + isLocalTarget := s.shardManager.IsLocalShard(targetShardID) if isLocalSource || !isLocalTarget { logger.Info("Skipping intra-proxy between two local shards or two remote shards. Client may use outdated shard info.", tag.NewBoolTag("isLocalSource", isLocalSource), @@ -392,7 +392,7 @@ func (s *adminServiceProxyServer) streamIntraProxyRouting( // Sender: handle ACKs coming from peer and forward to original owner sender := &intraProxyStreamSender{ logger: logger, - shardManager: s.clusterConnection.shardManager, + shardManager: s.shardManager, peerNodeName: peerNodeName, sourceShardID: sourceShardID, targetShardID: targetShardID, @@ -421,7 +421,7 @@ func (s *adminServiceProxyServer) streamRouting( // server: stream sender proxyStreamSender := &proxyStreamSender{ logger: logger, - shardManager: s.clusterConnection.shardManager, + shardManager: s.shardManager, sourceShardID: sourceShardID, targetShardID: targetShardID, directionLabel: s.routingParameters.DirectionLabel, @@ -429,7 +429,7 @@ func (s *adminServiceProxyServer) streamRouting( proxyStreamReceiver := &proxyStreamReceiver{ logger: s.logger, - shardManager: s.clusterConnection.shardManager, + shardManager: s.shardManager, adminClient: s.adminClientReverse, localShardCount: s.routingParameters.RoutingLocalShardCount, sourceShardID: targetShardID, // reverse direction diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index b727d519..b5eb8900 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -330,7 +330,7 @@ func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, obs c.lcmParameters, c.routingParameters, c.logger, - c.clusterConnection, + c.clusterConnection.shardManager, ) var accessControl *auth.AccessControl if c.clusterDefinition.ACLPolicy != nil { From 34e07dba97026bcfd97df7fe5af0b07179c8acf2 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 17 Dec 2025 11:02:38 -0800 Subject: [PATCH 22/38] fix unit test --- proxy/cluster_connection.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index b5eb8900..ec7c5ed0 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -132,6 +132,8 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon return nil, err } + cc.shardManager = NewShardManager(connConfig.MemberlistConfig, connConfig.ShardCountConfig, logger) + getLCMParameters := func(shardCountConfig config.ShardCountConfig, inverse bool) LCMParameters { if shardCountConfig.Mode != config.ShardCountLCM { return LCMParameters{} @@ -202,8 +204,6 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon return nil, err } - cc.shardManager = NewShardManager(connConfig.MemberlistConfig, 
connConfig.ShardCountConfig, logger) - return cc, nil } From 2f26e6f512a230fbbd845c85174582373d8a50f8 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 17 Dec 2025 11:16:27 -0800 Subject: [PATCH 23/38] fix test error --- transport/mux/multi_mux_manager_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transport/mux/multi_mux_manager_test.go b/transport/mux/multi_mux_manager_test.go index 87105828..391ceac2 100644 --- a/transport/mux/multi_mux_manager_test.go +++ b/transport/mux/multi_mux_manager_test.go @@ -47,12 +47,14 @@ func TestMultiMuxManager(t *testing.T) { require.False(t, muxesOnPipes.clientMM.CanAcceptConnections(), "All connections should have been consumed") // Close connections. We should see both sides fire disconnectFn + require.Eventually(t, func() bool { return clientConns.Load() != nil }, 2*time.Second, 10*time.Millisecond, "clientConns should be set") for _, v := range *clientConns.Load() { v.Close() } clientEvent = proxyassert.RequireCh(t, muxesOnPipes.clientEvents, 2*time.Second, "Client connection failed to disconnect!\nclientMux:%s", muxesOnPipes.clientMM.Describe()) require.Equal(t, "closed", clientEvent.eventType) require.Same(t, clientSession, clientEvent.session) + require.Eventually(t, func() bool { return serverConns.Load() != nil }, 2*time.Second, 10*time.Millisecond, "serverConns should be set") for _, v := range *serverConns.Load() { v.Close() } From 970b6adcf9593d1dacdc5c528a7a7c4edd0d72e1 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 19 Dec 2025 17:04:22 -0800 Subject: [PATCH 24/38] fix intra proxy streams; add connection debug info. --- metrics/prometheus_defs.go | 7 +- proxy/adminservice.go | 2 +- proxy/cluster_connection.go | 2 +- proxy/debug.go | 104 +++++++++++++++++-- proxy/intra_proxy_router.go | 59 ++++++----- proxy/shard_manager.go | 15 ++- transport/mux/multi_mux_manager.go | 13 +++ transport/mux/session/managed_mux_session.go | 12 +++ 8 files changed, 171 insertions(+), 43 deletions(-) diff --git a/metrics/prometheus_defs.go b/metrics/prometheus_defs.go index 45bd3ba1..f84c6b97 100644 --- a/metrics/prometheus_defs.go +++ b/metrics/prometheus_defs.go @@ -36,8 +36,9 @@ var ( GRPCServerStarted = DefaultCounterVec("grpc_server_started", "Emits when the grpc server is started", "service_name") GRPCServerStopped = DefaultCounterVec("grpc_server_stopped", "Emits when the grpc server is stopped", "service_name", "error") - GRPCOutboundClientMetrics = GetStandardGRPCClientInterceptor("outbound") - GRPCInboundClientMetrics = GetStandardGRPCClientInterceptor("inbound") + GRPCOutboundClientMetrics = GetStandardGRPCClientInterceptor("outbound") + GRPCInboundClientMetrics = GetStandardGRPCClientInterceptor("inbound") + GRPCIntraProxyClientMetrics = GetStandardGRPCClientInterceptor("intra_proxy") // /transport/mux @@ -86,6 +87,8 @@ func GetGRPCClientMetrics(directionLabel string) *grpcprom.ClientMetrics { return GRPCOutboundClientMetrics case "inbound": return GRPCInboundClientMetrics + case "intra_proxy": + return GRPCIntraProxyClientMetrics } panic("unknown direction label: " + directionLabel) } diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 9721265c..6cf491c7 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -381,7 +381,7 @@ func (s *adminServiceProxyServer) streamIntraProxyRouting( // Only allow intra-proxy when at least one shard is local to this proxy instance isLocalSource := s.shardManager.IsLocalShard(sourceShardID) isLocalTarget := s.shardManager.IsLocalShard(targetShardID) - if isLocalSource || 
!isLocalTarget { + if isLocalTarget || !isLocalSource { logger.Info("Skipping intra-proxy between two local shards or two remote shards. Client may use outdated shard info.", tag.NewBoolTag("isLocalSource", isLocalSource), tag.NewBoolTag("isLocalTarget", isLocalTarget), diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index ec7c5ed0..8fe3f9e8 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -132,7 +132,7 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon return nil, err } - cc.shardManager = NewShardManager(connConfig.MemberlistConfig, connConfig.ShardCountConfig, logger) + cc.shardManager = NewShardManager(connConfig.MemberlistConfig, connConfig.ShardCountConfig, connConfig.LocalServer.Connection.TcpClient.TLSConfig, logger) getLCMParameters := func(shardCountConfig config.ShardCountConfig, inverse bool) LCMParameters { if shardCountConfig.Mode != config.ShardCountLCM { diff --git a/proxy/debug.go b/proxy/debug.go index 44ab5cab..4af633e1 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -8,6 +8,9 @@ import ( "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + + "github.com/temporalio/s2s-proxy/transport/mux" + "github.com/temporalio/s2s-proxy/transport/mux/session" ) type ( @@ -80,12 +83,31 @@ type ( TotalAckChannels int `json:"total_ack_channels"` } + // MuxConnectionInfo holds debug information about a mux connection + MuxConnectionInfo struct { + ID string `json:"id"` + LocalAddr string `json:"local_addr"` + RemoteAddr string `json:"remote_addr"` + State string `json:"state"` + IsClosed bool `json:"is_closed"` + } + + // MuxConnectionsDebugInfo holds debug information about mux connections for a cluster connection + MuxConnectionsDebugInfo struct { + ConnectionName string `json:"connection_name"` + Direction string `json:"direction"` + Address string `json:"address"` + Connections []MuxConnectionInfo `json:"connections"` + ConnectionCount int `json:"connection_count"` + } + DebugResponse struct { - Timestamp time.Time `json:"timestamp"` - ActiveStreams []StreamInfo `json:"active_streams"` - StreamCount int `json:"stream_count"` - ShardInfos []ShardDebugInfo `json:"shard_infos"` - ChannelInfos []ChannelDebugInfo `json:"channel_infos"` + Timestamp time.Time `json:"timestamp"` + ActiveStreams []StreamInfo `json:"active_streams"` + StreamCount int `json:"stream_count"` + ShardInfos []ShardDebugInfo `json:"shard_infos"` + ChannelInfos []ChannelDebugInfo `json:"channel_infos"` + MuxConnections []MuxConnectionsDebugInfo `json:"mux_connections"` } ) @@ -96,6 +118,7 @@ func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Prox var streamCount int var shardInfos []ShardDebugInfo var channelInfos []ChannelDebugInfo + var muxConnections []MuxConnectionsDebugInfo // Get active streams information streamTracker := GetGlobalStreamTracker() @@ -106,14 +129,19 @@ func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Prox shardInfos = append(shardInfos, clusterConnection.shardManager.GetShardInfos()...) channelInfos = append(channelInfos, clusterConnection.shardManager.GetChannelInfo()) } + + // Collect mux connection info from inbound and outbound servers + muxConnections = append(muxConnections, getMuxConnectionsInfo(clusterConnection.inboundServer, "inbound")...) + muxConnections = append(muxConnections, getMuxConnectionsInfo(clusterConnection.outboundServer, "outbound")...) 
} response := DebugResponse{ - Timestamp: time.Now(), - ActiveStreams: activeStreams, - StreamCount: streamCount, - ShardInfos: shardInfos, - ChannelInfos: channelInfos, + Timestamp: time.Now(), + ActiveStreams: activeStreams, + StreamCount: streamCount, + ShardInfos: shardInfos, + ChannelInfos: channelInfos, + MuxConnections: muxConnections, } if err := json.NewEncoder(w).Encode(response); err != nil { @@ -121,3 +149,59 @@ func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Prox http.Error(w, "Internal server error", http.StatusInternalServerError) } } + +func getMuxConnectionsInfo(server contextAwareServer, direction string) []MuxConnectionsDebugInfo { + muxMgr, ok := server.(mux.MultiMuxManager) + if !ok { + return nil + } + + connections := muxMgr.GetMuxConnections() + if len(connections) == 0 { + return nil + } + + var connInfos []MuxConnectionInfo + for id, muxSession := range connections { + localAddr, remoteAddr := muxSession.GetConnectionInfo() + state := muxSession.State() + stateStr := "unknown" + if state != nil { + switch state.State { + case session.Connected: + stateStr = "connected" + case session.Closed: + stateStr = "closed" + case session.Error: + stateStr = "error" + } + } + + localAddrStr := "" + if localAddr != nil { + localAddrStr = localAddr.String() + } + remoteAddrStr := "" + if remoteAddr != nil { + remoteAddrStr = remoteAddr.String() + } + + connInfos = append(connInfos, MuxConnectionInfo{ + ID: id, + LocalAddr: localAddrStr, + RemoteAddr: remoteAddrStr, + State: stateStr, + IsClosed: muxSession.IsClosed(), + }) + } + + return []MuxConnectionsDebugInfo{ + { + ConnectionName: muxMgr.Name(), + Direction: direction, + Address: muxMgr.Address(), + Connections: connInfos, + ConnectionCount: len(connInfos), + }, + } +} diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index a4715772..726432cf 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "crypto/tls" "fmt" "io" "sync" @@ -17,6 +18,9 @@ import ( "google.golang.org/grpc/metadata" "github.com/temporalio/s2s-proxy/common" + "github.com/temporalio/s2s-proxy/encryption" + "github.com/temporalio/s2s-proxy/metrics" + "github.com/temporalio/s2s-proxy/transport/grpcutil" ) // intraProxyManager maintains long-lived intra-proxy streams to peer proxies and @@ -375,59 +379,62 @@ func (m *intraProxyManager) ensurePeer( ctx context.Context, peerNodeName string, ) (*peerState, error) { + logger := log.With(m.logger, tag.NewStringTag("peerNodeName", peerNodeName)) + logger.Info("ensurePeer started") + defer logger.Info("ensurePeer finished") + m.streamsMu.RLock() if ps, ok := m.peers[peerNodeName]; ok && ps != nil && ps.conn != nil { m.streamsMu.RUnlock() + logger.Info("ensurePeer found existing peer with connection") return ps, nil } m.streamsMu.RUnlock() + logger.Info("ensurePeer creating new peer connection") + // Build TLS from this proxy's outbound client TLS config if available - var dialOpts []grpc.DialOption - - // TODO: FIX this for new config format - // var tlsCfg *config.ClientTLSConfig - // if p.outboundServer != nil { - // t := p.outboundServer.config.Client.TLS - // tlsCfg = &t - // } else if p.inboundServer != nil { - // t := p.inboundServer.config.Client.TLS - // tlsCfg = &t - // } - // if tlsCfg != nil && tlsCfg.IsEnabled() { - // cfg, e := encryption.GetClientTLSConfig(*tlsCfg) - // if e != nil { - // return nil, e - // } - // dialOpts = append(dialOpts, 
grpc.WithTransportCredentials(credentials.NewTLS(cfg))) - // } else { - // dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) - // } - // // Reuse default grpc options from transport - // dialOpts = append(dialOpts, - // grpc.WithDefaultServiceConfig(transport.DefaultServiceConfig), - // grpc.WithDisableServiceConfig(), - // ) + tlsCfg := m.shardManager.GetIntraProxyTLSConfig() + var parsedTLSCfg *tls.Config + if tlsCfg.IsEnabled() { + logger.Info("ensurePeer TLS enabled, building TLS config") + var err error + parsedTLSCfg, err = encryption.GetClientTLSConfig(tlsCfg) + if err != nil { + logger.Error("ensurePeer failed to create TLS config", tag.Error(err)) + return nil, fmt.Errorf("config error when creating tls config: %w", err) + } + } else { + logger.Info("ensurePeer TLS disabled") + } + dialOpts := grpcutil.MakeDialOptions(parsedTLSCfg, metrics.GetGRPCClientMetrics("intra_proxy")) proxyAddresses, ok := m.shardManager.GetProxyAddress(peerNodeName) if !ok { + logger.Error("ensurePeer proxy address not found") return nil, fmt.Errorf("proxy address not found") } + logger.Info("ensurePeer dialing peer", tag.NewStringTag("proxyAddresses", proxyAddresses)) - cc, err := grpc.DialContext(ctx, proxyAddresses, dialOpts...) //nolint:staticcheck // acceptable here + cc, err := grpc.NewClient(proxyAddresses, dialOpts...) if err != nil { + logger.Error("ensurePeer failed to dial peer", tag.Error(err)) return nil, err } + logger.Info("ensurePeer successfully dialed peer") m.streamsMu.Lock() ps := m.peers[peerNodeName] if ps == nil { + logger.Info("ensurePeer creating new peer state") ps = &peerState{conn: cc, receivers: make(map[peerStreamKey]*intraProxyStreamReceiver), senders: make(map[peerStreamKey]*intraProxyStreamSender), recvShutdown: make(map[peerStreamKey]channel.ShutdownOnce)} m.peers[peerNodeName] = ps } else { + logger.Info("ensurePeer updating existing peer state with new connection") old := ps.conn ps.conn = cc if old != nil { + logger.Info("ensurePeer closing old connection") _ = old.Close() } if ps.receivers == nil { diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 91f17e17..6219956d 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -15,6 +15,7 @@ import ( "go.temporal.io/server/common/log/tag" "github.com/temporalio/s2s-proxy/config" + "github.com/temporalio/s2s-proxy/encryption" ) type ( @@ -52,6 +53,8 @@ type ( TerminatePreviousLocalReceiver(shardID history.ClusterShardID, logger log.Logger) // GetIntraProxyManager returns the intra-proxy manager if it exists GetIntraProxyManager() *intraProxyManager + // GetIntraProxyTLSConfig returns the TLS config for intra-proxy connections + GetIntraProxyTLSConfig() encryption.TLSConfig // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool // DeliverMessagesToShardOwner routes replication messages to the appropriate shard owner (local or remote) @@ -110,8 +113,9 @@ type ( onLocalShardChange func(shard history.ClusterShardID, added bool) onRemoteShardChange func(peer string, shard history.ClusterShardID, added bool) // Local shards owned by this node, keyed by short id - localShards map[string]ShardInfo - intraMgr *intraProxyManager + localShards map[string]ShardInfo + intraMgr *intraProxyManager + intraProxyTLSConfig encryption.TLSConfig // Join retry control 
stopJoinRetry chan struct{} joinWg sync.WaitGroup @@ -159,7 +163,7 @@ type ( ) // NewShardManager creates a new shard manager instance -func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig config.ShardCountConfig, logger log.Logger) ShardManager { +func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig config.ShardCountConfig, intraProxyTLSConfig encryption.TLSConfig, logger log.Logger) ShardManager { delegate := &shardDelegate{ logger: logger, } @@ -170,6 +174,7 @@ func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig delegate: delegate, localShards: make(map[string]ShardInfo), intraMgr: nil, + intraProxyTLSConfig: intraProxyTLSConfig, stopJoinRetry: make(chan struct{}), activeReceivers: make(map[history.ClusterShardID]*proxyStreamReceiver), remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), @@ -877,6 +882,10 @@ func (sm *shardManagerImpl) GetIntraProxyManager() *intraProxyManager { return sm.intraMgr } +func (sm *shardManagerImpl) GetIntraProxyTLSConfig() encryption.TLSConfig { + return sm.intraProxyTLSConfig +} + func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.ClusterShardID) { if !sm.started || sm.ml == nil || sm.memberlistConfig == nil { return diff --git a/transport/mux/multi_mux_manager.go b/transport/mux/multi_mux_manager.go index 536ee34a..8c9312e1 100644 --- a/transport/mux/multi_mux_manager.go +++ b/transport/mux/multi_mux_manager.go @@ -53,6 +53,8 @@ type ( CanAcceptConnections() bool Describe() string Name() string + // GetMuxConnections returns a snapshot of active mux connections + GetMuxConnections() map[string]session.ManagedMuxSession } MuxProviderBuilder func(AddNewMux, context.Context) (MuxProvider, error) ) @@ -202,3 +204,14 @@ func (m *multiMuxManager) Describe() string { func (m *multiMuxManager) Name() string { return m.name } + +func (m *multiMuxManager) GetMuxConnections() map[string]session.ManagedMuxSession { + m.muxesLock.RLock() + defer m.muxesLock.RUnlock() + // Return a copy to avoid holding the lock + result := make(map[string]session.ManagedMuxSession, len(m.muxes)) + for k, v := range m.muxes { + result[k] = v + } + return result +} diff --git a/transport/mux/session/managed_mux_session.go b/transport/mux/session/managed_mux_session.go index b9a582d9..2d87dbd2 100644 --- a/transport/mux/session/managed_mux_session.go +++ b/transport/mux/session/managed_mux_session.go @@ -38,6 +38,8 @@ type ( Open() (net.Conn, error) State() *MuxSessionInfo Describe() string + // GetConnectionInfo returns the local and remote addresses of the underlying connection + GetConnectionInfo() (localAddr, remoteAddr net.Addr) } ) @@ -142,3 +144,13 @@ func (s *muxSession) Addr() net.Addr { func (s *muxSession) Describe() string { return fmt.Sprintf("[muxSession %s, state=%v, address=%s]", s.id, s.state.Load(), s.conn.RemoteAddr().String()) } + +func (s *muxSession) GetConnectionInfo() (localAddr, remoteAddr net.Addr) { + if s.session != nil { + return s.session.LocalAddr(), s.session.RemoteAddr() + } + if s.conn != nil { + return s.conn.LocalAddr(), s.conn.RemoteAddr() + } + return nil, nil +} From bed6b1d2cb7b78e2e811bc0759f6c27adcf79f9c Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 19 Dec 2025 22:08:32 -0800 Subject: [PATCH 25/38] add tcp_proxy for test. add intra_proxy test file. 
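This adds endtoendtest.TCPProxy, a small least-connections TCP load balancer intended
for tests that want to put multiple proxy instances behind a single address, plus an
intra-proxy routing test file. Illustrative usage of the new API (a sketch only, not
part of the patch; the logger helper comes from the existing test utilities and the
ports are placeholders):

    logger := log.NewTestLogger()
    upstream := NewUpstream([]string{"localhost:7001", "localhost:7002"})
    rules := []*ProxyRule{{ListenPort: "7000", Upstream: upstream}}
    lb := NewTCPProxy(logger, rules)
    if err := lb.Start(); err != nil {
        // handle listen failure
    }
    defer lb.Stop()

Each accepted connection is forwarded to the upstream server with the fewest active
connections, tracked with per-server atomic counters.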
--- endtoendtest/tcp_proxy.go | 176 +++++++ endtoendtest/tcp_proxy_test.go | 168 ++++++ proxy/test/intra_proxy_routing_test.go | 659 ++++++++++++++++++++++++ proxy/test/replication_failover_test.go | 25 +- testutil/testutil.go | 21 + 5 files changed, 1030 insertions(+), 19 deletions(-) create mode 100644 endtoendtest/tcp_proxy.go create mode 100644 endtoendtest/tcp_proxy_test.go create mode 100644 proxy/test/intra_proxy_routing_test.go create mode 100644 testutil/testutil.go diff --git a/endtoendtest/tcp_proxy.go b/endtoendtest/tcp_proxy.go new file mode 100644 index 00000000..9cacac2c --- /dev/null +++ b/endtoendtest/tcp_proxy.go @@ -0,0 +1,176 @@ +package endtoendtest + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" +) + +type ( + UpstreamServer struct { + Address string + conns atomic.Int64 + } + + Upstream struct { + Servers []*UpstreamServer + mu sync.RWMutex + } + + ProxyRule struct { + ListenPort string + Upstream *Upstream + } + + TCPProxy struct { + rules []*ProxyRule + logger log.Logger + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + servers []net.Listener + } +) + +func NewUpstream(servers []string) *Upstream { + upstreamServers := make([]*UpstreamServer, len(servers)) + for i, addr := range servers { + upstreamServers[i] = &UpstreamServer{Address: addr} + } + return &Upstream{Servers: upstreamServers} +} + +func (u *Upstream) selectLeastConn() *UpstreamServer { + u.mu.RLock() + defer u.mu.RUnlock() + + if len(u.Servers) == 0 { + return nil + } + + selected := u.Servers[0] + minConns := selected.conns.Load() + + for i := 1; i < len(u.Servers); i++ { + conns := u.Servers[i].conns.Load() + if conns < minConns { + minConns = conns + selected = u.Servers[i] + } + } + + return selected +} + +func (u *Upstream) incrementConn(server *UpstreamServer) { + server.conns.Add(1) +} + +func (u *Upstream) decrementConn(server *UpstreamServer) { + server.conns.Add(-1) +} + +func NewTCPProxy(logger log.Logger, rules []*ProxyRule) *TCPProxy { + ctx, cancel := context.WithCancel(context.Background()) + return &TCPProxy{ + rules: rules, + logger: logger, + ctx: ctx, + cancel: cancel, + } +} + +func (p *TCPProxy) Start() error { + for _, rule := range p.rules { + listener, err := net.Listen("tcp", ":"+rule.ListenPort) + if err != nil { + p.Stop() + return fmt.Errorf("failed to listen on port %s: %w", rule.ListenPort, err) + } + p.servers = append(p.servers, listener) + + p.wg.Add(1) + go p.handleListener(listener, rule) + } + + return nil +} + +func (p *TCPProxy) Stop() { + p.cancel() + for _, server := range p.servers { + _ = server.Close() + } + p.wg.Wait() +} + +func (p *TCPProxy) handleListener(listener net.Listener, rule *ProxyRule) { + defer p.wg.Done() + + for { + select { + case <-p.ctx.Done(): + return + default: + } + + clientConn, err := listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + p.logger.Warn("failed to accept connection", tag.Error(err)) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(clientConn, rule) + } +} + +func (p *TCPProxy) handleConnection(clientConn net.Conn, rule *ProxyRule) { + defer p.wg.Done() + defer func() { _ = clientConn.Close() }() + + upstream := rule.Upstream.selectLeastConn() + if upstream == nil { + p.logger.Error("no upstream servers available") + return + } + + rule.Upstream.incrementConn(upstream) + defer rule.Upstream.decrementConn(upstream) + + serverConn, err := 
net.DialTimeout("tcp", upstream.Address, 5*time.Second) + if err != nil { + p.logger.Warn("failed to connect to upstream", tag.NewStringTag("upstream", upstream.Address), tag.Error(err)) + return + } + defer func() { _ = serverConn.Close() }() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, _ = io.Copy(serverConn, clientConn) + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(clientConn, serverConn) + _ = clientConn.Close() + }() + + wg.Wait() +} diff --git a/endtoendtest/tcp_proxy_test.go b/endtoendtest/tcp_proxy_test.go new file mode 100644 index 00000000..aa5ecfeb --- /dev/null +++ b/endtoendtest/tcp_proxy_test.go @@ -0,0 +1,168 @@ +package endtoendtest + +import ( + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/log" + + "github.com/temporalio/s2s-proxy/testutil" +) + +func TestTCPProxy(t *testing.T) { + logger := log.NewTestLogger() + + server1Port := testutil.GetFreePort() + server2Port := testutil.GetFreePort() + server3Port := testutil.GetFreePort() + server4Port := testutil.GetFreePort() + server5Port := testutil.GetFreePort() + server6Port := testutil.GetFreePort() + + echoServer1 := startEchoServer(t, fmt.Sprintf("localhost:%d", server1Port)) + echoServer2 := startEchoServer(t, fmt.Sprintf("localhost:%d", server2Port)) + echoServer3 := startEchoServer(t, fmt.Sprintf("localhost:%d", server3Port)) + echoServer4 := startEchoServer(t, fmt.Sprintf("localhost:%d", server4Port)) + echoServer5 := startEchoServer(t, fmt.Sprintf("localhost:%d", server5Port)) + echoServer6 := startEchoServer(t, fmt.Sprintf("localhost:%d", server6Port)) + + defer func() { _ = echoServer1.Close() }() + defer func() { _ = echoServer2.Close() }() + defer func() { _ = echoServer3.Close() }() + defer func() { _ = echoServer4.Close() }() + defer func() { _ = echoServer5.Close() }() + defer func() { _ = echoServer6.Close() }() + + proxyPort1 := testutil.GetFreePort() + proxyPort2 := testutil.GetFreePort() + proxyPort3 := testutil.GetFreePort() + + rules := []*ProxyRule{ + { + ListenPort: fmt.Sprintf("%d", proxyPort1), + Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", server1Port), fmt.Sprintf("localhost:%d", server2Port)}), + }, + { + ListenPort: fmt.Sprintf("%d", proxyPort2), + Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", server3Port), fmt.Sprintf("localhost:%d", server4Port)}), + }, + { + ListenPort: fmt.Sprintf("%d", proxyPort3), + Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", server5Port), fmt.Sprintf("localhost:%d", server6Port)}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + defer proxy.Stop() + + // Test proxy on port 1 + testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort1), "test message 1") + + // Test proxy on port 2 + testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort2), "test message 2") + + // Test proxy on port 3 + testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort3), "test message 3") +} + +func testProxyConnection(t *testing.T, proxyAddr, message string) { + conn, err := net.DialTimeout("tcp", proxyAddr, 5*time.Second) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + _, err = conn.Write([]byte(message)) + require.NoError(t, err) + + buf := make([]byte, len(message)) + _, err = io.ReadFull(conn, buf) + require.NoError(t, err) + require.Equal(t, message, string(buf)) +} + +func startEchoServer(t *testing.T, addr string) net.Listener { 
+ listener, err := net.Listen("tcp", addr) + require.NoError(t, err) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { _ = c.Close() }() + _, _ = io.Copy(c, c) + }(conn) + } + }() + + return listener +} + +func TestTCPProxyLeastConn(t *testing.T) { + logger := log.NewTestLogger() + + // Create two echo servers + server1Port := testutil.GetFreePort() + server2Port := testutil.GetFreePort() + server1 := startEchoServer(t, fmt.Sprintf("localhost:%d", server1Port)) + server2 := startEchoServer(t, fmt.Sprintf("localhost:%d", server2Port)) + defer func() { _ = server1.Close() }() + defer func() { _ = server2.Close() }() + + // Create proxy with two upstreams + proxyPort := testutil.GetFreePort() + rules := []*ProxyRule{ + { + ListenPort: fmt.Sprintf("%d", proxyPort), + Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", server1Port), fmt.Sprintf("localhost:%d", server2Port)}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + defer proxy.Stop() + + // Make multiple connections to verify load balancing + for i := 0; i < 10; i++ { + testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort), "test") + time.Sleep(10 * time.Millisecond) + } +} + +func TestTCPProxyContextCancellation(t *testing.T) { + logger := log.NewTestLogger() + + serverPort := testutil.GetFreePort() + server := startEchoServer(t, fmt.Sprintf("localhost:%d", serverPort)) + defer func() { _ = server.Close() }() + + proxyPort := testutil.GetFreePort() + rules := []*ProxyRule{ + { + ListenPort: fmt.Sprintf("%d", proxyPort), + Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", serverPort)}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + + // Verify it's working + testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort), "test") + + // Stop the proxy + proxy.Stop() + + // Verify new connections fail + _, err = net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", proxyPort), 100*time.Millisecond) + require.Error(t, err) +} diff --git a/proxy/test/intra_proxy_routing_test.go b/proxy/test/intra_proxy_routing_test.go new file mode 100644 index 00000000..d4656072 --- /dev/null +++ b/proxy/test/intra_proxy_routing_test.go @@ -0,0 +1,659 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/common" + "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/tests/testcore" + + "github.com/temporalio/s2s-proxy/config" + s2sproxy "github.com/temporalio/s2s-proxy/proxy" + "github.com/temporalio/s2s-proxy/testutil" +) + +type ( + IntraProxyRoutingTestSuite struct { + suite.Suite + *require.Assertions + + logger log.Logger + + clusterA *testcore.TestCluster + clusterB *testcore.TestCluster + + proxyA1 *s2sproxy.Proxy + proxyA2 *s2sproxy.Proxy + proxyB1 *s2sproxy.Proxy + proxyB2 *s2sproxy.Proxy + + proxyA1Outbound string + proxyA2Outbound string + proxyB1Outbound string + proxyB2Outbound string + + proxyB1Mux string + proxyB2Mux string + + proxyA1MemberlistPort int + proxyA2MemberlistPort int + proxyB1MemberlistPort int + proxyB2MemberlistPort int + + loadBalancerA 
*trackingTCPProxy + loadBalancerB *trackingTCPProxy + loadBalancerC *trackingTCPProxy + + loadBalancerAPort string + loadBalancerBPort string + loadBalancerCPort string + + connectionCountsA1 atomic.Int64 + connectionCountsA2 atomic.Int64 + connectionCountsB1 atomic.Int64 + connectionCountsB2 atomic.Int64 + connectionCountsPA1 atomic.Int64 + connectionCountsPA2 atomic.Int64 + } +) + +func TestIntraProxyRoutingTestSuite(t *testing.T) { + s := &IntraProxyRoutingTestSuite{} + suite.Run(t, s) +} + +func (s *IntraProxyRoutingTestSuite) SetupSuite() { + s.Assertions = require.New(s.T()) + s.logger = log.NewTestLogger() + + s.logger.Info("Setting up intra-proxy routing test suite") + + s.clusterA = s.createCluster("cluster-a", 2, 1, 1) + s.clusterB = s.createCluster("cluster-b", 2, 2, 1) + + s.proxyA1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyA2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + + s.proxyB1Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB2Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + + loadBalancerAPort := fmt.Sprintf("%d", testutil.GetFreePort()) + loadBalancerBPort := fmt.Sprintf("%d", testutil.GetFreePort()) + loadBalancerCPort := fmt.Sprintf("%d", testutil.GetFreePort()) + + s.loadBalancerAPort = loadBalancerAPort + s.loadBalancerBPort = loadBalancerBPort + s.loadBalancerCPort = loadBalancerCPort + + proxyA1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyA2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyB1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyB2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + + proxyAddressesA := map[string]string{ + "proxy-node-a-1": proxyA1Address, + "proxy-node-a-2": proxyA2Address, + } + proxyAddressesB := map[string]string{ + "proxy-node-b-1": proxyB1Address, + "proxy-node-b-2": proxyB2Address, + } + + s.proxyA1MemberlistPort = testutil.GetFreePort() + s.proxyA2MemberlistPort = testutil.GetFreePort() + s.proxyB1MemberlistPort = testutil.GetFreePort() + s.proxyB2MemberlistPort = testutil.GetFreePort() + + s.proxyB1 = s.createProxy("proxy-b-1", proxyB1Address, s.proxyB1Outbound, s.proxyB1Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-1", "127.0.0.1", s.proxyB1MemberlistPort, nil, proxyAddressesB) + s.proxyB2 = s.createProxy("proxy-b-2", proxyB2Address, s.proxyB2Outbound, s.proxyB2Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-2", "127.0.0.1", s.proxyB2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyB1MemberlistPort)}, proxyAddressesB) + + s.logger.Info("Setting up load balancers") + + s.loadBalancerA = s.createLoadBalancer(loadBalancerAPort, []string{s.proxyA1Outbound, s.proxyA2Outbound}, &s.connectionCountsA1, &s.connectionCountsA2) + s.loadBalancerB = s.createLoadBalancer(loadBalancerBPort, []string{s.proxyB1Mux, s.proxyB2Mux}, &s.connectionCountsPA1, &s.connectionCountsPA2) + s.loadBalancerC = s.createLoadBalancer(loadBalancerCPort, []string{s.proxyB1Outbound, s.proxyB2Outbound}, &s.connectionCountsB1, &s.connectionCountsB2) + + muxLoadBalancerBAddress := fmt.Sprintf("localhost:%s", loadBalancerBPort) + s.proxyA1 = s.createProxy("proxy-a-1", proxyA1Address, s.proxyA1Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, 
"proxy-node-a-1", "127.0.0.1", s.proxyA1MemberlistPort, nil, proxyAddressesA) + s.proxyA2 = s.createProxy("proxy-a-2", proxyA2Address, s.proxyA2Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-2", "127.0.0.1", s.proxyA2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyA1MemberlistPort)}, proxyAddressesA) + + s.logger.Info("Waiting for proxies to start and connect") + time.Sleep(15 * time.Second) + + s.logger.Info("Configuring remote clusters") + s.configureRemoteCluster(s.clusterA, s.clusterB.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerAPort)) + s.configureRemoteCluster(s.clusterB, s.clusterA.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerCPort)) + s.waitForReplicationReady() +} + +func (s *IntraProxyRoutingTestSuite) TearDownSuite() { + s.logger.Info("Tearing down intra-proxy routing test suite") + if s.clusterA != nil && s.clusterB != nil { + s.logger.Info("Removing remote cluster A from cluster B") + s.removeRemoteCluster(s.clusterA, s.clusterB.ClusterName()) + s.logger.Info("Remote cluster A removed") + s.logger.Info("Removing remote cluster B from cluster A") + s.removeRemoteCluster(s.clusterB, s.clusterA.ClusterName()) + s.logger.Info("Remote cluster B removed") + } + if s.clusterA != nil { + s.NoError(s.clusterA.TearDownCluster()) + s.logger.Info("Cluster A torn down") + } + if s.clusterB != nil { + s.NoError(s.clusterB.TearDownCluster()) + s.logger.Info("Cluster B torn down") + } + if s.loadBalancerA != nil { + s.logger.Info("Stopping load balancer A") + s.loadBalancerA.Stop() + s.logger.Info("Load balancer A stopped") + } + if s.loadBalancerB != nil { + s.logger.Info("Stopping load balancer B") + s.loadBalancerB.Stop() + s.logger.Info("Load balancer B stopped") + } + if s.loadBalancerC != nil { + s.logger.Info("Stopping load balancer C") + s.loadBalancerC.Stop() + s.logger.Info("Load balancer C stopped") + } + if s.proxyA1 != nil { + s.logger.Info("Stopping proxy A1") + s.proxyA1.Stop() + s.logger.Info("Proxy A1 stopped") + } + if s.proxyA2 != nil { + s.logger.Info("Stopping proxy A2") + s.proxyA2.Stop() + s.logger.Info("Proxy A2 stopped") + } + if s.proxyB1 != nil { + s.logger.Info("Stopping proxy B1") + s.proxyB1.Stop() + s.logger.Info("Proxy B1 stopped") + } + if s.proxyB2 != nil { + s.logger.Info("Stopping proxy B2") + s.proxyB2.Stop() + s.logger.Info("Proxy B2 stopped") + } + s.logger.Info("Intra-proxy routing test suite torn down") +} + +func (s *IntraProxyRoutingTestSuite) createCluster( + clusterName string, + numShards int, + initialFailoverVersion int64, + numHistoryHosts int, +) *testcore.TestCluster { + clusterSuffix := common.GenerateRandomString(8) + fullClusterName := fmt.Sprintf("%s-%s", clusterName, clusterSuffix) + + clusterConfig := &testcore.TestClusterConfig{ + ClusterMetadata: cluster.Config{ + EnableGlobalNamespace: true, + FailoverVersionIncrement: 10, + MasterClusterName: fullClusterName, + CurrentClusterName: fullClusterName, + ClusterInformation: map[string]cluster.ClusterInformation{ + fullClusterName: { + Enabled: true, + InitialFailoverVersion: initialFailoverVersion, + }, + }, + }, + HistoryConfig: testcore.HistoryConfig{ + NumHistoryShards: int32(numShards), + NumHistoryHosts: numHistoryHosts, + }, + DynamicConfigOverrides: map[dynamicconfig.Key]interface{}{ + dynamicconfig.NamespaceCacheRefreshInterval.Key(): time.Second, + dynamicconfig.EnableReplicationStream.Key(): true, + dynamicconfig.EnableReplicationTaskBatching.Key(): true, + }, + } + + testClusterFactory 
:= testcore.NewTestClusterFactory() + logger := log.With(s.logger, tag.NewStringTag("clusterName", clusterName)) + cluster, err := testClusterFactory.NewCluster(s.T(), clusterConfig, logger) + s.NoError(err, "Failed to create cluster %s", clusterName) + s.NotNil(cluster) + + return cluster +} + +func (s *IntraProxyRoutingTestSuite) createProxy( + name string, + inboundAddress string, + outboundAddress string, + muxAddress string, + cluster *testcore.TestCluster, + muxMode config.MuxMode, + shardCountConfig config.ShardCountConfig, + nodeName string, + memberlistBindAddr string, + memberlistBindPort int, + memberlistJoinAddrs []string, + proxyAddresses map[string]string, +) *s2sproxy.Proxy { + var muxConnectionType config.ConnectionType + var muxAddressInfo config.TCPTLSInfo + if muxMode == config.ServerMode { + muxConnectionType = config.ConnTypeMuxServer + muxAddressInfo = config.TCPTLSInfo{ + ConnectionString: muxAddress, + } + } else { + muxConnectionType = config.ConnTypeMuxClient + muxAddressInfo = config.TCPTLSInfo{ + ConnectionString: muxAddress, + } + } + + cfg := &config.S2SProxyConfig{ + ClusterConnections: []config.ClusterConnConfig{ + { + Name: name, + LocalServer: config.ClusterDefinition{ + Connection: config.TransportInfo{ + ConnectionType: config.ConnTypeTCP, + TcpClient: config.TCPTLSInfo{ + ConnectionString: cluster.Host().FrontendGRPCAddress(), + }, + TcpServer: config.TCPTLSInfo{ + ConnectionString: outboundAddress, + }, + }, + }, + RemoteServer: config.ClusterDefinition{ + Connection: config.TransportInfo{ + ConnectionType: muxConnectionType, + MuxCount: 1, + MuxAddressInfo: muxAddressInfo, + }, + }, + ShardCountConfig: shardCountConfig, + MemberlistConfig: &config.MemberlistConfig{ + Enabled: true, + NodeName: nodeName, + BindAddr: memberlistBindAddr, + BindPort: memberlistBindPort, + JoinAddrs: memberlistJoinAddrs, + ProxyAddresses: proxyAddresses, + TCPOnly: true, + }, + }, + }, + } + + configProvider := &simpleConfigProvider{cfg: *cfg} + proxy := s2sproxy.NewProxy(configProvider, s.logger) + s.NotNil(proxy) + + err := proxy.Start() + s.NoError(err, "Failed to start proxy %s", name) + + s.logger.Info("Started proxy", tag.NewStringTag("name", name), + tag.NewStringTag("inboundAddress", inboundAddress), + tag.NewStringTag("outboundAddress", outboundAddress), + tag.NewStringTag("muxAddress", muxAddress), + tag.NewStringTag("muxMode", string(muxMode)), + tag.NewStringTag("nodeName", nodeName), + ) + + return proxy +} + +type trackingUpstreamServer struct { + address string + conns atomic.Int64 + count1 *atomic.Int64 + count2 *atomic.Int64 +} + +type trackingUpstream struct { + servers []*trackingUpstreamServer + mu sync.RWMutex +} + +func (u *trackingUpstream) selectLeastConn() *trackingUpstreamServer { + u.mu.RLock() + defer u.mu.RUnlock() + + if len(u.servers) == 0 { + return nil + } + + selected := u.servers[0] + minConns := selected.conns.Load() + + for i := 1; i < len(u.servers); i++ { + conns := u.servers[i].conns.Load() + if conns < minConns { + minConns = conns + selected = u.servers[i] + } + } + + if selected != nil { + if selected == u.servers[0] { + selected.count1.Add(1) + } else if len(u.servers) > 1 && selected == u.servers[1] { + selected.count2.Add(1) + } + } + + return selected +} + +func (u *trackingUpstream) incrementConn(server *trackingUpstreamServer) { + server.conns.Add(1) +} + +func (u *trackingUpstream) decrementConn(server *trackingUpstreamServer) { + server.conns.Add(-1) +} + +type trackingTCPProxy struct { + rules []*trackingProxyRule + 
logger log.Logger + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + servers []net.Listener +} + +type trackingProxyRule struct { + ListenPort string + Upstream *trackingUpstream +} + +func (p *trackingTCPProxy) Start() error { + for _, rule := range p.rules { + listener, err := net.Listen("tcp", ":"+rule.ListenPort) + if err != nil { + p.Stop() + return fmt.Errorf("failed to listen on port %s: %w", rule.ListenPort, err) + } + p.servers = append(p.servers, listener) + + p.wg.Add(1) + go p.handleListener(listener, rule) + } + + return nil +} + +func (p *trackingTCPProxy) Stop() { + p.logger.Info("Stopping tracking TCP proxy") + p.cancel() + for _, server := range p.servers { + p.logger.Info("Closing server", tag.NewStringTag("server", server.Addr().String())) + _ = server.Close() + } + p.logger.Info("Waiting for goroutines to finish") + p.wg.Wait() + p.logger.Info("Tracking TCP proxy stopped") +} + +func (p *trackingTCPProxy) handleListener(listener net.Listener, rule *trackingProxyRule) { + defer p.wg.Done() + + for { + select { + case <-p.ctx.Done(): + return + default: + } + + clientConn, err := listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + p.logger.Warn("failed to accept connection", tag.Error(err)) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(clientConn, rule) + } +} + +func (p *trackingTCPProxy) handleConnection(clientConn net.Conn, rule *trackingProxyRule) { + defer p.wg.Done() + defer func() { _ = clientConn.Close() }() + + // Check if already cancelled + select { + case <-p.ctx.Done(): + return + default: + } + + upstream := rule.Upstream.selectLeastConn() + if upstream == nil { + p.logger.Error("no upstream servers available") + return + } + + rule.Upstream.incrementConn(upstream) + defer rule.Upstream.decrementConn(upstream) + + serverConn, err := net.DialTimeout("tcp", upstream.address, 5*time.Second) + if err != nil { + p.logger.Warn("failed to connect to upstream", tag.NewStringTag("upstream", upstream.address), tag.Error(err)) + return + } + defer func() { _ = serverConn.Close() }() + + // Close connections when context is cancelled to unblock io.Copy + var wg sync.WaitGroup + wg.Add(3) + + go func() { + defer wg.Done() + <-p.ctx.Done() + _ = clientConn.Close() + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(serverConn, clientConn) + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(clientConn, serverConn) + _ = clientConn.Close() + }() + + wg.Wait() +} + +func (s *IntraProxyRoutingTestSuite) createLoadBalancer( + listenPort string, + upstreams []string, + count1 *atomic.Int64, + count2 *atomic.Int64, +) *trackingTCPProxy { + trackingServers := make([]*trackingUpstreamServer, len(upstreams)) + for i, addr := range upstreams { + trackingServers[i] = &trackingUpstreamServer{ + address: addr, + count1: count1, + count2: count2, + } + } + + trackingUpstream := &trackingUpstream{ + servers: trackingServers, + } + + rules := []*trackingProxyRule{ + { + ListenPort: listenPort, + Upstream: trackingUpstream, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + trackingProxy := &trackingTCPProxy{ + rules: rules, + logger: s.logger, + ctx: ctx, + cancel: cancel, + } + + err := trackingProxy.Start() + s.NoError(err, "Failed to start load balancer on port %s", listenPort) + + return trackingProxy +} + +func (s *IntraProxyRoutingTestSuite) configureRemoteCluster( + cluster *testcore.TestCluster, + remoteClusterName string, + 
proxyAddress string, +) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := cluster.AdminClient().AddOrUpdateRemoteCluster( + ctx, + &adminservice.AddOrUpdateRemoteClusterRequest{ + FrontendAddress: proxyAddress, + EnableRemoteClusterConnection: true, + }, + ) + s.NoError(err, "Failed to configure remote cluster %s", remoteClusterName) + s.logger.Info("Configured remote cluster", + tag.NewStringTag("remoteClusterName", remoteClusterName), + tag.NewStringTag("proxyAddress", proxyAddress), + tag.NewStringTag("clusterName", cluster.ClusterName()), + ) +} + +func (s *IntraProxyRoutingTestSuite) removeRemoteCluster( + cluster *testcore.TestCluster, + remoteClusterName string, +) { + _, err := cluster.AdminClient().RemoveRemoteCluster( + context.Background(), + &adminservice.RemoveRemoteClusterRequest{ + ClusterName: remoteClusterName, + }, + ) + s.NoError(err, "Failed to remove remote cluster %s", remoteClusterName) + s.logger.Info("Removed remote cluster", + tag.NewStringTag("remoteClusterName", remoteClusterName), + tag.NewStringTag("clusterName", cluster.ClusterName()), + ) +} + +func (s *IntraProxyRoutingTestSuite) waitForReplicationReady() { + time.Sleep(1 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for _, cluster := range []*testcore.TestCluster{s.clusterA, s.clusterB} { + s.Eventually(func() bool { + _, err := cluster.HistoryClient().GetReplicationStatus( + ctx, + &historyservice.GetReplicationStatusRequest{}, + ) + return err == nil + }, 5*time.Second, 200*time.Millisecond, "Replication infrastructure not ready") + } + + time.Sleep(1 * time.Second) +} + +func (s *IntraProxyRoutingTestSuite) TestIntraProxyRoutingDistribution() { + s.logger.Info("Testing intra-proxy routing distribution") + + ctx := context.Background() + + s.logger.Info("Triggering replication connections to verify distribution") + + var wg sync.WaitGroup + numConnections := 20 + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.clusterA.HistoryClient().GetReplicationStatus( + ctx, + &historyservice.GetReplicationStatusRequest{}, + ) + if err != nil { + s.logger.Warn("GetReplicationStatus failed", tag.Error(err)) + } + }() + } + + wg.Wait() + + time.Sleep(2 * time.Second) + + countA1 := s.connectionCountsA1.Load() + countA2 := s.connectionCountsA2.Load() + countB1 := s.connectionCountsB1.Load() + countB2 := s.connectionCountsB2.Load() + countPA1 := s.connectionCountsPA1.Load() + countPA2 := s.connectionCountsPA2.Load() + + s.logger.Info("Connection distribution results", + tag.NewInt64("loadBalancerA_pa1", countA1), + tag.NewInt64("loadBalancerA_pa2", countA2), + tag.NewInt64("loadBalancerB_pb1_from_pa", countPA1), + tag.NewInt64("loadBalancerB_pb2_from_pa", countPA2), + tag.NewInt64("loadBalancerC_pb1", countB1), + tag.NewInt64("loadBalancerC_pb2", countB2), + ) + + s.Greater(countA1, int64(0), "Load balancer A should route to pa1") + s.Greater(countA2, int64(0), "Load balancer A should route to pa2") + s.Greater(countB1, int64(0), "Load balancer C should route to pb1") + s.Greater(countB2, int64(0), "Load balancer C should route to pb2") + s.Greater(countPA1, int64(0), "Load balancer B should route to pb1 from pa") + s.Greater(countPA2, int64(0), "Load balancer B should route to pb2 from pa") + + totalA := countA1 + countA2 + totalB := countB1 + countB2 + totalPA := countPA1 + countPA2 + + s.logger.Info("Total connections", + tag.NewInt64("totalA", 
totalA), + tag.NewInt64("totalB", totalB), + tag.NewInt64("totalPA", totalPA), + ) + + s.Greater(totalA, int64(0), "Should have connections through load balancer A") + s.Greater(totalB, int64(0), "Should have connections through load balancer C") + s.Greater(totalPA, int64(0), "Should have connections through load balancer B") +} diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index 692f5864..d1651c67 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -3,7 +3,6 @@ package proxy import ( "context" "fmt" - "net" "sync" "testing" "time" @@ -28,6 +27,7 @@ import ( "github.com/temporalio/s2s-proxy/config" s2sproxy "github.com/temporalio/s2s-proxy/proxy" + "github.com/temporalio/s2s-proxy/testutil" ) type ( @@ -136,19 +136,6 @@ func TestReplicationFailoverTestSuite(t *testing.T) { } } -func getFreePort() int { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - panic(fmt.Sprintf("failed to get free port: %v", err)) - } - defer func() { - if err := l.Close(); err != nil { - fmt.Printf("Failed to close listener: %v\n", err) - } - }() - return l.Addr().(*net.TCPAddr).Port -} - func (s *ReplicationTestSuite) SetupSuite() { s.Assertions = require.New(s.T()) s.logger = log.NewTestLogger() @@ -162,11 +149,11 @@ func (s *ReplicationTestSuite) SetupSuite() { s.clusterA = s.createCluster("cluster-a", s.shardCountA, 1) s.clusterB = s.createCluster("cluster-b", s.shardCountB, 2) - s.proxyAAddress = fmt.Sprintf("localhost:%d", getFreePort()) - proxyAOutbound := fmt.Sprintf("localhost:%d", getFreePort()) - s.proxyBAddress = fmt.Sprintf("localhost:%d", getFreePort()) - proxyBOutbound := fmt.Sprintf("localhost:%d", getFreePort()) - muxServerAddress := fmt.Sprintf("localhost:%d", getFreePort()) + s.proxyAAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyAOutbound := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyBAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyBOutbound := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + muxServerAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) proxyBShardConfig := s.shardCountConfigB if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { diff --git a/testutil/testutil.go b/testutil/testutil.go new file mode 100644 index 00000000..28e3a15f --- /dev/null +++ b/testutil/testutil.go @@ -0,0 +1,21 @@ +package testutil + +import ( + "fmt" + "net" +) + +// GetFreePort returns an available TCP port by listening on localhost:0. +// This is useful for tests that need to allocate ports dynamically. 
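+// Note that the listener is closed before the port is returned, so another
+// process could in principle claim the port before the caller binds it.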
+func GetFreePort() int { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(fmt.Sprintf("failed to get free port: %v", err)) + } + defer func() { + if err := l.Close(); err != nil { + fmt.Printf("Failed to close listener: %v\n", err) + } + }() + return l.Addr().(*net.TCPAddr).Port +} From 682211304d1bebd7dd4c4fa9bde05d1a48df69e9 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Sat, 20 Dec 2025 17:20:18 -0800 Subject: [PATCH 26/38] update tests --- proxy/test/bench_test.go | 97 +++- proxy/test/echo_proxy_test.go | 125 +++--- proxy/test/intra_proxy_routing_test.go | 432 +----------------- proxy/test/replication_failover_test.go | 563 ++++++++++++------------ proxy/test/test_common.go | 476 ++++++++++++++++++++ proxy/test/wiring_test.go | 46 +- 6 files changed, 958 insertions(+), 781 deletions(-) create mode 100644 proxy/test/test_common.go diff --git a/proxy/test/bench_test.go b/proxy/test/bench_test.go index 60247c45..090b2a9e 100644 --- a/proxy/test/bench_test.go +++ b/proxy/test/bench_test.go @@ -1,6 +1,7 @@ package proxy import ( + "fmt" "testing" "go.temporal.io/server/api/adminservice/v1" @@ -10,17 +11,89 @@ import ( "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/endtoendtest" + "github.com/temporalio/s2s-proxy/testutil" ) -func benchmarkStreamSendRecvWithoutProxy(b *testing.B, payloadSize int) { +func createEchoServerConfigWithPorts( + echoServerAddress string, + serverProxyInboundAddress string, + serverProxyOutboundAddress string, + opts ...cfgOption, +) *config.S2SProxyConfig { + return createS2SProxyConfig(&config.S2SProxyConfig{ + Inbound: &config.ProxyConfig{ + Name: "proxy1-inbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: serverProxyInboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: echoServerAddress, + }, + }, + }, + Outbound: &config.ProxyConfig{ + Name: "proxy1-outbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: serverProxyOutboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: "to-be-added", + }, + }, + }, + }, opts) +} + +func createEchoClientConfigWithPorts( + echoClientAddress string, + clientProxyInboundAddress string, + clientProxyOutboundAddress string, + opts ...cfgOption, +) *config.S2SProxyConfig { + return createS2SProxyConfig(&config.S2SProxyConfig{ + Inbound: &config.ProxyConfig{ + Name: "proxy2-inbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: clientProxyInboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: echoClientAddress, + }, + }, + }, + Outbound: &config.ProxyConfig{ + Name: "proxy2-outbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: clientProxyOutboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: "to-be-added", + }, + }, + }, + }, opts) +} +func benchmarkStreamSendRecvWithoutProxy(b *testing.B, payloadSize int) { echoServerInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), ClusterShardID: serverClusterShard, } echoClientInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + 
ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), ClusterShardID: clientClusterShard, } @@ -31,7 +104,18 @@ func benchmarkStreamSendRecvWithMuxProxy(b *testing.B, payloadSize int) { b.Log("Start BenchmarkStreamSendRecv") muxTransportName := "muxed" - echoServerConfig := createEchoServerConfig( + // Allocate ports dynamically + echoServerAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + serverProxyInboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + serverProxyOutboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + echoClientAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + clientProxyInboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + clientProxyOutboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + + echoServerConfig := createEchoServerConfigWithPorts( + echoServerAddress, + serverProxyInboundAddress, + serverProxyOutboundAddress, withMux( config.MuxTransportConfig{ Name: muxTransportName, @@ -54,7 +138,10 @@ func benchmarkStreamSendRecvWithMuxProxy(b *testing.B, payloadSize int) { }, false), ) - echoClientConfig := createEchoClientConfig( + echoClientConfig := createEchoClientConfigWithPorts( + echoClientAddress, + clientProxyInboundAddress, + clientProxyOutboundAddress, withMux( config.MuxTransportConfig{ Name: muxTransportName, diff --git a/proxy/test/echo_proxy_test.go b/proxy/test/echo_proxy_test.go index c5ec33e8..3b4850db 100644 --- a/proxy/test/echo_proxy_test.go +++ b/proxy/test/echo_proxy_test.go @@ -1,6 +1,7 @@ package proxy import ( + "fmt" "os" "path/filepath" "testing" @@ -16,6 +17,7 @@ import ( "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/endtoendtest" + "github.com/temporalio/s2s-proxy/testutil" "github.com/temporalio/s2s-proxy/transport/mux" ) @@ -25,16 +27,6 @@ func init() { mux.MuxManagerStartDelay = 0 } -const ( - echoServerAddress = "localhost:7266" - serverProxyInboundAddress = "localhost:7366" - serverProxyOutboundAddress = "localhost:7466" - echoClientAddress = "localhost:8266" - clientProxyInboundAddress = "localhost:8366" - clientProxyOutboundAddress = "localhost:8466" - invalidAddress = "" -) - var ( serverClusterShard = history.ClusterShardID{ ClusterID: 1, @@ -56,8 +48,14 @@ var ( type ( proxyTestSuite struct { suite.Suite - originalPath string - developPath string + originalPath string + developPath string + echoServerAddress string + serverProxyInboundAddress string + serverProxyOutboundAddress string + echoClientAddress string + clientProxyInboundAddress string + clientProxyOutboundAddress string } cfgOption func(c *config.S2SProxyConfig) @@ -156,18 +154,18 @@ func createS2SProxyConfig(cfg *config.S2SProxyConfig, opts []cfgOption) *config. 
return cfg } -func createEchoServerConfig(opts ...cfgOption) *config.S2SProxyConfig { +func (s *proxyTestSuite) createEchoServerConfig(opts ...cfgOption) *config.S2SProxyConfig { return createS2SProxyConfig(&config.S2SProxyConfig{ Inbound: &config.ProxyConfig{ Name: "proxy1-inbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: serverProxyInboundAddress, + ListenAddress: s.serverProxyInboundAddress, }, }, Client: config.ProxyClientConfig{ TCPClientSetting: config.TCPClientSetting{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, }, }, }, @@ -175,7 +173,7 @@ func createEchoServerConfig(opts ...cfgOption) *config.S2SProxyConfig { Name: "proxy1-outbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: serverProxyOutboundAddress, + ListenAddress: s.serverProxyOutboundAddress, }, }, Client: config.ProxyClientConfig{ @@ -210,18 +208,18 @@ func EchoClientTLSOptions() []cfgOption { } } -func createEchoClientConfig(opts ...cfgOption) *config.S2SProxyConfig { +func (s *proxyTestSuite) createEchoClientConfig(opts ...cfgOption) *config.S2SProxyConfig { return createS2SProxyConfig(&config.S2SProxyConfig{ Inbound: &config.ProxyConfig{ Name: "proxy2-inbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: clientProxyInboundAddress, + ListenAddress: s.clientProxyInboundAddress, }, }, Client: config.ProxyClientConfig{ TCPClientSetting: config.TCPClientSetting{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, }, }, }, @@ -229,7 +227,7 @@ func createEchoClientConfig(opts ...cfgOption) *config.S2SProxyConfig { Name: "proxy2-outbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: clientProxyOutboundAddress, + ListenAddress: s.clientProxyOutboundAddress, }, }, Client: config.ProxyClientConfig{ @@ -252,6 +250,14 @@ func (s *proxyTestSuite) SetupTest() { s.developPath = filepath.Join("..", "..", "develop") err = os.Chdir(s.developPath) s.NoError(err) + + // Allocate free ports for each test + s.echoServerAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.serverProxyInboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.serverProxyOutboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.echoClientAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.clientProxyInboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.clientProxyOutboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) } func (s *proxyTestSuite) TearDownTest() { @@ -300,11 +306,11 @@ func (s *proxyTestSuite) Test_Echo_Success() { // echo_server <- - -> echo_client name: "no-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, }, }, @@ -312,12 +318,12 @@ func (s *proxyTestSuite) Test_Echo_Success() { // echo_server <-> proxy.inbound <- - -> echo_client name: "server-side-only-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(), + S2sProxyConfig: s.createEchoServerConfig(), }, echoClientInfo: 
endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, }, }, @@ -325,49 +331,49 @@ func (s *proxyTestSuite) Test_Echo_Success() { // echo_server <- - -> proxy.outbound <-> echo_client name: "client-side-only-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, }, { // echo_server <-> proxy.inbound <- - -> proxy.outbound <-> echo_client name: "server-and-client-side-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(), + S2sProxyConfig: s.createEchoServerConfig(), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, }, { // echo_server <-> proxy.inbound <- mTLS -> proxy.outbound <-> echo_client name: "server-and-client-side-proxy-mTLS", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(EchoServerTLSOptions()...), + S2sProxyConfig: s.createEchoServerConfig(EchoServerTLSOptions()...), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(EchoClientTLSOptions()...), + S2sProxyConfig: s.createEchoClientConfig(EchoClientTLSOptions()...), }, }, { name: "server-and-client-side-proxy-ACL", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(withACLPolicy( + S2sProxyConfig: s.createEchoServerConfig(withACLPolicy( &config.ACLPolicy{ AllowedMethods: config.AllowedMethods{ AdminService: []string{ @@ -383,9 +389,9 @@ func (s *proxyTestSuite) Test_Echo_Success() { )), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, }, } @@ -439,9 +445,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { { name: "server-and-client-side-proxy-namespacetrans", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(withNamespaceTranslation( + S2sProxyConfig: s.createEchoServerConfig(withNamespaceTranslation( []config.NameMappingConfig{ { LocalName: "local", @@ -452,9 +458,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { )), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, 
serverNamespace: "local", clientNamespace: "remote", @@ -462,9 +468,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { { name: "server-and-client-side-proxy-namespacetrans-acl", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig( + S2sProxyConfig: s.createEchoServerConfig( withNamespaceTranslation( []config.NameMappingConfig{ { @@ -489,9 +495,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { ), )}, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, serverNamespace: "local", clientNamespace: "remote", @@ -533,13 +539,13 @@ func (s *proxyTestSuite) Test_Echo_WithMuxTransport() { // // echoServer proxy1.inbound.Server(muxClient) <- proxy2.outbound.Client(muxServer) echoClient // echoServer proxy1.outbound.Client(muxClient) -> proxy2.inbound.Server(muxServer) echoClient - echoServerConfig := createEchoServerConfig( + echoServerConfig := s.createEchoServerConfig( withMux( config.MuxTransportConfig{ Name: muxTransportName, Mode: config.ClientMode, Client: config.TCPClientSetting{ - ServerAddress: clientProxyInboundAddress, + ServerAddress: s.clientProxyInboundAddress, }, }), withServerConfig( @@ -556,13 +562,13 @@ func (s *proxyTestSuite) Test_Echo_WithMuxTransport() { }, false), ) - echoClientConfig := createEchoClientConfig( + echoClientConfig := s.createEchoClientConfig( withMux( config.MuxTransportConfig{ Name: muxTransportName, Mode: config.ServerMode, Server: config.TCPServerSetting{ - ListenAddress: clientProxyInboundAddress, + ListenAddress: s.clientProxyInboundAddress, }, }), withServerConfig( @@ -580,12 +586,12 @@ func (s *proxyTestSuite) Test_Echo_WithMuxTransport() { ) echoServerInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, S2sProxyConfig: echoServerConfig, } echoClientInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, S2sProxyConfig: echoClientConfig, } @@ -614,14 +620,14 @@ func (s *proxyTestSuite) Test_ForceStopSourceServer() { logger := log.NewTestLogger() echoServerInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, } echoClientInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), } echoServer := endtoendtest.NewEchoServer(echoServerInfo, echoClientInfo, "EchoServer", logger, nil) @@ -629,6 +635,10 @@ func (s *proxyTestSuite) Test_ForceStopSourceServer() { echoServer.Start() echoClient.Start() + defer func() { + echoClient.Stop() + echoServer.Stop() + }() stream, err := echoClient.CreateStreamClient() s.NoError(err) @@ -649,5 +659,4 @@ func (s *proxyTestSuite) Test_ForceStopSourceServer() { s.ErrorContains(err, "EOF") _ = stream.CloseSend() - echoClient.Stop() } diff --git a/proxy/test/intra_proxy_routing_test.go b/proxy/test/intra_proxy_routing_test.go index d4656072..87a132d0 100644 --- a/proxy/test/intra_proxy_routing_test.go +++ 
b/proxy/test/intra_proxy_routing_test.go @@ -3,8 +3,6 @@ package proxy import ( "context" "fmt" - "io" - "net" "sync" "sync/atomic" "testing" @@ -12,11 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/api/historyservice/v1" - "go.temporal.io/server/common" - "go.temporal.io/server/common/cluster" - "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/tests/testcore" @@ -82,8 +76,8 @@ func (s *IntraProxyRoutingTestSuite) SetupSuite() { s.logger.Info("Setting up intra-proxy routing test suite") - s.clusterA = s.createCluster("cluster-a", 2, 1, 1) - s.clusterB = s.createCluster("cluster-b", 2, 2, 1) + s.clusterA = createCluster(s.logger, s.T(), "cluster-a", 2, 1, 1) + s.clusterB = createCluster(s.logger, s.T(), "cluster-b", 2, 2, 1) s.proxyA1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) s.proxyA2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) @@ -120,36 +114,40 @@ func (s *IntraProxyRoutingTestSuite) SetupSuite() { s.proxyB1MemberlistPort = testutil.GetFreePort() s.proxyB2MemberlistPort = testutil.GetFreePort() - s.proxyB1 = s.createProxy("proxy-b-1", proxyB1Address, s.proxyB1Outbound, s.proxyB1Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-1", "127.0.0.1", s.proxyB1MemberlistPort, nil, proxyAddressesB) - s.proxyB2 = s.createProxy("proxy-b-2", proxyB2Address, s.proxyB2Outbound, s.proxyB2Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-2", "127.0.0.1", s.proxyB2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyB1MemberlistPort)}, proxyAddressesB) + s.proxyB1 = createProxy(s.logger, s.T(), "proxy-b-1", proxyB1Address, s.proxyB1Outbound, s.proxyB1Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-1", "127.0.0.1", s.proxyB1MemberlistPort, nil, proxyAddressesB) + s.proxyB2 = createProxy(s.logger, s.T(), "proxy-b-2", proxyB2Address, s.proxyB2Outbound, s.proxyB2Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-2", "127.0.0.1", s.proxyB2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyB1MemberlistPort)}, proxyAddressesB) s.logger.Info("Setting up load balancers") - s.loadBalancerA = s.createLoadBalancer(loadBalancerAPort, []string{s.proxyA1Outbound, s.proxyA2Outbound}, &s.connectionCountsA1, &s.connectionCountsA2) - s.loadBalancerB = s.createLoadBalancer(loadBalancerBPort, []string{s.proxyB1Mux, s.proxyB2Mux}, &s.connectionCountsPA1, &s.connectionCountsPA2) - s.loadBalancerC = s.createLoadBalancer(loadBalancerCPort, []string{s.proxyB1Outbound, s.proxyB2Outbound}, &s.connectionCountsB1, &s.connectionCountsB2) + var err error + s.loadBalancerA, err = createLoadBalancer(s.logger, loadBalancerAPort, []string{s.proxyA1Outbound, s.proxyA2Outbound}, &s.connectionCountsA1, &s.connectionCountsA2) + s.NoError(err, "Failed to start load balancer A") + s.loadBalancerB, err = createLoadBalancer(s.logger, loadBalancerBPort, []string{s.proxyB1Mux, s.proxyB2Mux}, &s.connectionCountsPA1, &s.connectionCountsPA2) + s.NoError(err, "Failed to start load balancer B") + s.loadBalancerC, err = createLoadBalancer(s.logger, loadBalancerCPort, []string{s.proxyB1Outbound, s.proxyB2Outbound}, &s.connectionCountsB1, &s.connectionCountsB2) + s.NoError(err, "Failed to start load balancer C") muxLoadBalancerBAddress := fmt.Sprintf("localhost:%s", loadBalancerBPort) - 
s.proxyA1 = s.createProxy("proxy-a-1", proxyA1Address, s.proxyA1Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-1", "127.0.0.1", s.proxyA1MemberlistPort, nil, proxyAddressesA) - s.proxyA2 = s.createProxy("proxy-a-2", proxyA2Address, s.proxyA2Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-2", "127.0.0.1", s.proxyA2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyA1MemberlistPort)}, proxyAddressesA) + s.proxyA1 = createProxy(s.logger, s.T(), "proxy-a-1", proxyA1Address, s.proxyA1Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-1", "127.0.0.1", s.proxyA1MemberlistPort, nil, proxyAddressesA) + s.proxyA2 = createProxy(s.logger, s.T(), "proxy-a-2", proxyA2Address, s.proxyA2Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-2", "127.0.0.1", s.proxyA2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyA1MemberlistPort)}, proxyAddressesA) s.logger.Info("Waiting for proxies to start and connect") time.Sleep(15 * time.Second) s.logger.Info("Configuring remote clusters") - s.configureRemoteCluster(s.clusterA, s.clusterB.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerAPort)) - s.configureRemoteCluster(s.clusterB, s.clusterA.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerCPort)) - s.waitForReplicationReady() + configureRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerAPort)) + configureRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerCPort)) + waitForReplicationReady(s.logger, s.T(), s.clusterA, s.clusterB) } func (s *IntraProxyRoutingTestSuite) TearDownSuite() { s.logger.Info("Tearing down intra-proxy routing test suite") if s.clusterA != nil && s.clusterB != nil { s.logger.Info("Removing remote cluster A from cluster B") - s.removeRemoteCluster(s.clusterA, s.clusterB.ClusterName()) + removeRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName()) s.logger.Info("Remote cluster A removed") s.logger.Info("Removing remote cluster B from cluster A") - s.removeRemoteCluster(s.clusterB, s.clusterA.ClusterName()) + removeRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName()) s.logger.Info("Remote cluster B removed") } if s.clusterA != nil { @@ -198,400 +196,6 @@ func (s *IntraProxyRoutingTestSuite) TearDownSuite() { s.logger.Info("Intra-proxy routing test suite torn down") } -func (s *IntraProxyRoutingTestSuite) createCluster( - clusterName string, - numShards int, - initialFailoverVersion int64, - numHistoryHosts int, -) *testcore.TestCluster { - clusterSuffix := common.GenerateRandomString(8) - fullClusterName := fmt.Sprintf("%s-%s", clusterName, clusterSuffix) - - clusterConfig := &testcore.TestClusterConfig{ - ClusterMetadata: cluster.Config{ - EnableGlobalNamespace: true, - FailoverVersionIncrement: 10, - MasterClusterName: fullClusterName, - CurrentClusterName: fullClusterName, - ClusterInformation: map[string]cluster.ClusterInformation{ - fullClusterName: { - Enabled: true, - InitialFailoverVersion: initialFailoverVersion, - }, - }, - }, - HistoryConfig: testcore.HistoryConfig{ - NumHistoryShards: int32(numShards), - NumHistoryHosts: numHistoryHosts, - }, - DynamicConfigOverrides: map[dynamicconfig.Key]interface{}{ - dynamicconfig.NamespaceCacheRefreshInterval.Key(): time.Second, - 
dynamicconfig.EnableReplicationStream.Key(): true, - dynamicconfig.EnableReplicationTaskBatching.Key(): true, - }, - } - - testClusterFactory := testcore.NewTestClusterFactory() - logger := log.With(s.logger, tag.NewStringTag("clusterName", clusterName)) - cluster, err := testClusterFactory.NewCluster(s.T(), clusterConfig, logger) - s.NoError(err, "Failed to create cluster %s", clusterName) - s.NotNil(cluster) - - return cluster -} - -func (s *IntraProxyRoutingTestSuite) createProxy( - name string, - inboundAddress string, - outboundAddress string, - muxAddress string, - cluster *testcore.TestCluster, - muxMode config.MuxMode, - shardCountConfig config.ShardCountConfig, - nodeName string, - memberlistBindAddr string, - memberlistBindPort int, - memberlistJoinAddrs []string, - proxyAddresses map[string]string, -) *s2sproxy.Proxy { - var muxConnectionType config.ConnectionType - var muxAddressInfo config.TCPTLSInfo - if muxMode == config.ServerMode { - muxConnectionType = config.ConnTypeMuxServer - muxAddressInfo = config.TCPTLSInfo{ - ConnectionString: muxAddress, - } - } else { - muxConnectionType = config.ConnTypeMuxClient - muxAddressInfo = config.TCPTLSInfo{ - ConnectionString: muxAddress, - } - } - - cfg := &config.S2SProxyConfig{ - ClusterConnections: []config.ClusterConnConfig{ - { - Name: name, - LocalServer: config.ClusterDefinition{ - Connection: config.TransportInfo{ - ConnectionType: config.ConnTypeTCP, - TcpClient: config.TCPTLSInfo{ - ConnectionString: cluster.Host().FrontendGRPCAddress(), - }, - TcpServer: config.TCPTLSInfo{ - ConnectionString: outboundAddress, - }, - }, - }, - RemoteServer: config.ClusterDefinition{ - Connection: config.TransportInfo{ - ConnectionType: muxConnectionType, - MuxCount: 1, - MuxAddressInfo: muxAddressInfo, - }, - }, - ShardCountConfig: shardCountConfig, - MemberlistConfig: &config.MemberlistConfig{ - Enabled: true, - NodeName: nodeName, - BindAddr: memberlistBindAddr, - BindPort: memberlistBindPort, - JoinAddrs: memberlistJoinAddrs, - ProxyAddresses: proxyAddresses, - TCPOnly: true, - }, - }, - }, - } - - configProvider := &simpleConfigProvider{cfg: *cfg} - proxy := s2sproxy.NewProxy(configProvider, s.logger) - s.NotNil(proxy) - - err := proxy.Start() - s.NoError(err, "Failed to start proxy %s", name) - - s.logger.Info("Started proxy", tag.NewStringTag("name", name), - tag.NewStringTag("inboundAddress", inboundAddress), - tag.NewStringTag("outboundAddress", outboundAddress), - tag.NewStringTag("muxAddress", muxAddress), - tag.NewStringTag("muxMode", string(muxMode)), - tag.NewStringTag("nodeName", nodeName), - ) - - return proxy -} - -type trackingUpstreamServer struct { - address string - conns atomic.Int64 - count1 *atomic.Int64 - count2 *atomic.Int64 -} - -type trackingUpstream struct { - servers []*trackingUpstreamServer - mu sync.RWMutex -} - -func (u *trackingUpstream) selectLeastConn() *trackingUpstreamServer { - u.mu.RLock() - defer u.mu.RUnlock() - - if len(u.servers) == 0 { - return nil - } - - selected := u.servers[0] - minConns := selected.conns.Load() - - for i := 1; i < len(u.servers); i++ { - conns := u.servers[i].conns.Load() - if conns < minConns { - minConns = conns - selected = u.servers[i] - } - } - - if selected != nil { - if selected == u.servers[0] { - selected.count1.Add(1) - } else if len(u.servers) > 1 && selected == u.servers[1] { - selected.count2.Add(1) - } - } - - return selected -} - -func (u *trackingUpstream) incrementConn(server *trackingUpstreamServer) { - server.conns.Add(1) -} - -func (u *trackingUpstream) 
decrementConn(server *trackingUpstreamServer) { - server.conns.Add(-1) -} - -type trackingTCPProxy struct { - rules []*trackingProxyRule - logger log.Logger - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - servers []net.Listener -} - -type trackingProxyRule struct { - ListenPort string - Upstream *trackingUpstream -} - -func (p *trackingTCPProxy) Start() error { - for _, rule := range p.rules { - listener, err := net.Listen("tcp", ":"+rule.ListenPort) - if err != nil { - p.Stop() - return fmt.Errorf("failed to listen on port %s: %w", rule.ListenPort, err) - } - p.servers = append(p.servers, listener) - - p.wg.Add(1) - go p.handleListener(listener, rule) - } - - return nil -} - -func (p *trackingTCPProxy) Stop() { - p.logger.Info("Stopping tracking TCP proxy") - p.cancel() - for _, server := range p.servers { - p.logger.Info("Closing server", tag.NewStringTag("server", server.Addr().String())) - _ = server.Close() - } - p.logger.Info("Waiting for goroutines to finish") - p.wg.Wait() - p.logger.Info("Tracking TCP proxy stopped") -} - -func (p *trackingTCPProxy) handleListener(listener net.Listener, rule *trackingProxyRule) { - defer p.wg.Done() - - for { - select { - case <-p.ctx.Done(): - return - default: - } - - clientConn, err := listener.Accept() - if err != nil { - select { - case <-p.ctx.Done(): - return - default: - p.logger.Warn("failed to accept connection", tag.Error(err)) - continue - } - } - - p.wg.Add(1) - go p.handleConnection(clientConn, rule) - } -} - -func (p *trackingTCPProxy) handleConnection(clientConn net.Conn, rule *trackingProxyRule) { - defer p.wg.Done() - defer func() { _ = clientConn.Close() }() - - // Check if already cancelled - select { - case <-p.ctx.Done(): - return - default: - } - - upstream := rule.Upstream.selectLeastConn() - if upstream == nil { - p.logger.Error("no upstream servers available") - return - } - - rule.Upstream.incrementConn(upstream) - defer rule.Upstream.decrementConn(upstream) - - serverConn, err := net.DialTimeout("tcp", upstream.address, 5*time.Second) - if err != nil { - p.logger.Warn("failed to connect to upstream", tag.NewStringTag("upstream", upstream.address), tag.Error(err)) - return - } - defer func() { _ = serverConn.Close() }() - - // Close connections when context is cancelled to unblock io.Copy - var wg sync.WaitGroup - wg.Add(3) - - go func() { - defer wg.Done() - <-p.ctx.Done() - _ = clientConn.Close() - _ = serverConn.Close() - }() - - go func() { - defer wg.Done() - _, _ = io.Copy(serverConn, clientConn) - _ = serverConn.Close() - }() - - go func() { - defer wg.Done() - _, _ = io.Copy(clientConn, serverConn) - _ = clientConn.Close() - }() - - wg.Wait() -} - -func (s *IntraProxyRoutingTestSuite) createLoadBalancer( - listenPort string, - upstreams []string, - count1 *atomic.Int64, - count2 *atomic.Int64, -) *trackingTCPProxy { - trackingServers := make([]*trackingUpstreamServer, len(upstreams)) - for i, addr := range upstreams { - trackingServers[i] = &trackingUpstreamServer{ - address: addr, - count1: count1, - count2: count2, - } - } - - trackingUpstream := &trackingUpstream{ - servers: trackingServers, - } - - rules := []*trackingProxyRule{ - { - ListenPort: listenPort, - Upstream: trackingUpstream, - }, - } - - ctx, cancel := context.WithCancel(context.Background()) - trackingProxy := &trackingTCPProxy{ - rules: rules, - logger: s.logger, - ctx: ctx, - cancel: cancel, - } - - err := trackingProxy.Start() - s.NoError(err, "Failed to start load balancer on port %s", listenPort) - - return 
trackingProxy -} - -func (s *IntraProxyRoutingTestSuite) configureRemoteCluster( - cluster *testcore.TestCluster, - remoteClusterName string, - proxyAddress string, -) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - _, err := cluster.AdminClient().AddOrUpdateRemoteCluster( - ctx, - &adminservice.AddOrUpdateRemoteClusterRequest{ - FrontendAddress: proxyAddress, - EnableRemoteClusterConnection: true, - }, - ) - s.NoError(err, "Failed to configure remote cluster %s", remoteClusterName) - s.logger.Info("Configured remote cluster", - tag.NewStringTag("remoteClusterName", remoteClusterName), - tag.NewStringTag("proxyAddress", proxyAddress), - tag.NewStringTag("clusterName", cluster.ClusterName()), - ) -} - -func (s *IntraProxyRoutingTestSuite) removeRemoteCluster( - cluster *testcore.TestCluster, - remoteClusterName string, -) { - _, err := cluster.AdminClient().RemoveRemoteCluster( - context.Background(), - &adminservice.RemoveRemoteClusterRequest{ - ClusterName: remoteClusterName, - }, - ) - s.NoError(err, "Failed to remove remote cluster %s", remoteClusterName) - s.logger.Info("Removed remote cluster", - tag.NewStringTag("remoteClusterName", remoteClusterName), - tag.NewStringTag("clusterName", cluster.ClusterName()), - ) -} - -func (s *IntraProxyRoutingTestSuite) waitForReplicationReady() { - time.Sleep(1 * time.Second) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - for _, cluster := range []*testcore.TestCluster{s.clusterA, s.clusterB} { - s.Eventually(func() bool { - _, err := cluster.HistoryClient().GetReplicationStatus( - ctx, - &historyservice.GetReplicationStatusRequest{}, - ) - return err == nil - }, 5*time.Second, 200*time.Millisecond, "Replication infrastructure not ready") - } - - time.Sleep(1 * time.Second) -} - func (s *IntraProxyRoutingTestSuite) TestIntraProxyRoutingDistribution() { s.logger.Info("Testing intra-proxy routing distribution") diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index d1651c67..660dba29 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" "testing" "time" @@ -14,11 +15,8 @@ import ( replicationpb "go.temporal.io/api/replication/v1" taskqueuepb "go.temporal.io/api/taskqueue/v1" "go.temporal.io/api/workflowservice/v1" - "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/common" - "go.temporal.io/server/common/cluster" - "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/namespace" @@ -30,6 +28,13 @@ import ( "github.com/temporalio/s2s-proxy/testutil" ) +type SetupMode string + +const ( + SetupModeSimple SetupMode = "simple" // Case B: Two proxies, direct connection, no load balancer, no memberlist + SetupModeMultiProxy SetupMode = "multiproxy" // Case A: Multi-proxy with load balancers and memberlist +) + type ( // ReplicationTestSuite tests s2s-proxy replication and failover across multiple shard configurations ReplicationTestSuite struct { @@ -41,11 +46,41 @@ type ( clusterA *testcore.TestCluster clusterB *testcore.TestCluster + // Case B: Simple setup proxyA *s2sproxy.Proxy proxyB *s2sproxy.Proxy - proxyAAddress string - proxyBAddress string + proxyAOutbound string + proxyBOutbound string + + // Case A: Multi-proxy setup + proxyA1 
*s2sproxy.Proxy + proxyA2 *s2sproxy.Proxy + proxyB1 *s2sproxy.Proxy + proxyB2 *s2sproxy.Proxy + + proxyA1Outbound string + proxyA2Outbound string + proxyB1Outbound string + proxyB2Outbound string + + proxyB1Mux string + proxyB2Mux string + + proxyA1MemberlistPort int + proxyA2MemberlistPort int + proxyB1MemberlistPort int + proxyB2MemberlistPort int + + loadBalancerA *trackingTCPProxy + loadBalancerB *trackingTCPProxy + loadBalancerC *trackingTCPProxy + + loadBalancerAPort string + loadBalancerBPort string + loadBalancerCPort string + + setupMode SetupMode shardCountA int shardCountB int @@ -74,51 +109,108 @@ type ( ShardCountB int WorkflowsPerPair int ShardCountConfigB config.ShardCountConfig + SetupMode SetupMode } ) var testConfigs = []TestConfig{ + // Case B: Simple setup tests { - Name: "SingleShard", + Name: "Simple_SingleShard", ShardCountA: 1, ShardCountB: 1, WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, }, { - Name: "FourShards", + Name: "Simple_FourShards", ShardCountA: 4, ShardCountB: 4, WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, }, { - Name: "AsymmetricShards_4to2", + Name: "Simple_AsymmetricShards_4to2", ShardCountA: 4, ShardCountB: 2, WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, }, { - Name: "AsymmetricShards_2to4", + Name: "Simple_AsymmetricShards_2to4", ShardCountA: 2, ShardCountB: 4, WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, }, { - Name: "ArbitraryShards_2to3_LCM", + Name: "Simple_ArbitraryShards_2to3_LCM", ShardCountA: 2, ShardCountB: 3, WorkflowsPerPair: 1, ShardCountConfigB: config.ShardCountConfig{ Mode: config.ShardCountLCM, }, + SetupMode: SetupModeSimple, }, { - Name: "ArbitraryShards_2to3_Routing", + Name: "Simple_ArbitraryShards_2to3_Routing", ShardCountA: 2, ShardCountB: 3, WorkflowsPerPair: 1, ShardCountConfigB: config.ShardCountConfig{ Mode: config.ShardCountRouting, }, + SetupMode: SetupModeSimple, + }, + // Case A: Multi-proxy setup tests + { + Name: "MultiProxy_SingleShard", + ShardCountA: 1, + ShardCountB: 1, + WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, + }, + { + Name: "MultiProxy_FourShards", + ShardCountA: 4, + ShardCountB: 4, + WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, + }, + { + Name: "MultiProxy_AsymmetricShards_4to2", + ShardCountA: 4, + ShardCountB: 2, + WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, + }, + { + Name: "MultiProxy_AsymmetricShards_2to4", + ShardCountA: 2, + ShardCountB: 4, + WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, + }, + { + Name: "MultiProxy_ArbitraryShards_2to3_LCM", + ShardCountA: 2, + ShardCountB: 3, + WorkflowsPerPair: 1, + ShardCountConfigB: config.ShardCountConfig{ + Mode: config.ShardCountLCM, + }, + SetupMode: SetupModeMultiProxy, + }, + { + Name: "MultiProxy_ArbitraryShards_2to3_Routing", + ShardCountA: 2, + ShardCountB: 3, + WorkflowsPerPair: 1, + ShardCountConfigB: config.ShardCountConfig{ + Mode: config.ShardCountRouting, + }, + SetupMode: SetupModeMultiProxy, }, } @@ -130,6 +222,7 @@ func TestReplicationFailoverTestSuite(t *testing.T) { shardCountB: tc.ShardCountB, shardCountConfigB: tc.ShardCountConfigB, workflowsPerPair: tc.WorkflowsPerPair, + setupMode: tc.SetupMode, } suite.Run(t, s) }) @@ -144,268 +237,172 @@ func (s *ReplicationTestSuite) SetupSuite() { s.logger.Info("Setting up replication test suite", tag.NewInt("shardCountA", s.shardCountA), tag.NewInt("shardCountB", s.shardCountB), + tag.NewStringTag("setupMode", string(s.setupMode)), ) - s.clusterA = s.createCluster("cluster-a", s.shardCountA, 1) - s.clusterB = 
s.createCluster("cluster-b", s.shardCountB, 2) + s.clusterA = createCluster(s.logger, s.T(), "cluster-a", s.shardCountA, 1, 1) + s.clusterB = createCluster(s.logger, s.T(), "cluster-b", s.shardCountB, 2, 1) - s.proxyAAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyAOutbound := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyBAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyBOutbound := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - muxServerAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - - proxyBShardConfig := s.shardCountConfigB - if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { - proxyBShardConfig.LocalShardCount = int32(s.shardCountB) - proxyBShardConfig.RemoteShardCount = int32(s.shardCountA) + if s.setupMode == SetupModeSimple { + s.setupSimple() + } else { + s.setupMultiProxy() } - s.proxyA = s.createProxy("proxy-a", s.proxyAAddress, proxyAOutbound, muxServerAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}) - s.proxyB = s.createProxy("proxy-b", s.proxyBAddress, proxyBOutbound, muxServerAddress, s.clusterB, config.ServerMode, proxyBShardConfig) - s.logger.Info("Waiting for proxies to start and connect") - time.Sleep(10 * time.Second) // TODO: remove this once we have a better way to wait for proxies to start and connect + time.Sleep(10 * time.Second) s.logger.Info("Configuring remote clusters") - s.configureRemoteCluster(s.clusterA, s.clusterB.ClusterName(), proxyAOutbound) - s.configureRemoteCluster(s.clusterB, s.clusterA.ClusterName(), proxyBOutbound) - s.waitForReplicationReady() + if s.setupMode == SetupModeSimple { + configureRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName(), s.proxyAOutbound) + configureRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName(), s.proxyBOutbound) + } else { + configureRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName(), fmt.Sprintf("localhost:%s", s.loadBalancerAPort)) + configureRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName(), fmt.Sprintf("localhost:%s", s.loadBalancerCPort)) + } + + waitForReplicationReady(s.logger, s.T(), s.clusterA, s.clusterB) s.namespace = s.createGlobalNamespace() s.waitForClusterSynced() } -func (s *ReplicationTestSuite) TearDownSuite() { - if s.namespace != "" && s.clusterA != nil { - s.deglobalizeNamespace(s.namespace) - } +func (s *ReplicationTestSuite) setupSimple() { + s.logger.Info("Setting up simple two-proxy configuration") - if s.clusterA != nil && s.clusterB != nil { - s.removeRemoteCluster(s.clusterA, s.clusterB.ClusterName()) - s.removeRemoteCluster(s.clusterB, s.clusterA.ClusterName()) - } - if s.clusterA != nil { - s.NoError(s.clusterA.TearDownCluster()) - } - if s.clusterB != nil { - s.NoError(s.clusterB.TearDownCluster()) - } - if s.proxyA != nil { - s.proxyA.Stop() - } - if s.proxyB != nil { - s.proxyB.Stop() + proxyAOutbound := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyBOutbound := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + muxServerAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + + s.proxyAOutbound = proxyAOutbound + s.proxyBOutbound = proxyBOutbound + + proxyBShardConfig := s.shardCountConfigB + if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { + proxyBShardConfig.LocalShardCount = int32(s.shardCountB) + proxyBShardConfig.RemoteShardCount = int32(s.shardCountA) } + s.proxyA = 
createProxy(s.logger, s.T(), "proxy-a", "", proxyAOutbound, muxServerAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "", "", 0, nil, nil) + s.proxyB = createProxy(s.logger, s.T(), "proxy-b", "", proxyBOutbound, muxServerAddress, s.clusterB, config.ServerMode, proxyBShardConfig, "", "", 0, nil, nil) } -func (s *ReplicationTestSuite) SetupTest() { - s.workflows = nil +func (s *ReplicationTestSuite) setupMultiProxy() { + s.logger.Info("Setting up multi-proxy configuration with load balancers") - if s.namespace != "" { - s.ensureNamespaceActive(s.clusterA.ClusterName()) - } -} + s.proxyA1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyA2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) -func (s *ReplicationTestSuite) createCluster( - clusterName string, - numShards int, - initialFailoverVersion int64, -) *testcore.TestCluster { - clusterSuffix := common.GenerateRandomString(8) - fullClusterName := fmt.Sprintf("%s-%s", clusterName, clusterSuffix) - - clusterConfig := &testcore.TestClusterConfig{ - ClusterMetadata: cluster.Config{ - EnableGlobalNamespace: true, - FailoverVersionIncrement: 10, - MasterClusterName: fullClusterName, - CurrentClusterName: fullClusterName, - ClusterInformation: map[string]cluster.ClusterInformation{ - fullClusterName: { - Enabled: true, - InitialFailoverVersion: initialFailoverVersion, - }, - }, - }, - HistoryConfig: testcore.HistoryConfig{ - NumHistoryShards: int32(numShards), - NumHistoryHosts: 1, - }, - DynamicConfigOverrides: map[dynamicconfig.Key]interface{}{ - dynamicconfig.NamespaceCacheRefreshInterval.Key(): time.Second, - dynamicconfig.EnableReplicationStream.Key(): true, - dynamicconfig.EnableReplicationTaskBatching.Key(): true, - }, - } + s.proxyB1Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB2Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - testClusterFactory := testcore.NewTestClusterFactory() - logger := log.With(s.logger, tag.NewStringTag("clusterName", clusterName)) - cluster, err := testClusterFactory.NewCluster(s.T(), clusterConfig, logger) - s.NoError(err, "Failed to create cluster %s", clusterName) - s.NotNil(cluster) + loadBalancerAPort := fmt.Sprintf("%d", testutil.GetFreePort()) + loadBalancerBPort := fmt.Sprintf("%d", testutil.GetFreePort()) + loadBalancerCPort := fmt.Sprintf("%d", testutil.GetFreePort()) - return cluster -} + s.loadBalancerAPort = loadBalancerAPort + s.loadBalancerBPort = loadBalancerBPort + s.loadBalancerCPort = loadBalancerCPort -func (s *ReplicationTestSuite) createProxy( - name string, - inboundAddress string, - outboundAddress string, - muxAddress string, - cluster *testcore.TestCluster, - muxMode config.MuxMode, - shardCountConfig config.ShardCountConfig, -) *s2sproxy.Proxy { - var muxConnectionType config.ConnectionType - var muxAddressInfo config.TCPTLSInfo - if muxMode == config.ServerMode { - muxConnectionType = config.ConnTypeMuxServer - muxAddressInfo = config.TCPTLSInfo{ - ConnectionString: muxAddress, - } - } else { - muxConnectionType = config.ConnTypeMuxClient - muxAddressInfo = config.TCPTLSInfo{ - ConnectionString: muxAddress, - } - } + proxyA1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyA2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyB1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyB2Address := 
fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - cfg := &config.S2SProxyConfig{ - ClusterConnections: []config.ClusterConnConfig{ - { - Name: name, - LocalServer: config.ClusterDefinition{ - Connection: config.TransportInfo{ - ConnectionType: config.ConnTypeTCP, - TcpClient: config.TCPTLSInfo{ - ConnectionString: cluster.Host().FrontendGRPCAddress(), - }, - TcpServer: config.TCPTLSInfo{ - ConnectionString: outboundAddress, - }, - }, - }, - RemoteServer: config.ClusterDefinition{ - Connection: config.TransportInfo{ - ConnectionType: muxConnectionType, - MuxCount: 1, - MuxAddressInfo: muxAddressInfo, - }, - }, - ShardCountConfig: shardCountConfig, - }, - }, + proxyAddressesA := map[string]string{ + "proxy-node-a-1": proxyA1Address, + "proxy-node-a-2": proxyA2Address, + } + proxyAddressesB := map[string]string{ + "proxy-node-b-1": proxyB1Address, + "proxy-node-b-2": proxyB2Address, } - configProvider := &simpleConfigProvider{cfg: *cfg} - proxy := s2sproxy.NewProxy(configProvider, s.logger) - s.NotNil(proxy) - - err := proxy.Start() - s.NoError(err, "Failed to start proxy %s", name) + s.proxyA1MemberlistPort = testutil.GetFreePort() + s.proxyA2MemberlistPort = testutil.GetFreePort() + s.proxyB1MemberlistPort = testutil.GetFreePort() + s.proxyB2MemberlistPort = testutil.GetFreePort() - s.logger.Info("Started proxy", tag.NewStringTag("name", name), - tag.NewStringTag("inboundAddress", inboundAddress), - tag.NewStringTag("outboundAddress", outboundAddress), - tag.NewStringTag("muxAddress", muxAddress), - tag.NewStringTag("muxMode", string(muxMode)), - ) + proxyBShardConfig := s.shardCountConfigB + if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { + proxyBShardConfig.LocalShardCount = int32(s.shardCountB) + proxyBShardConfig.RemoteShardCount = int32(s.shardCountA) + } - return proxy -} + s.proxyB1 = createProxy(s.logger, s.T(), "proxy-b-1", proxyB1Address, s.proxyB1Outbound, s.proxyB1Mux, s.clusterB, config.ServerMode, proxyBShardConfig, "proxy-node-b-1", "127.0.0.1", s.proxyB1MemberlistPort, nil, proxyAddressesB) + s.proxyB2 = createProxy(s.logger, s.T(), "proxy-b-2", proxyB2Address, s.proxyB2Outbound, s.proxyB2Mux, s.clusterB, config.ServerMode, proxyBShardConfig, "proxy-node-b-2", "127.0.0.1", s.proxyB2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyB1MemberlistPort)}, proxyAddressesB) -type simpleConfigProvider struct { - cfg config.S2SProxyConfig -} + var countA1, countA2, countB1, countB2, countPA1, countPA2 atomic.Int64 -func (p *simpleConfigProvider) GetS2SProxyConfig() config.S2SProxyConfig { - return p.cfg -} + var err error + s.loadBalancerA, err = createLoadBalancer(s.logger, loadBalancerAPort, []string{s.proxyA1Outbound, s.proxyA2Outbound}, &countA1, &countA2) + s.NoError(err, "Failed to start load balancer A") + s.loadBalancerB, err = createLoadBalancer(s.logger, loadBalancerBPort, []string{s.proxyB1Mux, s.proxyB2Mux}, &countPA1, &countPA2) + s.NoError(err, "Failed to start load balancer B") + s.loadBalancerC, err = createLoadBalancer(s.logger, loadBalancerCPort, []string{s.proxyB1Outbound, s.proxyB2Outbound}, &countB1, &countB2) + s.NoError(err, "Failed to start load balancer C") -func (s *ReplicationTestSuite) configureRemoteCluster( - cluster *testcore.TestCluster, - remoteClusterName string, - proxyAddress string, -) { - _, err := cluster.AdminClient().AddOrUpdateRemoteCluster( - context.Background(), - &adminservice.AddOrUpdateRemoteClusterRequest{ - FrontendAddress: proxyAddress, - 
EnableRemoteClusterConnection: true, - }, - ) - s.NoError(err, "Failed to configure remote cluster %s", remoteClusterName) - s.logger.Info("Configured remote cluster", - tag.NewStringTag("remoteClusterName", remoteClusterName), - tag.NewStringTag("proxyAddress", proxyAddress), - tag.NewStringTag("clusterName", cluster.ClusterName()), - ) + muxLoadBalancerBAddress := fmt.Sprintf("localhost:%s", loadBalancerBPort) + s.proxyA1 = createProxy(s.logger, s.T(), "proxy-a-1", proxyA1Address, s.proxyA1Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-1", "127.0.0.1", s.proxyA1MemberlistPort, nil, proxyAddressesA) + s.proxyA2 = createProxy(s.logger, s.T(), "proxy-a-2", proxyA2Address, s.proxyA2Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-2", "127.0.0.1", s.proxyA2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyA1MemberlistPort)}, proxyAddressesA) } -func (s *ReplicationTestSuite) deglobalizeNamespace(namespaceName string) { - if s.clusterA == nil { - return +func (s *ReplicationTestSuite) TearDownSuite() { + if s.namespace != "" && s.clusterA != nil { + s.deglobalizeNamespace(s.namespace) } - ctx := context.Background() - updateReq := &workflowservice.UpdateNamespaceRequest{ - Namespace: namespaceName, - ReplicationConfig: &replicationpb.NamespaceReplicationConfig{ - ActiveClusterName: s.clusterA.ClusterName(), - Clusters: []*replicationpb.ClusterReplicationConfig{ - {ClusterName: s.clusterA.ClusterName()}, - }, - }, + if s.clusterA != nil && s.clusterB != nil { + removeRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName()) + removeRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName()) } - - _, err := s.clusterA.FrontendClient().UpdateNamespace(ctx, updateReq) - if err != nil { - s.logger.Warn("Failed to deglobalize namespace", tag.NewStringTag("namespace", namespaceName), tag.Error(err)) - return + if s.clusterA != nil { + s.NoError(s.clusterA.TearDownCluster()) + } + if s.clusterB != nil { + s.NoError(s.clusterB.TearDownCluster()) } - s.Eventually(func() bool { - for _, c := range []*testcore.TestCluster{s.clusterA, s.clusterB} { - if c == nil { - continue - } - descResp, err := c.FrontendClient().DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{ - Namespace: namespaceName, - }) - if err != nil || descResp == nil { - return false - } - clusters := descResp.ReplicationConfig.GetClusters() - if len(clusters) != 1 { - return false - } - if clusters[0].GetClusterName() != s.clusterA.ClusterName() { - return false - } + if s.setupMode == SetupModeSimple { + if s.proxyA != nil { + s.proxyA.Stop() } - return true - }, 10*time.Second, 200*time.Millisecond, "Namespace deglobalization not propagated") - - s.logger.Info("Deglobalized namespace", tag.NewStringTag("namespace", namespaceName)) + if s.proxyB != nil { + s.proxyB.Stop() + } + } else { + if s.loadBalancerA != nil { + s.loadBalancerA.Stop() + } + if s.loadBalancerB != nil { + s.loadBalancerB.Stop() + } + if s.loadBalancerC != nil { + s.loadBalancerC.Stop() + } + if s.proxyA1 != nil { + s.proxyA1.Stop() + } + if s.proxyA2 != nil { + s.proxyA2.Stop() + } + if s.proxyB1 != nil { + s.proxyB1.Stop() + } + if s.proxyB2 != nil { + s.proxyB2.Stop() + } + } } -func (s *ReplicationTestSuite) removeRemoteCluster( - cluster *testcore.TestCluster, - remoteClusterName string, -) { - _, err := cluster.AdminClient().RemoveRemoteCluster( - context.Background(), - 
&adminservice.RemoveRemoteClusterRequest{ - ClusterName: remoteClusterName, - }, - ) - s.NoError(err, "Failed to remove remote cluster %s", remoteClusterName) - s.logger.Info("Removed remote cluster", - tag.NewStringTag("remoteClusterName", remoteClusterName), - tag.NewStringTag("clusterName", cluster.ClusterName()), - ) +func (s *ReplicationTestSuite) SetupTest() { + s.workflows = nil + + if s.namespace != "" { + s.ensureNamespaceActive(s.clusterA.ClusterName()) + } } func (s *ReplicationTestSuite) createGlobalNamespace() string { @@ -517,31 +514,22 @@ func (s *ReplicationTestSuite) generateWorkflowsWithLoad(workflowsPerPair int) [ // - sourceShard 1 (0-based: 0) can only map to targetShard 1 or 3 (0-based: 0 or 2) // - sourceShard 2 (0-based: 1) can only map to targetShard 2 or 4 (0-based: 1 or 3) func (s *ReplicationTestSuite) isValidShardPair(sourceShard int32, targetShard int32) bool { - // If shard counts are equal, source and target shards must match - // (same hash function with same shard count produces identical shard assignment) if s.shardCountA == s.shardCountB { return sourceShard == targetShard } - // Convert to 0-based for modulo arithmetic sourceShard0 := sourceShard - 1 targetShard0 := targetShard - 1 - // Case 1: targetShardCount divides sourceShardCount (e.g., 4 -> 2) - // Source shard x maps to target shard (x % targetShardCount) if s.shardCountA%s.shardCountB == 0 { expectedTarget := sourceShard0 % int32(s.shardCountB) return targetShard0 == expectedTarget } - // Case 2: sourceShardCount divides targetShardCount (e.g., 2 -> 4) - // Source shard x can map to target shards in set {x, x+sourceShardCount, x+2*sourceShardCount, ...} - // where all values are < targetShardCount if s.shardCountB%s.shardCountA == 0 { return targetShard0%int32(s.shardCountA) == sourceShard0 } - // No divisibility relationship, all pairs are possible (though may be hard to find) return true } @@ -572,25 +560,6 @@ func (s *ReplicationTestSuite) findWorkflowIDForShardPairWithIndex(sourceShard i return "" } -func (s *ReplicationTestSuite) waitForReplicationReady() { - time.Sleep(1 * time.Second) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - for _, cluster := range []*testcore.TestCluster{s.clusterA, s.clusterB} { - s.Eventually(func() bool { - _, err := cluster.HistoryClient().GetReplicationStatus( - ctx, - &historyservice.GetReplicationStatusRequest{}, - ) - return err == nil - }, 5*time.Second, 200*time.Millisecond, "Replication infrastructure not ready") - } - - time.Sleep(1 * time.Second) -} - func (s *ReplicationTestSuite) waitForClusterSynced() { s.waitForClusterConnected(s.clusterA, s.clusterB.ClusterName()) s.waitForClusterConnected(s.clusterB, s.clusterA.ClusterName()) @@ -614,28 +583,13 @@ func (s *ReplicationTestSuite) waitForClusterConnected( s.logger.Debug("GetReplicationStatus failed", tag.Error(err)) return false } - s.logger.Info("GetReplicationStatus response", - tag.NewStringTag("response", fmt.Sprintf("%+v", resp)), - tag.NewStringTag("source", sourceCluster.ClusterName()), - tag.NewStringTag("target", targetClusterName), - ) if len(resp.Shards) == 0 { return false } for _, shard := range resp.Shards { - s.logger.Info("Replication status", - tag.NewStringTag("shard", fmt.Sprintf("%d", shard.ShardId)), - tag.NewInt64("maxTaskId", shard.MaxReplicationTaskId), - tag.NewStringTag("remoteClusters", fmt.Sprintf("%+v", shard.RemoteClusters)), - ) - if shard.MaxReplicationTaskId <= 0 { - s.logger.Info("Max replication task id is 0", - 
tag.NewStringTag("shard", fmt.Sprintf("%d", shard.ShardId)), - tag.NewInt64("maxTaskId", shard.MaxReplicationTaskId), - ) continue } @@ -644,19 +598,10 @@ func (s *ReplicationTestSuite) waitForClusterConnected( remoteInfo, ok := shard.RemoteClusters[targetClusterName] if !ok || remoteInfo == nil { - s.logger.Info("Remote cluster not found", - tag.NewStringTag("shard", fmt.Sprintf("%d", shard.ShardId)), - tag.NewStringTag("targetClusterName", targetClusterName), - ) return false } if remoteInfo.AckedTaskId < shard.MaxReplicationTaskId { - s.logger.Debug("Replication not synced", - tag.ShardID(shard.ShardId), - tag.NewInt64("maxTaskId", shard.MaxReplicationTaskId), - tag.NewInt64("ackedTaskId", remoteInfo.AckedTaskId), - ) return false } } @@ -708,8 +653,6 @@ func (s *ReplicationTestSuite) TestReplication() { } } - // TODO: make some progress on the workflows - s.waitForClusterSynced() clientB := s.clusterB.FrontendClient() @@ -720,7 +663,6 @@ func (s *ReplicationTestSuite) TestReplication() { s.failoverNamespace(ctx, s.namespace, s.clusterB.ClusterName()) for _, wf := range s.workflows { - // TODO: continue the workflows instead of just terminating them s.completeWorkflow(ctx, clientB, wf) } @@ -850,6 +792,53 @@ func (s *ReplicationTestSuite) failoverNamespace( s.logger.Info("Namespace failover completed", tag.NewStringTag("namespace", namespaceName), tag.NewStringTag("targetCluster", targetCluster)) } +func (s *ReplicationTestSuite) deglobalizeNamespace(namespaceName string) { + if s.clusterA == nil { + return + } + + ctx := context.Background() + updateReq := &workflowservice.UpdateNamespaceRequest{ + Namespace: namespaceName, + ReplicationConfig: &replicationpb.NamespaceReplicationConfig{ + ActiveClusterName: s.clusterA.ClusterName(), + Clusters: []*replicationpb.ClusterReplicationConfig{ + {ClusterName: s.clusterA.ClusterName()}, + }, + }, + } + + _, err := s.clusterA.FrontendClient().UpdateNamespace(ctx, updateReq) + if err != nil { + s.logger.Warn("Failed to deglobalize namespace", tag.NewStringTag("namespace", namespaceName), tag.Error(err)) + return + } + + s.Eventually(func() bool { + for _, c := range []*testcore.TestCluster{s.clusterA, s.clusterB} { + if c == nil { + continue + } + descResp, err := c.FrontendClient().DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{ + Namespace: namespaceName, + }) + if err != nil || descResp == nil { + return false + } + clusters := descResp.ReplicationConfig.GetClusters() + if len(clusters) != 1 { + return false + } + if clusters[0].GetClusterName() != s.clusterA.ClusterName() { + return false + } + } + return true + }, 10*time.Second, 200*time.Millisecond, "Namespace deglobalization not propagated") + + s.logger.Info("Deglobalized namespace", tag.NewStringTag("namespace", namespaceName)) +} + func (s *ReplicationTestSuite) ensureNamespaceActive(targetCluster string) { descResp, err := s.clusterA.FrontendClient().DescribeNamespace(context.Background(), &workflowservice.DescribeNamespaceRequest{ Namespace: s.namespace, diff --git a/proxy/test/test_common.go b/proxy/test/test_common.go new file mode 100644 index 00000000..de132cc4 --- /dev/null +++ b/proxy/test/test_common.go @@ -0,0 +1,476 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/common" + "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/dynamicconfig" + 
"go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/tests/testcore" + + "github.com/temporalio/s2s-proxy/config" + s2sproxy "github.com/temporalio/s2s-proxy/proxy" +) + +type simpleConfigProvider struct { + cfg config.S2SProxyConfig +} + +func (p *simpleConfigProvider) GetS2SProxyConfig() config.S2SProxyConfig { + return p.cfg +} + +type trackingUpstreamServer struct { + address string + conns atomic.Int64 + count1 *atomic.Int64 + count2 *atomic.Int64 +} + +type trackingUpstream struct { + servers []*trackingUpstreamServer + mu sync.RWMutex +} + +func (u *trackingUpstream) selectLeastConn() *trackingUpstreamServer { + u.mu.RLock() + defer u.mu.RUnlock() + + if len(u.servers) == 0 { + return nil + } + + selected := u.servers[0] + minConns := selected.conns.Load() + + for i := 1; i < len(u.servers); i++ { + conns := u.servers[i].conns.Load() + if conns < minConns { + minConns = conns + selected = u.servers[i] + } + } + + if selected != nil { + if selected == u.servers[0] { + selected.count1.Add(1) + } else if len(u.servers) > 1 && selected == u.servers[1] { + selected.count2.Add(1) + } + } + + return selected +} + +func (u *trackingUpstream) incrementConn(server *trackingUpstreamServer) { + server.conns.Add(1) +} + +func (u *trackingUpstream) decrementConn(server *trackingUpstreamServer) { + server.conns.Add(-1) +} + +type trackingTCPProxy struct { + rules []*trackingProxyRule + logger log.Logger + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + servers []net.Listener +} + +type trackingProxyRule struct { + ListenPort string + Upstream *trackingUpstream +} + +func (p *trackingTCPProxy) Start() error { + for _, rule := range p.rules { + listener, err := net.Listen("tcp", ":"+rule.ListenPort) + if err != nil { + p.Stop() + return fmt.Errorf("failed to listen on port %s: %w", rule.ListenPort, err) + } + p.servers = append(p.servers, listener) + + p.wg.Add(1) + go p.handleListener(listener, rule) + } + + return nil +} + +func (p *trackingTCPProxy) Stop() { + p.logger.Info("Stopping tracking TCP proxy") + p.cancel() + for _, server := range p.servers { + p.logger.Info("Closing server", tag.NewStringTag("server", server.Addr().String())) + _ = server.Close() + } + p.logger.Info("Waiting for goroutines to finish") + p.wg.Wait() + p.logger.Info("Tracking TCP proxy stopped") +} + +func (p *trackingTCPProxy) handleListener(listener net.Listener, rule *trackingProxyRule) { + defer p.wg.Done() + + for { + select { + case <-p.ctx.Done(): + return + default: + } + + clientConn, err := listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + p.logger.Warn("failed to accept connection", tag.Error(err)) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(clientConn, rule) + } +} + +func (p *trackingTCPProxy) handleConnection(clientConn net.Conn, rule *trackingProxyRule) { + defer p.wg.Done() + defer func() { _ = clientConn.Close() }() + + select { + case <-p.ctx.Done(): + return + default: + } + + upstream := rule.Upstream.selectLeastConn() + if upstream == nil { + p.logger.Error("no upstream servers available") + return + } + + rule.Upstream.incrementConn(upstream) + defer rule.Upstream.decrementConn(upstream) + + serverConn, err := net.DialTimeout("tcp", upstream.address, 5*time.Second) + if err != nil { + p.logger.Warn("failed to connect to upstream", tag.NewStringTag("upstream", upstream.address), tag.Error(err)) + return + } + defer func() { _ = serverConn.Close() }() + + var wg 
sync.WaitGroup + wg.Add(3) + + go func() { + defer wg.Done() + <-p.ctx.Done() + _ = clientConn.Close() + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(serverConn, clientConn) + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(clientConn, serverConn) + _ = clientConn.Close() + }() + + wg.Wait() +} + +func createLoadBalancer( + logger log.Logger, + listenPort string, + upstreams []string, + count1 *atomic.Int64, + count2 *atomic.Int64, +) (*trackingTCPProxy, error) { + trackingServers := make([]*trackingUpstreamServer, len(upstreams)) + for i, addr := range upstreams { + trackingServers[i] = &trackingUpstreamServer{ + address: addr, + count1: count1, + count2: count2, + } + } + + trackingUpstream := &trackingUpstream{ + servers: trackingServers, + } + + rules := []*trackingProxyRule{ + { + ListenPort: listenPort, + Upstream: trackingUpstream, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + trackingProxy := &trackingTCPProxy{ + rules: rules, + logger: logger, + ctx: ctx, + cancel: cancel, + } + + err := trackingProxy.Start() + if err != nil { + return nil, err + } + + return trackingProxy, nil +} + +func createCluster( + logger log.Logger, + t testingT, + clusterName string, + numShards int, + initialFailoverVersion int64, + numHistoryHosts int, +) *testcore.TestCluster { + clusterSuffix := common.GenerateRandomString(8) + fullClusterName := fmt.Sprintf("%s-%s", clusterName, clusterSuffix) + + clusterConfig := &testcore.TestClusterConfig{ + ClusterMetadata: cluster.Config{ + EnableGlobalNamespace: true, + FailoverVersionIncrement: 10, + MasterClusterName: fullClusterName, + CurrentClusterName: fullClusterName, + ClusterInformation: map[string]cluster.ClusterInformation{ + fullClusterName: { + Enabled: true, + InitialFailoverVersion: initialFailoverVersion, + }, + }, + }, + HistoryConfig: testcore.HistoryConfig{ + NumHistoryShards: int32(numShards), + NumHistoryHosts: numHistoryHosts, + }, + DynamicConfigOverrides: map[dynamicconfig.Key]interface{}{ + dynamicconfig.NamespaceCacheRefreshInterval.Key(): time.Second, + dynamicconfig.EnableReplicationStream.Key(): true, + dynamicconfig.EnableReplicationTaskBatching.Key(): true, + }, + } + + testClusterFactory := testcore.NewTestClusterFactory() + logger = log.With(logger, tag.NewStringTag("clusterName", clusterName)) + + testT := getTestingT(t) + cluster, err := testClusterFactory.NewCluster(testT, clusterConfig, logger) + if err != nil { + t.Fatalf("Failed to create cluster %s: %v", clusterName, err) + } + + return cluster +} + +func createProxy( + logger log.Logger, + t testingT, + name string, + inboundAddress string, + outboundAddress string, + muxAddress string, + cluster *testcore.TestCluster, + muxMode config.MuxMode, + shardCountConfig config.ShardCountConfig, + nodeName string, + memberlistBindAddr string, + memberlistBindPort int, + memberlistJoinAddrs []string, + proxyAddresses map[string]string, +) *s2sproxy.Proxy { + var muxConnectionType config.ConnectionType + var muxAddressInfo config.TCPTLSInfo + if muxMode == config.ServerMode { + muxConnectionType = config.ConnTypeMuxServer + muxAddressInfo = config.TCPTLSInfo{ + ConnectionString: muxAddress, + } + } else { + muxConnectionType = config.ConnTypeMuxClient + muxAddressInfo = config.TCPTLSInfo{ + ConnectionString: muxAddress, + } + } + + cfg := &config.S2SProxyConfig{ + ClusterConnections: []config.ClusterConnConfig{ + { + Name: name, + LocalServer: config.ClusterDefinition{ + Connection: 
config.TransportInfo{ + ConnectionType: config.ConnTypeTCP, + TcpClient: config.TCPTLSInfo{ + ConnectionString: cluster.Host().FrontendGRPCAddress(), + }, + TcpServer: config.TCPTLSInfo{ + ConnectionString: outboundAddress, + }, + }, + }, + RemoteServer: config.ClusterDefinition{ + Connection: config.TransportInfo{ + ConnectionType: muxConnectionType, + MuxCount: 1, + MuxAddressInfo: muxAddressInfo, + }, + }, + ShardCountConfig: shardCountConfig, + }, + }, + } + + if nodeName != "" && memberlistBindAddr != "" { + cfg.ClusterConnections[0].MemberlistConfig = &config.MemberlistConfig{ + Enabled: true, + NodeName: nodeName, + BindAddr: memberlistBindAddr, + BindPort: memberlistBindPort, + JoinAddrs: memberlistJoinAddrs, + ProxyAddresses: proxyAddresses, + TCPOnly: true, + } + } + + configProvider := &simpleConfigProvider{cfg: *cfg} + proxy := s2sproxy.NewProxy(configProvider, logger) + if proxy == nil { + t.Fatalf("Failed to create proxy %s", name) + } + + err := proxy.Start() + if err != nil { + t.Fatalf("Failed to start proxy %s: %v", name, err) + } + + logger.Info("Started proxy", tag.NewStringTag("name", name), + tag.NewStringTag("inboundAddress", inboundAddress), + tag.NewStringTag("outboundAddress", outboundAddress), + tag.NewStringTag("muxAddress", muxAddress), + tag.NewStringTag("muxMode", string(muxMode)), + tag.NewStringTag("nodeName", nodeName), + ) + + return proxy +} + +func configureRemoteCluster( + logger log.Logger, + t testingT, + cluster *testcore.TestCluster, + remoteClusterName string, + proxyAddress string, +) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := cluster.AdminClient().AddOrUpdateRemoteCluster( + ctx, + &adminservice.AddOrUpdateRemoteClusterRequest{ + FrontendAddress: proxyAddress, + EnableRemoteClusterConnection: true, + }, + ) + if err != nil { + t.Fatalf("Failed to configure remote cluster %s: %v", remoteClusterName, err) + } + logger.Info("Configured remote cluster", + tag.NewStringTag("remoteClusterName", remoteClusterName), + tag.NewStringTag("proxyAddress", proxyAddress), + tag.NewStringTag("clusterName", cluster.ClusterName()), + ) +} + +func removeRemoteCluster( + logger log.Logger, + t testingT, + cluster *testcore.TestCluster, + remoteClusterName string, +) { + _, err := cluster.AdminClient().RemoveRemoteCluster( + context.Background(), + &adminservice.RemoveRemoteClusterRequest{ + ClusterName: remoteClusterName, + }, + ) + if err != nil { + t.Fatalf("Failed to remove remote cluster %s: %v", remoteClusterName, err) + } + logger.Info("Removed remote cluster", + tag.NewStringTag("remoteClusterName", remoteClusterName), + tag.NewStringTag("clusterName", cluster.ClusterName()), + ) +} + +func waitForReplicationReady( + logger log.Logger, + t testingT, + clusters ...*testcore.TestCluster, +) { + time.Sleep(1 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for _, cluster := range clusters { + ready := false + for i := 0; i < 25; i++ { + _, err := cluster.HistoryClient().GetReplicationStatus( + ctx, + &historyservice.GetReplicationStatusRequest{}, + ) + if err == nil { + ready = true + break + } + time.Sleep(200 * time.Millisecond) + } + if !ready { + t.Fatalf("Replication infrastructure not ready for cluster %s", cluster.ClusterName()) + } + } + + time.Sleep(1 * time.Second) +} + +type testingT interface { + Helper() + Fatalf(format string, args ...interface{}) +} + +func getTestingT(t testingT) *testing.T { + if testT, ok := 
t.(*testing.T); ok { + return testT + } + if suiteT, ok := t.(interface{ T() *testing.T }); ok { + return suiteT.T() + } + panic("testingT must be *testing.T or have T() method") +} diff --git a/proxy/test/wiring_test.go b/proxy/test/wiring_test.go index 768b3919..b7527378 100644 --- a/proxy/test/wiring_test.go +++ b/proxy/test/wiring_test.go @@ -23,6 +23,7 @@ import ( "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/endtoendtest" + "github.com/temporalio/s2s-proxy/testutil" ) func init() { @@ -41,31 +42,36 @@ type ( ) var ( - // Create some believable echo server configs - echoServerInfo = endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + logger log.Logger +) + +func getEchoServerInfo() endtoendtest.ClusterInfo { + return endtoendtest.ClusterInfo{ + ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), ClusterShardID: serverClusterShard, S2sProxyConfig: makeS2SConfig(s2sAddresses{ - echoServer: "localhost:7266", - inbound: "localhost:7366", - outbound: "localhost:7466", - prometheus: "localhost:7468", - healthCheck: "localhost:7479", + echoServer: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + inbound: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + outbound: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + prometheus: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + healthCheck: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), }), } - echoClientInfo = endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, +} + +func getEchoClientInfo() endtoendtest.ClusterInfo { + return endtoendtest.ClusterInfo{ + ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), ClusterShardID: clientClusterShard, S2sProxyConfig: makeS2SConfig(s2sAddresses{ - echoServer: "localhost:8266", - inbound: "localhost:8366", - outbound: "localhost:8466", - prometheus: "localhost:7467", - healthCheck: "localhost:7478", + echoServer: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + inbound: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + outbound: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + prometheus: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + healthCheck: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), }), } - logger log.Logger -) +} type hangupAdminServer struct { adminservice.UnimplementedAdminServiceServer @@ -129,6 +135,12 @@ func TestEOFFromServer(t *testing.T) { } func TestWiringWithEchoService(t *testing.T) { + echoServerInfo := getEchoServerInfo() + echoClientInfo := getEchoClientInfo() + // Update outbound client address to point to the other proxy's inbound + echoServerInfo.S2sProxyConfig.Outbound.Client.ServerAddress = echoClientInfo.S2sProxyConfig.Inbound.Server.ListenAddress + echoClientInfo.S2sProxyConfig.Outbound.Client.ServerAddress = echoServerInfo.S2sProxyConfig.Inbound.Server.ListenAddress + echoServer := endtoendtest.NewEchoServer(echoServerInfo, echoClientInfo, "EchoServer", logger, nil) echoClient := endtoendtest.NewEchoServer(echoClientInfo, echoServerInfo, "EchoClient", logger, nil) echoServer.Start() From 949ec25f44f79e67bbcb48fcae49fae8e069c65e Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Mon, 22 Dec 2025 17:24:45 -0800 Subject: [PATCH 27/38] handle late-registered remote shard --- Makefile | 2 +- proxy/intra_proxy_router.go | 133 +++++++++++++++++++++++- proxy/proxy_streams.go | 39 +++++-- proxy/shard_manager.go | 35 +++++-- proxy/test/replication_failover_test.go | 32 +++++- 5 files changed, 215 insertions(+), 26 deletions(-) diff --git a/Makefile 
b/Makefile index 4dbc3f41..1d4bdcaf 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ GO_GET_TOOL = go get -tool -modfile=$(TOOLS_MOD_FILE) # Disable cgo by default. CGO_ENABLED ?= 0 -TEST_ARG ?= -race -timeout=5m -tags test_dep +TEST_ARG ?= -race -timeout=15m -tags test_dep BENCH_ARG ?= -benchtime=5000x ALL_SRC := $(shell find . -name "*.go") diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 726432cf..07ddb738 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -16,6 +16,7 @@ import ( "go.temporal.io/server/common/log/tag" "google.golang.org/grpc" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/encryption" @@ -88,6 +89,28 @@ func (s *intraProxyStreamSender) Run( s.shardManager.GetIntraProxyManager().RegisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID, s) defer s.shardManager.GetIntraProxyManager().UnregisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID) + // Send pending watermarks to late-registering shards + // When a sender is registered, check if there's an active receiver for the source shard + // that has a pending watermark, and send it immediately to the peer + if receiver, ok := s.shardManager.GetActiveReceiver(s.sourceShardID); ok { + if lastWatermark := receiver.GetLastWatermark(); lastWatermark != nil && lastWatermark.ExclusiveHighWatermark > 0 { + s.logger.Info("Sending pending watermark to peer on sender registration", + tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark), + tag.NewStringTag("peer", s.peerNodeName)) + resp := &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: lastWatermark.ExclusiveHighWatermark, + Priority: lastWatermark.Priority, + }, + }, + } + if err := s.sendReplicationMessages(resp); err != nil { + s.logger.Warn("Failed to send pending watermark to peer on sender registration", tag.Error(err)) + } + } + } + // recv ACKs from peer and route to original source shard owner return s.recvAck(shutdownChan) } @@ -175,6 +198,9 @@ type intraProxyStreamReceiver struct { streamID string shutdown channel.ShutdownOnce cancel context.CancelFunc + // lastWatermark tracks the last watermark received from source shard for late-registering target shards + lastWatermarkMu sync.RWMutex + lastWatermark *replicationv1.WorkflowReplicationMessages } // Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. 
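// A minimal, standalone sketch (not part of this patch) of the pattern the hunks below
// add to intraProxyStreamReceiver: cache the most recent watermark under an RWMutex and
// replay it when a target shard registers late. All names here are illustrative.
package main

import (
	"fmt"
	"sync"
)

// watermark stands in for replicationv1.WorkflowReplicationMessages in this sketch.
type watermark struct{ exclusiveHigh int64 }

type lateReplayReceiver struct {
	mu   sync.RWMutex
	last *watermark
}

// observe records the latest watermark seen on the stream.
func (r *lateReplayReceiver) observe(w *watermark) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.last = w
}

// replayTo pushes the cached watermark, if any, to a newly registered shard.
func (r *lateReplayReceiver) replayTo(send func(*watermark)) {
	r.mu.RLock()
	last := r.last
	r.mu.RUnlock()
	if last == nil || last.exclusiveHigh == 0 {
		return // nothing pending yet
	}
	send(last)
}

func main() {
	r := &lateReplayReceiver{}
	r.observe(&watermark{exclusiveHigh: 42})
	// A shard that registers after the watermark arrived still catches up.
	r.replayTo(func(w *watermark) { fmt.Println("replayed watermark", w.exclusiveHigh) })
}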
@@ -199,14 +225,17 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, shardManager ShardMa r.cancel = cancel client := adminservice.NewAdminServiceClient(conn) - stream, err := client.StreamWorkflowReplicationMessages(ctx) + streamClient, err := client.StreamWorkflowReplicationMessages(ctx) if err != nil { if r.cancel != nil { r.cancel() } return err } - r.streamClient = stream + r.streamClient = streamClient + + r.shardManager.RegisterActiveReceiver(r.sourceShardID, r) + defer r.shardManager.UnregisterActiveReceiver(r.sourceShardID) // Register client-side intra-proxy stream in tracker st := GetGlobalStreamTracker() @@ -246,6 +275,14 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { st.UpdateStreamReplicationMessages(r.streamID, msgs.Messages.ExclusiveHighWatermark) st.UpdateStream(r.streamID) + // Track last watermark for late-registering shards + r.lastWatermarkMu.Lock() + r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: msgs.Messages.ExclusiveHighWatermark, + Priority: msgs.Messages.Priority, + } + r.lastWatermarkMu.Unlock() + r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", msgs.Messages.ExclusiveHighWatermark, ids)) msg := RoutedMessage{SourceShard: r.sourceShardID, Resp: resp} @@ -305,6 +342,92 @@ func (r *intraProxyStreamReceiver) sendAck(req *adminservice.StreamWorkflowRepli return nil } +// GetTargetShardID returns the target shard ID for this receiver +func (r *intraProxyStreamReceiver) GetTargetShardID() history.ClusterShardID { + return r.targetShardID +} + +// GetSourceShardID returns the source shard ID for this receiver +func (r *intraProxyStreamReceiver) GetSourceShardID() history.ClusterShardID { + return r.sourceShardID +} + +// GetLastWatermark returns the last watermark received from the source shard +func (r *intraProxyStreamReceiver) GetLastWatermark() *replicationv1.WorkflowReplicationMessages { + r.lastWatermarkMu.RLock() + defer r.lastWatermarkMu.RUnlock() + return r.lastWatermark +} + +// NotifyNewTargetShard notifies the receiver about a newly registered target shard +func (r *intraProxyStreamReceiver) NotifyNewTargetShard(targetShardID history.ClusterShardID) { + r.sendPendingWatermarkToShard(targetShardID) +} + +// sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard +// This ensures late-registering shards receive watermarks that were sent before they registered +func (r *intraProxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID) { + r.lastWatermarkMu.RLock() + lastWatermark := r.lastWatermark + r.lastWatermarkMu.RUnlock() + + if lastWatermark == nil || lastWatermark.ExclusiveHighWatermark == 0 { + // No pending watermark to send + return + } + + r.logger.Info("Sending pending watermark to newly registered shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark)) + + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: lastWatermark.ExclusiveHighWatermark, + Priority: lastWatermark.Priority, + }, + }, + }, + } + + // Try to send to local shard first + if sendChan, exists := r.shardManager.GetRemoteSendChan(targetShardID); exists { + clonedResp := 
proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + select { + case sendChan <- clonedMsg: + r.logger.Info("Sent pending watermark to local shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + default: + r.logger.Warn("Failed to send pending watermark to local shard (channel full)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + return + } + + // If not local, try to send to remote shard + if r.shardManager != nil { + shutdownChan := channel.NewShutdownOnce() + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, shutdownChan, r.logger) { + r.logger.Info("Sent pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } else { + r.logger.Warn("Failed to send pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + } +} + func (m *intraProxyManager) RegisterSender( peerNodeName string, targetShard history.ClusterShardID, @@ -691,7 +814,7 @@ func (m *intraProxyManager) Notify() { // for a given peer and closes any sender/receiver not in the desired set. // This mirrors the Temporal StreamReceiverMonitor approach. func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { - m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("peerNodeName", peerNodeName)) + m.logger.Info("ReconcilePeerStreams started", tag.NewStringTag("peerNodeName", peerNodeName)) defer m.logger.Info("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) localShards := m.shardManager.GetLocalShards() @@ -700,7 +823,7 @@ func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) return } - m.logger.Info("ReconcilePeerStreams", + m.logger.Info("ReconcilePeerStreams remote and local shards", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards)), tag.NewStringTag("localShards", fmt.Sprintf("%v", localShards)), @@ -733,7 +856,7 @@ func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { } } - m.logger.Info("ReconcilePeerStreams", tag.NewStringTag("desiredReceivers", fmt.Sprintf("%v", desiredReceivers)), tag.NewStringTag("desiredSenders", fmt.Sprintf("%v", desiredSenders))) + m.logger.Info("ReconcilePeerStreams desired receivers and senders", tag.NewStringTag("desiredReceivers", fmt.Sprintf("%v", desiredReceivers)), tag.NewStringTag("desiredSenders", fmt.Sprintf("%v", desiredSenders))) // Ensure all desired receivers exist for key := range desiredReceivers { diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 8ebbd370..abbb008f 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -726,14 +726,6 @@ func (r *proxyStreamReceiver) recvReplicationMessages( // record last source exclusive high watermark (original id space) r.lastExclusiveHighOriginal = attr.Messages.ExclusiveHighWatermark - // Track last watermark for late-registering shards - r.lastWatermarkMu.Lock() - r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ - ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, - Priority: attr.Messages.Priority, 
- } - r.lastWatermarkMu.Unlock() - // update tracker for incoming messages if r.streamTracker != nil && r.streamID != "" { r.streamTracker.UpdateStreamLastTaskIDs(r.streamID, ids) @@ -745,6 +737,15 @@ func (r *proxyStreamReceiver) recvReplicationMessages( // If replication tasks are empty, still log the empty batch and send watermark if len(attr.Messages.ReplicationTasks) == 0 { r.logger.Info("Receiver received empty replication batch", tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + + // Track last watermark for late-registering shards + r.lastWatermarkMu.Lock() + r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, + Priority: attr.Messages.Priority, + } + r.lastWatermarkMu.Unlock() + msg := RoutedMessage{ SourceShard: r.sourceShardID, Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ @@ -870,6 +871,28 @@ func (r *proxyStreamReceiver) recvReplicationMessages( return nil } +// GetTargetShardID returns the target shard ID for this receiver +func (r *proxyStreamReceiver) GetTargetShardID() history.ClusterShardID { + return r.targetShardID +} + +// GetSourceShardID returns the source shard ID for this receiver +func (r *proxyStreamReceiver) GetSourceShardID() history.ClusterShardID { + return r.sourceShardID +} + +// GetLastWatermark returns the last watermark received from the source shard +func (r *proxyStreamReceiver) GetLastWatermark() *replicationv1.WorkflowReplicationMessages { + r.lastWatermarkMu.RLock() + defer r.lastWatermarkMu.RUnlock() + return r.lastWatermark +} + +// NotifyNewTargetShard notifies the receiver about a newly registered target shard +func (r *proxyStreamReceiver) NotifyNewTargetShard(targetShardID history.ClusterShardID) { + r.sendPendingWatermarkToShard(targetShardID) +} + // sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard // This ensures late-registering shards receive watermarks that were sent before they registered func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID) { diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 6219956d..fd6a0d65 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -9,6 +9,7 @@ import ( "time" "github.com/hashicorp/memberlist" + replicationv1 "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/client/history" "go.temporal.io/server/common/channel" "go.temporal.io/server/common/log" @@ -19,6 +20,14 @@ import ( ) type ( + // ActiveReceiver is an interface for receivers that can be notified of new target shards + ActiveReceiver interface { + GetTargetShardID() history.ClusterShardID + GetSourceShardID() history.ClusterShardID + NotifyNewTargetShard(targetShardID history.ClusterShardID) + GetLastWatermark() *replicationv1.WorkflowReplicationMessages + } + // ShardManager manages distributed shard ownership across proxy instances ShardManager interface { // Start initializes the memberlist cluster and starts the manager @@ -68,9 +77,11 @@ type ( // New: notify when remote shard set changes for a peer SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) // RegisterActiveReceiver registers an active receiver for watermark propagation - RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver *proxyStreamReceiver) + RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver ActiveReceiver) // UnregisterActiveReceiver removes an active receiver 
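+	// (receivers typically register in Run and unregister via defer when their stream ends;
+	// see intraProxyStreamReceiver.Run for an example)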
UnregisterActiveReceiver(sourceShardID history.ClusterShardID) + // GetActiveReceiver returns the active receiver for the given source shard + GetActiveReceiver(sourceShardID history.ClusterShardID) (ActiveReceiver, bool) // SetRemoteSendChan registers a send channel for a specific shard ID SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) // GetRemoteSendChan retrieves the send channel for a specific shard ID @@ -120,8 +131,8 @@ type ( stopJoinRetry chan struct{} joinWg sync.WaitGroup joinLoopRunning bool - // activeReceivers tracks active proxyStreamReceiver instances by source shard for watermark propagation - activeReceivers map[history.ClusterShardID]*proxyStreamReceiver + // activeReceivers tracks active receiver instances by source shard for watermark propagation + activeReceivers map[history.ClusterShardID]ActiveReceiver activeReceiversMu sync.RWMutex // remoteSendChannels maps shard IDs to send channels for replication message routing remoteSendChannels map[history.ClusterShardID]chan RoutedMessage @@ -176,7 +187,7 @@ func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig intraMgr: nil, intraProxyTLSConfig: intraProxyTLSConfig, stopJoinRetry: make(chan struct{}), - activeReceivers: make(map[history.ClusterShardID]*proxyStreamReceiver), + activeReceivers: make(map[history.ClusterShardID]ActiveReceiver), remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), @@ -1030,7 +1041,7 @@ func (sm *shardManagerImpl) removeLocalShard(shard history.ClusterShardID) { } // RegisterActiveReceiver registers an active receiver for watermark propagation -func (sm *shardManagerImpl) RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver *proxyStreamReceiver) { +func (sm *shardManagerImpl) RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver ActiveReceiver) { sm.activeReceiversMu.Lock() defer sm.activeReceiversMu.Unlock() sm.activeReceivers[sourceShardID] = receiver @@ -1043,11 +1054,19 @@ func (sm *shardManagerImpl) UnregisterActiveReceiver(sourceShardID history.Clust delete(sm.activeReceivers, sourceShardID) } +// GetActiveReceiver returns the active receiver for the given source shard +func (sm *shardManagerImpl) GetActiveReceiver(sourceShardID history.ClusterShardID) (ActiveReceiver, bool) { + sm.activeReceiversMu.RLock() + defer sm.activeReceiversMu.RUnlock() + receiver, ok := sm.activeReceivers[sourceShardID] + return receiver, ok +} + // notifyReceiversOfNewShard notifies all receivers about a newly registered target shard // so they can send pending watermarks if available func (sm *shardManagerImpl) notifyReceiversOfNewShard(targetShardID history.ClusterShardID) { sm.activeReceiversMu.RLock() - receivers := make([]*proxyStreamReceiver, 0, len(sm.activeReceivers)) + receivers := make([]ActiveReceiver, 0, len(sm.activeReceivers)) for _, receiver := range sm.activeReceivers { receivers = append(receivers, receiver) } @@ -1055,8 +1074,8 @@ func (sm *shardManagerImpl) notifyReceiversOfNewShard(targetShardID history.Clus for _, receiver := range receivers { // Only notify receivers that route to the same cluster as the newly registered shard - if receiver.targetShardID.ClusterID == targetShardID.ClusterID { - receiver.sendPendingWatermarkToShard(targetShardID) + if receiver.GetTargetShardID().ClusterID == targetShardID.ClusterID { + 
receiver.NotifyNewTargetShard(targetShardID) } } } diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index 660dba29..d8580fa8 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -311,13 +311,14 @@ func (s *ReplicationTestSuite) setupMultiProxy() { proxyB1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) proxyB2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + // For intra-proxy communication, use outbound addresses where proxies listen proxyAddressesA := map[string]string{ - "proxy-node-a-1": proxyA1Address, - "proxy-node-a-2": proxyA2Address, + "proxy-node-a-1": s.proxyA1Outbound, + "proxy-node-a-2": s.proxyA2Outbound, } proxyAddressesB := map[string]string{ - "proxy-node-b-1": proxyB1Address, - "proxy-node-b-2": proxyB2Address, + "proxy-node-b-1": s.proxyB1Outbound, + "proxy-node-b-2": s.proxyB2Outbound, } s.proxyA1MemberlistPort = testutil.GetFreePort() @@ -513,7 +514,14 @@ func (s *ReplicationTestSuite) generateWorkflowsWithLoad(workflowsPerPair int) [ // Vice versa: if sourceShardCount=2 and targetShardCount=4: // - sourceShard 1 (0-based: 0) can only map to targetShard 1 or 3 (0-based: 0 or 2) // - sourceShard 2 (0-based: 1) can only map to targetShard 2 or 4 (0-based: 1 or 3) +// +// When using routing mode, all pairs are valid because intra-proxy routing can handle arbitrary mappings. func (s *ReplicationTestSuite) isValidShardPair(sourceShard int32, targetShard int32) bool { + // In routing mode, all pairs are valid because intra-proxy routing can handle arbitrary mappings + if s.shardCountConfigB.Mode == config.ShardCountRouting { + return true + } + if s.shardCountA == s.shardCountB { return sourceShard == targetShard } @@ -575,6 +583,11 @@ func (s *ReplicationTestSuite) waitForClusterConnected( ) s.Eventually(func() bool { + s.logger.Info("Checking replication status for clusters to sync", + tag.NewStringTag("source", sourceCluster.ClusterName()), + tag.NewStringTag("target", targetClusterName), + ) + resp, err := sourceCluster.HistoryClient().GetReplicationStatus( context.Background(), &historyservice.GetReplicationStatusRequest{}, @@ -583,12 +596,23 @@ func (s *ReplicationTestSuite) waitForClusterConnected( s.logger.Debug("GetReplicationStatus failed", tag.Error(err)) return false } + s.logger.Info("GetReplicationStatus succeeded", + tag.NewStringTag("source", sourceCluster.ClusterName()), + tag.NewStringTag("target", targetClusterName), + tag.NewStringTag("resp", fmt.Sprintf("%v", resp)), + ) if len(resp.Shards) == 0 { return false } for _, shard := range resp.Shards { + s.logger.Info("Checking shard", + tag.NewInt32("shardId", shard.ShardId), + tag.NewInt64("maxReplicationTaskId", shard.MaxReplicationTaskId), + tag.NewStringTag("shardLocalTime", fmt.Sprintf("%v", shard.ShardLocalTime.AsTime())), + tag.NewStringTag("remoteClusters", fmt.Sprintf("%v", shard.RemoteClusters)), + ) if shard.MaxReplicationTaskId <= 0 { continue } From e1dc33081e321bae5674c56c6d79b9e36d84a81a Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Mon, 22 Dec 2025 17:37:52 -0800 Subject: [PATCH 28/38] use make bins to avoid testcore related error --- .github/workflows/pull-request.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 3302d852..8a81cf33 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -59,7 +59,7 @@ jobs: cache: ${{ github.ref == 
'refs/heads/main' }} # only update the cache in main. - name: Run go build - run: go build ./... + run: make bins - name: Run go unittest run: make test From 6bd84496fb6b75851ffae7edef02410025410a51 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Tue, 23 Dec 2025 22:32:17 -0800 Subject: [PATCH 29/38] fix test error; refactor test --- Makefile | 2 +- endtoendtest/tcp_proxy_test.go | 168 ------------------ proxy/proxy_streams.go | 9 +- proxy/shard_manager.go | 202 ++++++++++++++++------ proxy/test/bench_test.go | 18 +- proxy/test/echo_proxy_test.go | 14 +- proxy/test/intra_proxy_routing_test.go | 35 ++-- proxy/test/replication_failover_test.go | 41 +++-- {endtoendtest => proxy/test}/tcp_proxy.go | 2 +- proxy/test/tcp_proxy_test.go | 171 ++++++++++++++++++ proxy/test/test_common.go | 20 +++ proxy/test/wiring_test.go | 38 ++-- testutil/testutil.go | 21 --- 13 files changed, 424 insertions(+), 317 deletions(-) delete mode 100644 endtoendtest/tcp_proxy_test.go rename {endtoendtest => proxy/test}/tcp_proxy.go (99%) create mode 100644 proxy/test/tcp_proxy_test.go delete mode 100644 testutil/testutil.go diff --git a/Makefile b/Makefile index 1d4bdcaf..be35b5a2 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ GO_GET_TOOL = go get -tool -modfile=$(TOOLS_MOD_FILE) # Disable cgo by default. CGO_ENABLED ?= 0 -TEST_ARG ?= -race -timeout=15m -tags test_dep +TEST_ARG ?= -race -timeout=15m -tags test_dep -count=1 BENCH_ARG ?= -benchtime=5000x ALL_SRC := $(shell find . -name "*.go") diff --git a/endtoendtest/tcp_proxy_test.go b/endtoendtest/tcp_proxy_test.go deleted file mode 100644 index aa5ecfeb..00000000 --- a/endtoendtest/tcp_proxy_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package endtoendtest - -import ( - "fmt" - "io" - "net" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.temporal.io/server/common/log" - - "github.com/temporalio/s2s-proxy/testutil" -) - -func TestTCPProxy(t *testing.T) { - logger := log.NewTestLogger() - - server1Port := testutil.GetFreePort() - server2Port := testutil.GetFreePort() - server3Port := testutil.GetFreePort() - server4Port := testutil.GetFreePort() - server5Port := testutil.GetFreePort() - server6Port := testutil.GetFreePort() - - echoServer1 := startEchoServer(t, fmt.Sprintf("localhost:%d", server1Port)) - echoServer2 := startEchoServer(t, fmt.Sprintf("localhost:%d", server2Port)) - echoServer3 := startEchoServer(t, fmt.Sprintf("localhost:%d", server3Port)) - echoServer4 := startEchoServer(t, fmt.Sprintf("localhost:%d", server4Port)) - echoServer5 := startEchoServer(t, fmt.Sprintf("localhost:%d", server5Port)) - echoServer6 := startEchoServer(t, fmt.Sprintf("localhost:%d", server6Port)) - - defer func() { _ = echoServer1.Close() }() - defer func() { _ = echoServer2.Close() }() - defer func() { _ = echoServer3.Close() }() - defer func() { _ = echoServer4.Close() }() - defer func() { _ = echoServer5.Close() }() - defer func() { _ = echoServer6.Close() }() - - proxyPort1 := testutil.GetFreePort() - proxyPort2 := testutil.GetFreePort() - proxyPort3 := testutil.GetFreePort() - - rules := []*ProxyRule{ - { - ListenPort: fmt.Sprintf("%d", proxyPort1), - Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", server1Port), fmt.Sprintf("localhost:%d", server2Port)}), - }, - { - ListenPort: fmt.Sprintf("%d", proxyPort2), - Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", server3Port), fmt.Sprintf("localhost:%d", server4Port)}), - }, - { - ListenPort: fmt.Sprintf("%d", proxyPort3), - Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", 
server5Port), fmt.Sprintf("localhost:%d", server6Port)}), - }, - } - - proxy := NewTCPProxy(logger, rules) - err := proxy.Start() - require.NoError(t, err) - defer proxy.Stop() - - // Test proxy on port 1 - testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort1), "test message 1") - - // Test proxy on port 2 - testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort2), "test message 2") - - // Test proxy on port 3 - testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort3), "test message 3") -} - -func testProxyConnection(t *testing.T, proxyAddr, message string) { - conn, err := net.DialTimeout("tcp", proxyAddr, 5*time.Second) - require.NoError(t, err) - defer func() { _ = conn.Close() }() - - _, err = conn.Write([]byte(message)) - require.NoError(t, err) - - buf := make([]byte, len(message)) - _, err = io.ReadFull(conn, buf) - require.NoError(t, err) - require.Equal(t, message, string(buf)) -} - -func startEchoServer(t *testing.T, addr string) net.Listener { - listener, err := net.Listen("tcp", addr) - require.NoError(t, err) - - go func() { - for { - conn, err := listener.Accept() - if err != nil { - return - } - go func(c net.Conn) { - defer func() { _ = c.Close() }() - _, _ = io.Copy(c, c) - }(conn) - } - }() - - return listener -} - -func TestTCPProxyLeastConn(t *testing.T) { - logger := log.NewTestLogger() - - // Create two echo servers - server1Port := testutil.GetFreePort() - server2Port := testutil.GetFreePort() - server1 := startEchoServer(t, fmt.Sprintf("localhost:%d", server1Port)) - server2 := startEchoServer(t, fmt.Sprintf("localhost:%d", server2Port)) - defer func() { _ = server1.Close() }() - defer func() { _ = server2.Close() }() - - // Create proxy with two upstreams - proxyPort := testutil.GetFreePort() - rules := []*ProxyRule{ - { - ListenPort: fmt.Sprintf("%d", proxyPort), - Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", server1Port), fmt.Sprintf("localhost:%d", server2Port)}), - }, - } - - proxy := NewTCPProxy(logger, rules) - err := proxy.Start() - require.NoError(t, err) - defer proxy.Stop() - - // Make multiple connections to verify load balancing - for i := 0; i < 10; i++ { - testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort), "test") - time.Sleep(10 * time.Millisecond) - } -} - -func TestTCPProxyContextCancellation(t *testing.T) { - logger := log.NewTestLogger() - - serverPort := testutil.GetFreePort() - server := startEchoServer(t, fmt.Sprintf("localhost:%d", serverPort)) - defer func() { _ = server.Close() }() - - proxyPort := testutil.GetFreePort() - rules := []*ProxyRule{ - { - ListenPort: fmt.Sprintf("%d", proxyPort), - Upstream: NewUpstream([]string{fmt.Sprintf("localhost:%d", serverPort)}), - }, - } - - proxy := NewTCPProxy(logger, rules) - err := proxy.Start() - require.NoError(t, err) - - // Verify it's working - testProxyConnection(t, fmt.Sprintf("localhost:%d", proxyPort), "test") - - // Stop the proxy - proxy.Stop() - - // Verify new connections fail - _, err = net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", proxyPort), 100*time.Millisecond) - require.Error(t, err) -} diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index abbb008f..05c9270c 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -724,7 +724,9 @@ func (r *proxyStreamReceiver) recvReplicationMessages( r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", attr.Messages.ExclusiveHighWatermark, ids)) // record last source exclusive high watermark (original id space) + r.ackMu.Lock() 
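+	// ackMu also guards lastExclusiveHighOriginal: sendAck reads it under the same lock
+	// below before clamping the aggregated ACK.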
r.lastExclusiveHighOriginal = attr.Messages.ExclusiveHighWatermark + r.ackMu.Unlock() // update tracker for incoming messages if r.streamTracker != nil && r.streamID != "" { @@ -986,14 +988,15 @@ func (r *proxyStreamReceiver) sendAck( } } lastSentMin := r.lastSentMin + lastExclusiveHighOriginal := r.lastExclusiveHighOriginal r.ackMu.Unlock() if !first && min >= lastSentMin { // Clamp ACK to last known exclusive high watermark from source - if r.lastExclusiveHighOriginal > 0 && min > r.lastExclusiveHighOriginal { + if lastExclusiveHighOriginal > 0 && min > lastExclusiveHighOriginal { r.logger.Warn("Aggregated ACK exceeds last source high watermark; clamping", tag.NewInt64("ack_min", min), - tag.NewInt64("source_exclusive_high", r.lastExclusiveHighOriginal)) - min = r.lastExclusiveHighOriginal + tag.NewInt64("source_exclusive_high", lastExclusiveHighOriginal)) + min = lastExclusiveHighOriginal } // Send aggregated minimal ack upstream aggregated := &adminservice.StreamWorkflowReplicationMessagesRequest{ diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index fd6a0d65..d6c27c90 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -116,6 +116,7 @@ type ( ml *memberlist.Memberlist delegate *shardDelegate mutex sync.RWMutex + mlMutex sync.RWMutex // Protects memberlist operations (Members, NumMembers, UpdateNode, etc.) localAddr string started bool onPeerJoin func(nodeName string) @@ -143,6 +144,10 @@ type ( // localReceiverCancelFuncs maps shard IDs to context cancel functions for local receiver termination localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc localReceiverCancelFuncsMu sync.RWMutex + // remoteNodeStates stores remote node shard states (from MergeRemoteState) + // keyed by node name, includes the meta (shard state) information + remoteNodeStates map[string]NodeShardState + remoteNodeStatesMu sync.RWMutex } // shardDelegate implements memberlist.Delegate for shard state management @@ -171,8 +176,39 @@ type ( Shards map[string]ShardInfo `json:"shards"` Updated time.Time `json:"updated"` } + + // memberSnapshot is a thread-safe copy of memberlist node data + memberSnapshot struct { + Name string + Meta []byte + } ) +// getMembersSnapshot returns a thread-safe snapshot of remote node states. +// Uses the remoteNodeStates map instead of ml.Members() to avoid data races. 
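+// Each snapshot carries the node name and its re-marshalled shard state (meta), so callers
+// never touch live memberlist.Node pointers.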
+func (sm *shardManagerImpl) getMembersSnapshot() []memberSnapshot { + sm.remoteNodeStatesMu.RLock() + defer sm.remoteNodeStatesMu.RUnlock() + + snapshots := make([]memberSnapshot, 0, len(sm.remoteNodeStates)) + for nodeName, state := range sm.remoteNodeStates { + // Marshal the state to get the meta bytes + metaBytes, err := json.Marshal(state) + if err != nil { + sm.logger.Warn("Failed to marshal node state for snapshot", + tag.NewStringTag("node", nodeName), + tag.Error(err)) + continue + } + snapshot := memberSnapshot{ + Name: nodeName, + Meta: metaBytes, + } + snapshots = append(snapshots, snapshot) + } + return snapshots +} + // NewShardManager creates a new shard manager instance func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig config.ShardCountConfig, intraProxyTLSConfig encryption.TLSConfig, logger log.Logger) ShardManager { delegate := &shardDelegate{ @@ -191,6 +227,7 @@ func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), + remoteNodeStates: make(map[string]NodeShardState), } delegate.manager = sm @@ -368,7 +405,11 @@ func (sm *shardManagerImpl) Stop() { } func (sm *shardManagerImpl) shutdownMemberlist() { - if sm.ml == nil { + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + + if ml == nil { return } @@ -377,16 +418,22 @@ func (sm *shardManagerImpl) shutdownMemberlist() { sm.joinWg.Wait() // Leave the cluster gracefully - err := sm.ml.Leave(5 * time.Second) + sm.mlMutex.Lock() + err := ml.Leave(5 * time.Second) if err != nil { sm.logger.Error("Error leaving memberlist cluster", tag.Error(err)) } - err = sm.ml.Shutdown() + err = ml.Shutdown() if err != nil { sm.logger.Error("Error shutting down memberlist", tag.Error(err)) } + sm.mlMutex.Unlock() + + // Clear pointer under main mutex + sm.mutex.Lock() sm.ml = nil + sm.mutex.Unlock() } // startJoinLoop starts the join retry loop if not already running @@ -442,7 +489,10 @@ func (sm *shardManagerImpl) retryJoinCluster() { tag.NewStringTag("attempt", strconv.Itoa(attempt)), tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", joinAddrs))) + // Serialize Join with other memberlist operations + sm.mlMutex.Lock() num, err := ml.Join(joinAddrs) + sm.mlMutex.Unlock() if err != nil { sm.logger.Warn("Failed to join cluster", tag.Error(err)) @@ -472,10 +522,20 @@ func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) sm.broadcastShardChange("register", clientShardID) // Trigger memberlist metadata update to propagate NodeMeta to other nodes - if sm.ml != nil { - if err := sm.ml.UpdateNode(0); err != nil { // 0 timeout means immediate update - sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) - } + // Run asynchronously to avoid blocking callers + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + if ml != nil { + go func() { + // Use mlMutex to serialize with getMembersSnapshot and other memberlist operations + sm.mlMutex.Lock() + err := ml.UpdateNode(0) // 0 timeout means immediate update + sm.mlMutex.Unlock() + if err != nil { + sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + } + }() } // Notify listeners if sm.onLocalShardChange != nil { @@ -499,10 +559,20 @@ func (sm *shardManagerImpl) UnregisterShard(clientShardID history.ClusterShardID sm.broadcastShardChange("unregister", 
clientShardID) // Trigger memberlist metadata update to propagate NodeMeta to other nodes - if sm.ml != nil { - if err := sm.ml.UpdateNode(0); err != nil { // 0 timeout means immediate update - sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) - } + // Run asynchronously to avoid blocking callers + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + if ml != nil { + go func() { + // Use mlMutex to serialize with getMembersSnapshot and other memberlist operations + sm.mlMutex.Lock() + err := ml.UpdateNode(0) // 0 timeout means immediate update + sm.mlMutex.Unlock() + if err != nil { + sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + } + }() } // Notify listeners if sm.onLocalShardChange != nil { @@ -549,14 +619,14 @@ func (sm *shardManagerImpl) GetMemberNodes() []string { } // Use a timeout to prevent deadlocks when memberlist is busy - membersChan := make(chan []*memberlist.Node, 1) + membersChan := make(chan []memberSnapshot, 1) go func() { defer func() { if r := recover(); r != nil { sm.logger.Error("Panic in GetMemberNodes", tag.NewStringTag("error", fmt.Sprintf("%v", r))) } }() - membersChan <- sm.ml.Members() + membersChan <- sm.getMembersSnapshot() }() select { @@ -683,45 +753,21 @@ func (sm *shardManagerImpl) GetShardOwner(shard history.ClusterShardID) (string, } // GetRemoteShardsForPeer returns all shards owned by the specified peer node. -// Non-blocking: uses memberlist metadata and tolerates timeouts by returning a best-effort set. +// Uses the remoteNodeStates map instead of ml.Members() to avoid data races. func (sm *shardManagerImpl) GetRemoteShardsForPeer(peerNodeName string) (map[string]NodeShardState, error) { result := make(map[string]NodeShardState) - if sm.ml == nil { - return result, nil - } - // Read members with a short timeout to avoid blocking debug paths - membersChan := make(chan []*memberlist.Node, 1) - go func() { - defer func() { _ = recover() }() - sm.mutex.RLock() - defer sm.mutex.RUnlock() - membersChan <- sm.ml.Members() - }() + sm.remoteNodeStatesMu.RLock() + defer sm.remoteNodeStatesMu.RUnlock() - var members []*memberlist.Node - select { - case members = <-membersChan: - case <-time.After(100 * time.Millisecond): - sm.logger.Warn("GetRemoteShardsForPeer timeout") - return result, fmt.Errorf("timeout") - } - - for _, member := range members { - if member == nil || len(member.Meta) == 0 { - continue - } - if member.Name == sm.GetNodeName() { - continue - } - if peerNodeName != "" && member.Name != peerNodeName { + for nodeName, state := range sm.remoteNodeStates { + if nodeName == sm.GetNodeName() { continue } - var nodeState NodeShardState - if err := json.Unmarshal(member.Meta, &nodeState); err != nil { + if peerNodeName != "" && nodeName != peerNodeName { continue } - result[member.Name] = nodeState + result[nodeName] = state } return result, nil @@ -915,21 +961,58 @@ func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.C return } - for _, member := range sm.ml.Members() { + // Use remoteNodeStates map to get list of nodes to send to + sm.remoteNodeStatesMu.RLock() + nodeNames := make([]string, 0, len(sm.remoteNodeStates)) + for nodeName := range sm.remoteNodeStates { // Skip sending to self node - if member.Name == sm.GetNodeName() { + if nodeName == sm.GetNodeName() { continue } + nodeNames = append(nodeNames, nodeName) + } + sm.remoteNodeStatesMu.RUnlock() + for _, nodeName := range nodeNames { // Send in goroutine to make it non-blocking - go func(m 
*memberlist.Node) { - err := sm.ml.SendReliable(m, data) + // Look up fresh node pointer when sending to avoid race with memberlist updates + go func(targetNodeName string) { + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + + if ml == nil { + return + } + + // Find the node by name from current members list + // Serialize memberlist operations to prevent races with UpdateNode + sm.mlMutex.RLock() + var targetNode *memberlist.Node + for _, n := range ml.Members() { + if n != nil && n.Name == targetNodeName { + targetNode = n + break + } + } + if targetNode == nil { + sm.mlMutex.RUnlock() + sm.logger.Warn("Node not found for broadcast", + tag.NewStringTag("target_node", targetNodeName)) + return + } + // SendReliable reads node fields (via FullAddress()), so we must hold the lock + // to prevent races with memberlist's internal updates. + // Note: This is a blocking network call, but we need the lock to prevent + // memberlist from modifying the node while SendReliable reads its fields. + err := ml.SendReliable(targetNode, data) + sm.mlMutex.RUnlock() if err != nil { sm.logger.Error("Failed to broadcast shard change", tag.Error(err), - tag.NewStringTag("target_node", m.Name)) + tag.NewStringTag("target_node", targetNodeName)) } - }(member) + }(nodeName) } } @@ -986,7 +1069,10 @@ func (sd *shardDelegate) NotifyMsg(data []byte) { // if shard is previously registered as local shard, but now is registered as remote shard, // check if the remote shard is newer than the local shard. If so, unregister the local shard. if added { + // Lock when reading localShards to prevent race with concurrent writes + sd.manager.mutex.RLock() localShard, ok := sd.manager.localShards[ClusterShardIDtoShortString(msg.ClientShard)] + sd.manager.mutex.RUnlock() if ok { if localShard.Created.Before(msg.Timestamp) { // Force unregister the local shard by passing its own timestamp @@ -1015,6 +1101,13 @@ func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { return } + // Save the remote state to local map + if sd.manager != nil { + sd.manager.remoteNodeStatesMu.Lock() + sd.manager.remoteNodeStates[state.NodeName] = state + sd.manager.remoteNodeStatesMu.Unlock() + } + sd.logger.Info("Merged remote shard state", tag.NewStringTag("node", state.NodeName), tag.NewStringTag("shards", strconv.Itoa(len(state.Shards))), @@ -1225,9 +1318,18 @@ func (sed *shardEventDelegate) NotifyLeave(node *memberlist.Node) { tag.NewStringTag("node", node.Name), tag.NewStringTag("addr", node.Addr.String())) + // Remove the node from remoteNodeStates map + if sed.manager != nil { + sed.manager.remoteNodeStatesMu.Lock() + delete(sed.manager.remoteNodeStates, node.Name) + sed.manager.remoteNodeStatesMu.Unlock() + } + // If we're now isolated and have join addresses configured, restart join loop if sed.manager != nil && sed.manager.ml != nil && sed.manager.memberlistConfig != nil { + sed.manager.mlMutex.RLock() numMembers := sed.manager.ml.NumMembers() + sed.manager.mlMutex.RUnlock() if numMembers == 1 && len(sed.manager.memberlistConfig.JoinAddrs) > 0 { sed.logger.Info("Node is now isolated, restarting join loop", tag.NewStringTag("numMembers", strconv.Itoa(numMembers))) diff --git a/proxy/test/bench_test.go b/proxy/test/bench_test.go index 090b2a9e..6f72943b 100644 --- a/proxy/test/bench_test.go +++ b/proxy/test/bench_test.go @@ -1,7 +1,6 @@ package proxy import ( - "fmt" "testing" "go.temporal.io/server/api/adminservice/v1" @@ -11,7 +10,6 @@ import ( "github.com/temporalio/s2s-proxy/config" 
"github.com/temporalio/s2s-proxy/endtoendtest" - "github.com/temporalio/s2s-proxy/testutil" ) func createEchoServerConfigWithPorts( @@ -88,12 +86,12 @@ func createEchoClientConfigWithPorts( func benchmarkStreamSendRecvWithoutProxy(b *testing.B, payloadSize int) { echoServerInfo := endtoendtest.ClusterInfo{ - ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + ServerAddress: GetLocalhostAddress(), ClusterShardID: serverClusterShard, } echoClientInfo := endtoendtest.ClusterInfo{ - ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + ServerAddress: GetLocalhostAddress(), ClusterShardID: clientClusterShard, } @@ -105,12 +103,12 @@ func benchmarkStreamSendRecvWithMuxProxy(b *testing.B, payloadSize int) { muxTransportName := "muxed" // Allocate ports dynamically - echoServerAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - serverProxyInboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - serverProxyOutboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - echoClientAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - clientProxyInboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - clientProxyOutboundAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + echoServerAddress := GetLocalhostAddress() + serverProxyInboundAddress := GetLocalhostAddress() + serverProxyOutboundAddress := GetLocalhostAddress() + echoClientAddress := GetLocalhostAddress() + clientProxyInboundAddress := GetLocalhostAddress() + clientProxyOutboundAddress := GetLocalhostAddress() echoServerConfig := createEchoServerConfigWithPorts( echoServerAddress, diff --git a/proxy/test/echo_proxy_test.go b/proxy/test/echo_proxy_test.go index 3b4850db..0788feaa 100644 --- a/proxy/test/echo_proxy_test.go +++ b/proxy/test/echo_proxy_test.go @@ -1,7 +1,6 @@ package proxy import ( - "fmt" "os" "path/filepath" "testing" @@ -17,7 +16,6 @@ import ( "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/endtoendtest" - "github.com/temporalio/s2s-proxy/testutil" "github.com/temporalio/s2s-proxy/transport/mux" ) @@ -252,12 +250,12 @@ func (s *proxyTestSuite) SetupTest() { s.NoError(err) // Allocate free ports for each test - s.echoServerAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.serverProxyInboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.serverProxyOutboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.echoClientAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.clientProxyInboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.clientProxyOutboundAddress = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.echoServerAddress = GetLocalhostAddress() + s.serverProxyInboundAddress = GetLocalhostAddress() + s.serverProxyOutboundAddress = GetLocalhostAddress() + s.echoClientAddress = GetLocalhostAddress() + s.clientProxyInboundAddress = GetLocalhostAddress() + s.clientProxyOutboundAddress = GetLocalhostAddress() } func (s *proxyTestSuite) TearDownTest() { diff --git a/proxy/test/intra_proxy_routing_test.go b/proxy/test/intra_proxy_routing_test.go index 87a132d0..538c5284 100644 --- a/proxy/test/intra_proxy_routing_test.go +++ b/proxy/test/intra_proxy_routing_test.go @@ -17,7 +17,6 @@ import ( "github.com/temporalio/s2s-proxy/config" s2sproxy "github.com/temporalio/s2s-proxy/proxy" - "github.com/temporalio/s2s-proxy/testutil" ) type ( @@ -79,26 +78,26 @@ func (s *IntraProxyRoutingTestSuite) SetupSuite() { s.clusterA 
= createCluster(s.logger, s.T(), "cluster-a", 2, 1, 1) s.clusterB = createCluster(s.logger, s.T(), "cluster-b", 2, 2, 1) - s.proxyA1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyA2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyB1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyB2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyA1Outbound = GetLocalhostAddress() + s.proxyA2Outbound = GetLocalhostAddress() + s.proxyB1Outbound = GetLocalhostAddress() + s.proxyB2Outbound = GetLocalhostAddress() - s.proxyB1Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyB2Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB1Mux = GetLocalhostAddress() + s.proxyB2Mux = GetLocalhostAddress() - loadBalancerAPort := fmt.Sprintf("%d", testutil.GetFreePort()) - loadBalancerBPort := fmt.Sprintf("%d", testutil.GetFreePort()) - loadBalancerCPort := fmt.Sprintf("%d", testutil.GetFreePort()) + loadBalancerAPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerBPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerCPort := fmt.Sprintf("%d", GetFreePort()) s.loadBalancerAPort = loadBalancerAPort s.loadBalancerBPort = loadBalancerBPort s.loadBalancerCPort = loadBalancerCPort - proxyA1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyA2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyB1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyB2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyA1Address := GetLocalhostAddress() + proxyA2Address := GetLocalhostAddress() + proxyB1Address := GetLocalhostAddress() + proxyB2Address := GetLocalhostAddress() proxyAddressesA := map[string]string{ "proxy-node-a-1": proxyA1Address, @@ -109,10 +108,10 @@ func (s *IntraProxyRoutingTestSuite) SetupSuite() { "proxy-node-b-2": proxyB2Address, } - s.proxyA1MemberlistPort = testutil.GetFreePort() - s.proxyA2MemberlistPort = testutil.GetFreePort() - s.proxyB1MemberlistPort = testutil.GetFreePort() - s.proxyB2MemberlistPort = testutil.GetFreePort() + s.proxyA1MemberlistPort = GetFreePort() + s.proxyA2MemberlistPort = GetFreePort() + s.proxyB1MemberlistPort = GetFreePort() + s.proxyB2MemberlistPort = GetFreePort() s.proxyB1 = createProxy(s.logger, s.T(), "proxy-b-1", proxyB1Address, s.proxyB1Outbound, s.proxyB1Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-1", "127.0.0.1", s.proxyB1MemberlistPort, nil, proxyAddressesB) s.proxyB2 = createProxy(s.logger, s.T(), "proxy-b-2", proxyB2Address, s.proxyB2Outbound, s.proxyB2Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-2", "127.0.0.1", s.proxyB2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyB1MemberlistPort)}, proxyAddressesB) diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index d8580fa8..f8229b71 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -25,7 +25,6 @@ import ( "github.com/temporalio/s2s-proxy/config" s2sproxy "github.com/temporalio/s2s-proxy/proxy" - "github.com/temporalio/s2s-proxy/testutil" ) type SetupMode string @@ -270,9 +269,9 @@ func (s *ReplicationTestSuite) SetupSuite() { func (s *ReplicationTestSuite) setupSimple() { s.logger.Info("Setting up simple two-proxy configuration") - proxyAOutbound := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyBOutbound := fmt.Sprintf("localhost:%d", 
testutil.GetFreePort()) - muxServerAddress := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyAOutbound := GetLocalhostAddress() + proxyBOutbound := GetLocalhostAddress() + muxServerAddress := GetLocalhostAddress() s.proxyAOutbound = proxyAOutbound s.proxyBOutbound = proxyBOutbound @@ -290,26 +289,26 @@ func (s *ReplicationTestSuite) setupSimple() { func (s *ReplicationTestSuite) setupMultiProxy() { s.logger.Info("Setting up multi-proxy configuration with load balancers") - s.proxyA1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyA2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyB1Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyB2Outbound = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyA1Outbound = GetLocalhostAddress() + s.proxyA2Outbound = GetLocalhostAddress() + s.proxyB1Outbound = GetLocalhostAddress() + s.proxyB2Outbound = GetLocalhostAddress() - s.proxyB1Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - s.proxyB2Mux = fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + s.proxyB1Mux = GetLocalhostAddress() + s.proxyB2Mux = GetLocalhostAddress() - loadBalancerAPort := fmt.Sprintf("%d", testutil.GetFreePort()) - loadBalancerBPort := fmt.Sprintf("%d", testutil.GetFreePort()) - loadBalancerCPort := fmt.Sprintf("%d", testutil.GetFreePort()) + loadBalancerAPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerBPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerCPort := fmt.Sprintf("%d", GetFreePort()) s.loadBalancerAPort = loadBalancerAPort s.loadBalancerBPort = loadBalancerBPort s.loadBalancerCPort = loadBalancerCPort - proxyA1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyA2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyB1Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) - proxyB2Address := fmt.Sprintf("localhost:%d", testutil.GetFreePort()) + proxyA1Address := GetLocalhostAddress() + proxyA2Address := GetLocalhostAddress() + proxyB1Address := GetLocalhostAddress() + proxyB2Address := GetLocalhostAddress() // For intra-proxy communication, use outbound addresses where proxies listen proxyAddressesA := map[string]string{ @@ -321,10 +320,10 @@ func (s *ReplicationTestSuite) setupMultiProxy() { "proxy-node-b-2": s.proxyB2Outbound, } - s.proxyA1MemberlistPort = testutil.GetFreePort() - s.proxyA2MemberlistPort = testutil.GetFreePort() - s.proxyB1MemberlistPort = testutil.GetFreePort() - s.proxyB2MemberlistPort = testutil.GetFreePort() + s.proxyA1MemberlistPort = GetFreePort() + s.proxyA2MemberlistPort = GetFreePort() + s.proxyB1MemberlistPort = GetFreePort() + s.proxyB2MemberlistPort = GetFreePort() proxyBShardConfig := s.shardCountConfigB if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { diff --git a/endtoendtest/tcp_proxy.go b/proxy/test/tcp_proxy.go similarity index 99% rename from endtoendtest/tcp_proxy.go rename to proxy/test/tcp_proxy.go index 9cacac2c..bbdd811c 100644 --- a/endtoendtest/tcp_proxy.go +++ b/proxy/test/tcp_proxy.go @@ -1,4 +1,4 @@ -package endtoendtest +package proxy import ( "context" diff --git a/proxy/test/tcp_proxy_test.go b/proxy/test/tcp_proxy_test.go new file mode 100644 index 00000000..f0aaff84 --- /dev/null +++ b/proxy/test/tcp_proxy_test.go @@ -0,0 +1,171 @@ +package proxy + +import ( + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/log" +) + +func TestTCPProxy(t *testing.T) { 
+ logger := log.NewTestLogger() + + server1Addr := GetLocalhostAddress() + server2Addr := GetLocalhostAddress() + server3Addr := GetLocalhostAddress() + server4Addr := GetLocalhostAddress() + server5Addr := GetLocalhostAddress() + server6Addr := GetLocalhostAddress() + + echoServer1 := startEchoServer(t, server1Addr) + echoServer2 := startEchoServer(t, server2Addr) + echoServer3 := startEchoServer(t, server3Addr) + echoServer4 := startEchoServer(t, server4Addr) + echoServer5 := startEchoServer(t, server5Addr) + echoServer6 := startEchoServer(t, server6Addr) + + defer func() { _ = echoServer1.Close() }() + defer func() { _ = echoServer2.Close() }() + defer func() { _ = echoServer3.Close() }() + defer func() { _ = echoServer4.Close() }() + defer func() { _ = echoServer5.Close() }() + defer func() { _ = echoServer6.Close() }() + + proxyAddr1 := GetLocalhostAddress() + proxyAddr2 := GetLocalhostAddress() + proxyAddr3 := GetLocalhostAddress() + + _, proxyPort1, _ := net.SplitHostPort(proxyAddr1) + _, proxyPort2, _ := net.SplitHostPort(proxyAddr2) + _, proxyPort3, _ := net.SplitHostPort(proxyAddr3) + + rules := []*ProxyRule{ + { + ListenPort: proxyPort1, + Upstream: NewUpstream([]string{server1Addr, server2Addr}), + }, + { + ListenPort: proxyPort2, + Upstream: NewUpstream([]string{server3Addr, server4Addr}), + }, + { + ListenPort: proxyPort3, + Upstream: NewUpstream([]string{server5Addr, server6Addr}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + defer proxy.Stop() + + // Test proxy on port 1 + testProxyConnection(t, proxyAddr1, "test message 1") + + // Test proxy on port 2 + testProxyConnection(t, proxyAddr2, "test message 2") + + // Test proxy on port 3 + testProxyConnection(t, proxyAddr3, "test message 3") +} + +func testProxyConnection(t *testing.T, proxyAddr, message string) { + conn, err := net.DialTimeout("tcp", proxyAddr, 5*time.Second) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + _, err = conn.Write([]byte(message)) + require.NoError(t, err) + + buf := make([]byte, len(message)) + _, err = io.ReadFull(conn, buf) + require.NoError(t, err) + require.Equal(t, message, string(buf)) +} + +func startEchoServer(t *testing.T, addr string) net.Listener { + listener, err := net.Listen("tcp", addr) + require.NoError(t, err) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { _ = c.Close() }() + _, _ = io.Copy(c, c) + }(conn) + } + }() + + return listener +} + +func TestTCPProxyLeastConn(t *testing.T) { + logger := log.NewTestLogger() + + // Create two echo servers + server1Addr := GetLocalhostAddress() + server2Addr := GetLocalhostAddress() + server1 := startEchoServer(t, server1Addr) + server2 := startEchoServer(t, server2Addr) + defer func() { _ = server1.Close() }() + defer func() { _ = server2.Close() }() + + // Create proxy with two upstreams + proxyAddr := GetLocalhostAddress() + _, proxyPort, _ := net.SplitHostPort(proxyAddr) + rules := []*ProxyRule{ + { + ListenPort: proxyPort, + Upstream: NewUpstream([]string{server1Addr, server2Addr}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + defer proxy.Stop() + + // Make multiple connections to verify load balancing + for i := 0; i < 10; i++ { + testProxyConnection(t, proxyAddr, "test") + time.Sleep(10 * time.Millisecond) + } +} + +func TestTCPProxyContextCancellation(t *testing.T) { + logger := log.NewTestLogger() + + serverAddr := 
GetLocalhostAddress() + server := startEchoServer(t, serverAddr) + defer func() { _ = server.Close() }() + + proxyAddr := GetLocalhostAddress() + _, proxyPort, _ := net.SplitHostPort(proxyAddr) + rules := []*ProxyRule{ + { + ListenPort: proxyPort, + Upstream: NewUpstream([]string{serverAddr}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + + // Verify it's working + testProxyConnection(t, proxyAddr, "test") + + // Stop the proxy + proxy.Stop() + + // Verify new connections fail + _, err = net.DialTimeout("tcp", proxyAddr, 100*time.Millisecond) + require.Error(t, err) +} diff --git a/proxy/test/test_common.go b/proxy/test/test_common.go index de132cc4..30f8189a 100644 --- a/proxy/test/test_common.go +++ b/proxy/test/test_common.go @@ -474,3 +474,23 @@ func getTestingT(t testingT) *testing.T { } panic("testingT must be *testing.T or have T() method") } + +// GetFreePort returns an available TCP port by listening on localhost:0. +// This is useful for tests that need to allocate ports dynamically. +func GetFreePort() int { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(fmt.Sprintf("failed to get free port: %v", err)) + } + defer func() { + if err := l.Close(); err != nil { + fmt.Printf("Failed to close listener: %v\n", err) + } + }() + return l.Addr().(*net.TCPAddr).Port +} + +// GetLocalhostAddress returns a localhost address with a free port +func GetLocalhostAddress() string { + return fmt.Sprintf("localhost:%d", GetFreePort()) +} diff --git a/proxy/test/wiring_test.go b/proxy/test/wiring_test.go index b7527378..55e58fc5 100644 --- a/proxy/test/wiring_test.go +++ b/proxy/test/wiring_test.go @@ -23,7 +23,6 @@ import ( "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/endtoendtest" - "github.com/temporalio/s2s-proxy/testutil" ) func init() { @@ -46,29 +45,39 @@ var ( ) func getEchoServerInfo() endtoendtest.ClusterInfo { + echoServerAddress := GetLocalhostAddress() + serverProxyInboundAddress := GetLocalhostAddress() + serverProxyOutboundAddress := GetLocalhostAddress() + prometheusAddress := GetLocalhostAddress() + healthCheckAddress := GetLocalhostAddress() return endtoendtest.ClusterInfo{ - ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + ServerAddress: echoServerAddress, ClusterShardID: serverClusterShard, S2sProxyConfig: makeS2SConfig(s2sAddresses{ - echoServer: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), - inbound: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), - outbound: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), - prometheus: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), - healthCheck: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + echoServer: echoServerAddress, + inbound: serverProxyInboundAddress, + outbound: serverProxyOutboundAddress, + prometheus: prometheusAddress, + healthCheck: healthCheckAddress, }), } } func getEchoClientInfo() endtoendtest.ClusterInfo { + echoClientAddress := GetLocalhostAddress() + clientProxyInboundAddress := GetLocalhostAddress() + clientProxyOutboundAddress := GetLocalhostAddress() + prometheusAddress := GetLocalhostAddress() + healthCheckAddress := GetLocalhostAddress() return endtoendtest.ClusterInfo{ - ServerAddress: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + ServerAddress: echoClientAddress, ClusterShardID: clientClusterShard, S2sProxyConfig: makeS2SConfig(s2sAddresses{ - echoServer: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), - inbound: fmt.Sprintf("localhost:%d", 
testutil.GetFreePort()), - outbound: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), - prometheus: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), - healthCheck: fmt.Sprintf("localhost:%d", testutil.GetFreePort()), + echoServer: echoClientAddress, + inbound: clientProxyInboundAddress, + outbound: clientProxyOutboundAddress, + prometheus: prometheusAddress, + healthCheck: healthCheckAddress, }), } } @@ -137,9 +146,6 @@ func TestEOFFromServer(t *testing.T) { func TestWiringWithEchoService(t *testing.T) { echoServerInfo := getEchoServerInfo() echoClientInfo := getEchoClientInfo() - // Update outbound client address to point to the other proxy's inbound - echoServerInfo.S2sProxyConfig.Outbound.Client.ServerAddress = echoClientInfo.S2sProxyConfig.Inbound.Server.ListenAddress - echoClientInfo.S2sProxyConfig.Outbound.Client.ServerAddress = echoServerInfo.S2sProxyConfig.Inbound.Server.ListenAddress echoServer := endtoendtest.NewEchoServer(echoServerInfo, echoClientInfo, "EchoServer", logger, nil) echoClient := endtoendtest.NewEchoServer(echoClientInfo, echoServerInfo, "EchoClient", logger, nil) diff --git a/testutil/testutil.go b/testutil/testutil.go deleted file mode 100644 index 28e3a15f..00000000 --- a/testutil/testutil.go +++ /dev/null @@ -1,21 +0,0 @@ -package testutil - -import ( - "fmt" - "net" -) - -// GetFreePort returns an available TCP port by listening on localhost:0. -// This is useful for tests that need to allocate ports dynamically. -func GetFreePort() int { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - panic(fmt.Sprintf("failed to get free port: %v", err)) - } - defer func() { - if err := l.Close(); err != nil { - fmt.Printf("Failed to close listener: %v\n", err) - } - }() - return l.Addr().(*net.TCPAddr).Port -} From 40b222cf4c979c5291ad4cae9f0da6ae3aebe134 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Wed, 24 Dec 2025 09:43:44 -0800 Subject: [PATCH 30/38] fix test error --- proxy/intra_proxy_router.go | 14 +++++++++----- proxy/proxy_streams.go | 15 +++++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 07ddb738..24e0655b 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -265,6 +265,10 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { return err } if msgs, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && msgs.Messages != nil { + // Capture watermark value immediately to avoid data race with sender + exclusiveHighWatermark := msgs.Messages.ExclusiveHighWatermark + priority := msgs.Messages.Priority + // Update client-side intra-proxy tracker for received messages st := GetGlobalStreamTracker() ids := make([]int64, 0, len(msgs.Messages.ReplicationTasks)) @@ -272,18 +276,18 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { ids = append(ids, t.SourceTaskId) } st.UpdateStreamLastTaskIDs(r.streamID, ids) - st.UpdateStreamReplicationMessages(r.streamID, msgs.Messages.ExclusiveHighWatermark) + st.UpdateStreamReplicationMessages(r.streamID, exclusiveHighWatermark) st.UpdateStream(r.streamID) // Track last watermark for late-registering shards r.lastWatermarkMu.Lock() r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ - ExclusiveHighWatermark: msgs.Messages.ExclusiveHighWatermark, - Priority: msgs.Messages.Priority, + ExclusiveHighWatermark: exclusiveHighWatermark, + Priority: priority, } r.lastWatermarkMu.Unlock() - 
r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", msgs.Messages.ExclusiveHighWatermark, ids)) + r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", exclusiveHighWatermark, ids)) msg := RoutedMessage{SourceShard: r.sourceShardID, Resp: resp} sent := false @@ -300,7 +304,7 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { select { case ch <- msg: sent = true - r.logger.Info("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", msgs.Messages.ExclusiveHighWatermark)) + r.logger.Info("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", exclusiveHighWatermark)) case <-shutdown.Channel(): // Will be handled outside the func } diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 05c9270c..8ae63597 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -474,6 +474,7 @@ func (s *proxyStreamSender) sendReplicationMessages( originalHigh := m.Messages.ExclusiveHighWatermark s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d", m.Messages.ExclusiveHighWatermark, originalHigh)) // Ensure exclusive high watermark is in proxy task ID space + var proxyExclusiveHigh int64 if len(m.Messages.ReplicationTasks) > 0 { for _, t := range m.Messages.ReplicationTasks { // allocate proxy task id @@ -490,7 +491,8 @@ func (s *proxyStreamSender) sendReplicationMessages( originalIDs = append(originalIDs, original) proxyIDs = append(proxyIDs, proxyID) } - m.Messages.ExclusiveHighWatermark = m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].SourceTaskId + 1 + proxyExclusiveHigh = m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].SourceTaskId + 1 + m.Messages.ExclusiveHighWatermark = proxyExclusiveHigh } else { // No tasks in this batch: allocate a synthetic proxy task id mapping s.nextProxyTaskID++ @@ -498,17 +500,18 @@ func (s *proxyStreamSender) sendReplicationMessages( s.idRing.Append(proxyHigh, routed.SourceShard, originalHigh) originalIDs = append(originalIDs, originalHigh) proxyIDs = append(proxyIDs, proxyHigh) - m.Messages.ExclusiveHighWatermark = proxyHigh - s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d proxy_high=%d original", m.Messages.ExclusiveHighWatermark, originalHigh, proxyHigh)) + proxyExclusiveHigh = proxyHigh + m.Messages.ExclusiveHighWatermark = proxyExclusiveHigh + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d proxy_high=%d original", proxyExclusiveHigh, originalHigh, proxyHigh)) } s.mu.Unlock() - // Log mapping from original -> proxy IDs - s.logger.Info(fmt.Sprintf("Sender sending ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs), tag.NewInt64("exclusive_high", m.Messages.ExclusiveHighWatermark)) + // Log mapping from original -> proxy IDs (use captured value to avoid data race) + s.logger.Info(fmt.Sprintf("Sender sending ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) if err := sourceStreamServer.Send(resp); err != nil { return err } - s.logger.Info("Sender sent ReplicationTasks", 
tag.NewStringTag("sourceShard", ClusterShardIDtoString(routed.SourceShard)), tag.NewInt64("exclusive_high", m.Messages.ExclusiveHighWatermark)) + s.logger.Info("Sender sent ReplicationTasks", tag.NewStringTag("sourceShard", ClusterShardIDtoString(routed.SourceShard)), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) // Update keepalive state s.mu.Lock() From affc6bba904eea95f7afd7f7e5ec49c7ed2aa96a Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 14:03:06 -0800 Subject: [PATCH 31/38] address comments: update logs level (info->debug); add comments --- proxy/admin_stream_transfer.go | 4 +- proxy/intra_proxy_router.go | 80 ++++++++++++------------- proxy/proxy_streams.go | 64 ++++++++++---------- proxy/shard_manager.go | 10 ++-- proxy/test/replication_failover_test.go | 50 ++++------------ proxy/test/tcp_proxy.go | 2 + 6 files changed, 94 insertions(+), 116 deletions(-) diff --git a/proxy/admin_stream_transfer.go b/proxy/admin_stream_transfer.go index ceae61c6..a48ff021 100644 --- a/proxy/admin_stream_transfer.go +++ b/proxy/admin_stream_transfer.go @@ -190,7 +190,7 @@ func (f *StreamForwarder) forwardReplicationMessages(wg *sync.WaitGroup) { for i, task := range attr.Messages.ReplicationTasks { msg = append(msg, fmt.Sprintf("[%d]: %v", i, task.SourceTaskId)) } - f.logger.Info(fmt.Sprintf("forwarding ReplicationMessages: exclusive %v, tasks: %v", attr.Messages.ExclusiveHighWatermark, strings.Join(msg, ", "))) + f.logger.Debug(fmt.Sprintf("forwarding ReplicationMessages: exclusive %v, tasks: %v", attr.Messages.ExclusiveHighWatermark, strings.Join(msg, ", "))) streamTracker := GetGlobalStreamTracker() streamTracker.UpdateStreamReplicationMessages(f.streamID, attr.Messages.ExclusiveHighWatermark) @@ -264,7 +264,7 @@ func (f *StreamForwarder) forwardAcks(wg *sync.WaitGroup) { switch attr := req.GetAttributes().(type) { case *adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState: - f.logger.Info(fmt.Sprintf("forwarding SyncReplicationState: inclusive %v, attr: %v", attr.SyncReplicationState.InclusiveLowWatermark, attr)) + f.logger.Debug(fmt.Sprintf("forwarding SyncReplicationState: inclusive %v, attr: %v", attr.SyncReplicationState.InclusiveLowWatermark, attr)) var watermarkTime *time.Time if attr.SyncReplicationState.InclusiveLowWatermarkTime != nil { diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 24e0655b..c17416de 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -94,7 +94,7 @@ func (s *intraProxyStreamSender) Run( // that has a pending watermark, and send it immediately to the peer if receiver, ok := s.shardManager.GetActiveReceiver(s.sourceShardID); ok { if lastWatermark := receiver.GetLastWatermark(); lastWatermark != nil && lastWatermark.ExclusiveHighWatermark > 0 { - s.logger.Info("Sending pending watermark to peer on sender registration", + s.logger.Debug("Sending pending watermark to peer on sender registration", tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark), tag.NewStringTag("peer", s.peerNodeName)) resp := &adminservice.StreamWorkflowReplicationMessagesResponse{ @@ -117,16 +117,16 @@ func (s *intraProxyStreamSender) Run( // recvAck reads ACKs from the peer and routes them to the source shard owner. 
func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) error { - s.logger.Info("intraProxyStreamSender recvAck") + s.logger.Debug("intraProxyStreamSender recvAck") defer func() { - s.logger.Info("intraProxyStreamSender recvAck finished") + s.logger.Debug("intraProxyStreamSender recvAck finished") shutdownChan.Shutdown() }() for !shutdownChan.IsShutdown() { req, err := s.sourceStreamServer.Recv() if err == io.EOF { - s.logger.Info("intraProxyStreamSender recvAck encountered EOF") + s.logger.Debug("intraProxyStreamSender recvAck encountered EOF") return nil } if err != nil { @@ -136,7 +136,7 @@ func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) erro if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { ack := attr.SyncReplicationState.InclusiveLowWatermark - s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", ack)) + s.logger.Debug("Sender received upstream ACK", tag.NewInt64("inclusive_low", ack)) // Update server-side intra-proxy stream tracker with sync watermark st := GetGlobalStreamTracker() @@ -152,7 +152,7 @@ func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) erro }, } - s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) + s.logger.Debug("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) // FIXME: should retry. If not succeed, return and shutdown the stream sent := s.shardManager.DeliverAckToShardOwner(s.sourceShardID, routedAck, shutdownChan, s.logger, ack, false) if !sent { @@ -248,8 +248,8 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, shardManager ShardMa // recvReplicationMessages receives replication messages and forwards to local shard owner. 
func (r *intraProxyStreamReceiver) recvReplicationMessages() error { - r.logger.Info("intraProxyStreamReceiver recvReplicationMessages started") - defer r.logger.Info("intraProxyStreamReceiver recvReplicationMessages finished") + r.logger.Debug("intraProxyStreamReceiver recvReplicationMessages started") + defer r.logger.Debug("intraProxyStreamReceiver recvReplicationMessages finished") shutdown := r.shutdown defer shutdown.Shutdown() @@ -257,7 +257,7 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { for !shutdown.IsShutdown() { resp, err := r.streamClient.Recv() if err == io.EOF { - r.logger.Info("recvReplicationMessages encountered EOF") + r.logger.Debug("recvReplicationMessages encountered EOF") return nil } if err != nil { @@ -287,7 +287,7 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { } r.lastWatermarkMu.Unlock() - r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", exclusiveHighWatermark, ids)) + r.logger.Debug(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", exclusiveHighWatermark, ids)) msg := RoutedMessage{SourceShard: r.sourceShardID, Resp: resp} sent := false @@ -304,7 +304,7 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { select { case ch <- msg: sent = true - r.logger.Info("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", exclusiveHighWatermark)) + r.logger.Debug("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", exclusiveHighWatermark)) case <-shutdown.Channel(): // Will be handled outside the func } @@ -332,8 +332,8 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { // sendAck sends an ACK upstream via the client stream and updates tracker. 
func (r *intraProxyStreamReceiver) sendAck(req *adminservice.StreamWorkflowReplicationMessagesRequest) error { - r.logger.Info("intraProxyStreamReceiver sendAck started") - defer r.logger.Info("intraProxyStreamReceiver sendAck finished") + r.logger.Debug("intraProxyStreamReceiver sendAck started") + defer r.logger.Debug("intraProxyStreamReceiver sendAck finished") if err := r.streamClient.Send(req); err != nil { return err @@ -380,7 +380,7 @@ func (r *intraProxyStreamReceiver) sendPendingWatermarkToShard(targetShardID his return } - r.logger.Info("Sending pending watermark to newly registered shard", + r.logger.Debug("Sending pending watermark to newly registered shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark)) @@ -405,7 +405,7 @@ func (r *intraProxyStreamReceiver) sendPendingWatermarkToShard(targetShardID his } select { case sendChan <- clonedMsg: - r.logger.Info("Sent pending watermark to local shard", + r.logger.Debug("Sent pending watermark to local shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) default: r.logger.Warn("Failed to send pending watermark to local shard (channel full)", @@ -423,7 +423,7 @@ func (r *intraProxyStreamReceiver) sendPendingWatermarkToShard(targetShardID his Resp: clonedResp, } if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, shutdownChan, r.logger) { - r.logger.Info("Sent pending watermark to remote shard", + r.logger.Debug("Sent pending watermark to remote shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) } else { r.logger.Warn("Failed to send pending watermark to remote shard", @@ -477,7 +477,7 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(peerNodeName string, targ tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard))) - logger.Info("EnsureReceiverForPeerShard") + logger.Debug("EnsureReceiverForPeerShard") // Cross-cluster only if targetShard.ClusterID == sourceShard.ClusterID { @@ -491,7 +491,7 @@ func (m *intraProxyManager) EnsureReceiverForPeerShard(peerNodeName string, targ isLocalTargetShard := m.shardManager.IsLocalShard(targetShard) isLocalSourceShard := m.shardManager.IsLocalShard(sourceShard) if !isLocalTargetShard && !isLocalSourceShard { - logger.Info("EnsureReceiverForPeerShard skipping because neither shard is local", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewBoolTag("isLocalTargetShard", isLocalTargetShard), tag.NewBoolTag("isLocalSourceShard", isLocalSourceShard)) + logger.Debug("EnsureReceiverForPeerShard skipping because neither shard is local", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewBoolTag("isLocalTargetShard", isLocalTargetShard), tag.NewBoolTag("isLocalSourceShard", isLocalSourceShard)) return } // Consolidated path: ensure stream and background loops @@ -507,24 +507,24 @@ func (m *intraProxyManager) ensurePeer( peerNodeName string, ) (*peerState, error) { logger := log.With(m.logger, tag.NewStringTag("peerNodeName", peerNodeName)) - logger.Info("ensurePeer started") - defer logger.Info("ensurePeer finished") + logger.Debug("ensurePeer started") + defer logger.Debug("ensurePeer finished") 
m.streamsMu.RLock() if ps, ok := m.peers[peerNodeName]; ok && ps != nil && ps.conn != nil { m.streamsMu.RUnlock() - logger.Info("ensurePeer found existing peer with connection") + logger.Debug("ensurePeer found existing peer with connection") return ps, nil } m.streamsMu.RUnlock() - logger.Info("ensurePeer creating new peer connection") + logger.Debug("ensurePeer creating new peer connection") // Build TLS from this proxy's outbound client TLS config if available tlsCfg := m.shardManager.GetIntraProxyTLSConfig() var parsedTLSCfg *tls.Config if tlsCfg.IsEnabled() { - logger.Info("ensurePeer TLS enabled, building TLS config") + logger.Debug("ensurePeer TLS enabled, building TLS config") var err error parsedTLSCfg, err = encryption.GetClientTLSConfig(tlsCfg) if err != nil { @@ -532,7 +532,7 @@ func (m *intraProxyManager) ensurePeer( return nil, fmt.Errorf("config error when creating tls config: %w", err) } } else { - logger.Info("ensurePeer TLS disabled") + logger.Debug("ensurePeer TLS disabled") } dialOpts := grpcutil.MakeDialOptions(parsedTLSCfg, metrics.GetGRPCClientMetrics("intra_proxy")) @@ -541,27 +541,27 @@ func (m *intraProxyManager) ensurePeer( logger.Error("ensurePeer proxy address not found") return nil, fmt.Errorf("proxy address not found") } - logger.Info("ensurePeer dialing peer", tag.NewStringTag("proxyAddresses", proxyAddresses)) + logger.Debug("ensurePeer dialing peer", tag.NewStringTag("proxyAddresses", proxyAddresses)) cc, err := grpc.NewClient(proxyAddresses, dialOpts...) if err != nil { logger.Error("ensurePeer failed to dial peer", tag.Error(err)) return nil, err } - logger.Info("ensurePeer successfully dialed peer") + logger.Debug("ensurePeer successfully dialed peer") m.streamsMu.Lock() ps := m.peers[peerNodeName] if ps == nil { - logger.Info("ensurePeer creating new peer state") + logger.Debug("ensurePeer creating new peer state") ps = &peerState{conn: cc, receivers: make(map[peerStreamKey]*intraProxyStreamReceiver), senders: make(map[peerStreamKey]*intraProxyStreamSender), recvShutdown: make(map[peerStreamKey]channel.ShutdownOnce)} m.peers[peerNodeName] = ps } else { - logger.Info("ensurePeer updating existing peer state with new connection") + logger.Debug("ensurePeer updating existing peer state with new connection") old := ps.conn ps.conn = cc if old != nil { - logger.Info("ensurePeer closing old connection") + logger.Debug("ensurePeer closing old connection") _ = old.Close() } if ps.receivers == nil { @@ -586,7 +586,7 @@ func (m *intraProxyManager) ensureStream( targetShard history.ClusterShardID, sourceShard history.ClusterShardID, ) error { - logger.Info("ensureStream") + logger.Debug("ensureStream") key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} // Fast path: already exists @@ -594,7 +594,7 @@ func (m *intraProxyManager) ensureStream( if ps, ok := m.peers[peerNodeName]; ok && ps != nil { if r, ok2 := ps.receivers[key]; ok2 && r != nil && r.streamClient != nil { m.streamsMu.RUnlock() - logger.Info("ensureStream reused") + logger.Debug("ensureStream reused") return nil } } @@ -625,7 +625,7 @@ func (m *intraProxyManager) ensureStream( ps.receivers[key] = recv ps.recvShutdown[key] = recv.shutdown m.streamsMu.Unlock() - m.logger.Info("intraProxyStreamReceiver added", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", recv.streamID)) + m.logger.Debug("intraProxyStreamReceiver added", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", 
fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", recv.streamID)) // Let the receiver open stream, register tracking, and start goroutines go func() { @@ -674,8 +674,8 @@ func (m *intraProxyManager) sendReplicationMessages( ) error { key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} logger := log.With(m.logger, tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("task-source-shard", ClusterShardIDtoString(sourceShard))) - logger.Info("sendReplicationMessages") - defer logger.Info("sendReplicationMessages finished") + logger.Debug("sendReplicationMessages") + defer logger.Debug("sendReplicationMessages finished") // Try server stream first with short retry/backoff to await registration deadline := time.Now().Add(2 * time.Second) @@ -685,13 +685,13 @@ func (m *intraProxyManager) sendReplicationMessages( m.streamsMu.RLock() ps, ok := m.peers[peerNodeName] if ok && ps != nil && ps.senders != nil { - logger.Info("sendReplicationMessages senders for node", tag.NewStringTag("node", peerNodeName), tag.NewStringTag("senders", fmt.Sprintf("%v", ps.senders))) + logger.Debug("sendReplicationMessages senders for node", tag.NewStringTag("node", peerNodeName), tag.NewStringTag("senders", fmt.Sprintf("%v", ps.senders))) if s, ok2 := ps.senders[key]; ok2 && s != nil { sender = s } } m.streamsMu.RUnlock() - logger.Info("sendReplicationMessages sender", tag.NewStringTag("sender", fmt.Sprintf("%v", sender))) + logger.Debug("sendReplicationMessages sender", tag.NewStringTag("sender", fmt.Sprintf("%v", sender))) if sender != nil { if err := sender.sendReplicationMessages(resp); err != nil { @@ -818,8 +818,8 @@ func (m *intraProxyManager) Notify() { // for a given peer and closes any sender/receiver not in the desired set. // This mirrors the Temporal StreamReceiverMonitor approach. 
func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { - m.logger.Info("ReconcilePeerStreams started", tag.NewStringTag("peerNodeName", peerNodeName)) - defer m.logger.Info("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) + m.logger.Debug("ReconcilePeerStreams started", tag.NewStringTag("peerNodeName", peerNodeName)) + defer m.logger.Debug("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) localShards := m.shardManager.GetLocalShards() remoteShards, err := m.shardManager.GetRemoteShardsForPeer(peerNodeName) @@ -827,7 +827,7 @@ func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) return } - m.logger.Info("ReconcilePeerStreams remote and local shards", + m.logger.Debug("ReconcilePeerStreams remote and local shards", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards)), tag.NewStringTag("localShards", fmt.Sprintf("%v", localShards)), @@ -860,7 +860,7 @@ func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { } } - m.logger.Info("ReconcilePeerStreams desired receivers and senders", tag.NewStringTag("desiredReceivers", fmt.Sprintf("%v", desiredReceivers)), tag.NewStringTag("desiredSenders", fmt.Sprintf("%v", desiredSenders))) + m.logger.Debug("ReconcilePeerStreams desired receivers and senders", tag.NewStringTag("desiredReceivers", fmt.Sprintf("%v", desiredReceivers)), tag.NewStringTag("desiredSenders", fmt.Sprintf("%v", desiredSenders))) // Ensure all desired receivers exist for key := range desiredReceivers { diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go index 8ae63597..9f8d406b 100644 --- a/proxy/proxy_streams.go +++ b/proxy/proxy_streams.go @@ -275,9 +275,9 @@ func (s *proxyStreamSender) recvAck( sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, shutdownChan channel.ShutdownOnce, ) error { - s.logger.Info("proxyStreamSender recvAck started") + s.logger.Debug("proxyStreamSender recvAck started") defer func() { - s.logger.Info("proxyStreamSender recvAck finished") + s.logger.Debug("proxyStreamSender recvAck finished") shutdownChan.Shutdown() }() for !shutdownChan.IsShutdown() { @@ -301,7 +301,7 @@ func (s *proxyStreamSender) recvAck( shardToAck, pendingDiscard := s.idRing.AggregateUpTo(proxyAckWatermark) s.mu.Unlock() - s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", proxyAckWatermark), tag.NewStringTag("shardToAck", fmt.Sprintf("%v", shardToAck)), tag.NewInt("pendingDiscard", pendingDiscard)) + s.logger.Debug("Sender received upstream ACK", tag.NewInt64("inclusive_low", proxyAckWatermark), tag.NewStringTag("shardToAck", fmt.Sprintf("%v", shardToAck)), tag.NewInt("pendingDiscard", pendingDiscard)) if len(shardToAck) > 0 { sent := make(map[history.ClusterShardID]bool, len(shardToAck)) @@ -331,7 +331,7 @@ func (s *proxyStreamSender) recvAck( }, } - s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) + s.logger.Debug("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, shutdownChan, s.logger, originalAck, true) { sent[srcShard] = true @@ -394,7 +394,7 @@ func (s *proxyStreamSender) recvAck( }, } // Log fallback ACK for this source 
shard - s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) + s.logger.Debug("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, shutdownChan, s.logger, prev, true) { sent[srcShard] = true numRemaining-- @@ -435,9 +435,9 @@ func (s *proxyStreamSender) sendReplicationMessages( sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, shutdownChan channel.ShutdownOnce, ) error { - s.logger.Info("proxyStreamSender sendReplicationMessages started") + s.logger.Debug("proxyStreamSender sendReplicationMessages started") defer func() { - s.logger.Info("proxyStreamSender sendReplicationMessages finished") + s.logger.Debug("proxyStreamSender sendReplicationMessages finished") shutdownChan.Shutdown() }() @@ -453,7 +453,7 @@ func (s *proxyStreamSender) sendReplicationMessages( if !ok { return nil } - s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: routed.Resp=%p", routed.Resp), tag.NewStringTag("routed", fmt.Sprintf("%v", routed))) + s.logger.Debug(fmt.Sprintf("Sender received ReplicationTasks: routed.Resp=%p", routed.Resp), tag.NewStringTag("routed", fmt.Sprintf("%v", routed))) resp := routed.Resp m, ok := resp.Attributes.(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages) if !ok || m.Messages == nil { @@ -464,7 +464,7 @@ func (s *proxyStreamSender) sendReplicationMessages( for _, t := range m.Messages.ReplicationTasks { sourceTaskIds = append(sourceTaskIds, t.SourceTaskId) } - s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d ids=%v", m.Messages.ExclusiveHighWatermark, sourceTaskIds)) + s.logger.Debug(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d ids=%v", m.Messages.ExclusiveHighWatermark, sourceTaskIds)) // rewrite task ids s.mu.Lock() @@ -472,7 +472,7 @@ func (s *proxyStreamSender) sendReplicationMessages( var proxyIDs []int64 // capture original exclusive high watermark before rewriting originalHigh := m.Messages.ExclusiveHighWatermark - s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d", m.Messages.ExclusiveHighWatermark, originalHigh)) + s.logger.Debug(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d", m.Messages.ExclusiveHighWatermark, originalHigh)) // Ensure exclusive high watermark is in proxy task ID space var proxyExclusiveHigh int64 if len(m.Messages.ReplicationTasks) > 0 { @@ -502,16 +502,16 @@ func (s *proxyStreamSender) sendReplicationMessages( proxyIDs = append(proxyIDs, proxyHigh) proxyExclusiveHigh = proxyHigh m.Messages.ExclusiveHighWatermark = proxyExclusiveHigh - s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d proxy_high=%d original", proxyExclusiveHigh, originalHigh, proxyHigh)) + s.logger.Debug(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d proxy_high=%d original", proxyExclusiveHigh, originalHigh, proxyHigh)) } s.mu.Unlock() // Log mapping from original -> proxy IDs (use captured value to avoid data race) - s.logger.Info(fmt.Sprintf("Sender sending ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) + s.logger.Debug(fmt.Sprintf("Sender 
sending ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) if err := sourceStreamServer.Send(resp); err != nil { return err } - s.logger.Info("Sender sent ReplicationTasks", tag.NewStringTag("sourceShard", ClusterShardIDtoString(routed.SourceShard)), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) + s.logger.Debug("Sender sent ReplicationTasks", tag.NewStringTag("sourceShard", ClusterShardIDtoString(routed.SourceShard)), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) // Update keepalive state s.mu.Lock() @@ -539,7 +539,7 @@ func (s *proxyStreamSender) sendReplicationMessages( }, }, } - s.logger.Info("Sender sending keepalive message", tag.NewInt64("watermark", watermark)) + s.logger.Debug("Sender sending keepalive message", tag.NewInt64("watermark", watermark)) if err := sourceStreamServer.Send(keepaliveResp); err != nil { return err } @@ -696,13 +696,13 @@ func (r *proxyStreamReceiver) recvReplicationMessages( sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, shutdownChan channel.ShutdownOnce, ) error { - r.logger.Info("proxyStreamReceiver recvReplicationMessages started") - defer r.logger.Info("proxyStreamReceiver recvReplicationMessages finished") + r.logger.Debug("proxyStreamReceiver recvReplicationMessages started") + defer r.logger.Debug("proxyStreamReceiver recvReplicationMessages finished") for !shutdownChan.IsShutdown() { resp, err := sourceStreamClient.Recv() if err == io.EOF { - r.logger.Info("sourceStreamClient.Recv encountered EOF", tag.Error(err)) + r.logger.Debug("sourceStreamClient.Recv encountered EOF", tag.Error(err)) return nil } if err != nil { @@ -724,7 +724,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( } // Log every replication task id received at receiver - r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", attr.Messages.ExclusiveHighWatermark, ids)) + r.logger.Debug(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", attr.Messages.ExclusiveHighWatermark, ids)) // record last source exclusive high watermark (original id space) r.ackMu.Lock() @@ -741,7 +741,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( // If replication tasks are empty, still log the empty batch and send watermark if len(attr.Messages.ReplicationTasks) == 0 { - r.logger.Info("Receiver received empty replication batch", tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + r.logger.Debug("Receiver received empty replication batch", tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) // Track last watermark for late-registering shards r.lastWatermarkMu.Lock() @@ -763,7 +763,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( }, } localShardsToSend := r.shardManager.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) - r.logger.Info("Going to broadcast high watermark to local shards", tag.NewStringTag("localShardsToSend", fmt.Sprintf("%v", localShardsToSend))) + r.logger.Debug("Going to broadcast high watermark to local shards", tag.NewStringTag("localShardsToSend", fmt.Sprintf("%v", localShardsToSend))) for targetShardID, sendChan := range localShardsToSend { // Clone the message for each recipient to prevent shared mutation clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) @@ -771,7 +771,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( 
SourceShard: msg.SourceShard, Resp: clonedResp, } - r.logger.Info(fmt.Sprintf("Sending high watermark to target shard, msg.Resp=%p", clonedMsg.Resp), tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark), tag.NewStringTag("msg", fmt.Sprintf("%v", clonedMsg))) + r.logger.Debug(fmt.Sprintf("Sending high watermark to target shard, msg.Resp=%p", clonedMsg.Resp), tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark), tag.NewStringTag("msg", fmt.Sprintf("%v", clonedMsg))) // Use non-blocking send with recover to handle closed channels func() { defer func() { @@ -799,7 +799,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( r.logger.Error("Failed to get remote shards", tag.Error(err)) return err } - r.logger.Info("Going to broadcast high watermark to remote shards", tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards))) + r.logger.Debug("Going to broadcast high watermark to remote shards", tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards))) for _, shards := range remoteShards { for _, shard := range shards.Shards { if shard.ID.ClusterID != r.targetShardID.ClusterID { @@ -825,7 +825,7 @@ func (r *proxyStreamReceiver) recvReplicationMessages( for targetShardID := range tasksByTargetShard { sentByTarget[targetShardID] = false } - r.logger.Info("Going to broadcast ReplicationTasks to target shards", tag.NewStringTag("tasksByTargetShard", fmt.Sprintf("%v", tasksByTargetShard))) + r.logger.Debug("Going to broadcast ReplicationTasks to target shards", tag.NewStringTag("tasksByTargetShard", fmt.Sprintf("%v", tasksByTargetShard))) numRemaining := len(tasksByTargetShard) backoff := 10 * time.Millisecond for numRemaining > 0 { @@ -910,7 +910,7 @@ func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history. return } - r.logger.Info("Sending pending watermark to newly registered shard", + r.logger.Debug("Sending pending watermark to newly registered shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark)) @@ -935,7 +935,7 @@ func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history. } select { case sendChan <- clonedMsg: - r.logger.Info("Sent pending watermark to local shard", + r.logger.Debug("Sent pending watermark to local shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) default: r.logger.Warn("Failed to send pending watermark to local shard (channel full)", @@ -953,7 +953,7 @@ func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history. 
Resp: clonedResp, } if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, shutdownChan, r.logger) { - r.logger.Info("Sent pending watermark to remote shard", + r.logger.Debug("Sent pending watermark to remote shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) } else { r.logger.Warn("Failed to send pending watermark to remote shard", @@ -967,8 +967,8 @@ func (r *proxyStreamReceiver) sendAck( sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, shutdownChan channel.ShutdownOnce, ) error { - r.logger.Info("proxyStreamReceiver sendAck started") - defer r.logger.Info("proxyStreamReceiver sendAck finished") + r.logger.Debug("proxyStreamReceiver sendAck started") + defer r.logger.Debug("proxyStreamReceiver sendAck finished") ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() @@ -978,7 +978,7 @@ func (r *proxyStreamReceiver) sendAck( case routed := <-r.ackChan: // Update per-target watermark if attr, ok := routed.Req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { - r.logger.Info("Receiver received upstream ACK", tag.NewInt64("inclusive_low", attr.SyncReplicationState.InclusiveLowWatermark), tag.NewStringTag("targetShard", ClusterShardIDtoString(routed.TargetShard))) + r.logger.Debug("Receiver received upstream ACK", tag.NewInt64("inclusive_low", attr.SyncReplicationState.InclusiveLowWatermark), tag.NewStringTag("targetShard", ClusterShardIDtoString(routed.TargetShard))) r.ackMu.Lock() r.ackByTarget[routed.TargetShard] = attr.SyncReplicationState.InclusiveLowWatermark // Compute minimal watermark across targets @@ -1009,12 +1009,12 @@ func (r *proxyStreamReceiver) sendAck( }, }, } - r.logger.Info("Receiver sending aggregated ACK upstream", tag.NewInt64("inclusive_low", min)) + r.logger.Debug("Receiver sending aggregated ACK upstream", tag.NewInt64("inclusive_low", min)) if err := sourceStreamClient.Send(aggregated); err != nil { if err != io.EOF { r.logger.Error("sourceStreamClient.Send encountered error", tag.Error(err)) } else { - r.logger.Info("sourceStreamClient.Send encountered EOF", tag.Error(err)) + r.logger.Debug("sourceStreamClient.Send encountered EOF", tag.Error(err)) } return err } @@ -1042,7 +1042,7 @@ func (r *proxyStreamReceiver) sendAck( r.ackMu.RUnlock() if shouldSendKeepalive { - r.logger.Info("Receiver sending keepalive ACK") + r.logger.Debug("Receiver sending keepalive ACK") if err := sourceStreamClient.Send(lastAck); err != nil { if err != io.EOF { r.logger.Error("sourceStreamClient.Send keepalive encountered error", tag.Error(err)) diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index d6c27c90..901c394d 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -793,7 +793,7 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( }() select { case ackCh <- *routedAck: - logger.Info("Delivered ACK to local shard owner") + logger.Debug("Delivered ACK to local shard owner") delivered = true case <-shutdownChan.Channel(): // Shutdown signal received @@ -822,7 +822,7 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( logger.Error("Failed to forward ACK to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) return false } - logger.Info("Forwarded ACK to shard owner via intra-proxy", tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + logger.Debug("Forwarded ACK to shard owner via intra-proxy", 
tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) return true } logger.Warn("Owner proxy address not found for shard") @@ -855,7 +855,7 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( }() select { case ch <- *routedMsg: - logger.Info("Delivered messages to local shard owner") + logger.Debug("Delivered messages to local shard owner") delivered = true case <-shutdownChan.Channel(): // Shutdown signal received @@ -1057,7 +1057,7 @@ func (sd *shardDelegate) NotifyMsg(data []byte) { return } - sd.logger.Info("Received shard message", + sd.logger.Debug("Received shard message", tag.NewStringTag("type", msg.Type), tag.NewStringTag("node", msg.NodeName), tag.NewStringTag("shard", ClusterShardIDtoString(msg.ClientShard))) @@ -1108,7 +1108,7 @@ func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { sd.manager.remoteNodeStatesMu.Unlock() } - sd.logger.Info("Merged remote shard state", + sd.logger.Debug("Merged remote shard state", tag.NewStringTag("node", state.NodeName), tag.NewStringTag("shards", strconv.Itoa(len(state.Shards))), tag.NewStringTag("state", fmt.Sprintf("%+v", state))) diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index f8229b71..64fd47d7 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -354,46 +354,22 @@ func (s *ReplicationTestSuite) TearDownSuite() { s.deglobalizeNamespace(s.namespace) } - if s.clusterA != nil && s.clusterB != nil { - removeRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName()) - removeRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName()) - } - if s.clusterA != nil { - s.NoError(s.clusterA.TearDownCluster()) - } - if s.clusterB != nil { - s.NoError(s.clusterB.TearDownCluster()) - } + removeRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName()) + removeRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName()) + s.NoError(s.clusterA.TearDownCluster()) + s.NoError(s.clusterB.TearDownCluster()) if s.setupMode == SetupModeSimple { - if s.proxyA != nil { - s.proxyA.Stop() - } - if s.proxyB != nil { - s.proxyB.Stop() - } + s.proxyA.Stop() + s.proxyB.Stop() } else { - if s.loadBalancerA != nil { - s.loadBalancerA.Stop() - } - if s.loadBalancerB != nil { - s.loadBalancerB.Stop() - } - if s.loadBalancerC != nil { - s.loadBalancerC.Stop() - } - if s.proxyA1 != nil { - s.proxyA1.Stop() - } - if s.proxyA2 != nil { - s.proxyA2.Stop() - } - if s.proxyB1 != nil { - s.proxyB1.Stop() - } - if s.proxyB2 != nil { - s.proxyB2.Stop() - } + s.loadBalancerA.Stop() + s.loadBalancerB.Stop() + s.loadBalancerC.Stop() + s.proxyA1.Stop() + s.proxyA2.Stop() + s.proxyB1.Stop() + s.proxyB2.Stop() } } diff --git a/proxy/test/tcp_proxy.go b/proxy/test/tcp_proxy.go index bbdd811c..9bb40f7b 100644 --- a/proxy/test/tcp_proxy.go +++ b/proxy/test/tcp_proxy.go @@ -29,6 +29,8 @@ type ( Upstream *Upstream } + // TCPProxy is a simple TCP proxy that distributes connections to the least-used upstream server. + // It is used to test the connections and replication streams across multiple s2s-proxy instances. 
TCPProxy struct { rules []*ProxyRule logger log.Logger From 9012f6b723e32285f4634beac49222e2d3cde13a Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 14:20:57 -0800 Subject: [PATCH 32/38] refactor and put stream handling into separated file --- proxy/admin_stream_transfer.go | 173 +++++++++++++++++++++++++++++++++ proxy/adminservice.go | 158 ++---------------------------- 2 files changed, 181 insertions(+), 150 deletions(-) diff --git a/proxy/admin_stream_transfer.go b/proxy/admin_stream_transfer.go index a48ff021..e1578265 100644 --- a/proxy/admin_stream_transfer.go +++ b/proxy/admin_stream_transfer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "strconv" "strings" "sync" "time" @@ -11,11 +12,14 @@ import ( "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/client/history" + servercommon "go.temporal.io/server/common" "go.temporal.io/server/common/channel" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "google.golang.org/grpc/metadata" + "github.com/temporalio/s2s-proxy/common" + "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/metrics" ) @@ -291,3 +295,172 @@ func (f *StreamForwarder) forwardAcks(wg *sync.WaitGroup) { } } } + +// handleStream handles the routing logic for StreamWorkflowReplicationMessages based on shard count mode. +func handleStream( + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + targetMetadata metadata.MD, + sourceClusterShardID history.ClusterShardID, + targetClusterShardID history.ClusterShardID, + logger log.Logger, + shardCountConfig config.ShardCountConfig, + lcmParameters LCMParameters, + routingParameters RoutingParameters, + adminClient adminservice.AdminServiceClient, + adminClientReverse adminservice.AdminServiceClient, + shardManager ShardManager, + metricLabelValues []string, +) error { + switch shardCountConfig.Mode { + case config.ShardCountLCM: + // Arbitrary shard count support. + // + // Temporal only supports shard counts where one shard count is an even multiple of the other. + // The trick in this mode is the proxy will present the Least Common Multiple of both cluster shard counts. + // Temporal establishes outbound replication streams to the proxy for all unique shard id pairs between + // itself and the proxy's shard count. Then the proxy directly forwards those streams along to the target + // cluster, remapping proxy stream shard ids to the target cluster shard ids. + newTargetShardID := history.ClusterShardID{ + ClusterID: targetClusterShardID.ClusterID, + ShardID: sourceClusterShardID.ShardID, // proxy fake shard id + } + newSourceShardID := history.ClusterShardID{ + ClusterID: sourceClusterShardID.ClusterID, + } + // Remap shard id using the pre-calculated target shard count. + newSourceShardID.ShardID = mapShardIDUnique(lcmParameters.LCM, lcmParameters.TargetShardCount, sourceClusterShardID.ShardID) + + logger = log.With(logger, + tag.NewStringTag("newTarget", ClusterShardIDtoString(newTargetShardID)), + tag.NewStringTag("newSource", ClusterShardIDtoString(newSourceShardID))) + + // Maybe there's a cleaner way. Trying to preserve any other metadata. 
+ targetMetadata.Set(history.MetadataKeyClientClusterID, strconv.Itoa(int(newTargetShardID.ClusterID))) + targetMetadata.Set(history.MetadataKeyClientShardID, strconv.Itoa(int(newTargetShardID.ShardID))) + targetMetadata.Set(history.MetadataKeyServerClusterID, strconv.Itoa(int(newSourceShardID.ClusterID))) + targetMetadata.Set(history.MetadataKeyServerShardID, strconv.Itoa(int(newSourceShardID.ShardID))) + case config.ShardCountRouting: + isIntraProxy := common.IsIntraProxy(streamServer.Context()) + if isIntraProxy { + return streamIntraProxyRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID, shardManager) + } + return streamRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID, shardManager, adminClientReverse, routingParameters) + } + + forwarder := newStreamForwarder( + adminClient, + streamServer, + targetMetadata, + sourceClusterShardID, + targetClusterShardID, + metricLabelValues, + logger, + ) + return forwarder.Run() +} + +func streamIntraProxyRouting( + logger log.Logger, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceShardID history.ClusterShardID, + targetShardID history.ClusterShardID, + shardManager ShardManager, +) error { + logger.Info("streamIntraProxyRouting started") + defer logger.Info("streamIntraProxyRouting finished") + + // Determine remote peer identity from intra-proxy headers + peerNodeName := "" + if md, ok := metadata.FromIncomingContext(streamServer.Context()); ok { + vals := md.Get(common.IntraProxyOriginProxyIDHeader) + if len(vals) > 0 { + peerNodeName = vals[0] + } + } + + // Only allow intra-proxy when at least one shard is local to this proxy instance + isLocalSource := shardManager.IsLocalShard(sourceShardID) + isLocalTarget := shardManager.IsLocalShard(targetShardID) + if isLocalTarget || !isLocalSource { + logger.Info("Skipping intra-proxy between two local shards or two remote shards. 
Client may use outdated shard info.", + tag.NewBoolTag("isLocalSource", isLocalSource), + tag.NewBoolTag("isLocalTarget", isLocalTarget), + ) + return nil + } + + // Sender: handle ACKs coming from peer and forward to original owner + sender := &intraProxyStreamSender{ + logger: logger, + shardManager: shardManager, + peerNodeName: peerNodeName, + sourceShardID: sourceShardID, + targetShardID: targetShardID, + } + + shutdownChan := channel.NewShutdownOnce() + go func() { + if err := sender.Run(streamServer, shutdownChan); err != nil { + logger.Error("intraProxyStreamSender.Run error", tag.Error(err)) + } + }() + <-shutdownChan.Channel() + return nil +} + +func streamRouting( + logger log.Logger, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceShardID history.ClusterShardID, + targetShardID history.ClusterShardID, + shardManager ShardManager, + adminClientReverse adminservice.AdminServiceClient, + routingParameters RoutingParameters, +) error { + logger.Info("streamRouting started") + defer logger.Info("streamRouting stopped") + + // client: stream receiver + // server: stream sender + proxyStreamSender := &proxyStreamSender{ + logger: logger, + shardManager: shardManager, + sourceShardID: sourceShardID, + targetShardID: targetShardID, + directionLabel: routingParameters.DirectionLabel, + } + + proxyStreamReceiver := &proxyStreamReceiver{ + logger: logger, + shardManager: shardManager, + adminClient: adminClientReverse, + localShardCount: routingParameters.RoutingLocalShardCount, + sourceShardID: targetShardID, // reverse direction + targetShardID: sourceShardID, // reverse direction + directionLabel: routingParameters.DirectionLabel, + } + + shutdownChan := channel.NewShutdownOnce() + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + proxyStreamSender.Run(streamServer, shutdownChan) + }() + go func() { + defer wg.Done() + proxyStreamReceiver.Run(shutdownChan) + }() + wg.Wait() + + return nil +} + +func mapShardIDUnique(sourceShardCount, targetShardCount, sourceShardID int32) int32 { + targetShardID := servercommon.MapShardID(sourceShardCount, targetShardCount, sourceShardID) + if len(targetShardID) != 1 { + panic(fmt.Sprintf("remapping shard count error: sourceShardCount=%d targetShardCount=%d sourceShardID=%d targetShardID=%v\n", + sourceShardCount, targetShardCount, sourceShardID, targetShardID)) + } + return targetShardID[0] +} diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 6cf491c7..4b6a3f84 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -3,14 +3,10 @@ package proxy import ( "context" "fmt" - "strconv" - "sync" "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/client/history" - servercommon "go.temporal.io/server/common" - "go.temporal.io/server/common/channel" "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" @@ -291,9 +287,6 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( return err } - // Detect intra-proxy streams early for logging/behavior toggles - isIntraProxy := common.IsIntraProxy(streamServer.Context()) - logger := log.With(s.logger, tag.NewStringTag("source", ClusterShardIDtoString(sourceClusterShardID)), tag.NewStringTag("target", ClusterShardIDtoString(targetClusterShardID))) @@ -305,53 +298,20 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( streamsActiveGauge.Inc() defer streamsActiveGauge.Dec() - if 
s.shardCountConfig.Mode == config.ShardCountLCM { - // Arbitrary shard count support. - // - // Temporal only supports shard counts where one shard count is an even multiple of the other. - // The trick in this mode is the proxy will present the Least Common Multiple of both cluster shard counts. - // Temporal establishes outbound replication streams to the proxy for all unique shard id pairs between - // itself and the proxy's shard count. Then the proxy directly forwards those streams along to the target - // cluster, remapping proxy stream shard ids to the target cluster shard ids. - newTargetShardID := history.ClusterShardID{ - ClusterID: targetClusterShardID.ClusterID, - ShardID: sourceClusterShardID.ShardID, // proxy fake shard id - } - newSourceShardID := history.ClusterShardID{ - ClusterID: sourceClusterShardID.ClusterID, - } - // Remap shard id using the pre-calculated target shard count. - newSourceShardID.ShardID = mapShardIDUnique(s.lcmParameters.LCM, s.lcmParameters.TargetShardCount, sourceClusterShardID.ShardID) - - logger = log.With(logger, - tag.NewStringTag("newTarget", ClusterShardIDtoString(newTargetShardID)), - tag.NewStringTag("newSource", ClusterShardIDtoString(newSourceShardID))) - - // Maybe there's a cleaner way. Trying to preserve any other metadata. - targetMetadata.Set(history.MetadataKeyClientClusterID, strconv.Itoa(int(newTargetShardID.ClusterID))) - targetMetadata.Set(history.MetadataKeyClientShardID, strconv.Itoa(int(newTargetShardID.ShardID))) - targetMetadata.Set(history.MetadataKeyServerClusterID, strconv.Itoa(int(newSourceShardID.ClusterID))) - targetMetadata.Set(history.MetadataKeyServerShardID, strconv.Itoa(int(newSourceShardID.ShardID))) - } - - if isIntraProxy { - return s.streamIntraProxyRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID) - } - - if s.shardCountConfig.Mode == config.ShardCountRouting { - return s.streamRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID) - } - - forwarder := newStreamForwarder( - s.adminClient, + err = handleStream( streamServer, targetMetadata, sourceClusterShardID, targetClusterShardID, - s.metricLabelValues, logger, + s.shardCountConfig, + s.lcmParameters, + s.routingParameters, + s.adminClient, + s.adminClientReverse, + s.shardManager, + s.metricLabelValues, ) - err = forwarder.Run() if err != nil { return err } @@ -359,105 +319,3 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( // to the client. return nil } - -func (s *adminServiceProxyServer) streamIntraProxyRouting( - logger log.Logger, - streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, - sourceShardID history.ClusterShardID, - targetShardID history.ClusterShardID, -) error { - logger.Info("streamIntraProxyRouting started") - defer logger.Info("streamIntraProxyRouting finished") - - // Determine remote peer identity from intra-proxy headers - peerNodeName := "" - if md, ok := metadata.FromIncomingContext(streamServer.Context()); ok { - vals := md.Get(common.IntraProxyOriginProxyIDHeader) - if len(vals) > 0 { - peerNodeName = vals[0] - } - } - - // Only allow intra-proxy when at least one shard is local to this proxy instance - isLocalSource := s.shardManager.IsLocalShard(sourceShardID) - isLocalTarget := s.shardManager.IsLocalShard(targetShardID) - if isLocalTarget || !isLocalSource { - logger.Info("Skipping intra-proxy between two local shards or two remote shards. 
Client may use outdated shard info.", - tag.NewBoolTag("isLocalSource", isLocalSource), - tag.NewBoolTag("isLocalTarget", isLocalTarget), - ) - return nil - } - - // Sender: handle ACKs coming from peer and forward to original owner - sender := &intraProxyStreamSender{ - logger: logger, - shardManager: s.shardManager, - peerNodeName: peerNodeName, - sourceShardID: sourceShardID, - targetShardID: targetShardID, - } - - shutdownChan := channel.NewShutdownOnce() - go func() { - if err := sender.Run(streamServer, shutdownChan); err != nil { - logger.Error("intraProxyStreamSender.Run error", tag.Error(err)) - } - }() - <-shutdownChan.Channel() - return nil -} - -func (s *adminServiceProxyServer) streamRouting( - logger log.Logger, - streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, - sourceShardID history.ClusterShardID, - targetShardID history.ClusterShardID, -) error { - logger.Info("streamRouting started") - defer logger.Info("streamRouting stopped") - - // client: stream receiver - // server: stream sender - proxyStreamSender := &proxyStreamSender{ - logger: logger, - shardManager: s.shardManager, - sourceShardID: sourceShardID, - targetShardID: targetShardID, - directionLabel: s.routingParameters.DirectionLabel, - } - - proxyStreamReceiver := &proxyStreamReceiver{ - logger: s.logger, - shardManager: s.shardManager, - adminClient: s.adminClientReverse, - localShardCount: s.routingParameters.RoutingLocalShardCount, - sourceShardID: targetShardID, // reverse direction - targetShardID: sourceShardID, // reverse direction - directionLabel: s.routingParameters.DirectionLabel, - } - - shutdownChan := channel.NewShutdownOnce() - wg := sync.WaitGroup{} - wg.Add(2) - go func() { - defer wg.Done() - proxyStreamSender.Run(streamServer, shutdownChan) - }() - go func() { - defer wg.Done() - proxyStreamReceiver.Run(shutdownChan) - }() - wg.Wait() - - return nil -} - -func mapShardIDUnique(sourceShardCount, targetShardCount, sourceShardID int32) int32 { - targetShardID := servercommon.MapShardID(sourceShardCount, targetShardCount, sourceShardID) - if len(targetShardID) != 1 { - panic(fmt.Sprintf("remapping shard count error: sourceShardCount=%d targetShardCount=%d sourceShardID=%d targetShardID=%v\n", - sourceShardCount, targetShardCount, sourceShardID, targetShardID)) - } - return targetShardID[0] -} From 7bab5d2da53bc4a7405939add89e084d02c14ec5 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 14:38:45 -0800 Subject: [PATCH 33/38] address comments --- proxy/adminservice_test.go | 5 +++-- proxy/cluster_connection.go | 8 ++++---- proxy/debug.go | 4 ++++ proxy/intra_proxy_router.go | 18 ++++++------------ 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/proxy/adminservice_test.go b/proxy/adminservice_test.go index e964ba95..427c11d1 100644 --- a/proxy/adminservice_test.go +++ b/proxy/adminservice_test.go @@ -10,6 +10,7 @@ import ( "go.temporal.io/server/common/log" gomock "go.uber.org/mock/gomock" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/config" @@ -123,7 +124,7 @@ func (s *adminserviceSuite) TestAddOrUpdateRemoteCluster() { s.adminClientMock.EXPECT().AddOrUpdateRemoteCluster(ctx, c.expectedReq).Return(expResp, nil) resp, err := server.AddOrUpdateRemoteCluster(ctx, makeOriginalReq()) s.NoError(err) - s.Equal(expResp, resp) + s.True(proto.Equal(expResp, resp)) s.Equal("[]", observer.PrintActiveStreams()) }) } @@ -235,7 +236,7 @@ func (s 
*adminserviceSuite) TestAPIOverrides_FailoverVersionIncrement() { s.adminClientMock.EXPECT().DescribeCluster(ctx, gomock.Any()).Return(c.mockResp, nil) resp, err := server.DescribeCluster(ctx, req) s.NoError(err) - s.Equal(c.expResp.FailoverVersionIncrement, resp.FailoverVersionIncrement) + s.True(proto.Equal(c.expResp, resp)) s.Equal("[]", observer.PrintActiveStreams()) }) } diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index 8fe3f9e8..387beb15 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -95,7 +95,7 @@ type ( shardCountConfig config.ShardCountConfig logger log.Logger - clusterConnection *ClusterConnection + shardManager ShardManager lcmParameters LCMParameters routingParameters RoutingParameters } @@ -178,7 +178,7 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon saTranslations: saTranslations.Inverse(), shardCountConfig: connConfig.ShardCountConfig, logger: cc.logger, - clusterConnection: cc, + shardManager: cc.shardManager, lcmParameters: getLCMParameters(connConfig.ShardCountConfig, true), routingParameters: getRoutingParameters(connConfig.ShardCountConfig, true, "inbound"), }) @@ -196,7 +196,7 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon saTranslations: saTranslations, shardCountConfig: connConfig.ShardCountConfig, logger: cc.logger, - clusterConnection: cc, + shardManager: cc.shardManager, lcmParameters: getLCMParameters(connConfig.ShardCountConfig, false), routingParameters: getRoutingParameters(connConfig.ShardCountConfig, false, "outbound"), }) @@ -330,7 +330,7 @@ func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, obs c.lcmParameters, c.routingParameters, c.logger, - c.clusterConnection.shardManager, + c.shardManager, ) var accessControl *auth.AccessControl if c.clusterDefinition.ACLPolicy != nil { diff --git a/proxy/debug.go b/proxy/debug.go index 4af633e1..6f5dc456 100644 --- a/proxy/debug.go +++ b/proxy/debug.go @@ -13,6 +13,10 @@ import ( "github.com/temporalio/s2s-proxy/transport/mux/session" ) +// HandleDebugInfo is the HTTP handler for the proxy debug endpoint. +// It returns JSON-encoded debug information including active streams, shard distribution, +// channel states, and mux connection details. + type ( // ProxyIDEntry is a preview of a ring buffer entry diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index c17416de..08e157ee 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "sync" + "sync/atomic" "time" "go.temporal.io/server/api/adminservice/v1" @@ -199,8 +200,7 @@ type intraProxyStreamReceiver struct { shutdown channel.ShutdownOnce cancel context.CancelFunc // lastWatermark tracks the last watermark received from source shard for late-registering target shards - lastWatermarkMu sync.RWMutex - lastWatermark *replicationv1.WorkflowReplicationMessages + lastWatermark atomic.Pointer[replicationv1.WorkflowReplicationMessages] } // Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. 
@@ -280,12 +280,10 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { st.UpdateStream(r.streamID) // Track last watermark for late-registering shards - r.lastWatermarkMu.Lock() - r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ + r.lastWatermark.Store(&replicationv1.WorkflowReplicationMessages{ ExclusiveHighWatermark: exclusiveHighWatermark, Priority: priority, - } - r.lastWatermarkMu.Unlock() + }) r.logger.Debug(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", exclusiveHighWatermark, ids)) @@ -358,9 +356,7 @@ func (r *intraProxyStreamReceiver) GetSourceShardID() history.ClusterShardID { // GetLastWatermark returns the last watermark received from the source shard func (r *intraProxyStreamReceiver) GetLastWatermark() *replicationv1.WorkflowReplicationMessages { - r.lastWatermarkMu.RLock() - defer r.lastWatermarkMu.RUnlock() - return r.lastWatermark + return r.lastWatermark.Load() } // NotifyNewTargetShard notifies the receiver about a newly registered target shard @@ -371,9 +367,7 @@ func (r *intraProxyStreamReceiver) NotifyNewTargetShard(targetShardID history.Cl // sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard // This ensures late-registering shards receive watermarks that were sent before they registered func (r *intraProxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID) { - r.lastWatermarkMu.RLock() - lastWatermark := r.lastWatermark - r.lastWatermarkMu.RUnlock() + lastWatermark := r.GetLastWatermark() if lastWatermark == nil || lastWatermark.ExclusiveHighWatermark == 0 { // No pending watermark to send From 7290d27167cf2e2192354a4388dc747355d1e72f Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 14:44:43 -0800 Subject: [PATCH 34/38] move structs to better place --- proxy/intra_proxy_router.go | 12 ++++++++++++ proxy/proxy.go | 14 -------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index 08e157ee..e76c5f94 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -25,6 +25,18 @@ import ( "github.com/temporalio/s2s-proxy/transport/grpcutil" ) +// RoutedAck wraps an ACK with the target shard it originated from +type RoutedAck struct { + TargetShard history.ClusterShardID + Req *adminservice.StreamWorkflowReplicationMessagesRequest +} + +// RoutedMessage wraps a replication response with originating client shard info +type RoutedMessage struct { + SourceShard history.ClusterShardID + Resp *adminservice.StreamWorkflowReplicationMessagesResponse +} + // intraProxyManager maintains long-lived intra-proxy streams to peer proxies and // provides simple send helpers (e.g., forwarding ACKs). 
type intraProxyManager struct { diff --git a/proxy/proxy.go b/proxy/proxy.go index 8541ef4e..4acfacdb 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -7,8 +7,6 @@ import ( "net/http" "strings" - "go.temporal.io/server/api/adminservice/v1" - "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" @@ -23,18 +21,6 @@ type ( //accountId string } - // RoutedAck wraps an ACK with the target shard it originated from - RoutedAck struct { - TargetShard history.ClusterShardID - Req *adminservice.StreamWorkflowReplicationMessagesRequest - } - - // RoutedMessage wraps a replication response with originating client shard info - RoutedMessage struct { - SourceShard history.ClusterShardID - Resp *adminservice.StreamWorkflowReplicationMessagesResponse - } - Proxy struct { lifetime context.Context cancel context.CancelFunc From 2fa9239b22b8710bd8a826957166de5477493d55 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 15:49:26 -0800 Subject: [PATCH 35/38] revert lastWatermark atomic.Pointer change. We can redo the change after oss ReplicationStreamSendEmptyTaskDuration is available. --- proxy/intra_proxy_router.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index e76c5f94..e85ccf6f 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "sync" - "sync/atomic" "time" "go.temporal.io/server/api/adminservice/v1" @@ -212,7 +211,8 @@ type intraProxyStreamReceiver struct { shutdown channel.ShutdownOnce cancel context.CancelFunc // lastWatermark tracks the last watermark received from source shard for late-registering target shards - lastWatermark atomic.Pointer[replicationv1.WorkflowReplicationMessages] + lastWatermarkMu sync.RWMutex + lastWatermark *replicationv1.WorkflowReplicationMessages } // Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. 
@@ -292,10 +292,12 @@ func (r *intraProxyStreamReceiver) recvReplicationMessages() error { st.UpdateStream(r.streamID) // Track last watermark for late-registering shards - r.lastWatermark.Store(&replicationv1.WorkflowReplicationMessages{ + r.lastWatermarkMu.Lock() + r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ ExclusiveHighWatermark: exclusiveHighWatermark, Priority: priority, - }) + } + r.lastWatermarkMu.Unlock() r.logger.Debug(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", exclusiveHighWatermark, ids)) @@ -368,7 +370,9 @@ func (r *intraProxyStreamReceiver) GetSourceShardID() history.ClusterShardID { // GetLastWatermark returns the last watermark received from the source shard func (r *intraProxyStreamReceiver) GetLastWatermark() *replicationv1.WorkflowReplicationMessages { - return r.lastWatermark.Load() + r.lastWatermarkMu.RLock() + defer r.lastWatermarkMu.RUnlock() + return r.lastWatermark } // NotifyNewTargetShard notifies the receiver about a newly registered target shard @@ -379,7 +383,9 @@ func (r *intraProxyStreamReceiver) NotifyNewTargetShard(targetShardID history.Cl // sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard // This ensures late-registering shards receive watermarks that were sent before they registered func (r *intraProxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID) { - lastWatermark := r.GetLastWatermark() + r.lastWatermarkMu.RLock() + lastWatermark := r.lastWatermark + r.lastWatermarkMu.RUnlock() if lastWatermark == nil || lastWatermark.ExclusiveHighWatermark == 0 { // No pending watermark to send From 873b2ce6a2b2824a331ac74e9fc7f07ee1a28dfe Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 16:15:33 -0800 Subject: [PATCH 36/38] close stream when cluster connection is closed --- proxy/admin_stream_transfer.go | 15 +++++++++++++-- proxy/adminservice.go | 4 ++++ proxy/adminservice_test.go | 2 +- proxy/cluster_connection.go | 7 ++++--- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/proxy/admin_stream_transfer.go b/proxy/admin_stream_transfer.go index e1578265..bcc05537 100644 --- a/proxy/admin_stream_transfer.go +++ b/proxy/admin_stream_transfer.go @@ -310,6 +310,7 @@ func handleStream( adminClientReverse adminservice.AdminServiceClient, shardManager ShardManager, metricLabelValues []string, + lifetime context.Context, ) error { switch shardCountConfig.Mode { case config.ShardCountLCM: @@ -342,9 +343,9 @@ func handleStream( case config.ShardCountRouting: isIntraProxy := common.IsIntraProxy(streamServer.Context()) if isIntraProxy { - return streamIntraProxyRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID, shardManager) + return streamIntraProxyRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID, shardManager, lifetime) } - return streamRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID, shardManager, adminClientReverse, routingParameters) + return streamRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID, shardManager, adminClientReverse, routingParameters, lifetime) } forwarder := newStreamForwarder( @@ -365,6 +366,7 @@ func streamIntraProxyRouting( sourceShardID history.ClusterShardID, targetShardID history.ClusterShardID, shardManager ShardManager, + lifetime context.Context, ) error { logger.Info("streamIntraProxyRouting started") defer logger.Info("streamIntraProxyRouting finished") @@ -399,6 +401,10 @@ func 
streamIntraProxyRouting( } shutdownChan := channel.NewShutdownOnce() + // Wire lifetime context to shutdownChan so cluster connection termination closes the stream + context.AfterFunc(lifetime, func() { + shutdownChan.Shutdown() + }) go func() { if err := sender.Run(streamServer, shutdownChan); err != nil { logger.Error("intraProxyStreamSender.Run error", tag.Error(err)) @@ -416,6 +422,7 @@ func streamRouting( shardManager ShardManager, adminClientReverse adminservice.AdminServiceClient, routingParameters RoutingParameters, + lifetime context.Context, ) error { logger.Info("streamRouting started") defer logger.Info("streamRouting stopped") @@ -441,6 +448,10 @@ func streamRouting( } shutdownChan := channel.NewShutdownOnce() + // Wire lifetime context to shutdownChan so cluster connection termination closes the stream + context.AfterFunc(lifetime, func() { + shutdownChan.Shutdown() + }) wg := sync.WaitGroup{} wg.Add(2) go func() { diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 4b6a3f84..9a7ee52c 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -41,6 +41,7 @@ type ( shardCountConfig config.ShardCountConfig lcmParameters LCMParameters routingParameters RoutingParameters + lifetime context.Context } ) @@ -57,6 +58,7 @@ func NewAdminServiceProxyServer( routingParameters RoutingParameters, logger log.Logger, shardManager ShardManager, + lifetime context.Context, ) adminservice.AdminServiceServer { // The AdminServiceStreams will duplicate the same output for an underlying connection issue hundreds of times. // Limit their output to three times per minute @@ -73,6 +75,7 @@ func NewAdminServiceProxyServer( shardCountConfig: shardCountConfig, lcmParameters: lcmParameters, routingParameters: routingParameters, + lifetime: lifetime, } } @@ -311,6 +314,7 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( s.adminClientReverse, s.shardManager, s.metricLabelValues, + s.lifetime, ) if err != nil { return err diff --git a/proxy/adminservice_test.go b/proxy/adminservice_test.go index 427c11d1..655e9260 100644 --- a/proxy/adminservice_test.go +++ b/proxy/adminservice_test.go @@ -47,7 +47,7 @@ type adminProxyServerInput struct { func (s *adminserviceSuite) newAdminServiceProxyServer(in adminProxyServerInput, observer *ReplicationStreamObserver) adminservice.AdminServiceServer { return NewAdminServiceProxyServer("test-service-name", s.adminClientMock, s.adminClientMock, - in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, RoutingParameters{}, log.NewTestLogger(), nil) + in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, RoutingParameters{}, log.NewTestLogger(), nil, context.Background()) } func (s *adminserviceSuite) TestAddOrUpdateRemoteCluster() { diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index 387beb15..a90f665c 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -227,7 +227,7 @@ func createServer(lifetime context.Context, c serverConfiguration) (contextAware return createTCPServer(lifetime, c) case config.ConnTypeMuxClient, config.ConnTypeMuxServer: observer := NewReplicationStreamObserver(c.logger) - grpcServer, err := buildProxyServer(c, c.clusterDefinition.Connection.MuxAddressInfo.TLSConfig, observer.ReportStreamValue) + grpcServer, err := buildProxyServer(c, c.clusterDefinition.Connection.MuxAddressInfo.TLSConfig, observer.ReportStreamValue, lifetime) if err != nil { return nil, nil, 
err } @@ -248,7 +248,7 @@ func createTCPServer(lifetime context.Context, c serverConfiguration) (contextAw if err != nil { return nil, nil, fmt.Errorf("invalid configuration for inbound server: %w", err) } - grpcServer, err := buildProxyServer(c, c.clusterDefinition.Connection.TcpServer.TLSConfig, observer.ReportStreamValue) + grpcServer, err := buildProxyServer(c, c.clusterDefinition.Connection.TcpServer.TLSConfig, observer.ReportStreamValue, lifetime) if err != nil { return nil, nil, fmt.Errorf("failed to create inbound server: %w", err) } @@ -312,7 +312,7 @@ func (c *ClusterConnection) AcceptingOutboundTraffic() bool { // buildProxyServer uses the provided grpc.ClientConnInterface and config.ProxyConfig to create a grpc.Server that proxies // the Temporal API across the ClientConnInterface. -func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, observeFn func(int32, int32)) (*grpc.Server, error) { +func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, observeFn func(int32, int32), lifetime context.Context) (*grpc.Server, error) { serverOpts, err := makeServerOptions(c, tlsConfig) if err != nil { return nil, fmt.Errorf("could not parse server options: %w", err) @@ -331,6 +331,7 @@ func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, obs c.routingParameters, c.logger, c.shardManager, + lifetime, ) var accessControl *auth.AccessControl if c.clusterDefinition.ACLPolicy != nil { From a75d610b6aa1140f1a039d0b1e48e055709793e5 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 16:37:35 -0800 Subject: [PATCH 37/38] clean up shardManager interface --- proxy/intra_proxy_router.go | 7 +- proxy/shard_manager.go | 173 +++++++++++------------------------- 2 files changed, 59 insertions(+), 121 deletions(-) diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go index e85ccf6f..be4b1a8a 100644 --- a/proxy/intra_proxy_router.go +++ b/proxy/intra_proxy_router.go @@ -228,8 +228,13 @@ func (r *intraProxyStreamReceiver) Run(ctx context.Context, shardManager ShardMa md.Set(history.MetadataKeyServerClusterID, fmt.Sprintf("%d", r.sourceShardID.ClusterID)) md.Set(history.MetadataKeyServerShardID, fmt.Sprintf("%d", r.sourceShardID.ShardID)) ctx = metadata.NewOutgoingContext(ctx, md) + shardInfos := shardManager.GetShardInfos() + nodeName := "" + if len(shardInfos) > 0 { + nodeName = shardInfos[0].NodeName + } ctx = common.WithIntraProxyHeaders(ctx, map[string]string{ - common.IntraProxyOriginProxyIDHeader: shardManager.GetShardInfo().NodeName, + common.IntraProxyOriginProxyIDHeader: nodeName, }) // Ensure we can cancel Recv() by canceling the context when tearing down diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go index 901c394d..61fc34f0 100644 --- a/proxy/shard_manager.go +++ b/proxy/shard_manager.go @@ -30,64 +30,51 @@ type ( // ShardManager manages distributed shard ownership across proxy instances ShardManager interface { + // Lifecycle // Start initializes the memberlist cluster and starts the manager Start(lifetime context.Context) error // Stop shuts down the manager and leaves the cluster Stop() + + // Shard ownership // RegisterShard registers a clientShardID as owned by this proxy instance and returns the registration timestamp RegisterShard(clientShardID history.ClusterShardID) time.Time // UnregisterShard removes a clientShardID from this proxy's ownership only if the timestamp matches UnregisterShard(clientShardID history.ClusterShardID, expectedRegisteredAt time.Time) - // GetProxyAddress 
returns the proxy service address for the given node name - GetProxyAddress(nodeName string) (string, bool) // IsLocalShard checks if this proxy instance owns the given shard IsLocalShard(clientShardID history.ClusterShardID) bool + + // Cluster information // GetNodeName returns the name of this proxy instance GetNodeName() string - // GetMemberNodes returns all active proxy nodes in the cluster - GetMemberNodes() []string + // GetProxyAddress returns the proxy service address for the given node name + GetProxyAddress(nodeName string) (string, bool) // GetLocalShards returns all shards currently handled by this proxy instance, keyed by short id GetLocalShards() map[string]history.ClusterShardID // GetRemoteShardsForPeer returns all shards owned by the specified peer node, keyed by short id GetRemoteShardsForPeer(peerNodeName string) (map[string]NodeShardState, error) - // GetShardInfo returns debug information about shard distribution - GetShardInfo() ShardDebugInfo - // GetShardInfos returns debug information about shard distribution as a slice - GetShardInfos() []ShardDebugInfo - // GetChannelInfo returns debug information about active channels - GetChannelInfo() ChannelDebugInfo - // GetShardOwner returns the node name that owns the given shard - GetShardOwner(shard history.ClusterShardID) (string, bool) - // TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed - TerminatePreviousLocalReceiver(shardID history.ClusterShardID, logger log.Logger) - // GetIntraProxyManager returns the intra-proxy manager if it exists - GetIntraProxyManager() *intraProxyManager - // GetIntraProxyTLSConfig returns the TLS config for intra-proxy connections - GetIntraProxyTLSConfig() encryption.TLSConfig + + // Message routing // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool // DeliverMessagesToShardOwner routes replication messages to the appropriate shard owner (local or remote) DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, shutdownChan channel.ShutdownOnce, logger log.Logger) bool - // SetOnPeerJoin registers a callback invoked when a new peer joins - SetOnPeerJoin(handler func(nodeName string)) - // SetOnPeerLeave registers a callback invoked when a peer leaves. 
- SetOnPeerLeave(handler func(nodeName string)) - // New: notify when local shard set changes - SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) - // New: notify when remote shard set changes for a peer - SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) + + // Active receivers // RegisterActiveReceiver registers an active receiver for watermark propagation RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver ActiveReceiver) // UnregisterActiveReceiver removes an active receiver UnregisterActiveReceiver(sourceShardID history.ClusterShardID) // GetActiveReceiver returns the active receiver for the given source shard GetActiveReceiver(sourceShardID history.ClusterShardID) (ActiveReceiver, bool) + // TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed + TerminatePreviousLocalReceiver(shardID history.ClusterShardID, logger log.Logger) + + // Channel management (used for message routing) // SetRemoteSendChan registers a send channel for a specific shard ID SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) // GetRemoteSendChan retrieves the send channel for a specific shard ID GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) - // GetAllRemoteSendChans returns a map of all remote send channels - GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage // GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage // RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel @@ -96,18 +83,26 @@ type ( SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) // GetLocalAckChan retrieves the ack channel for a specific shard ID GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) - // GetAllLocalAckChans returns a map of all local ack channels - GetAllLocalAckChans() map[history.ClusterShardID]chan RoutedAck // RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) - // ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID - ForceRemoveLocalAckChan(shardID history.ClusterShardID) // SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) // GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) // RemoveLocalReceiverCancelFunc unconditionally removes the cancel function for a local receiver for a specific shard ID RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) + + // Intra-proxy + // GetIntraProxyManager returns the intra-proxy manager if it exists + GetIntraProxyManager() *intraProxyManager + // GetIntraProxyTLSConfig returns the TLS config for intra-proxy connections + GetIntraProxyTLSConfig() encryption.TLSConfig + + // Debug + // GetShardInfos returns debug information about shard distribution as a slice + GetShardInfos() []ShardDebugInfo + // GetChannelInfo returns debug information about active channels + GetChannelInfo() 
ChannelDebugInfo } shardManagerImpl struct { @@ -176,39 +171,8 @@ type ( Shards map[string]ShardInfo `json:"shards"` Updated time.Time `json:"updated"` } - - // memberSnapshot is a thread-safe copy of memberlist node data - memberSnapshot struct { - Name string - Meta []byte - } ) -// getMembersSnapshot returns a thread-safe snapshot of remote node states. -// Uses the remoteNodeStates map instead of ml.Members() to avoid data races. -func (sm *shardManagerImpl) getMembersSnapshot() []memberSnapshot { - sm.remoteNodeStatesMu.RLock() - defer sm.remoteNodeStatesMu.RUnlock() - - snapshots := make([]memberSnapshot, 0, len(sm.remoteNodeStates)) - for nodeName, state := range sm.remoteNodeStates { - // Marshal the state to get the meta bytes - metaBytes, err := json.Marshal(state) - if err != nil { - sm.logger.Warn("Failed to marshal node state for snapshot", - tag.NewStringTag("node", nodeName), - tag.Error(err)) - continue - } - snapshot := memberSnapshot{ - Name: nodeName, - Meta: metaBytes, - } - snapshots = append(snapshots, snapshot) - } - return snapshots -} - // NewShardManager creates a new shard manager instance func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig config.ShardCountConfig, intraProxyTLSConfig encryption.TLSConfig, logger log.Logger) ShardManager { delegate := &shardDelegate{ @@ -239,29 +203,29 @@ func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig return sm } -// SetOnPeerJoin registers a callback invoked on new peer joins. -func (sm *shardManagerImpl) SetOnPeerJoin(handler func(nodeName string)) { +// setOnPeerJoin registers a callback invoked on new peer joins. +func (sm *shardManagerImpl) setOnPeerJoin(handler func(nodeName string)) { sm.mutex.Lock() defer sm.mutex.Unlock() sm.onPeerJoin = handler } -// SetOnPeerLeave registers a callback invoked when a peer leaves. -func (sm *shardManagerImpl) SetOnPeerLeave(handler func(nodeName string)) { +// setOnPeerLeave registers a callback invoked when a peer leaves. +func (sm *shardManagerImpl) setOnPeerLeave(handler func(nodeName string)) { sm.mutex.Lock() defer sm.mutex.Unlock() sm.onPeerLeave = handler } -// SetOnLocalShardChange registers local shard change callback. -func (sm *shardManagerImpl) SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) { +// setOnLocalShardChange registers local shard change callback. +func (sm *shardManagerImpl) setOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) { sm.mutex.Lock() defer sm.mutex.Unlock() sm.onLocalShardChange = handler } -// SetOnRemoteShardChange registers remote shard change callback. -func (sm *shardManagerImpl) SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) { +// setOnRemoteShardChange registers remote shard change callback. 
+func (sm *shardManagerImpl) setOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) { sm.mutex.Lock() defer sm.mutex.Unlock() sm.onRemoteShardChange = handler @@ -613,37 +577,6 @@ func (sm *shardManagerImpl) GetNodeName() string { return sm.memberlistConfig.NodeName } -func (sm *shardManagerImpl) GetMemberNodes() []string { - if !sm.started || sm.ml == nil { - return []string{sm.GetNodeName()} - } - - // Use a timeout to prevent deadlocks when memberlist is busy - membersChan := make(chan []memberSnapshot, 1) - go func() { - defer func() { - if r := recover(); r != nil { - sm.logger.Error("Panic in GetMemberNodes", tag.NewStringTag("error", fmt.Sprintf("%v", r))) - } - }() - membersChan <- sm.getMembersSnapshot() - }() - - select { - case members := <-membersChan: - nodes := make([]string, len(members)) - for i, member := range members { - nodes[i] = member.Name - } - return nodes - case <-time.After(100 * time.Millisecond): - // Timeout: return cached node name to prevent hanging - sm.logger.Warn("GetMemberNodes timeout, returning self node", - tag.NewStringTag("node", sm.GetNodeName())) - return []string{sm.GetNodeName()} - } -} - func (sm *shardManagerImpl) GetLocalShards() map[string]history.ClusterShardID { sm.mutex.RLock() defer sm.mutex.RUnlock() @@ -654,7 +587,7 @@ func (sm *shardManagerImpl) GetLocalShards() map[string]history.ClusterShardID { return shards } -func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { +func (sm *shardManagerImpl) getShardInfo() ShardDebugInfo { localShardMap := sm.GetLocalShards() remoteShards, err := sm.GetRemoteShardsForPeer("") if err != nil { @@ -687,7 +620,7 @@ func (sm *shardManagerImpl) GetShardInfos() []ShardDebugInfo { if sm.memberlistConfig == nil { return []ShardDebugInfo{} } - return []ShardDebugInfo{sm.GetShardInfo()} + return []ShardDebugInfo{sm.getShardInfo()} } // GetChannelInfo returns debug information about active channels @@ -696,7 +629,7 @@ func (sm *shardManagerImpl) GetChannelInfo() ChannelDebugInfo { var totalSendChannels int // Collect remote send channel info first - allSendChans := sm.GetAllRemoteSendChans() + allSendChans := sm.getAllRemoteSendChans() for shardID, ch := range allSendChans { shardKey := ClusterShardIDtoString(shardID) remoteSendChannels[shardKey] = len(ch) @@ -707,7 +640,7 @@ func (sm *shardManagerImpl) GetChannelInfo() ChannelDebugInfo { var totalAckChannels int // Collect local ack channel info separately - allAckChans := sm.GetAllLocalAckChans() + allAckChans := sm.getAllLocalAckChans() for shardID, ch := range allAckChans { shardKey := ClusterShardIDtoString(shardID) localAckChannels[shardKey] = len(ch) @@ -733,11 +666,11 @@ func (sm *shardManagerImpl) TerminatePreviousLocalReceiver(shardID history.Clust // Force remove the cancel function and ack channel from tracking sm.RemoveLocalReceiverCancelFunc(shardID) - sm.ForceRemoveLocalAckChan(shardID) + sm.forceRemoveLocalAckChan(shardID) } } -func (sm *shardManagerImpl) GetShardOwner(shard history.ClusterShardID) (string, bool) { +func (sm *shardManagerImpl) getShardOwner(shard history.ClusterShardID) (string, bool) { remoteShards, err := sm.GetRemoteShardsForPeer("") if err != nil { sm.logger.Error("Failed to get remote shards", tag.Error(err)) @@ -813,7 +746,7 @@ func (sm *shardManagerImpl) DeliverAckToShardOwner( // Attempt remote delivery via intra-proxy when enabled and shard is remote if sm.memberlistConfig != nil { - if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.GetNodeName() { + if owner, 
ok := sm.getShardOwner(sourceShard); ok && owner != sm.GetNodeName() { if addr, found := sm.GetProxyAddress(owner); found { clientShard := routedAck.TargetShard serverShard := sourceShard @@ -871,7 +804,7 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( // Attempt remote delivery via intra-proxy when enabled and shard is remote if sm.memberlistConfig != nil { - if owner, ok := sm.GetShardOwner(targetShard); ok && owner != sm.GetNodeName() { + if owner, ok := sm.getShardOwner(targetShard); ok && owner != sm.GetNodeName() { if addr, found := sm.GetProxyAddress(owner); found { if mgr := sm.GetIntraProxyManager(); mgr != nil { resp := routedMsg.Resp @@ -893,7 +826,7 @@ func (sm *shardManagerImpl) DeliverMessagesToShardOwner( func (sm *shardManagerImpl) SetupCallbacks() { // Wire memberlist peer-join callback to reconcile intra-proxy receivers for local/remote pairs - sm.SetOnPeerJoin(func(nodeName string) { + sm.setOnPeerJoin(func(nodeName string) { sm.logger.Info("OnPeerJoin", tag.NewStringTag("nodeName", nodeName)) defer sm.logger.Info("OnPeerJoin done", tag.NewStringTag("nodeName", nodeName)) if sm.intraMgr != nil { @@ -902,7 +835,7 @@ func (sm *shardManagerImpl) SetupCallbacks() { }) // Wire peer-leave to cleanup intra-proxy resources for that peer - sm.SetOnPeerLeave(func(nodeName string) { + sm.setOnPeerLeave(func(nodeName string) { sm.logger.Info("OnPeerLeave", tag.NewStringTag("nodeName", nodeName)) defer sm.logger.Info("OnPeerLeave done", tag.NewStringTag("nodeName", nodeName)) if sm.intraMgr != nil { @@ -911,7 +844,7 @@ func (sm *shardManagerImpl) SetupCallbacks() { }) // Wire local shard changes to reconcile intra-proxy receivers - sm.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { + sm.setOnLocalShardChange(func(shard history.ClusterShardID, added bool) { sm.logger.Info("OnLocalShardChange", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) defer sm.logger.Info("OnLocalShardChange done", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) if added { @@ -923,7 +856,7 @@ func (sm *shardManagerImpl) SetupCallbacks() { }) // Wire remote shard changes to reconcile intra-proxy receivers - sm.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { + sm.setOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { sm.logger.Info("OnRemoteShardChange", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) defer sm.logger.Info("OnRemoteShardChange done", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) if added { @@ -1189,8 +1122,8 @@ func (sm *shardManagerImpl) GetRemoteSendChan(shardID history.ClusterShardID) (c return ch, exists } -// GetAllRemoteSendChans returns a map of all remote send channels -func (sm *shardManagerImpl) GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { +// getAllRemoteSendChans returns a map of all remote send channels +func (sm *shardManagerImpl) getAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { sm.remoteSendChannelsMu.RLock() defer sm.remoteSendChannelsMu.RUnlock() @@ -1244,8 +1177,8 @@ func (sm *shardManagerImpl) GetLocalAckChan(shardID history.ClusterShardID) (cha return ch, exists } -// GetAllLocalAckChans returns a map of all local 
ack channels -func (sm *shardManagerImpl) GetAllLocalAckChans() map[history.ClusterShardID]chan RoutedAck { +// getAllLocalAckChans returns a map of all local ack channels +func (sm *shardManagerImpl) getAllLocalAckChans() map[history.ClusterShardID]chan RoutedAck { sm.localAckChannelsMu.RLock() defer sm.localAckChannelsMu.RUnlock() @@ -1269,8 +1202,8 @@ func (sm *shardManagerImpl) RemoveLocalAckChan(shardID history.ClusterShardID, e } } -// ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID -func (sm *shardManagerImpl) ForceRemoveLocalAckChan(shardID history.ClusterShardID) { +// forceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID +func (sm *shardManagerImpl) forceRemoveLocalAckChan(shardID history.ClusterShardID) { sm.logger.Info("Force remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) sm.localAckChannelsMu.Lock() defer sm.localAckChannelsMu.Unlock() From fc1f735510aa24e2288eab599b5692847b1d39b6 Mon Sep 17 00:00:00 2001 From: Hai Zhao Date: Fri, 2 Jan 2026 16:37:58 -0800 Subject: [PATCH 38/38] run test in parallel to reduce time --- Makefile | 2 +- proxy/test/replication_failover_test.go | 1 + proxy/test/test_common.go | 6 ++++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index be35b5a2..b54ea61b 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ GO_GET_TOOL = go get -tool -modfile=$(TOOLS_MOD_FILE) # Disable cgo by default. CGO_ENABLED ?= 0 -TEST_ARG ?= -race -timeout=15m -tags test_dep -count=1 +TEST_ARG ?= -race -timeout=5m -tags test_dep -count=1 BENCH_ARG ?= -benchtime=5000x ALL_SRC := $(shell find . -name "*.go") diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index 64fd47d7..5a26b679 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -216,6 +216,7 @@ var testConfigs = []TestConfig{ func TestReplicationFailoverTestSuite(t *testing.T) { for _, tc := range testConfigs { t.Run(tc.Name, func(t *testing.T) { + t.Parallel() s := &ReplicationTestSuite{ shardCountA: tc.ShardCountA, shardCountB: tc.ShardCountB, diff --git a/proxy/test/test_common.go b/proxy/test/test_common.go index 30f8189a..f321335d 100644 --- a/proxy/test/test_common.go +++ b/proxy/test/test_common.go @@ -243,6 +243,10 @@ func createLoadBalancer( return trackingProxy, nil } +// clusterCreationMu serializes cluster creation to avoid data races in Temporal server's +// test infrastructure which uses global cached values that aren't thread-safe. +var clusterCreationMu sync.Mutex + func createCluster( logger log.Logger, t testingT, @@ -251,6 +255,8 @@ func createCluster( initialFailoverVersion int64, numHistoryHosts int, ) *testcore.TestCluster { + clusterCreationMu.Lock() + defer clusterCreationMu.Unlock() clusterSuffix := common.GenerateRandomString(8) fullClusterName := fmt.Sprintf("%s-%s", clusterName, clusterSuffix)