diff --git a/config/config.go b/config/config.go index d1b28116..3f2b1eb7 100644 --- a/config/config.go +++ b/config/config.go @@ -251,10 +251,10 @@ type ( func FromServerTLSConfig(cfg ServerTLSConfig) encryption.TLSConfig { return encryption.TLSConfig{ - CertificatePath: cfg.CertificatePath, - KeyPath: cfg.KeyPath, - RemoteCAPath: cfg.ClientCAPath, - ValidateClientCA: cfg.RequireClientAuth, + CertificatePath: cfg.CertificatePath, + KeyPath: cfg.KeyPath, + RemoteCAPath: cfg.ClientCAPath, + VerifyCA: cfg.RequireClientAuth, } } func FromClientTLSConfig(cfg ClientTLSConfig) encryption.TLSConfig { diff --git a/config/converter.go b/config/converter.go index 93412caf..fb09ce06 100644 --- a/config/converter.go +++ b/config/converter.go @@ -119,11 +119,11 @@ func translateClientTCPTLSInfo(cfg TCPClientSetting) TCPTLSInfo { return TCPTLSInfo{ ConnectionString: cfg.ServerAddress, TLSConfig: encryption.TLSConfig{ - CertificatePath: cfg.TLS.CertificatePath, - KeyPath: cfg.TLS.KeyPath, - RemoteCAPath: cfg.TLS.ServerCAPath, - CAServerName: cfg.TLS.ServerName, - ValidateClientCA: false, + CertificatePath: cfg.TLS.CertificatePath, + KeyPath: cfg.TLS.KeyPath, + RemoteCAPath: cfg.TLS.ServerCAPath, + CAServerName: cfg.TLS.ServerName, + VerifyCA: cfg.TLS.ServerName != "" && cfg.TLS.ServerCAPath != "", }, } } @@ -131,10 +131,10 @@ func translateServerTCPTLSInfo(cfg TCPServerSetting) TCPTLSInfo { return TCPTLSInfo{ ConnectionString: cfg.ListenAddress, TLSConfig: encryption.TLSConfig{ - CertificatePath: cfg.TLS.CertificatePath, - KeyPath: cfg.TLS.KeyPath, - RemoteCAPath: cfg.TLS.ClientCAPath, - ValidateClientCA: cfg.TLS.RequireClientAuth, + CertificatePath: cfg.TLS.CertificatePath, + KeyPath: cfg.TLS.KeyPath, + RemoteCAPath: cfg.TLS.ClientCAPath, + VerifyCA: cfg.TLS.RequireClientAuth, }, } } diff --git a/config/new_config_test.go b/config/new_config_test.go index e8db47fc..1db59702 100644 --- a/config/new_config_test.go +++ b/config/new_config_test.go @@ -30,7 +30,7 @@ func TestBasic(t 
*testing.T) { require.Equal(t, "127.0.0.1:9004", proxyConfig.ClusterConnections[0].RemoteServer.Connection.MuxAddressInfo.ConnectionString) require.Equal(t, "", proxyConfig.ClusterConnections[0].RemoteServer.Connection.TcpServer.ConnectionString) require.Equal(t, "", proxyConfig.ClusterConnections[0].RemoteServer.Connection.TcpClient.ConnectionString) - require.Equal(t, false, proxyConfig.ClusterConnections[0].RemoteServer.Connection.MuxAddressInfo.TLSConfig.ValidateClientCA) + require.Equal(t, false, proxyConfig.ClusterConnections[0].RemoteServer.Connection.MuxAddressInfo.TLSConfig.VerifyCA) nsTranslation, err := proxyConfig.ClusterConnections[0].NamespaceTranslation.AsLocalToRemoteBiMap() require.NoError(t, err) require.Equal(t, "remoteName", nsTranslation.Get("localName")) @@ -55,9 +55,26 @@ func TestConversion(t *testing.T) { require.Nil(t, converted.Inbound) require.Nil(t, converted.Outbound) require.Equal(t, ConnTypeTCP, converted.ClusterConnections[0].RemoteServer.Connection.ConnectionType) - require.True(t, converted.ClusterConnections[0].RemoteServer.Connection.TcpServer.TLSConfig.ValidateClientCA) + require.True(t, converted.ClusterConnections[0].RemoteServer.Connection.TcpServer.TLSConfig.VerifyCA) require.Equal(t, ConnTypeTCP, converted.ClusterConnections[0].LocalServer.Connection.ConnectionType) require.Equal(t, "AddOrUpdateRemoteCluster", converted.ClusterConnections[0].LocalServer.ACLPolicy.AllowedMethods.AdminService[0]) require.Equal(t, "namespace1", converted.ClusterConnections[0].LocalServer.ACLPolicy.AllowedNamespaces[0]) require.Equal(t, int64(100), *converted.ClusterConnections[0].LocalServer.APIOverrides.AdminService.DescribeCluster.Response.FailoverVersionIncrement) } + +func TestConversionWithTLS(t *testing.T) { + samplePath := filepath.Join("..", "develop", "old-config-with-TLS.yaml") + + proxyConfig, err := LoadConfig[S2SProxyConfig](samplePath) + require.NoError(t, err) + converted := ToClusterConnConfig(proxyConfig) + require.Equal(t, 
1, len(converted.ClusterConnections)) + require.Nil(t, converted.Inbound) + require.Nil(t, converted.Outbound) + require.Equal(t, ConnTypeMuxClient, converted.ClusterConnections[0].RemoteServer.Connection.ConnectionType) + require.False(t, converted.ClusterConnections[0].RemoteServer.Connection.TcpServer.TLSConfig.VerifyCA) + require.Equal(t, ConnTypeTCP, converted.ClusterConnections[0].LocalServer.Connection.ConnectionType) + require.Equal(t, "AddOrUpdateRemoteCluster", converted.ClusterConnections[0].LocalServer.ACLPolicy.AllowedMethods.AdminService[0]) + require.Equal(t, 0, len(converted.ClusterConnections[0].LocalServer.ACLPolicy.AllowedNamespaces)) + require.Nil(t, converted.ClusterConnections[0].LocalServer.APIOverrides) +} diff --git a/develop/old-config-with-TLS.yaml b/develop/old-config-with-TLS.yaml new file mode 100644 index 00000000..1eadb4bb --- /dev/null +++ b/develop/old-config-with-TLS.yaml @@ -0,0 +1,75 @@ +inbound: + name: "inbound-server" + server: + # No TLS here because the mux connection is already mTLS + type: "mux" + mux: "muxed" + client: + tcp: + # Frontend of local cluster + serverAddress: "local-frontend:7233" + tls: + # Certificate that identifies this host to local Temporal frontend + certificatePath: "/tls/internode/tls.crt" + # Key this host will use to encrypt local communication to local Temporal frontend + keyPath: "/tls/internode/tls.key" + # CA for the local Temporal frontend + serverCAPath: "/tls/internode/ca.crt" + # Servername that should be stamped on the local Temporal frontend's cert + serverName: "frontend.temporal.svc.cluster.local" + aclPolicy: + allowedMethods: + adminService: + - AddOrUpdateRemoteCluster + - RemoveRemoteCluster + - DescribeCluster + - DescribeMutableState + - GetNamespaceReplicationMessages + - GetWorkflowExecutionRawHistoryV2 + - ListClusters + - StreamWorkflowReplicationMessages + - ReapplyEvents + - GetNamespace # for EagerGetNamespace +outbound: + name: "outbound-server" + server: + tcp: + 
listenAddress: "0.0.0.0:9233" + tls: + # Certificate that identifies this host to *local* Temporal cluster + certificatePath: "/tls/internode/tls.crt" + # Key this host will use to encrypt local communication to *local* Temporal cluster + keyPath: "/tls/internode/tls.key" + # CA for *local* Temporal cluster + clientCAPath: "/tls/internode/ca.crt" + # We trust our local cluster, we don't need to verify its cert + requireClientAuth: false + client: + # No TLS here because the mux connection is already mTLS + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "client" + num_connections: 10 + client: + serverAddress: "remote-s2s-proxy-endpoint:8233" + tls: + # Certificate that identifies this host to the remote s2s proxy + certificatePath: "/s2c-client-tls/tls.crt" + # Key for encrypting the mux tunnel to the remote s2s proxy + keyPath: "/s2c-client-tls/tls.key" + # CA for the remote s2s proxy + serverCAPath: "/s2c-server-tls/tls.crt" +healthCheck: + protocol: "http" + listenAddress: "0.0.0.0:8234" +namespaceNameTranslation: + mappings: + - localName: "mg-hybrid" + remoteName: "mg-hybrid" +metrics: + prometheus: + listenAddress: "0.0.0.0:9090" +# This enables profiling with the default config and port +profiling: {} \ No newline at end of file diff --git a/develop/sample-cluster-conn-config.yaml b/develop/sample-cluster-conn-config.yaml index 3f44a24c..4fc40acd 100644 --- a/develop/sample-cluster-conn-config.yaml +++ b/develop/sample-cluster-conn-config.yaml @@ -10,7 +10,7 @@ clusterConnections: keyPath: "" remoteCAPath: "" caServerName: "" - validateClientCA: false + verifyCA: false tcpServer: address: "127.0.0.1:9002" tls: @@ -18,7 +18,7 @@ clusterConnections: keyPath: "" remoteCAPath: "" caServerName: "" - validateClientCA: false + verifyCA: false clusterInfo: serverVersion: "over 9000" shardCount: 42 @@ -53,7 +53,7 @@ clusterConnections: keyPath: "" remoteCAPath: "" caServerName: "" - validateClientCA: false + verifyCA: false clusterInfo: shardCount: 42 
serverVersion: "v1.22" diff --git a/encryption/tls.go b/encryption/tls.go index ce288cd5..e5a2657f 100644 --- a/encryption/tls.go +++ b/encryption/tls.go @@ -17,12 +17,18 @@ import ( ) type ( + // TLSConfig sets TLS options for the proxy's clients and servers TLSConfig struct { - CertificatePath string `yaml:"certificatePath"` - KeyPath string `yaml:"keyPath"` - RemoteCAPath string `yaml:"remoteCAPath"` - CAServerName string `yaml:"caServerName"` - ValidateClientCA bool `yaml:"validateClientCA"` + // CertificatePath is the path to the TLS cert that identifies this host + CertificatePath string `yaml:"certificatePath"` + // KeyPath is the path to the private key that pairs with the certificate at CertificatePath + KeyPath string `yaml:"keyPath"` + // RemoteCAPath is the path to the TLS CA cert that is used to verify the remote host's certificate + RemoteCAPath string `yaml:"remoteCAPath"` + // CAServerName is the server name clients expect on the remote host's certificate (set as the TLS ServerName when VerifyCA is true) + CAServerName string `yaml:"caServerName"` + // VerifyCA enables verification of the remote host's certificate. If false, servers do not request client certs and clients skip server cert verification + VerifyCA bool `yaml:"verifyCA"` } HttpGetter interface { @@ -38,48 +44,52 @@ var netClient HttpGetter = &http.Client{ Timeout: time.Second * 10, } -func GetServerTLSConfig(serverConfig TLSConfig, logger log.Logger) (*tls.Config, error) { - certPath := serverConfig.CertificatePath - keyPath := serverConfig.KeyPath - clientCAPath := serverConfig.RemoteCAPath - +func GetServerTLSConfig(serverConfig TLSConfig, logger log.Logger) (tlsConfig *tls.Config, err error) { if !serverConfig.IsEnabled() { - return nil, nil + return } - var serverCert *tls.Certificate - var clientCAPool *x509.CertPool - - clientAuthType := tls.NoClientCert - if serverConfig.ValidateClientCA { - clientAuthType = tls.RequireAndVerifyClientCert - caCertPool, err := fetchCACert(clientCAPath) + tlsConfig = auth.NewEmptyTLSConfig() + if serverConfig.VerifyCA { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs, err = 
fetchCACert(serverConfig.RemoteCAPath) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read CACert from %s: %w", serverConfig.RemoteCAPath, err) } - clientCAPool = caCertPool + } else { + tlsConfig.ClientAuth = tls.NoClientCert } - if certPath != "" { - cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if serverConfig.CertificatePath != "" { + serverCert, err := tls.LoadX509KeyPair(serverConfig.CertificatePath, serverConfig.KeyPath) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load cert-key pair (%s, %s): %w", + serverConfig.CertificatePath, serverConfig.KeyPath, err) } - serverCert = &cert + tlsConfig.Certificates = []tls.Certificate{serverCert} } - c := auth.NewEmptyTLSConfig() - c.ClientAuth = clientAuthType - c.Certificates = []tls.Certificate{*serverCert} - c.ClientCAs = clientCAPool - c.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { - logger.Info("Received TLS handshake", tag.Address(hello.Conn.RemoteAddr().String())) + tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + logger.Info("Received TLS handshake", tag.Address(hello.Conn.RemoteAddr().String()), tag.ServerName(hello.ServerName)) return nil, nil } - c.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + tlsConfig.GetClientCertificate = func(clientInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { + var allCertFailures error + for _, cert := range tlsConfig.Certificates { + certErr := clientInfo.SupportsCertificate(&cert) + if certErr == nil { + return &cert, nil + } + allCertFailures = errors.Join(allCertFailures, certErr) + } + logger.Warn("Could not match cert request. 
Check cert failures in error tag", tag.Error(allCertFailures)) + return nil, allCertFailures + } + + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { if len(rawCerts) == 0 { - logger.Info("No client certificate provided") + logger.Info("No client certificate provided, so no verification performed") } else { cert, _ := x509.ParseCertificate(rawCerts[0]) logger.Info(fmt.Sprintf("Client certificate subject: %s", cert.Subject)) @@ -87,53 +97,41 @@ func GetServerTLSConfig(serverConfig TLSConfig, logger log.Logger) (*tls.Config, return nil } - return c, nil + return tlsConfig, nil } -func GetClientTLSConfig(clientConfig TLSConfig) (*tls.Config, error) { - certPath := clientConfig.CertificatePath - keyPath := clientConfig.KeyPath - caPath := clientConfig.RemoteCAPath - serverName := clientConfig.CAServerName - +func GetClientTLSConfig(clientConfig TLSConfig) (tlsConfig *tls.Config, err error) { if !clientConfig.IsEnabled() { - return nil, nil + return } - var cert *tls.Certificate - var caPool *x509.CertPool - - if caPath != "" { - caCertPool, err := fetchCACert(caPath) - if err != nil { - return nil, err + tlsConfig = auth.NewEmptyTLSConfig() + if !clientConfig.VerifyCA { + tlsConfig.InsecureSkipVerify = true + } else { + if clientConfig.CAServerName == "" || clientConfig.RemoteCAPath == "" { + return nil, errors.New("CAServerName and RemoteCAPath must be set when VerifyCA is true") } - caPool = caCertPool + tlsConfig.ServerName = clientConfig.CAServerName } - if certPath != "" { - myCert, err := tls.LoadX509KeyPair(certPath, keyPath) + if clientConfig.RemoteCAPath != "" { + caCertPool, err := fetchCACert(clientConfig.RemoteCAPath) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load CA cert: %w", err) } - cert = &myCert + tlsConfig.RootCAs = caCertPool } - // If we are given arguments to verify either server or client, configure TLS - if caPool != nil || cert != nil || serverName != "" { - 
enableHostVerification := serverName != "" && caPath != "" - tlsConfig := auth.NewTLSConfigForServer(serverName, enableHostVerification) - if caPool != nil { - tlsConfig.RootCAs = caPool - } - if cert != nil { - tlsConfig.Certificates = []tls.Certificate{*cert} + if clientConfig.CertificatePath != "" { + myCert, err := tls.LoadX509KeyPair(clientConfig.CertificatePath, clientConfig.KeyPath) + if err != nil { + return nil, fmt.Errorf("failed to load cert-key pair from path %s: %w", clientConfig.CertificatePath, err) } - - return tlsConfig, nil + tlsConfig.Certificates = []tls.Certificate{myCert} } - return nil, nil + return } func fetchCACert(pathOrUrl string) (caPool *x509.CertPool, err error) { diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 9a7ee52c..86d29d8d 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -34,7 +34,8 @@ type ( shardManager ShardManager adminClient adminservice.AdminServiceClient adminClientReverse adminservice.AdminServiceClient - logger log.Logger + verboseLogger log.Logger + replicationLogger log.Logger apiOverrides *config.APIOverridesConfig metricLabelValues []string reportStreamValue func(idx int32, value int32) @@ -60,15 +61,19 @@ func NewAdminServiceProxyServer( shardManager ShardManager, lifetime context.Context, ) adminservice.AdminServiceServer { - // The AdminServiceStreams will duplicate the same output for an underlying connection issue hundreds of times. - // Limit their output to three times per minute - logger = log.NewThrottledLogger(log.With(logger, common.ServiceTag(serviceName)), + // Replication streams / APIs will run many hundreds of times per second. 
Throttle their output + // to 3 / min + replicationLogger := log.NewThrottledLogger(log.With(logger, common.ServiceTag(serviceName)), func() float64 { return 3.0 / 60.0 }) + // For config operations, allow most logs so that we can see all the info without putting disk at risk + verboseLogger := log.NewThrottledLogger(log.With(logger, common.ServiceTag(serviceName)), + func() float64 { return 3.0 }) return &adminServiceProxyServer{ shardManager: shardManager, adminClient: adminClient, adminClientReverse: adminClientReverse, - logger: logger, + verboseLogger: verboseLogger, + replicationLogger: replicationLogger, apiOverrides: apiOverrides, metricLabelValues: metricLabelValues, reportStreamValue: reportStreamValue, @@ -79,7 +84,8 @@ func NewAdminServiceProxyServer( } } -func (s *adminServiceProxyServer) AddOrUpdateRemoteCluster(ctx context.Context, in0 *adminservice.AddOrUpdateRemoteClusterRequest) (*adminservice.AddOrUpdateRemoteClusterResponse, error) { +func (s *adminServiceProxyServer) AddOrUpdateRemoteCluster(ctx context.Context, in0 *adminservice.AddOrUpdateRemoteClusterRequest) (resp *adminservice.AddOrUpdateRemoteClusterResponse, err error) { + s.verboseLogger.Info("Received AddOrUpdateRemoteCluster", tag.Address(in0.FrontendAddress), tag.NewBoolTag("Enabled", in0.GetEnableRemoteClusterConnection()), tag.NewStringsTag("configTags", s.metricLabelValues)) if !common.IsRequestTranslationDisabled(ctx) && s.apiOverrides != nil { reqOverride := s.apiOverrides.AdminService.AddOrUpdateRemoteCluster if reqOverride != nil && len(reqOverride.Request.FrontendAddress) > 0 { @@ -88,9 +94,15 @@ func (s *adminServiceProxyServer) AddOrUpdateRemoteCluster(ctx context.Context, // from the local temporal server, or the proxy may be deployed behind a load balancer. // Only used in single-proxy scenarios, i.e. 
Temporal <> Proxy <> Temporal in0.FrontendAddress = reqOverride.Request.FrontendAddress + s.verboseLogger.Info("Overwrote outbound address", tag.Address(in0.FrontendAddress), tag.NewStringsTag("configTags", s.metricLabelValues)) } } - return s.adminClient.AddOrUpdateRemoteCluster(ctx, in0) + resp, err = s.adminClient.AddOrUpdateRemoteCluster(ctx, in0) + if err != nil { + s.verboseLogger.Error("Error when adding remote cluster", tag.Error(err), tag.Operation("AddOrUpdateRemoteCluster"), + tag.NewStringTag("FrontendAddress", in0.GetFrontendAddress())) + } + return } func (s *adminServiceProxyServer) AddSearchAttributes(ctx context.Context, in0 *adminservice.AddSearchAttributesRequest) (*adminservice.AddSearchAttributesResponse, error) { @@ -114,8 +126,16 @@ func (s *adminServiceProxyServer) DeleteWorkflowExecution(ctx context.Context, i } func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *adminservice.DescribeClusterRequest) (*adminservice.DescribeClusterResponse, error) { + s.verboseLogger.Info("Received DescribeClusterRequest") resp, err := s.adminClient.DescribeCluster(ctx, in0) + if resp != nil { + s.verboseLogger.Info("Raw DescribeClusterResponse", tag.NewStringTag("clusterID", resp.ClusterId), + tag.NewStringTag("clusterName", resp.ClusterName), tag.NewStringTag("version", resp.ServerVersion), + tag.NewInt64("failoverVersionIncrement", resp.FailoverVersionIncrement), tag.NewInt64("initialFailoverVersion", resp.InitialFailoverVersion), + tag.NewBoolTag("isGlobalNamespaceEnabled", resp.IsGlobalNamespaceEnabled), tag.NewStringsTag("configTags", s.metricLabelValues)) + } if common.IsRequestTranslationDisabled(ctx) { + s.verboseLogger.Info("Request translation disabled. 
Returning as-is") return resp, err } @@ -134,11 +154,20 @@ func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *admi responseOverride := s.apiOverrides.AdminService.DescribeCluster.Response if resp != nil && responseOverride.FailoverVersionIncrement != nil { resp.FailoverVersionIncrement = *responseOverride.FailoverVersionIncrement + s.verboseLogger.Info("Overwrite FailoverVersionIncrement", tag.NewInt64("failoverVersionIncrement", resp.FailoverVersionIncrement), + tag.NewStringsTag("configTags", s.metricLabelValues)) } } - s.logger.Info("DescribeCluster response", tag.NewStringTag("response", fmt.Sprintf("%v", resp))) - + if resp != nil { + s.verboseLogger.Info("Translated DescribeClusterResponse", tag.NewStringTag("clusterID", resp.ClusterId), + tag.NewStringTag("clusterName", resp.ClusterName), tag.NewStringTag("version", resp.ServerVersion), + tag.NewInt64("failoverVersionIncrement", resp.FailoverVersionIncrement), tag.NewInt64("initialFailoverVersion", resp.InitialFailoverVersion), + tag.NewBoolTag("isGlobalNamespaceEnabled", resp.IsGlobalNamespaceEnabled), tag.NewStringsTag("configTags", s.metricLabelValues)) + } + if err != nil { + s.verboseLogger.Info("Got error when calling DescribeCluster!", tag.Error(err), tag.NewStringsTag("configTags", s.metricLabelValues)) + } return resp, err } @@ -150,8 +179,17 @@ func (s *adminServiceProxyServer) DescribeHistoryHost(ctx context.Context, in0 * return s.adminClient.DescribeHistoryHost(ctx, in0) } -func (s *adminServiceProxyServer) DescribeMutableState(ctx context.Context, in0 *adminservice.DescribeMutableStateRequest) (*adminservice.DescribeMutableStateResponse, error) { - return s.adminClient.DescribeMutableState(ctx, in0) +func (s *adminServiceProxyServer) DescribeMutableState(ctx context.Context, in0 *adminservice.DescribeMutableStateRequest) (resp *adminservice.DescribeMutableStateResponse, err error) { + resp, err = s.adminClient.DescribeMutableState(ctx, in0) + if err != nil { + // This 
is a duplicate of the grpc client metrics, but not everyone has metrics set up + s.replicationLogger.Error("Failed to describe workflow", + tag.NewStringTag("WorkflowId", in0.GetExecution().GetWorkflowId()), + tag.NewStringTag("RunId", in0.GetExecution().GetRunId()), + tag.NewStringTag("Namespace", in0.GetNamespace()), + tag.Error(err), tag.Operation("DescribeMutableState")) + } + return } func (s *adminServiceProxyServer) GetDLQMessages(ctx context.Context, in0 *adminservice.GetDLQMessagesRequest) (*adminservice.GetDLQMessagesResponse, error) { @@ -170,12 +208,23 @@ func (s *adminServiceProxyServer) GetNamespace(ctx context.Context, in0 *adminse return s.adminClient.GetNamespace(ctx, in0) } -func (s *adminServiceProxyServer) GetNamespaceReplicationMessages(ctx context.Context, in0 *adminservice.GetNamespaceReplicationMessagesRequest) (*adminservice.GetNamespaceReplicationMessagesResponse, error) { - return s.adminClient.GetNamespaceReplicationMessages(ctx, in0) +func (s *adminServiceProxyServer) GetNamespaceReplicationMessages(ctx context.Context, in0 *adminservice.GetNamespaceReplicationMessagesRequest) (resp *adminservice.GetNamespaceReplicationMessagesResponse, err error) { + resp, err = s.adminClient.GetNamespaceReplicationMessages(ctx, in0) + if err != nil { + // This is a duplicate of the grpc client metrics, but not everyone has metrics set up + s.replicationLogger.Error("Failed to get namespace replication messages", tag.NewStringTag("Cluster", in0.GetClusterName()), + tag.Error(err), tag.Operation("GetNamespaceReplicationMessages")) + } + return } -func (s *adminServiceProxyServer) GetReplicationMessages(ctx context.Context, in0 *adminservice.GetReplicationMessagesRequest) (*adminservice.GetReplicationMessagesResponse, error) { - return s.adminClient.GetReplicationMessages(ctx, in0) +func (s *adminServiceProxyServer) GetReplicationMessages(ctx context.Context, in0 *adminservice.GetReplicationMessagesRequest) (resp 
*adminservice.GetReplicationMessagesResponse, err error) { + resp, err = s.adminClient.GetReplicationMessages(ctx, in0) + if err != nil { + s.replicationLogger.Error("Failed to get replication messages", tag.NewStringTag("Cluster", in0.GetClusterName()), + tag.Error(err), tag.Operation("GetReplicationMessages")) + } + return } func (s *adminServiceProxyServer) GetSearchAttributes(ctx context.Context, in0 *adminservice.GetSearchAttributesRequest) (*adminservice.GetSearchAttributesResponse, error) { @@ -277,7 +326,7 @@ func ClusterShardIDtoShortString(sd history.ClusterShardID) string { func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, ) (retError error) { - defer log.CapturePanic(s.logger, &retError) + defer log.CapturePanic(s.replicationLogger, &retError) targetMetadata, ok := metadata.FromIncomingContext(streamServer.Context()) if !ok { @@ -290,7 +339,7 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( return err } - logger := log.With(s.logger, + logger := log.With(s.replicationLogger, tag.NewStringTag("source", ClusterShardIDtoString(sourceClusterShardID)), tag.NewStringTag("target", ClusterShardIDtoString(targetClusterShardID))) diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index a90f665c..a4f49f34 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -404,10 +404,11 @@ func (s *simpleGRPCServer) Start() { err := s.server.Serve(s.listener) if s.lifetime.Err() != nil { // Cluster is closing, just exit. 
- return + break } if err != nil { - s.logger.Warn("GRPC server failed", tag.NewStringTag("direction", "outbound"), tag.Address(s.listener.Addr().String()), tag.Error(err)) + s.logger.Warn("GRPC server failed", tag.NewStringTag("direction", "outbound"), + tag.Address(s.listener.Addr().String()), tag.Error(err)) if err == io.EOF { metrics.GRPCServerStopped.WithLabelValues(s.name, "eof").Inc() } else if !errors.Is(err, grpc.ErrServerStopped) { @@ -416,6 +417,7 @@ func (s *simpleGRPCServer) Start() { } time.Sleep(1 * time.Second) } + s.logger.Info("TCP-TLS gRPC server closed as requested", tag.Name(s.name), tag.Address(s.listener.Addr().String())) }() // The basic net.Listen, grpc.Server, and ClientConn are not context-aware, so make sure they clean up on context close. context.AfterFunc(s.lifetime, func() {