diff --git a/agent/doctor/docker_runtime_healthcheck.go b/agent/doctor/docker_runtime_healthcheck.go index 1bb2f9d2b35..4b84a711b79 100644 --- a/agent/doctor/docker_runtime_healthcheck.go +++ b/agent/doctor/docker_runtime_healthcheck.go @@ -19,7 +19,7 @@ import ( "time" "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/cihub/seelog" ) @@ -28,29 +28,30 @@ const systemPingTimeout = time.Second * 2 var timeNow = time.Now type dockerRuntimeHealthcheck struct { - // HealthcheckType is the reported healthcheck type + // HealthcheckType is the reported healthcheck type. HealthcheckType string `json:"HealthcheckType,omitempty"` - // Status is the container health status - Status doctor.HealthcheckStatus `json:"HealthcheckStatus,omitempty"` - // Timestamp is the timestamp when container health status changed + // Status is the container health status. + Status ecstcs.InstanceHealthCheckStatus `json:"HealthcheckStatus,omitempty"` + // TimeStamp is the timestamp when container health status changed. TimeStamp time.Time `json:"TimeStamp,omitempty"` - // StatusChangeTime is the latest time the health status changed + // StatusChangeTime is the latest time the health status changed. StatusChangeTime time.Time `json:"StatusChangeTime,omitempty"` - // LastStatus is the last container health status - LastStatus doctor.HealthcheckStatus `json:"LastStatus,omitempty"` - // LastTimeStamp is the timestamp of last container health status + // LastStatus is the last container health status. + LastStatus ecstcs.InstanceHealthCheckStatus `json:"LastStatus,omitempty"` + // LastTimeStamp is the timestamp of last container health status. LastTimeStamp time.Time `json:"LastTimeStamp,omitempty"` client dockerapi.DockerClient lock sync.RWMutex } +// NewDockerRuntimeHealthcheck creates a new Docker runtime health check. 
func NewDockerRuntimeHealthcheck(client dockerapi.DockerClient) *dockerRuntimeHealthcheck { nowTime := timeNow() return &dockerRuntimeHealthcheck{ - HealthcheckType: doctor.HealthcheckTypeContainerRuntime, - Status: doctor.HealthcheckStatusInitializing, + HealthcheckType: ecstcs.InstanceHealthCheckTypeContainerRuntime, + Status: ecstcs.InstanceHealthCheckStatusInitializing, TimeStamp: nowTime, StatusChangeTime: nowTime, LastTimeStamp: nowTime, @@ -58,65 +59,73 @@ func NewDockerRuntimeHealthcheck(client dockerapi.DockerClient) *dockerRuntimeHe } } -func (dhc *dockerRuntimeHealthcheck) RunCheck() doctor.HealthcheckStatus { - // TODO pass in context as an argument +// RunCheck performs a health check by pinging the Docker daemon. +func (dhc *dockerRuntimeHealthcheck) RunCheck() ecstcs.InstanceHealthCheckStatus { + // TODO: Pass in context as an argument. res := dhc.client.SystemPing(context.TODO(), systemPingTimeout) - resultStatus := doctor.HealthcheckStatusOk + resultStatus := ecstcs.InstanceHealthCheckStatusOk if res.Error != nil { seelog.Infof("[DockerRuntimeHealthcheck] Docker Ping failed with error: %v", res.Error) - resultStatus = doctor.HealthcheckStatusImpaired + resultStatus = ecstcs.InstanceHealthCheckStatusImpaired } dhc.SetHealthcheckStatus(resultStatus) return resultStatus } -func (dhc *dockerRuntimeHealthcheck) SetHealthcheckStatus(healthStatus doctor.HealthcheckStatus) { +// SetHealthcheckStatus updates the health check status and timestamps. +func (dhc *dockerRuntimeHealthcheck) SetHealthcheckStatus(healthStatus ecstcs.InstanceHealthCheckStatus) { dhc.lock.Lock() defer dhc.lock.Unlock() nowTime := time.Now() - // if the status has changed, update status change timestamp + // If the status has changed, update status change timestamp. if dhc.Status != healthStatus { dhc.StatusChangeTime = nowTime } - // track previous status + // Track previous status. 
dhc.LastStatus = dhc.Status dhc.LastTimeStamp = dhc.TimeStamp - // update latest status + // Update latest status. dhc.Status = healthStatus dhc.TimeStamp = nowTime } +// GetHealthcheckType returns the type of this health check. func (dhc *dockerRuntimeHealthcheck) GetHealthcheckType() string { dhc.lock.RLock() defer dhc.lock.RUnlock() return dhc.HealthcheckType } -func (dhc *dockerRuntimeHealthcheck) GetHealthcheckStatus() doctor.HealthcheckStatus { +// GetHealthcheckStatus returns the current health check status. +func (dhc *dockerRuntimeHealthcheck) GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { dhc.lock.RLock() defer dhc.lock.RUnlock() return dhc.Status } +// GetHealthcheckTime returns the timestamp of the current health check status. func (dhc *dockerRuntimeHealthcheck) GetHealthcheckTime() time.Time { dhc.lock.RLock() defer dhc.lock.RUnlock() return dhc.TimeStamp } +// GetStatusChangeTime returns the timestamp when the status last changed. func (dhc *dockerRuntimeHealthcheck) GetStatusChangeTime() time.Time { dhc.lock.RLock() defer dhc.lock.RUnlock() return dhc.StatusChangeTime } -func (dhc *dockerRuntimeHealthcheck) GetLastHealthcheckStatus() doctor.HealthcheckStatus { +// GetLastHealthcheckStatus returns the previous health check status. +func (dhc *dockerRuntimeHealthcheck) GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { dhc.lock.RLock() defer dhc.lock.RUnlock() return dhc.LastStatus } +// GetLastHealthcheckTime returns the timestamp of the previous health check status. 
func (dhc *dockerRuntimeHealthcheck) GetLastHealthcheckTime() time.Time { dhc.lock.RLock() defer dhc.lock.RUnlock() diff --git a/agent/doctor/docker_runtime_healthcheck_test.go b/agent/doctor/docker_runtime_healthcheck_test.go index 21eec1ae62e..e9456765624 100644 --- a/agent/doctor/docker_runtime_healthcheck_test.go +++ b/agent/doctor/docker_runtime_healthcheck_test.go @@ -9,7 +9,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi" mock_dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/docker/docker/api/types" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -27,8 +27,8 @@ func TestNewDockerRuntimeHealthCheck(t *testing.T) { defer func() { timeNow = originalTimeNow }() expectedDockerRuntimeHealthcheck := &dockerRuntimeHealthcheck{ - HealthcheckType: doctor.HealthcheckTypeContainerRuntime, - Status: doctor.HealthcheckStatusInitializing, + HealthcheckType: ecstcs.InstanceHealthCheckTypeContainerRuntime, + Status: ecstcs.InstanceHealthCheckStatusInitializing, TimeStamp: mockTime, StatusChangeTime: mockTime, LastTimeStamp: mockTime, @@ -42,8 +42,8 @@ func TestRunCheck(t *testing.T) { testcases := []struct { name string dockerPingResponse *dockerapi.PingResponse - expectedStatus doctor.HealthcheckStatus - expectedLastStatus doctor.HealthcheckStatus + expectedStatus ecstcs.InstanceHealthCheckStatus + expectedLastStatus ecstcs.InstanceHealthCheckStatus }{ { name: "empty checks", @@ -51,8 +51,8 @@ func TestRunCheck(t *testing.T) { Response: &types.Ping{APIVersion: "test_api_version"}, Error: nil, }, - expectedStatus: doctor.HealthcheckStatusOk, - expectedLastStatus: doctor.HealthcheckStatusInitializing, + expectedStatus: ecstcs.InstanceHealthCheckStatusOk, + expectedLastStatus: ecstcs.InstanceHealthCheckStatusInitializing, }, { name: "all true checks", @@ -60,8 +60,8 @@ 
func TestRunCheck(t *testing.T) { Response: nil, Error: &dockerapi.DockerTimeoutError{}, }, - expectedStatus: doctor.HealthcheckStatusImpaired, - expectedLastStatus: doctor.HealthcheckStatusInitializing, + expectedStatus: ecstcs.InstanceHealthCheckStatusImpaired, + expectedLastStatus: ecstcs.InstanceHealthCheckStatusInitializing, }, } ctrl := gomock.NewController(t) @@ -85,9 +85,9 @@ func TestSetHealthCheckStatus(t *testing.T) { defer ctrl.Finish() dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) dockerRuntimeHealthCheck := NewDockerRuntimeHealthcheck(dockerClient) - healthCheckStatus := doctor.HealthcheckStatusOk + healthCheckStatus := ecstcs.InstanceHealthCheckStatusOk dockerRuntimeHealthCheck.SetHealthcheckStatus(healthCheckStatus) - assert.Equal(t, doctor.HealthcheckStatusOk, dockerRuntimeHealthCheck.Status) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusOk, dockerRuntimeHealthCheck.Status) } func TestSetHealthcheckStatusChange(t *testing.T) { @@ -96,23 +96,23 @@ func TestSetHealthcheckStatusChange(t *testing.T) { dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) dockerRuntimeHealthcheck := NewDockerRuntimeHealthcheck(dockerClient) - // we should start in initializing status - assert.Equal(t, doctor.HealthcheckStatusInitializing, dockerRuntimeHealthcheck.Status) + // We should start in initializing status. + assert.Equal(t, ecstcs.InstanceHealthCheckStatusInitializing, dockerRuntimeHealthcheck.Status) initializationChangeTime := dockerRuntimeHealthcheck.GetStatusChangeTime() - // we update to initializing again; our StatusChangeTime remains the same - dockerRuntimeHealthcheck.SetHealthcheckStatus(doctor.HealthcheckStatusInitializing) + // We update to initializing again; our StatusChangeTime remains the same. 
+ dockerRuntimeHealthcheck.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusInitializing) updateChangeTime := dockerRuntimeHealthcheck.GetStatusChangeTime() - assert.Equal(t, doctor.HealthcheckStatusInitializing, dockerRuntimeHealthcheck.Status) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusInitializing, dockerRuntimeHealthcheck.Status) assert.Equal(t, initializationChangeTime, updateChangeTime) - // add a sleep so we know time has elapsed between the initial status and status change time + // Add a sleep so we know time has elapsed between the initial status and status change time. time.Sleep(1 * time.Millisecond) - // change status. This should change the update time too - dockerRuntimeHealthcheck.SetHealthcheckStatus(doctor.HealthcheckStatusOk) - assert.Equal(t, doctor.HealthcheckStatusOk, dockerRuntimeHealthcheck.Status) + // Change status. This should change the update time too. + dockerRuntimeHealthcheck.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusOk) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusOk, dockerRuntimeHealthcheck.Status) okChangeTime := dockerRuntimeHealthcheck.GetStatusChangeTime() - // have we updated our change time? + // Have we updated our change time? assert.True(t, okChangeTime.After(initializationChangeTime)) } diff --git a/agent/doctor/ebs_csi_runtime_healthcheck.go b/agent/doctor/ebs_csi_runtime_healthcheck.go index 6a6c1c06c9c..6aa5cd488cd 100644 --- a/agent/doctor/ebs_csi_runtime_healthcheck.go +++ b/agent/doctor/ebs_csi_runtime_healthcheck.go @@ -21,24 +21,25 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" ) const ( - // Default request timeout for EBS CSI Daemon health check requests + // DefaultEBSHealthRequestTimeout is the default request timeout for EBS CSI Daemon health check requests. 
DefaultEBSHealthRequestTimeout = 2 * time.Second ) -// Health check for EBS CSI Daemon. +// ebsCSIDaemonHealthcheck is a health check for EBS CSI Daemon. type ebsCSIDaemonHealthcheck struct { csiClient csiclient.CSIClient requestTimeout time.Duration *statustracker.HealthCheckStatusTracker } -// Constructor for EBS CSI Daemon Health Check +// NewEBSCSIDaemonHealthCheck is the constructor for EBS CSI Daemon Health Check. func NewEBSCSIDaemonHealthCheck( csiClient csiclient.CSIClient, - requestTimeout time.Duration, // timeout for health check requests + requestTimeout time.Duration, // Timeout for health check requests. ) doctor.Healthcheck { return &ebsCSIDaemonHealthcheck{ csiClient: csiClient, @@ -47,24 +48,25 @@ func NewEBSCSIDaemonHealthCheck( } } -// Performs a health check for EBS CSI Daemon by sending a request to it to get -// node capabilities. If EBS CSI Daemon is not started yet then returns OK trivially. -func (e *ebsCSIDaemonHealthcheck) RunCheck() doctor.HealthcheckStatus { +// RunCheck performs a health check for EBS CSI Daemon by sending a request to it to get node capabilities. +// If EBS CSI Daemon is not started yet then returns OK trivially. +func (e *ebsCSIDaemonHealthcheck) RunCheck() ecstcs.InstanceHealthCheckStatus { ctx, cancel := context.WithTimeout(context.Background(), e.requestTimeout) defer cancel() resp, err := e.csiClient.NodeGetCapabilities(ctx) if err != nil { logger.Error("EBS CSI Daemon health check failed", logger.Fields{field.Error: err}) - e.SetHealthcheckStatus(doctor.HealthcheckStatusImpaired) + e.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusImpaired) return e.GetHealthcheckStatus() } logger.Info("EBS CSI Driver is healthy", logger.Fields{"nodeCapabilities": resp}) - e.SetHealthcheckStatus(doctor.HealthcheckStatusOk) + e.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusOk) return e.GetHealthcheckStatus() } +// GetHealthcheckType returns the type of this health check. 
func (e *ebsCSIDaemonHealthcheck) GetHealthcheckType() string { - return doctor.HealthcheckTypeEBSDaemon + return ecstcs.InstanceHealthCheckTypeEBSDaemon } diff --git a/agent/doctor/ebs_csi_runtime_healthcheck_test.go b/agent/doctor/ebs_csi_runtime_healthcheck_test.go index 135bb8459d6..6ed2f7e776d 100644 --- a/agent/doctor/ebs_csi_runtime_healthcheck_test.go +++ b/agent/doctor/ebs_csi_runtime_healthcheck_test.go @@ -20,13 +20,13 @@ import ( "testing" mock_csiclient "github.com/aws/amazon-ecs-agent/ecs-agent/csiclient/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) -// Tests that EBS Daemon Health Check is of the right health check type +// Tests that EBS Daemon Health Check is of the right health check type. func TestEBSGetHealthcheckType(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -34,10 +34,10 @@ func TestEBSGetHealthcheckType(t *testing.T) { csiClient := mock_csiclient.NewMockCSIClient(ctrl) hc := NewEBSCSIDaemonHealthCheck(csiClient, 0) - assert.Equal(t, doctor.HealthcheckTypeEBSDaemon, hc.GetHealthcheckType()) + assert.Equal(t, ecstcs.InstanceHealthCheckTypeEBSDaemon, hc.GetHealthcheckType()) } -// Tests initial health status of EBS Daemon +// Tests initial health status of EBS Daemon. func TestEBSInitialHealth(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -45,15 +45,15 @@ func TestEBSInitialHealth(t *testing.T) { csiClient := mock_csiclient.NewMockCSIClient(ctrl) hc := NewEBSCSIDaemonHealthCheck(csiClient, 0) - assert.Equal(t, doctor.HealthcheckStatusInitializing, hc.GetHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusInitializing, hc.GetHealthcheckStatus()) } -// Tests RunCheck method of EBS Daemon Health Check +// Tests RunCheck method of EBS Daemon Health Check. 
func TestEBSRunHealthCheck(t *testing.T) { tcs := []struct { name string setCSIClientExpectations func(csiClient *mock_csiclient.MockCSIClient) - expectedStatus doctor.HealthcheckStatus + expectedStatus ecstcs.InstanceHealthCheckStatus }{ { name: "OK when healthcheck succeeds", @@ -61,14 +61,14 @@ func TestEBSRunHealthCheck(t *testing.T) { csiClient.EXPECT().NodeGetCapabilities(gomock.Any()). Return(&csi.NodeGetCapabilitiesResponse{}, nil) }, - expectedStatus: doctor.HealthcheckStatusOk, + expectedStatus: ecstcs.InstanceHealthCheckStatusOk, }, { name: "IMPAIRED when healthcheck fails", setCSIClientExpectations: func(csiClient *mock_csiclient.MockCSIClient) { csiClient.EXPECT().NodeGetCapabilities(gomock.Any()).Return(nil, errors.New("err")) }, - expectedStatus: doctor.HealthcheckStatusImpaired, + expectedStatus: ecstcs.InstanceHealthCheckStatusImpaired, }, } diff --git a/agent/doctor/statustracker/statustracker.go b/agent/doctor/statustracker/statustracker.go index 55b83c87929..f73d21af0b2 100644 --- a/agent/doctor/statustracker/statustracker.go +++ b/agent/doctor/statustracker/statustracker.go @@ -16,70 +16,76 @@ import ( "sync" "time" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" ) -// Helper for keeping track of current and last health check status. +// HealthCheckStatusTracker is a helper for keeping track of current and last health check status. type HealthCheckStatusTracker struct { - status doctor.HealthcheckStatus + status ecstcs.InstanceHealthCheckStatus timeStamp time.Time statusChangeTime time.Time - lastStatus doctor.HealthcheckStatus + lastStatus ecstcs.InstanceHealthCheckStatus lastTimeStamp time.Time - now func() time.Time // function that returns current time (injected for testing) + now func() time.Time // Function that returns current time (injected for testing). 
lock sync.RWMutex } -func (e *HealthCheckStatusTracker) GetHealthcheckStatus() doctor.HealthcheckStatus { +// GetHealthcheckStatus returns the current health check status. +func (e *HealthCheckStatusTracker) GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { e.lock.RLock() defer e.lock.RUnlock() return e.status } +// GetHealthcheckTime returns the timestamp of the current health check status. func (e *HealthCheckStatusTracker) GetHealthcheckTime() time.Time { e.lock.RLock() defer e.lock.RUnlock() return e.timeStamp } +// GetStatusChangeTime returns the timestamp when the status last changed. func (e *HealthCheckStatusTracker) GetStatusChangeTime() time.Time { e.lock.RLock() defer e.lock.RUnlock() return e.statusChangeTime } -func (e *HealthCheckStatusTracker) GetLastHealthcheckStatus() doctor.HealthcheckStatus { +// GetLastHealthcheckStatus returns the previous health check status. +func (e *HealthCheckStatusTracker) GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { e.lock.RLock() defer e.lock.RUnlock() return e.lastStatus } +// GetLastHealthcheckTime returns the timestamp of the previous health check status. func (e *HealthCheckStatusTracker) GetLastHealthcheckTime() time.Time { e.lock.RLock() defer e.lock.RUnlock() return e.lastTimeStamp } -func (e *HealthCheckStatusTracker) SetHealthcheckStatus(healthStatus doctor.HealthcheckStatus) { +// SetHealthcheckStatus updates the health check status and timestamps. +func (e *HealthCheckStatusTracker) SetHealthcheckStatus(healthStatus ecstcs.InstanceHealthCheckStatus) { e.lock.Lock() defer e.lock.Unlock() nowTime := e.now() - // if the status has changed, update status change timestamp + // If the status has changed, update status change timestamp. if e.status != healthStatus { e.statusChangeTime = nowTime } - // track previous status + // Track previous status. e.lastStatus = e.status e.lastTimeStamp = e.timeStamp - // update latest status + // Update latest status. 
e.status = healthStatus e.timeStamp = nowTime } -// Returns a new HealthCheckStatusTracker +// NewHealthCheckStatusTracker returns a new HealthCheckStatusTracker. func NewHealthCheckStatusTracker() *HealthCheckStatusTracker { return newHealthCheckStatusTrackerWithTimeFn(time.Now) } @@ -87,7 +93,7 @@ func NewHealthCheckStatusTracker() *HealthCheckStatusTracker { func newHealthCheckStatusTrackerWithTimeFn(timeNow func() time.Time) *HealthCheckStatusTracker { now := timeNow() return &HealthCheckStatusTracker{ - status: doctor.HealthcheckStatusInitializing, + status: ecstcs.InstanceHealthCheckStatusInitializing, timeStamp: now, statusChangeTime: now, now: timeNow, diff --git a/agent/doctor/statustracker/statustracker_test.go b/agent/doctor/statustracker/statustracker_test.go index eb4b74fa106..76e426bfb6a 100644 --- a/agent/doctor/statustracker/statustracker_test.go +++ b/agent/doctor/statustracker/statustracker_test.go @@ -19,7 +19,7 @@ import ( "testing" "time" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/stretchr/testify/assert" ) @@ -27,40 +27,40 @@ func TestHealthCheckStatusTracker(t *testing.T) { t.Run("initialization", func(t *testing.T) { now := time.Unix(1000, 0) tracker := newHealthCheckStatusTrackerWithTimeFn(func() time.Time { return now }) - assert.Equal(t, doctor.HealthcheckStatusInitializing, tracker.GetHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusInitializing, tracker.GetHealthcheckStatus()) assert.Equal(t, now, tracker.GetHealthcheckTime()) assert.Equal(t, now, tracker.GetStatusChangeTime()) }) t.Run("last status and timestamp is captured", func(t *testing.T) { tracker := newHealthCheckStatusTrackerWithTimeFn(incrementalTime()) - tracker.SetHealthcheckStatus(doctor.HealthcheckStatusOk) + tracker.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusOk) - assert.Equal(t, doctor.HealthcheckStatusOk, tracker.GetHealthcheckStatus()) - assert.Equal(t, 
doctor.HealthcheckStatusInitializing, tracker.GetLastHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusOk, tracker.GetHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusInitializing, tracker.GetLastHealthcheckStatus()) assert.Equal(t, int64(1), tracker.GetLastHealthcheckTime().Unix()) assert.Equal(t, int64(2), tracker.GetHealthcheckTime().Unix()) - assert.Equal(t, int64(2), tracker.GetStatusChangeTime().Unix()) // changed to OK at time 2 + assert.Equal(t, int64(2), tracker.GetStatusChangeTime().Unix()) // Changed to OK at time 2. }) t.Run("status change time is not changed if status hasn't changed", func(t *testing.T) { tracker := newHealthCheckStatusTrackerWithTimeFn(incrementalTime()) - // update (but not change) status a bunch of times + // Update (but not change) status a bunch of times. for i := 0; i < 10; i++ { - tracker.SetHealthcheckStatus(doctor.HealthcheckStatusOk) + tracker.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusOk) } - assert.Equal(t, doctor.HealthcheckStatusOk, tracker.GetHealthcheckStatus()) - assert.Equal(t, doctor.HealthcheckStatusOk, tracker.GetLastHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusOk, tracker.GetHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusOk, tracker.GetLastHealthcheckStatus()) - // status change time remains at 2 + // Status change time remains at 2. 
assert.Equal(t, int64(2), tracker.GetStatusChangeTime().Unix()) }) t.Run("multiple updates", func(t *testing.T) { tracker := newHealthCheckStatusTrackerWithTimeFn(incrementalTime()) - tracker.SetHealthcheckStatus(doctor.HealthcheckStatusOk) - tracker.SetHealthcheckStatus(doctor.HealthcheckStatusImpaired) + tracker.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusOk) + tracker.SetHealthcheckStatus(ecstcs.InstanceHealthCheckStatusImpaired) - assert.Equal(t, doctor.HealthcheckStatusImpaired, tracker.GetHealthcheckStatus()) - assert.Equal(t, doctor.HealthcheckStatusOk, tracker.GetLastHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusImpaired, tracker.GetHealthcheckStatus()) + assert.Equal(t, ecstcs.InstanceHealthCheckStatusOk, tracker.GetLastHealthcheckStatus()) assert.Equal(t, int64(2), tracker.GetLastHealthcheckTime().Unix()) assert.Equal(t, int64(3), tracker.GetHealthcheckTime().Unix()) assert.Equal(t, int64(3), tracker.GetStatusChangeTime().Unix()) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/doctor.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/doctor.go index f80d066a9db..2ad8c5848ac 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/doctor.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/doctor.go @@ -19,13 +19,15 @@ import ( "github.com/pkg/errors" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" ) var ( - // EmptyHealthcheckError indicates an error when there are no healthcheck metrics to report + // EmptyHealthcheckError indicates an error when there are no healthcheck metrics to report. EmptyHealthcheckError = errors.New("No instance healthcheck status metrics to report") ) +// Doctor manages and runs health checks for the container instance. 
type Doctor struct { healthchecks []Healthcheck lock sync.RWMutex @@ -34,6 +36,7 @@ type Doctor struct { statusReported bool } +// NewDoctor creates a new Doctor instance with the provided health checks. func NewDoctor(healthchecks []Healthcheck, cluster string, containerInstanceArn string) (*Doctor, error) { newDoctor := &Doctor{ healthchecks: []Healthcheck{}, @@ -47,8 +50,7 @@ func NewDoctor(healthchecks []Healthcheck, cluster string, containerInstanceArn return newDoctor, nil } -// GetCluster returns the cluster that was provided to the doctor while -// being initialized +// GetCluster returns the cluster that was provided to the doctor while being initialized. func (doc *Doctor) GetCluster() string { doc.lock.RLock() defer doc.lock.RUnlock() @@ -56,8 +58,7 @@ func (doc *Doctor) GetCluster() string { return doc.cluster } -// GetContainerInstanceArn returns the container instance arn that was -// provided to the doctor while being initialized +// GetContainerInstanceArn returns the container instance ARN that was provided to the doctor while being initialized. func (doc *Doctor) GetContainerInstanceArn() string { doc.lock.RLock() defer doc.lock.RUnlock() @@ -65,8 +66,7 @@ func (doc *Doctor) GetContainerInstanceArn() string { return doc.containerInstanceArn } -// SetStatusReported tells the doctor that we have already reported the -// current status of the healthchecks to the backend +// SetStatusReported tells the doctor that we have already reported the current status of the healthchecks to the backend. func (doc *Doctor) SetStatusReported(statusReported bool) { doc.lock.Lock() defer doc.lock.Unlock() @@ -74,8 +74,7 @@ func (doc *Doctor) SetStatusReported(statusReported bool) { doc.statusReported = statusReported } -// HasStatusBeenReported returns whether we have already sent the current -// state of the healthchecks to the backend or not +// HasStatusBeenReported returns whether we have already sent the current state of the healthchecks to the backend or not. 
func (doc *Doctor) HasStatusBeenReported() bool { doc.lock.RLock() defer doc.lock.RUnlock() @@ -83,20 +82,18 @@ func (doc *Doctor) HasStatusBeenReported() bool { return doc.statusReported } -// AddHealthcheck adds a healthcheck to the list of healthchecks that the -// doctor will run every time doctor.RunHealthchecks() is called +// AddHealthcheck adds a healthcheck to the list of healthchecks that the doctor will run every time doctor.RunHealthchecks() is called. func (doc *Doctor) AddHealthcheck(healthcheck Healthcheck) { doc.lock.Lock() defer doc.lock.Unlock() doc.healthchecks = append(doc.healthchecks, healthcheck) } -// RunHealthchecks runs every healthcheck that the doctor knows about and -// returns a cumulative result; true if they all pass, false otherwise +// RunHealthchecks runs every healthcheck that the doctor knows about and returns a cumulative result; true if they all pass, false otherwise. func (doc *Doctor) RunHealthchecks() bool { doc.lock.Lock() defer doc.lock.Unlock() - allChecksResult := []HealthcheckStatus{} + allChecksResult := []ecstcs.InstanceHealthCheckStatus{} for _, healthcheck := range doc.healthchecks { res := healthcheck.RunCheck() @@ -111,8 +108,7 @@ func (doc *Doctor) RunHealthchecks() bool { return doc.allRight(allChecksResult) } -// GetHealthchecks returns a copy of list of healthchecks that the -// doctor is holding internally. +// GetHealthchecks returns a copy of list of healthchecks that the doctor is holding internally. 
func (doc *Doctor) GetHealthchecks() *[]Healthcheck { doc.lock.RLock() defer doc.lock.RUnlock() @@ -122,7 +118,7 @@ func (doc *Doctor) GetHealthchecks() *[]Healthcheck { return &healthcheckCopy } -func (doc *Doctor) allRight(checksResult []HealthcheckStatus) bool { +func (doc *Doctor) allRight(checksResult []ecstcs.InstanceHealthCheckStatus) bool { overallResult := true for _, checkResult := range checksResult { overallResult = overallResult && checkResult.Ok() diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/healthcheck.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/healthcheck.go index f185800b895..7a25e3f840c 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/healthcheck.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/healthcheck.go @@ -15,21 +15,18 @@ package doctor import ( "time" -) -const ( - HealthcheckTypeContainerRuntime = "ContainerRuntime" - HealthcheckTypeAgent = "Agent" - HealthcheckTypeEBSDaemon = "EBSDaemon" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" ) +// Healthcheck defines the interface for performing health checks on various components. 
type Healthcheck interface { GetHealthcheckType() string - GetHealthcheckStatus() HealthcheckStatus + GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus GetHealthcheckTime() time.Time GetStatusChangeTime() time.Time - GetLastHealthcheckStatus() HealthcheckStatus + GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus GetLastHealthcheckTime() time.Time - RunCheck() HealthcheckStatus - SetHealthcheckStatus(status HealthcheckStatus) + RunCheck() ecstcs.InstanceHealthCheckStatus + SetHealthcheckStatus(status ecstcs.InstanceHealthCheckStatus) } diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/healthcheckstatus.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/healthcheckstatus.go deleted file mode 100644 index 920373ab7be..00000000000 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/doctor/healthcheckstatus.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. 
- -package doctor - -import ( - "errors" - "strings" -) - -const ( - // HealthcheckStatusInitializing is the zero state of a healthcheck status - HealthcheckStatusInitializing HealthcheckStatus = iota - // HealthcheckStatusOk represents a healthcheck with a true/success result - HealthcheckStatusOk - // HealthcheckStatusImpaired represents a healthcheck with a false/fail result - HealthcheckStatusImpaired -) - -// HealthcheckStatus is an enumeration of possible instance statuses -type HealthcheckStatus int32 - -var healthcheckStatusMap = map[string]HealthcheckStatus{ - "INITIALIZING": HealthcheckStatusInitializing, - "OK": HealthcheckStatusOk, - "IMPAIRED": HealthcheckStatusImpaired, -} - -// String returns a human readable string representation of this object -func (hs HealthcheckStatus) String() string { - for k, v := range healthcheckStatusMap { - if v == hs { - return k - } - } - // we shouldn't see this - return "NONE" -} - -// Ok returns true if the Healthcheck status is OK or INITIALIZING -func (hs HealthcheckStatus) Ok() bool { - return hs == HealthcheckStatusOk || hs == HealthcheckStatusInitializing -} - -// UnmarshalJSON overrides the logic for parsing the JSON-encoded HealthcheckStatus data -func (hs *HealthcheckStatus) UnmarshalJSON(b []byte) error { - if strings.ToLower(string(b)) == "null" { - *hs = HealthcheckStatusInitializing - return nil - } - if b[0] != '"' || b[len(b)-1] != '"' { - *hs = HealthcheckStatusInitializing - return errors.New("healthcheck status unmarshal: status must be a string or null; Got " + string(b)) - } - - stat, ok := healthcheckStatusMap[string(b[1:len(b)-1])] - if !ok { - *hs = HealthcheckStatusInitializing - return errors.New("healthcheck status unmarshal: unrecognized status") - } - *hs = stat - return nil -} - -// MarshalJSON overrides the logic for JSON-encoding the HealthcheckStatus type -func (hs *HealthcheckStatus) MarshalJSON() ([]byte, error) { - if hs == nil { - return nil, nil - } - return []byte(`"` + 
hs.String() + `"`), nil -} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go index 9512237ab50..4fd80dfde69 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/client/client.go @@ -53,20 +53,37 @@ var ( ) // tcsClientServer implements wsclient.ClientServer interface for metrics backend. +// It handles publishing telemetry metrics, health messages, and instance status +// messages to the TCS backend through dedicated channels. type tcsClientServer struct { doctor *doctor.Doctor pullInstanceStatusTicker *time.Ticker disableResourceMetrics bool publishMetricsInterval time.Duration + // metrics is a receive-only channel for telemetry messages containing + // instance and task metrics to be published to the backend. metrics <-chan ecstcs.TelemetryMessage - health <-chan ecstcs.HealthMessage + + // health is a receive-only channel for health messages containing + // task health metrics to be published to the backend. + health <-chan ecstcs.HealthMessage + + // instanceStatus is a receive-only channel for instance status messages + // containing instance health status to be published to the backend. + instanceStatus <-chan ecstcs.InstanceStatusMessage wsclient.ClientServerImpl } // New returns a client/server to bidirectionally communicate with the backend. // The returned struct should have both 'Connect' and 'Serve' called upon it // before being used. +// +// The instanceStatusMessages parameter is optional and can be nil to maintain +// backward compatibility with existing functionality. When provided, it enables +// external components to send instance status updates through a dedicated channel, +// allowing for instance status publishing independent of the doctor module's +// periodic health checks. 
func New(url string, cfg *wsclient.WSClientMinAgentConfig, doctor *doctor.Doctor, @@ -76,6 +93,7 @@ func New(url string, rwTimeout time.Duration, metricsMessages <-chan ecstcs.TelemetryMessage, healthMessages <-chan ecstcs.HealthMessage, + instanceStatusMessages <-chan ecstcs.InstanceStatusMessage, metricsFactory metrics.EntryFactory, ) wsclient.ClientServer { cs := &tcsClientServer{ @@ -84,6 +102,7 @@ func New(url string, publishMetricsInterval: publishMetricsInterval, metrics: metricsMessages, health: healthMessages, + instanceStatus: instanceStatusMessages, disableResourceMetrics: disableResourceMetrics, ClientServerImpl: wsclient.ClientServerImpl{ URL: url, @@ -122,6 +141,16 @@ func (cs *tcsClientServer) Serve(ctx context.Context) error { return cs.ConsumeMessages(ctx) } +// publishMessages listens for messages on the metrics, health, and instanceStatus +// channels and publishes them to the TCS backend. This method runs in a separate +// goroutine and handles three types of messages concurrently: +// - Telemetry messages containing instance and task metrics +// - Health messages containing task health information +// - Instance status messages containing instance health status information +// +// The method continues processing messages until the context is cancelled. +// Errors during publishing are logged but do not terminate the processing loop, +// ensuring that failures with one message type do not affect others. 
func (cs *tcsClientServer) publishMessages(ctx context.Context) { for { select { @@ -144,6 +173,14 @@ func (cs *tcsClientServer) publishMessages(ctx context.Context) { field.Error: err, }) } + case instanceStatus := <-cs.instanceStatus: + logger.Debug("received instance status message in instanceStatusChannel") + err := cs.publishInstanceStatusOnce(instanceStatus) + if err != nil { + logger.Warn("Error publishing instance status", logger.Fields{ + field.Error: err, + }) + } } } } @@ -408,7 +445,16 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { select { case <-cs.pullInstanceStatusTicker.C: if !cs.doctor.HasStatusBeenReported() { - err := cs.publishInstanceStatusOnce() + // Create InstanceStatusMessage from doctor data + message, err := cs.createInstanceStatusMessageFromDoctor() + if err != nil { + logger.Warn("Unable to create instance status message from doctor", logger.Fields{ + field.Error: err, + }) + continue + } + + err = cs.publishInstanceStatusOnce(message) if err != nil { logger.Warn("Unable to publish instance status", logger.Fields{ field.Error: err, @@ -425,44 +471,45 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { } } -// publishInstanceStatusOnce gets called on a ticker to pull instance status -// from the doctor instance contained within cs and sned that information to -// the backend -func (cs *tcsClientServer) publishInstanceStatusOnce() error { - // Get the list of health request to send to backend. - request, err := cs.getPublishInstanceStatusRequest() - if err != nil { - return err +// publishInstanceStatusOnce publishes instance status using the provided message +// parameter instead of querying the doctor module. This method accepts an +// InstanceStatusMessage and creates a PublishInstanceStatusRequest from it, +// adding a timestamp and sending it to the TCS backend. 
+// +// This method enables external components to publish instance status updates +// through the instanceStatus channel, providing an alternative to the doctor +// module's periodic health check publishing mechanism. +func (cs *tcsClientServer) publishInstanceStatusOnce(message ecstcs.InstanceStatusMessage) error { + request := &ecstcs.PublishInstanceStatusRequest{ + Metadata: message.Metadata, + Statuses: message.Statuses, + Timestamp: (*utils.Timestamp)(aws.Time(time.Now())), } - // Make the publish instance status request to the backend. - err = cs.MakeRequest(request) + err := cs.MakeRequest(request) if err != nil { return err } - cs.doctor.SetStatusReported(true) - return nil } -// GetPublishInstanceStatusRequest will get all healthcheck statuses and generate -// a sendable PublishInstanceStatusRequest -func (cs *tcsClientServer) getPublishInstanceStatusRequest() (*ecstcs.PublishInstanceStatusRequest, error) { +// createInstanceStatusMessageFromDoctor creates an InstanceStatusMessage from doctor data +func (cs *tcsClientServer) createInstanceStatusMessageFromDoctor() (ecstcs.InstanceStatusMessage, error) { metadata := &ecstcs.InstanceStatusMetadata{ Cluster: aws.String(cs.doctor.GetCluster()), ContainerInstance: aws.String(cs.doctor.GetContainerInstanceArn()), RequestId: aws.String(uuid.NewRandom().String()), } + instanceStatuses := cs.getInstanceStatuses() if instanceStatuses == nil { - return nil, doctor.EmptyHealthcheckError + return ecstcs.InstanceStatusMessage{}, doctor.EmptyHealthcheckError } - return &ecstcs.PublishInstanceStatusRequest{ - Metadata: metadata, - Statuses: instanceStatuses, - Timestamp: (*utils.Timestamp)(aws.Time(time.Now())), + return ecstcs.InstanceStatusMessage{ + Metadata: metadata, + Statuses: instanceStatuses, }, nil } diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler/handler.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler/handler.go index 74e16c30564..2bddcdb6219 100644 --- 
a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler/handler.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler/handler.go @@ -161,7 +161,7 @@ func (session *telemetrySession) StartTelemetrySession(ctx context.Context) erro tcsEndpointUrl := formatURL(endpoint, session.cluster, session.containerInstanceArn, session.agentVersion, session.agentHash, containerRuntime, session.containerRuntimeVersion) client := tcsclient.New(tcsEndpointUrl, session.cfg, session.doctor, session.disableMetrics, tcsclient.DefaultContainerMetricsPublishInterval, - session.credentialsCache, wsRWTimeout, session.metricsChannel, session.healthChannel, session.metricsFactory) + session.credentialsCache, wsRWTimeout, session.metricsChannel, session.healthChannel, nil, session.metricsFactory) defer client.Close() if session.deregisterInstanceEventStream != nil { diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs/types.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs/types.go index 294ac84d9de..e840dc84b40 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs/types.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs/types.go @@ -14,6 +14,8 @@ package ecstcs import ( + "errors" + "strings" "time" "github.com/aws/amazon-ecs-agent/ecs-agent/utils" @@ -50,3 +52,84 @@ type HealthMessage struct { Metadata *HealthMetadata HealthMetrics []*TaskHealth } + +// InstanceStatusMessage represents a message containing instance health status +// information to be published to the TCS backend. +type InstanceStatusMessage struct { + // Metadata contains identifying information about the container instance + // including cluster name, container instance ARN, and request ID. 
+ Metadata *InstanceStatusMetadata `json:"metadata,omitempty"` + + // Statuses contains a collection of instance status checks that represent + // the health state of various components on the container instance. + Statuses []*InstanceStatus `json:"statuses,omitempty"` +} + +const ( + InstanceHealthCheckTypeContainerRuntime = "ContainerRuntime" + InstanceHealthCheckTypeAgent = "Agent" + InstanceHealthCheckTypeEBSDaemon = "EBSDaemon" + InstanceHealthCheckTypeNvidia = "NvidiaAcceleratedHardware" +) + +const ( + // HealthcheckStatusInitializing is the zero state of a healthcheck status. + InstanceHealthCheckStatusInitializing InstanceHealthCheckStatus = iota + // HealthcheckStatusOk represents a healthcheck with a true/success result. + InstanceHealthCheckStatusOk + // HealthcheckStatusImpaired represents a healthcheck with a false/fail result. + InstanceHealthCheckStatusImpaired +) + +// InstanceHealthCheckStatus is an enumeration of possible instance health check statuses. +type InstanceHealthCheckStatus int32 + +var instanceHealthCheckStatusMap = map[string]InstanceHealthCheckStatus{ + "INITIALIZING": InstanceHealthCheckStatusInitializing, + "OK": InstanceHealthCheckStatusOk, + "IMPAIRED": InstanceHealthCheckStatusImpaired, +} + +// String returns a human readable string representation of this object. +func (hs InstanceHealthCheckStatus) String() string { + for k, v := range instanceHealthCheckStatusMap { + if v == hs { + return k + } + } + // We shouldn't see this. + return "NONE" +} + +// Ok returns true if the instance health check status is OK or INITIALIZING. +func (hs InstanceHealthCheckStatus) Ok() bool { + return hs == InstanceHealthCheckStatusOk || hs == InstanceHealthCheckStatusInitializing +} + +// UnmarshalJSON overrides the logic for parsing the JSON-encoded InstanceHealthCheckStatus data. 
+func (hs *InstanceHealthCheckStatus) UnmarshalJSON(b []byte) error { + if strings.ToLower(string(b)) == "null" { + *hs = InstanceHealthCheckStatusInitializing + return nil + } + if b[0] != '"' || b[len(b)-1] != '"' { + *hs = InstanceHealthCheckStatusInitializing + return errors.New("instance health check status unmarshal: status must be a string or null; Got " + string(b)) + } + + stat, ok := instanceHealthCheckStatusMap[string(b[1:len(b)-1])] + if !ok { + *hs = InstanceHealthCheckStatusInitializing + return errors.New("instance health check status unmarshal: unrecognized status") + } + *hs = stat + return nil +} + +// MarshalJSON overrides the logic for JSON-encoding the InstanceHealthCheckStatus type. +func (hs *InstanceHealthCheckStatus) MarshalJSON() ([]byte, error) { + if hs == nil { + return nil, nil + } + return []byte(`"` + hs.String() + `"`), nil +} diff --git a/ecs-agent/doctor/doctor.go b/ecs-agent/doctor/doctor.go index f80d066a9db..2ad8c5848ac 100644 --- a/ecs-agent/doctor/doctor.go +++ b/ecs-agent/doctor/doctor.go @@ -19,13 +19,15 @@ import ( "github.com/pkg/errors" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" ) var ( - // EmptyHealthcheckError indicates an error when there are no healthcheck metrics to report + // EmptyHealthcheckError indicates an error when there are no healthcheck metrics to report. EmptyHealthcheckError = errors.New("No instance healthcheck status metrics to report") ) +// Doctor manages and runs health checks for the container instance. type Doctor struct { healthchecks []Healthcheck lock sync.RWMutex @@ -34,6 +36,7 @@ type Doctor struct { statusReported bool } +// NewDoctor creates a new Doctor instance with the provided health checks. 
func NewDoctor(healthchecks []Healthcheck, cluster string, containerInstanceArn string) (*Doctor, error) { newDoctor := &Doctor{ healthchecks: []Healthcheck{}, @@ -47,8 +50,7 @@ func NewDoctor(healthchecks []Healthcheck, cluster string, containerInstanceArn return newDoctor, nil } -// GetCluster returns the cluster that was provided to the doctor while -// being initialized +// GetCluster returns the cluster that was provided to the doctor while being initialized. func (doc *Doctor) GetCluster() string { doc.lock.RLock() defer doc.lock.RUnlock() @@ -56,8 +58,7 @@ func (doc *Doctor) GetCluster() string { return doc.cluster } -// GetContainerInstanceArn returns the container instance arn that was -// provided to the doctor while being initialized +// GetContainerInstanceArn returns the container instance ARN that was provided to the doctor while being initialized. func (doc *Doctor) GetContainerInstanceArn() string { doc.lock.RLock() defer doc.lock.RUnlock() @@ -65,8 +66,7 @@ func (doc *Doctor) GetContainerInstanceArn() string { return doc.containerInstanceArn } -// SetStatusReported tells the doctor that we have already reported the -// current status of the healthchecks to the backend +// SetStatusReported tells the doctor that we have already reported the current status of the healthchecks to the backend. func (doc *Doctor) SetStatusReported(statusReported bool) { doc.lock.Lock() defer doc.lock.Unlock() @@ -74,8 +74,7 @@ func (doc *Doctor) SetStatusReported(statusReported bool) { doc.statusReported = statusReported } -// HasStatusBeenReported returns whether we have already sent the current -// state of the healthchecks to the backend or not +// HasStatusBeenReported returns whether we have already sent the current state of the healthchecks to the backend or not. 
func (doc *Doctor) HasStatusBeenReported() bool { doc.lock.RLock() defer doc.lock.RUnlock() @@ -83,20 +82,18 @@ func (doc *Doctor) HasStatusBeenReported() bool { return doc.statusReported } -// AddHealthcheck adds a healthcheck to the list of healthchecks that the -// doctor will run every time doctor.RunHealthchecks() is called +// AddHealthcheck adds a healthcheck to the list of healthchecks that the doctor will run every time doctor.RunHealthchecks() is called. func (doc *Doctor) AddHealthcheck(healthcheck Healthcheck) { doc.lock.Lock() defer doc.lock.Unlock() doc.healthchecks = append(doc.healthchecks, healthcheck) } -// RunHealthchecks runs every healthcheck that the doctor knows about and -// returns a cumulative result; true if they all pass, false otherwise +// RunHealthchecks runs every healthcheck that the doctor knows about and returns a cumulative result; true if they all pass, false otherwise. func (doc *Doctor) RunHealthchecks() bool { doc.lock.Lock() defer doc.lock.Unlock() - allChecksResult := []HealthcheckStatus{} + allChecksResult := []ecstcs.InstanceHealthCheckStatus{} for _, healthcheck := range doc.healthchecks { res := healthcheck.RunCheck() @@ -111,8 +108,7 @@ func (doc *Doctor) RunHealthchecks() bool { return doc.allRight(allChecksResult) } -// GetHealthchecks returns a copy of list of healthchecks that the -// doctor is holding internally. +// GetHealthchecks returns a copy of list of healthchecks that the doctor is holding internally. 
func (doc *Doctor) GetHealthchecks() *[]Healthcheck { doc.lock.RLock() defer doc.lock.RUnlock() @@ -122,7 +118,7 @@ func (doc *Doctor) GetHealthchecks() *[]Healthcheck { return &healthcheckCopy } -func (doc *Doctor) allRight(checksResult []HealthcheckStatus) bool { +func (doc *Doctor) allRight(checksResult []ecstcs.InstanceHealthCheckStatus) bool { overallResult := true for _, checkResult := range checksResult { overallResult = overallResult && checkResult.Ok() diff --git a/ecs-agent/doctor/doctor_test.go b/ecs-agent/doctor/doctor_test.go index 9234e52c8e6..69218295fce 100644 --- a/ecs-agent/doctor/doctor_test.go +++ b/ecs-agent/doctor/doctor_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/stretchr/testify/assert" ) @@ -30,14 +31,18 @@ const ( type trueHealthcheck struct{} -func (tc *trueHealthcheck) RunCheck() HealthcheckStatus { return HealthcheckStatusOk } -func (tc *trueHealthcheck) SetHealthcheckStatus(status HealthcheckStatus) {} -func (tc *trueHealthcheck) GetHealthcheckType() string { return HealthcheckTypeAgent } -func (tc *trueHealthcheck) GetHealthcheckStatus() HealthcheckStatus { - return HealthcheckStatusInitializing +func (tc *trueHealthcheck) RunCheck() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusOk } -func (tc *trueHealthcheck) GetLastHealthcheckStatus() HealthcheckStatus { - return HealthcheckStatusInitializing +func (tc *trueHealthcheck) SetHealthcheckStatus(status ecstcs.InstanceHealthCheckStatus) {} +func (tc *trueHealthcheck) GetHealthcheckType() string { + return ecstcs.InstanceHealthCheckTypeAgent +} +func (tc *trueHealthcheck) GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing +} +func (tc *trueHealthcheck) GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing } func (tc *trueHealthcheck) GetHealthcheckTime() 
time.Time { return time.Date(1974, time.May, 19, 1, 2, 3, 4, time.UTC) @@ -51,14 +56,18 @@ func (tc *trueHealthcheck) GetLastHealthcheckTime() time.Time { type falseHealthcheck struct{} -func (fc *falseHealthcheck) RunCheck() HealthcheckStatus { return HealthcheckStatusImpaired } -func (fc *falseHealthcheck) SetHealthcheckStatus(status HealthcheckStatus) {} -func (fc *falseHealthcheck) GetHealthcheckType() string { return HealthcheckTypeAgent } -func (fc *falseHealthcheck) GetHealthcheckStatus() HealthcheckStatus { - return HealthcheckStatusInitializing +func (fc *falseHealthcheck) RunCheck() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusImpaired +} +func (fc *falseHealthcheck) SetHealthcheckStatus(status ecstcs.InstanceHealthCheckStatus) {} +func (fc *falseHealthcheck) GetHealthcheckType() string { + return ecstcs.InstanceHealthCheckTypeAgent +} +func (fc *falseHealthcheck) GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing } -func (fc *falseHealthcheck) GetLastHealthcheckStatus() HealthcheckStatus { - return HealthcheckStatusInitializing +func (fc *falseHealthcheck) GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing } func (fc *falseHealthcheck) GetHealthcheckTime() time.Time { return time.Date(1974, time.May, 19, 1, 2, 3, 4, time.UTC) @@ -161,27 +170,27 @@ func TestGetHealthchecks(t *testing.T) { func TestAllRight(t *testing.T) { testcases := []struct { name string - testChecksResult []HealthcheckStatus + testChecksResult []ecstcs.InstanceHealthCheckStatus expectedResult bool }{ { name: "empty checks", - testChecksResult: []HealthcheckStatus{}, + testChecksResult: []ecstcs.InstanceHealthCheckStatus{}, expectedResult: true, }, { name: "all true checks", - testChecksResult: []HealthcheckStatus{HealthcheckStatusOk, HealthcheckStatusOk}, + testChecksResult: 
[]ecstcs.InstanceHealthCheckStatus{ecstcs.InstanceHealthCheckStatusOk, ecstcs.InstanceHealthCheckStatusOk}, expectedResult: true, }, { name: "all false checks", - testChecksResult: []HealthcheckStatus{HealthcheckStatusImpaired, HealthcheckStatusImpaired}, + testChecksResult: []ecstcs.InstanceHealthCheckStatus{ecstcs.InstanceHealthCheckStatusImpaired, ecstcs.InstanceHealthCheckStatusImpaired}, expectedResult: false, }, { name: "mixed checks", - testChecksResult: []HealthcheckStatus{HealthcheckStatusOk, HealthcheckStatusImpaired}, + testChecksResult: []ecstcs.InstanceHealthCheckStatus{ecstcs.InstanceHealthCheckStatusOk, ecstcs.InstanceHealthCheckStatusImpaired}, expectedResult: false, }, } diff --git a/ecs-agent/doctor/healthcheck.go b/ecs-agent/doctor/healthcheck.go index f185800b895..7a25e3f840c 100644 --- a/ecs-agent/doctor/healthcheck.go +++ b/ecs-agent/doctor/healthcheck.go @@ -15,21 +15,18 @@ package doctor import ( "time" -) -const ( - HealthcheckTypeContainerRuntime = "ContainerRuntime" - HealthcheckTypeAgent = "Agent" - HealthcheckTypeEBSDaemon = "EBSDaemon" + "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" ) +// Healthcheck defines the interface for performing health checks on various components. 
type Healthcheck interface { GetHealthcheckType() string - GetHealthcheckStatus() HealthcheckStatus + GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus GetHealthcheckTime() time.Time GetStatusChangeTime() time.Time - GetLastHealthcheckStatus() HealthcheckStatus + GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus GetLastHealthcheckTime() time.Time - RunCheck() HealthcheckStatus - SetHealthcheckStatus(status HealthcheckStatus) + RunCheck() ecstcs.InstanceHealthCheckStatus + SetHealthcheckStatus(status ecstcs.InstanceHealthCheckStatus) } diff --git a/ecs-agent/doctor/healthcheckstatus.go b/ecs-agent/doctor/healthcheckstatus.go deleted file mode 100644 index 920373ab7be..00000000000 --- a/ecs-agent/doctor/healthcheckstatus.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. 
- -package doctor - -import ( - "errors" - "strings" -) - -const ( - // HealthcheckStatusInitializing is the zero state of a healthcheck status - HealthcheckStatusInitializing HealthcheckStatus = iota - // HealthcheckStatusOk represents a healthcheck with a true/success result - HealthcheckStatusOk - // HealthcheckStatusImpaired represents a healthcheck with a false/fail result - HealthcheckStatusImpaired -) - -// HealthcheckStatus is an enumeration of possible instance statuses -type HealthcheckStatus int32 - -var healthcheckStatusMap = map[string]HealthcheckStatus{ - "INITIALIZING": HealthcheckStatusInitializing, - "OK": HealthcheckStatusOk, - "IMPAIRED": HealthcheckStatusImpaired, -} - -// String returns a human readable string representation of this object -func (hs HealthcheckStatus) String() string { - for k, v := range healthcheckStatusMap { - if v == hs { - return k - } - } - // we shouldn't see this - return "NONE" -} - -// Ok returns true if the Healthcheck status is OK or INITIALIZING -func (hs HealthcheckStatus) Ok() bool { - return hs == HealthcheckStatusOk || hs == HealthcheckStatusInitializing -} - -// UnmarshalJSON overrides the logic for parsing the JSON-encoded HealthcheckStatus data -func (hs *HealthcheckStatus) UnmarshalJSON(b []byte) error { - if strings.ToLower(string(b)) == "null" { - *hs = HealthcheckStatusInitializing - return nil - } - if b[0] != '"' || b[len(b)-1] != '"' { - *hs = HealthcheckStatusInitializing - return errors.New("healthcheck status unmarshal: status must be a string or null; Got " + string(b)) - } - - stat, ok := healthcheckStatusMap[string(b[1:len(b)-1])] - if !ok { - *hs = HealthcheckStatusInitializing - return errors.New("healthcheck status unmarshal: unrecognized status") - } - *hs = stat - return nil -} - -// MarshalJSON overrides the logic for JSON-encoding the HealthcheckStatus type -func (hs *HealthcheckStatus) MarshalJSON() ([]byte, error) { - if hs == nil { - return nil, nil - } - return []byte(`"` + 
hs.String() + `"`), nil -} diff --git a/ecs-agent/tcs/client/client.go b/ecs-agent/tcs/client/client.go index 9512237ab50..4fd80dfde69 100644 --- a/ecs-agent/tcs/client/client.go +++ b/ecs-agent/tcs/client/client.go @@ -53,20 +53,37 @@ var ( ) // tcsClientServer implements wsclient.ClientServer interface for metrics backend. +// It handles publishing telemetry metrics, health messages, and instance status +// messages to the TCS backend through dedicated channels. type tcsClientServer struct { doctor *doctor.Doctor pullInstanceStatusTicker *time.Ticker disableResourceMetrics bool publishMetricsInterval time.Duration + // metrics is a receive-only channel for telemetry messages containing + // instance and task metrics to be published to the backend. metrics <-chan ecstcs.TelemetryMessage - health <-chan ecstcs.HealthMessage + + // health is a receive-only channel for health messages containing + // task health metrics to be published to the backend. + health <-chan ecstcs.HealthMessage + + // instanceStatus is a receive-only channel for instance status messages + // containing instance health status to be published to the backend. + instanceStatus <-chan ecstcs.InstanceStatusMessage wsclient.ClientServerImpl } // New returns a client/server to bidirectionally communicate with the backend. // The returned struct should have both 'Connect' and 'Serve' called upon it // before being used. +// +// The instanceStatusMessages parameter is optional and can be nil to maintain +// backward compatibility with existing functionality. When provided, it enables +// external components to send instance status updates through a dedicated channel, +// allowing for instance status publishing independent of the doctor module's +// periodic health checks. 
func New(url string, cfg *wsclient.WSClientMinAgentConfig, doctor *doctor.Doctor, @@ -76,6 +93,7 @@ func New(url string, rwTimeout time.Duration, metricsMessages <-chan ecstcs.TelemetryMessage, healthMessages <-chan ecstcs.HealthMessage, + instanceStatusMessages <-chan ecstcs.InstanceStatusMessage, metricsFactory metrics.EntryFactory, ) wsclient.ClientServer { cs := &tcsClientServer{ @@ -84,6 +102,7 @@ func New(url string, publishMetricsInterval: publishMetricsInterval, metrics: metricsMessages, health: healthMessages, + instanceStatus: instanceStatusMessages, disableResourceMetrics: disableResourceMetrics, ClientServerImpl: wsclient.ClientServerImpl{ URL: url, @@ -122,6 +141,16 @@ func (cs *tcsClientServer) Serve(ctx context.Context) error { return cs.ConsumeMessages(ctx) } +// publishMessages listens for messages on the metrics, health, and instanceStatus +// channels and publishes them to the TCS backend. This method runs in a separate +// goroutine and handles three types of messages concurrently: +// - Telemetry messages containing instance and task metrics +// - Health messages containing task health information +// - Instance status messages containing instance health status information +// +// The method continues processing messages until the context is cancelled. +// Errors during publishing are logged but do not terminate the processing loop, +// ensuring that failures with one message type do not affect others. 
func (cs *tcsClientServer) publishMessages(ctx context.Context) { for { select { @@ -144,6 +173,14 @@ func (cs *tcsClientServer) publishMessages(ctx context.Context) { field.Error: err, }) } + case instanceStatus := <-cs.instanceStatus: + logger.Debug("received instance status message in instanceStatusChannel") + err := cs.publishInstanceStatusOnce(instanceStatus) + if err != nil { + logger.Warn("Error publishing instance status", logger.Fields{ + field.Error: err, + }) + } } } } @@ -408,7 +445,16 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { select { case <-cs.pullInstanceStatusTicker.C: if !cs.doctor.HasStatusBeenReported() { - err := cs.publishInstanceStatusOnce() + // Create InstanceStatusMessage from doctor data + message, err := cs.createInstanceStatusMessageFromDoctor() + if err != nil { + logger.Warn("Unable to create instance status message from doctor", logger.Fields{ + field.Error: err, + }) + continue + } + + err = cs.publishInstanceStatusOnce(message) if err != nil { logger.Warn("Unable to publish instance status", logger.Fields{ field.Error: err, @@ -425,44 +471,45 @@ func (cs *tcsClientServer) publishInstanceStatus(ctx context.Context) { } } -// publishInstanceStatusOnce gets called on a ticker to pull instance status -// from the doctor instance contained within cs and sned that information to -// the backend -func (cs *tcsClientServer) publishInstanceStatusOnce() error { - // Get the list of health request to send to backend. - request, err := cs.getPublishInstanceStatusRequest() - if err != nil { - return err +// publishInstanceStatusOnce publishes instance status using the provided message +// parameter instead of querying the doctor module. This method accepts an +// InstanceStatusMessage and creates a PublishInstanceStatusRequest from it, +// adding a timestamp and sending it to the TCS backend. 
+// +// This method enables external components to publish instance status updates +// through the instanceStatus channel, providing an alternative to the doctor +// module's periodic health check publishing mechanism. +func (cs *tcsClientServer) publishInstanceStatusOnce(message ecstcs.InstanceStatusMessage) error { + request := &ecstcs.PublishInstanceStatusRequest{ + Metadata: message.Metadata, + Statuses: message.Statuses, + Timestamp: (*utils.Timestamp)(aws.Time(time.Now())), } - // Make the publish instance status request to the backend. - err = cs.MakeRequest(request) + err := cs.MakeRequest(request) if err != nil { return err } - cs.doctor.SetStatusReported(true) - return nil } -// GetPublishInstanceStatusRequest will get all healthcheck statuses and generate -// a sendable PublishInstanceStatusRequest -func (cs *tcsClientServer) getPublishInstanceStatusRequest() (*ecstcs.PublishInstanceStatusRequest, error) { +// createInstanceStatusMessageFromDoctor creates an InstanceStatusMessage from doctor data +func (cs *tcsClientServer) createInstanceStatusMessageFromDoctor() (ecstcs.InstanceStatusMessage, error) { metadata := &ecstcs.InstanceStatusMetadata{ Cluster: aws.String(cs.doctor.GetCluster()), ContainerInstance: aws.String(cs.doctor.GetContainerInstanceArn()), RequestId: aws.String(uuid.NewRandom().String()), } + instanceStatuses := cs.getInstanceStatuses() if instanceStatuses == nil { - return nil, doctor.EmptyHealthcheckError + return ecstcs.InstanceStatusMessage{}, doctor.EmptyHealthcheckError } - return &ecstcs.PublishInstanceStatusRequest{ - Metadata: metadata, - Statuses: instanceStatuses, - Timestamp: (*utils.Timestamp)(aws.Time(time.Now())), + return ecstcs.InstanceStatusMessage{ + Metadata: metadata, + Statuses: instanceStatuses, }, nil } diff --git a/ecs-agent/tcs/client/client_test.go b/ecs-agent/tcs/client/client_test.go index eb1ec070170..aa61e450ea0 100644 --- a/ecs-agent/tcs/client/client_test.go +++ b/ecs-agent/tcs/client/client_test.go @@ 
-23,6 +23,7 @@ package tcsclient import ( + "bytes" "context" "fmt" "math/rand" @@ -32,7 +33,6 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" - mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/aws/amazon-ecs-agent/ecs-agent/utils" "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" @@ -59,14 +59,18 @@ const ( type trueHealthcheck struct{} -func (tc *trueHealthcheck) RunCheck() doctor.HealthcheckStatus { return doctor.HealthcheckStatusOk } -func (tc *trueHealthcheck) SetHealthcheckStatus(status doctor.HealthcheckStatus) {} -func (tc *trueHealthcheck) GetHealthcheckType() string { return doctor.HealthcheckTypeAgent } -func (tc *trueHealthcheck) GetHealthcheckStatus() doctor.HealthcheckStatus { - return doctor.HealthcheckStatusInitializing +func (tc *trueHealthcheck) RunCheck() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusOk } -func (tc *trueHealthcheck) GetLastHealthcheckStatus() doctor.HealthcheckStatus { - return doctor.HealthcheckStatusInitializing +func (tc *trueHealthcheck) SetHealthcheckStatus(status ecstcs.InstanceHealthCheckStatus) {} +func (tc *trueHealthcheck) GetHealthcheckType() string { + return ecstcs.InstanceHealthCheckTypeAgent +} +func (tc *trueHealthcheck) GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing +} +func (tc *trueHealthcheck) GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing } func (tc *trueHealthcheck) GetHealthcheckTime() time.Time { return time.Date(1974, time.May, 19, 1, 2, 3, 4, time.UTC) @@ -80,16 +84,18 @@ func (tc *trueHealthcheck) GetLastHealthcheckTime() time.Time { type falseHealthcheck struct{} -func (fc *falseHealthcheck) RunCheck() doctor.HealthcheckStatus { - return doctor.HealthcheckStatusImpaired +func (fc 
*falseHealthcheck) RunCheck() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusImpaired } -func (fc *falseHealthcheck) SetHealthcheckStatus(status doctor.HealthcheckStatus) {} -func (fc *falseHealthcheck) GetHealthcheckType() string { return doctor.HealthcheckTypeAgent } -func (fc *falseHealthcheck) GetHealthcheckStatus() doctor.HealthcheckStatus { - return doctor.HealthcheckStatusInitializing +func (fc *falseHealthcheck) SetHealthcheckStatus(status ecstcs.InstanceHealthCheckStatus) {} +func (fc *falseHealthcheck) GetHealthcheckType() string { + return ecstcs.InstanceHealthCheckTypeAgent } -func (fc *falseHealthcheck) GetLastHealthcheckStatus() doctor.HealthcheckStatus { - return doctor.HealthcheckStatusInitializing +func (fc *falseHealthcheck) GetHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing +} +func (fc *falseHealthcheck) GetLastHealthcheckStatus() ecstcs.InstanceHealthCheckStatus { + return ecstcs.InstanceHealthCheckStatusInitializing } func (fc *falseHealthcheck) GetHealthcheckTime() time.Time { return time.Date(1974, time.May, 19, 1, 2, 3, 4, time.UTC) @@ -649,7 +655,7 @@ func testCS(conn *mock_wsconn.MockWebsocketConn, metricsMessages <-chan ecstcs.T AcceptInsecureCert: true, } cs := New("https://aws.amazon.com/ecs", cfg, emptyDoctor, false, testPublishMetricsInterval, - aws.NewCredentialsCache(testCreds), rwTimeout, metricsMessages, healthMessages, metrics.NewNopEntryFactory()).(*tcsClientServer) + aws.NewCredentialsCache(testCreds), rwTimeout, metricsMessages, healthMessages, nil, metrics.NewNopEntryFactory()).(*tcsClientServer) cs.SetConnection(conn) return cs } @@ -720,7 +726,7 @@ func TestHealthToPublishHealthRequests(t *testing.T) { IsDocker: true, } - cs := New("", cfg, emptyDoctor, true, testPublishMetricsInterval, aws.NewCredentialsCache(testCreds), rwTimeout, nil, nil, metrics.NewNopEntryFactory()) + cs := New("", cfg, emptyDoctor, true, 
testPublishMetricsInterval, aws.NewCredentialsCache(testCreds), rwTimeout, nil, nil, nil, metrics.NewNopEntryFactory()) cs.SetConnection(conn) testMetadata := &ecstcs.HealthMetadata{ @@ -907,25 +913,21 @@ func TestGetPublishInstanceStatusRequest(t *testing.T) { } cs.doctor.RunHealthchecks() - // note: setting RequestId and Timestamp to nil so I can make the comparison metadata := &ecstcs.InstanceStatusMetadata{ Cluster: aws.String(testCluster), ContainerInstance: aws.String(testContainerInstance), RequestId: nil, } - testResult, err := cs.getPublishInstanceStatusRequest() + testMessage, err := cs.createInstanceStatusMessageFromDoctor() if tc.expectedStatuses != nil { - expectedResult := &ecstcs.PublishInstanceStatusRequest{ - Metadata: metadata, - Statuses: tc.expectedStatuses, - Timestamp: nil, + expectedMessage := ecstcs.InstanceStatusMessage{ + Metadata: metadata, + Statuses: tc.expectedStatuses, } - // note: setting RequestId and Timestamp to nil so I can make the comparison - testResult.Timestamp = nil - testResult.Metadata.RequestId = nil - assert.Equal(t, testResult, expectedResult) + testMessage.Metadata.RequestId = nil + assert.Equal(t, testMessage, expectedMessage) } else { assert.Error(t, err, "Test failed") } @@ -1014,64 +1016,1421 @@ func TestInvalidFormatMessageOnChannel(t *testing.T) { conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Times(0) } -// TestTACSPublishMetricFailureMetric tests that the TACSPublishMetricFailure metric is recorded when there's a metrics publishing error -func TestTACSPublishMetricFailureMetric(t *testing.T) { +// TestNewConstructorWithInstanceStatusChannel tests the constructor with instanceStatus channel parameter. 
+func TestNewConstructorWithInstanceStatusChannel(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + disableResourceMetrics bool + publishMetricsInterval time.Duration + metricsMessages <-chan ecstcs.TelemetryMessage + healthMessages <-chan ecstcs.HealthMessage + instanceStatusMessages <-chan ecstcs.InstanceStatusMessage + expectedInstanceStatusChan bool + }{ + { + name: "constructor with valid instanceStatus channel", + url: "https://aws.amazon.com/ecs", + disableResourceMetrics: false, + publishMetricsInterval: testPublishMetricsInterval, + metricsMessages: make(chan ecstcs.TelemetryMessage, 1), + healthMessages: make(chan ecstcs.HealthMessage, 1), + instanceStatusMessages: make(chan ecstcs.InstanceStatusMessage, 1), + expectedInstanceStatusChan: true, + }, + { + name: "constructor with nil instanceStatus channel", + url: "https://aws.amazon.com/ecs", + disableResourceMetrics: true, + publishMetricsInterval: testPublishMetricsInterval, + metricsMessages: make(chan ecstcs.TelemetryMessage, 1), + healthMessages: make(chan ecstcs.HealthMessage, 1), + instanceStatusMessages: nil, + expectedInstanceStatusChan: false, + }, + { + name: "constructor with all channels nil", + url: "https://aws.amazon.com/ecs", + disableResourceMetrics: false, + publishMetricsInterval: testPublishMetricsInterval, + metricsMessages: nil, + healthMessages: nil, + instanceStatusMessages: nil, + expectedInstanceStatusChan: false, + }, + { + name: "constructor with different URL and settings", + url: "https://test.example.com", + disableResourceMetrics: true, + publishMetricsInterval: 2 * time.Second, + metricsMessages: make(chan ecstcs.TelemetryMessage, 5), + healthMessages: make(chan ecstcs.HealthMessage, 5), + instanceStatusMessages: make(chan ecstcs.InstanceStatusMessage, 5), + expectedInstanceStatusChan: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cfg := &wsclient.WSClientMinAgentConfig{ + 
AWSRegion: "us-east-1", + AcceptInsecureCert: true, + } + + cs := New( + tc.url, + cfg, + emptyDoctor, + tc.disableResourceMetrics, + tc.publishMetricsInterval, + aws.NewCredentialsCache(testCreds), + rwTimeout, + tc.metricsMessages, + tc.healthMessages, + tc.instanceStatusMessages, + metrics.NewNopEntryFactory(), + ).(*tcsClientServer) + + // Verify that the channel is properly stored in the struct + if tc.expectedInstanceStatusChan { + assert.NotNil(t, cs.instanceStatus, "instanceStatus channel should be stored when provided") + assert.Equal(t, tc.instanceStatusMessages, cs.instanceStatus, "instanceStatus channel should match the provided channel") + } else { + assert.Nil(t, cs.instanceStatus, "instanceStatus channel should be nil when not provided") + } + + // Verify other fields are properly set + assert.Equal(t, tc.disableResourceMetrics, cs.disableResourceMetrics, "disableResourceMetrics should match") + assert.Equal(t, tc.publishMetricsInterval, cs.publishMetricsInterval, "publishMetricsInterval should match") + + // Verify channels are set correctly (checking for nil/non-nil rather than exact equality due to type conversion) + if tc.metricsMessages != nil { + assert.NotNil(t, cs.metrics, "metrics channel should be set when provided") + } else { + assert.Nil(t, cs.metrics, "metrics channel should be nil when not provided") + } + + if tc.healthMessages != nil { + assert.NotNil(t, cs.health, "health channel should be set when provided") + } else { + assert.Nil(t, cs.health, "health channel should be nil when not provided") + } + + assert.Equal(t, emptyDoctor, cs.doctor, "doctor should match") + assert.Equal(t, tc.url, cs.URL, "URL should match") + }) + } +} + +// TestNewConstructorBackwardCompatibility tests backward compatibility of the constructor. 
+func TestNewConstructorBackwardCompatibility(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + instanceStatusMessages <-chan ecstcs.InstanceStatusMessage + description string + }{ + { + name: "nil instanceStatus channel maintains compatibility", + instanceStatusMessages: nil, + description: "Constructor should work with nil instanceStatusMessages parameter", + }, + { + name: "valid instanceStatus channel works correctly", + instanceStatusMessages: make(chan ecstcs.InstanceStatusMessage, 1), + description: "Constructor should work with valid instanceStatusMessages parameter", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cfg := &wsclient.WSClientMinAgentConfig{ + AWSRegion: "us-east-1", + AcceptInsecureCert: true, + } + + metricsMessages := make(chan ecstcs.TelemetryMessage, 1) + healthMessages := make(chan ecstcs.HealthMessage, 1) + + // Test that constructor works without errors + cs := New( + "https://aws.amazon.com/ecs", + cfg, + emptyDoctor, + false, + testPublishMetricsInterval, + aws.NewCredentialsCache(testCreds), + rwTimeout, + metricsMessages, + healthMessages, + tc.instanceStatusMessages, + metrics.NewNopEntryFactory(), + ) + + // Verify that the client server is created successfully + assert.NotNil(t, cs, "ClientServer should be created successfully") + + // Verify that it implements the expected interface + _, ok := cs.(wsclient.ClientServer) + assert.True(t, ok, "Returned object should implement wsclient.ClientServer interface") + + // Cast to concrete type to verify internal state + tcsCS := cs.(*tcsClientServer) + + // Verify existing functionality is not affected + assert.NotNil(t, tcsCS.metrics, "metrics channel should be set") + assert.NotNil(t, tcsCS.health, "health channel should be set") + assert.Equal(t, emptyDoctor, tcsCS.doctor, "doctor should be properly set") + assert.Equal(t, testPublishMetricsInterval, tcsCS.publishMetricsInterval, "publishMetricsInterval 
should be properly set") + + // Verify instanceStatus field is handled correctly + if tc.instanceStatusMessages != nil { + assert.NotNil(t, tcsCS.instanceStatus, "instanceStatus channel should be set when provided") + } else { + assert.Nil(t, tcsCS.instanceStatus, "instanceStatus channel should be nil when not provided") + } + + // Verify basic interface compliance without calling Close() which requires a connection + assert.NotNil(t, cs, "ClientServer should implement the interface correctly") + }) + } +} + +// TestPublishMessagesInstanceStatusReception tests instanceStatus message reception and processing. +func TestPublishMessagesInstanceStatusReception(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + instanceStatusMessage ecstcs.InstanceStatusMessage + expectPublishCall bool + mockSetup func(*mock_wsconn.MockWebsocketConn) + expectedError bool + }{ + { + name: "successful instanceStatus message processing", + instanceStatusMessage: ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + }, + }, + expectPublishCall: true, + mockSetup: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + expectedError: false, + }, + { + name: "instanceStatus message with multiple statuses", + instanceStatusMessage: ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + { + Status: aws.String("IMPAIRED"), + 
Type: aws.String("DOCKER"), + }, + }, + }, + expectPublishCall: true, + mockSetup: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + expectedError: false, + }, + { + name: "instanceStatus message with empty statuses", + instanceStatusMessage: ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{}, + }, + expectPublishCall: true, + mockSetup: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + expectedError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs := testCS(conn, nil, nil).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + if tc.expectPublishCall { + tc.mockSetup(conn) + } + + // Start publishMessages in a goroutine + go cs.publishMessages(ctx) + + // Send the instanceStatus message + instanceStatusMessages <- tc.instanceStatusMessage + + // Give some time for message processing + time.Sleep(100 * time.Millisecond) + + // Cancel context to stop publishMessages + cancel() + + // Verify message was consumed from channel + assert.Len(t, instanceStatusMessages, 0, "instanceStatus message should be consumed from channel") + }) + } +} + +// TestPublishMessagesConcurrentHandling tests concurrent handling of all three message types. 
+func TestPublishMessagesConcurrentHandling(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) defer ctrl.Finish() - mockMetricsFactory := mock_metrics.NewMockEntryFactory(ctrl) - mockEntry := mock_metrics.NewMockEntry(ctrl) + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + telemetryMessages := make(chan ecstcs.TelemetryMessage, 1) + healthMessages := make(chan ecstcs.HealthMessage, 1) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs := testCS(conn, telemetryMessages, healthMessages).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - telemetryMessages := make(chan ecstcs.TelemetryMessage, testTelemetryChannelDefaultBufferSize) - healthMessages := make(chan ecstcs.HealthMessage, testTelemetryChannelDefaultBufferSize) + // Expect three WriteMessage calls for the three different message types. + // Each WriteMessage is preceded by SetWriteDeadline. + // Use AnyTimes() to allow calls in any order. 
+ conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - // Create a connection that will fail when writing - conn := mock_wsconn.NewMockWebsocketConn(ctrl) - conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) - conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("connection error")) + // Start publishMessages in a goroutine + go cs.publishMessages(ctx) - cfg := &wsclient.WSClientMinAgentConfig{ - AWSRegion: "us-east-1", - AcceptInsecureCert: true, + // Create test messages + telemetryMessage := ecstcs.TelemetryMessage{ + Metadata: &ecstcs.MetricsMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + Idle: aws.Bool(true), + MessageId: aws.String("test-message"), + }, + TaskMetrics: []*ecstcs.TaskMetric{}, } - cs := New("https://aws.amazon.com/ecs", cfg, emptyDoctor, false, testPublishMetricsInterval, - aws.NewCredentialsCache(testCreds), rwTimeout, telemetryMessages, healthMessages, mockMetricsFactory).(*tcsClientServer) - cs.SetConnection(conn) - // Set expectations for the metrics calls - mockMetricsFactory.EXPECT().New(metrics.TACSPublishMetricFailure).Return(mockEntry).Times(1) - mockEntry.EXPECT().Done(gomock.Any()).Times(1) + healthMessage := ecstcs.HealthMessage{ + Metadata: &ecstcs.HealthMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + MessageId: aws.String("test-message"), + }, + HealthMetrics: []*ecstcs.TaskHealth{}, + } - // Create a valid telemetry message that will trigger publishMetricsOnce - telemetryMessage := ecstcs.TelemetryMessage{ - Metadata: &ecstcs.MetricsMetadata{ - Cluster: aws.String(testCluster), - ContainerInstance: aws.String(testContainerInstance), - Idle: aws.Bool(false), - MessageId: aws.String(testMessageId), + instanceStatusMessage := ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + 
Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), }, - TaskMetrics: []*ecstcs.TaskMetric{ + Statuses: []*ecstcs.InstanceStatus{ { - TaskArn: aws.String("test-task-arn"), + Status: aws.String("OK"), + Type: aws.String("AGENT"), }, }, } - // Send the message to the channel + // Send all three message types telemetryMessages <- telemetryMessage + healthMessages <- healthMessage + instanceStatusMessages <- instanceStatusMessage + + // Give some time for message processing + time.Sleep(200 * time.Millisecond) + + // Cancel context to stop publishMessages + cancel() + + // Verify all messages were consumed from channels + assert.Len(t, telemetryMessages, 0, "telemetry message should be consumed from channel") + assert.Len(t, healthMessages, 0, "health message should be consumed from channel") + assert.Len(t, instanceStatusMessages, 0, "instanceStatus message should be consumed from channel") +} + +// TestPublishMessagesErrorHandling tests error handling in publishMessages. 
+func TestPublishMessagesErrorHandling(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + setupMock func(*mock_wsconn.MockWebsocketConn) + sendMessage func(chan ecstcs.InstanceStatusMessage) + expectedErrorLogged bool + }{ + { + name: "publishInstanceStatusOnce fails with connection error", + setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("connection error")) + }, + sendMessage: func(ch chan ecstcs.InstanceStatusMessage) { + ch <- ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + }, + } + }, + expectedErrorLogged: true, + }, + { + name: "publishInstanceStatusOnce fails with write deadline error", + setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("write deadline exceeded")) + }, + sendMessage: func(ch chan ecstcs.InstanceStatusMessage) { + ch <- ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("IMPAIRED"), + Type: aws.String("DOCKER"), + }, + }, + } + }, + expectedErrorLogged: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs 
:= testCS(conn, nil, nil).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + tc.setupMock(conn) + + // Start publishMessages in a goroutine + go cs.publishMessages(ctx) + + // Send the message that should cause an error + tc.sendMessage(instanceStatusMessages) + + // Give some time for message processing and error logging + time.Sleep(100 * time.Millisecond) + + // Cancel context to stop publishMessages + cancel() + + // Verify message was consumed from channel even when error occurred + assert.Len(t, instanceStatusMessages, 0, "instanceStatus message should be consumed from channel even on error") + }) + } +} + +// TestPublishMessagesErrorsDoNotAffectOtherMessageTypes tests that errors in instanceStatus processing don't affect other message types. +func TestPublishMessagesErrorsDoNotAffectOtherMessageTypes(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + telemetryMessages := make(chan ecstcs.TelemetryMessage, 1) + healthMessages := make(chan ecstcs.HealthMessage, 1) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs := testCS(conn, telemetryMessages, healthMessages).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + // Set up mock expectations: instanceStatus fails, but telemetry and health succeed + // Use AnyTimes() to allow calls in any order since select is non-deterministic. 
+ conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(messageType int, data []byte) error { + // Check if this is an instanceStatus message by looking for "PublishInstanceStatusRequest" in the data + if bytes.Contains(data, []byte("PublishInstanceStatusRequest")) { + return fmt.Errorf("instanceStatus error") + } + return nil + }).AnyTimes() // Start publishMessages in a goroutine go cs.publishMessages(ctx) - // Give some time for the message to be processed - time.Sleep(100 * time.Millisecond) + // Create test messages + instanceStatusMessage := ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + }, + } + + telemetryMessage := ecstcs.TelemetryMessage{ + Metadata: &ecstcs.MetricsMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + Idle: aws.Bool(true), + MessageId: aws.String("test-message"), + }, + TaskMetrics: []*ecstcs.TaskMetric{}, + } + + healthMessage := ecstcs.HealthMessage{ + Metadata: &ecstcs.HealthMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + MessageId: aws.String("test-message"), + }, + HealthMetrics: []*ecstcs.TaskHealth{}, + } + + // Send instanceStatus message first (which will fail) + instanceStatusMessages <- instanceStatusMessage + + // Give some time for the error to be processed + time.Sleep(50 * time.Millisecond) + + // Send telemetry and health messages (which should succeed) + telemetryMessages <- telemetryMessage + healthMessages <- healthMessage + + // Give some time for message processing + time.Sleep(150 * time.Millisecond) + + // Cancel context to stop publishMessages + cancel() + + // Verify 
all messages were consumed from channels + assert.Len(t, instanceStatusMessages, 0, "instanceStatus message should be consumed from channel") + assert.Len(t, telemetryMessages, 0, "telemetry message should be consumed from channel") + assert.Len(t, healthMessages, 0, "health message should be consumed from channel") +} + +// TestPublishMessagesContextCancellation tests context cancellation behavior. +func TestPublishMessagesContextCancellation(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs := testCS(conn, nil, nil).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages + + ctx, cancel := context.WithCancel(context.TODO()) + + // Start publishMessages in a goroutine + done := make(chan bool) + go func() { + cs.publishMessages(ctx) + done <- true + }() - // Cancel the context to stop the goroutine + // Cancel context immediately cancel() - // Give some time for the goroutine to exit - time.Sleep(100 * time.Millisecond) + // Wait for publishMessages to return + select { + case <-done: + // publishMessages returned as expected + case <-time.After(1 * time.Second): + t.Fatal("publishMessages did not return after context cancellation") + } + + // Verify that any pending messages in channels are not processed after cancellation + instanceStatusMessages <- ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + }, + } + + // Give some time to ensure no processing occurs + time.Sleep(50 * time.Millisecond) + + // Message should still be in channel since publishMessages has stopped + assert.Len(t, instanceStatusMessages, 1, "instanceStatus 
message should remain in channel after context cancellation") +} + +// TestPublishMessagesWithInstanceStatusChannelSimple tests that publishMessages handles instanceStatus messages correctly. +func TestPublishMessagesWithInstanceStatusChannelSimple(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + + // Create all channels to avoid nil channel blocking + telemetryMessages := make(chan ecstcs.TelemetryMessage, 1) + healthMessages := make(chan ecstcs.HealthMessage, 1) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs := testCS(conn, telemetryMessages, healthMessages).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages + + ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) + defer cancel() + + // Expect SetWriteDeadline and WriteMessage for instanceStatus + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + + // Start publishMessages in a goroutine + go cs.publishMessages(ctx) + + // Send instanceStatus message + instanceStatusMessage := ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + }, + } + + instanceStatusMessages <- instanceStatusMessage + + // Give time for processing + time.Sleep(200 * time.Millisecond) + + // Verify message was consumed + assert.Len(t, instanceStatusMessages, 0, "instanceStatus message should be consumed from channel") +} + +// TestPublishMessagesInstanceStatusErrorSimple tests error handling for instanceStatus messages. 
+func TestPublishMessagesInstanceStatusErrorSimple(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + + // Create all channels to avoid nil channel blocking + telemetryMessages := make(chan ecstcs.TelemetryMessage, 1) + healthMessages := make(chan ecstcs.HealthMessage, 1) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs := testCS(conn, telemetryMessages, healthMessages).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages + + ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second) + defer cancel() + + // Expect SetWriteDeadline and WriteMessage that fails + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("connection error")) + + // Start publishMessages in a goroutine + go cs.publishMessages(ctx) + + // Send instanceStatus message + instanceStatusMessage := ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("IMPAIRED"), + Type: aws.String("DOCKER"), + }, + }, + } + + instanceStatusMessages <- instanceStatusMessage + + // Give time for processing + time.Sleep(200 * time.Millisecond) + + // Verify message was consumed even with error + assert.Len(t, instanceStatusMessages, 0, "instanceStatus message should be consumed from channel even on error") +} + +// TestPublishMessagesContextCancellationSimple tests context cancellation behavior. 
+func TestPublishMessagesContextCancellationSimple(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + + // Create all channels to avoid nil channel blocking + telemetryMessages := make(chan ecstcs.TelemetryMessage, 1) + healthMessages := make(chan ecstcs.HealthMessage, 1) + instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1) + + cs := testCS(conn, telemetryMessages, healthMessages).(*tcsClientServer) + cs.instanceStatus = instanceStatusMessages + + ctx, cancel := context.WithCancel(context.TODO()) + + // Start publishMessages in a goroutine + done := make(chan bool) + go func() { + cs.publishMessages(ctx) + done <- true + }() + + // Cancel context immediately + cancel() + + // Wait for publishMessages to return + select { + case <-done: + // publishMessages returned as expected + case <-time.After(1 * time.Second): + t.Fatal("publishMessages did not return after context cancellation") + } +} + +// TestPublishInstanceStatusOnce tests successful instanceStatus publishing. 
+func TestPublishInstanceStatusOnce(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + message ecstcs.InstanceStatusMessage + expectedError bool + setupMock func(*mock_wsconn.MockWebsocketConn) + }{ + { + name: "successful publish with single status", + message: ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + }, + }, + expectedError: false, + setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + }, + { + name: "successful publish with multiple statuses", + message: ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("production-cluster"), + ContainerInstance: aws.String("i-1234567890abcdef0"), + RequestId: aws.String("req-12345"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + { + Status: aws.String("IMPAIRED"), + Type: aws.String("DOCKER"), + }, + }, + }, + expectedError: false, + setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + }, + { + name: "successful publish with empty statuses", + message: ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: aws.String("test-cluster"), + ContainerInstance: aws.String("test-instance"), + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{}, + }, + expectedError: false, + setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) { + 
mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + }, + { + name: "successful publish with nil metadata fields", + message: ecstcs.InstanceStatusMessage{ + Metadata: &ecstcs.InstanceStatusMetadata{ + Cluster: nil, + ContainerInstance: nil, + RequestId: aws.String("test-request"), + }, + Statuses: []*ecstcs.InstanceStatus{ + { + Status: aws.String("OK"), + Type: aws.String("AGENT"), + }, + }, + }, + expectedError: false, + setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) { + mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) + mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn := mock_wsconn.NewMockWebsocketConn(ctrl) + cs := testCS(conn, nil, nil).(*tcsClientServer) + + tc.setupMock(conn) + + err := cs.publishInstanceStatusOnce(tc.message) + + if tc.expectedError { + assert.Error(t, err, "Expected error but got none") + } else { + assert.NoError(t, err, "Expected no error but got: %v", err) + } + }) + } +} + +// TestPublishInstanceStatusOnceErrorHandling tests error handling in publishInstanceStatusOnce. 
+func TestPublishInstanceStatusOnceErrorHandling(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name          string
+		message       ecstcs.InstanceStatusMessage
+		setupMock     func(*mock_wsconn.MockWebsocketConn) // registers the mock connection expectations for the case
+		expectedError string                               // substring that must appear in the returned error's message
+	}{
+		{
+			name: "MakeRequest fails with connection error",
+			message: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("test-cluster"),
+					ContainerInstance: aws.String("test-instance"),
+					RequestId:         aws.String("test-request"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("AGENT"),
+					},
+				},
+			},
+			setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) {
+				mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
+				mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("connection error"))
+			},
+			expectedError: "connection error",
+		},
+		{
+			name: "MakeRequest fails with write deadline error",
+			message: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("test-cluster"),
+					ContainerInstance: aws.String("test-instance"),
+					RequestId:         aws.String("test-request"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("IMPAIRED"),
+						Type:   aws.String("DOCKER"),
+					},
+				},
+			},
+			setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) {
+				mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
+				mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("write deadline exceeded"))
+			},
+			expectedError: "write deadline exceeded",
+		},
+		{
+			name: "MakeRequest fails with network timeout",
+			message: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("production-cluster"),
+					ContainerInstance: aws.String("i-1234567890abcdef0"),
+					RequestId:         aws.String("req-timeout"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("AGENT"),
+					},
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("DOCKER"),
+					},
+				},
+			},
+			setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) {
+				mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
+				mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("network timeout"))
+			},
+			expectedError: "network timeout",
+		},
+		{
+			name: "MakeRequest fails with SetWriteDeadline error",
+			message: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("test-cluster"),
+					ContainerInstance: aws.String("test-instance"),
+					RequestId:         aws.String("test-request"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("AGENT"),
+					},
+				},
+			},
+			setupMock: func(mockConn *mock_wsconn.MockWebsocketConn) {
+				mockConn.EXPECT().SetWriteDeadline(gomock.Any()).Return(fmt.Errorf("deadline error"))
+				// Even when SetWriteDeadline fails, WriteMessage is still called.
+				mockConn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(fmt.Errorf("deadline error"))
+			},
+			expectedError: "deadline error",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			ctrl := gomock.NewController(t)
+			defer ctrl.Finish()
+
+			conn := mock_wsconn.NewMockWebsocketConn(ctrl)
+			cs := testCS(conn, nil, nil).(*tcsClientServer)
+
+			tc.setupMock(conn)
+
+			err := cs.publishInstanceStatusOnce(tc.message)
+			if assert.Error(t, err, "Expected error but got none") { // guard: err.Error() on a nil error would panic
+				assert.Contains(t, err.Error(), tc.expectedError, "Error message should contain expected text")
+			}
+		})
+	}
+}
+
+// TestPublishInstanceStatusOnceRequestStructure tests proper PublishInstanceStatusRequest creation.
+func TestPublishInstanceStatusOnceRequestStructure(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name    string
+		message ecstcs.InstanceStatusMessage // message handed to publishInstanceStatusOnce
+	}{
+		{
+			name: "request structure with complete metadata",
+			message: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("test-cluster"),
+					ContainerInstance: aws.String("test-instance"),
+					RequestId:         aws.String("test-request"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("AGENT"),
+					},
+				},
+			},
+		},
+		{
+			name: "request structure with multiple statuses",
+			message: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("production-cluster"),
+					ContainerInstance: aws.String("i-1234567890abcdef0"),
+					RequestId:         aws.String("req-12345"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("AGENT"),
+					},
+					{
+						Status: aws.String("IMPAIRED"),
+						Type:   aws.String("DOCKER"),
+					},
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("EBS_CSI"),
+					},
+				},
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			ctrl := gomock.NewController(t)
+			defer ctrl.Finish()
+
+			conn := mock_wsconn.NewMockWebsocketConn(ctrl)
+			cs := testCS(conn, nil, nil).(*tcsClientServer)
+
+			// Inspect the bytes handed to WriteMessage to check the request that was built.
+			conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
+			conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).DoAndReturn(
+				func(messageType int, data []byte) error {
+					// The payload is checked by substring match on the raw bytes, not by decoding it.
+					// The data should contain the serialized PublishInstanceStatusRequest.
+					assert.NotEmpty(t, data, "Request data should not be empty")
+
+					// Verify that the data contains the metadata values that are set on the message.
+					dataStr := string(data)
+					if tc.message.Metadata != nil {
+						if tc.message.Metadata.Cluster != nil {
+							assert.Contains(t, dataStr, *tc.message.Metadata.Cluster, "Request should contain cluster name")
+						}
+						if tc.message.Metadata.ContainerInstance != nil {
+							assert.Contains(t, dataStr, *tc.message.Metadata.ContainerInstance, "Request should contain container instance")
+						}
+						if tc.message.Metadata.RequestId != nil {
+							assert.Contains(t, dataStr, *tc.message.Metadata.RequestId, "Request should contain request ID")
+						}
+					}
+
+					// Verify that every status value and type from the message appears in the payload.
+					for _, status := range tc.message.Statuses {
+						if status.Status != nil {
+							assert.Contains(t, dataStr, *status.Status, "Request should contain status value")
+						}
+						if status.Type != nil {
+							assert.Contains(t, dataStr, *status.Type, "Request should contain status type")
+						}
+					}
+
+					// Verify that the literal "timestamp" is included (should be present in all requests).
+					assert.Contains(t, dataStr, "timestamp", "Request should contain timestamp field")
+
+					return nil
+				},
+			)
+
+			err := cs.publishInstanceStatusOnce(tc.message)
+			assert.NoError(t, err, "Expected no error but got: %v", err)
+		})
+	}
+}
+
+// testCSIntegration builds a tcsClientServer via New with the given message channels and attaches conn as its connection.
+func testCSIntegration(conn *mock_wsconn.MockWebsocketConn,
+	metricsMessages <-chan ecstcs.TelemetryMessage,
+	healthMessages <-chan ecstcs.HealthMessage,
+	instanceStatusMessages <-chan ecstcs.InstanceStatusMessage) wsclient.ClientServer {
+	cfg := &wsclient.WSClientMinAgentConfig{
+		AWSRegion:          "us-east-1",
+		AcceptInsecureCert: true,
+	}
+	cs := New("https://aws.amazon.com/ecs", cfg, emptyDoctor, false, testPublishMetricsInterval,
+		aws.NewCredentialsCache(testCreds), rwTimeout, metricsMessages, healthMessages,
+		instanceStatusMessages, metrics.NewNopEntryFactory()).(*tcsClientServer)
+	cs.SetConnection(conn)
+	return cs
+}
+
+// TestEndToEndInstanceStatusFlow tests the complete flow from channel message to backend request.
+func TestEndToEndInstanceStatusFlow(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name                  string
+		instanceStatusMessage ecstcs.InstanceStatusMessage
+		expectedRequestCount  int    // exact number of SetWriteDeadline and WriteMessage calls the mock must see
+		description           string // documentation only; never referenced by the assertions below
+	}{
+		{
+			name: "complete flow with single status",
+			instanceStatusMessage: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("integration-test-cluster"),
+					ContainerInstance: aws.String("integration-test-instance"),
+					RequestId:         aws.String("integration-test-request"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("AGENT"),
+					},
+				},
+			},
+			expectedRequestCount: 1,
+			description:          "Single instanceStatus message should result in one backend request",
+		},
+		{
+			name: "complete flow with multiple statuses",
+			instanceStatusMessage: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("integration-test-cluster"),
+					ContainerInstance: aws.String("integration-test-instance"),
+					RequestId:         aws.String("integration-test-request-multi"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("AGENT"),
+					},
+					{
+						Status: aws.String("IMPAIRED"),
+						Type:   aws.String("DOCKER"),
+					},
+					{
+						Status: aws.String("OK"),
+						Type:   aws.String("EBS_CSI"),
+					},
+				},
+			},
+			expectedRequestCount: 1,
+			description:          "Multiple statuses in one message should result in one backend request",
+		},
+		{
+			name: "complete flow with empty statuses",
+			instanceStatusMessage: ecstcs.InstanceStatusMessage{
+				Metadata: &ecstcs.InstanceStatusMetadata{
+					Cluster:           aws.String("integration-test-cluster"),
+					ContainerInstance: aws.String("integration-test-instance"),
+					RequestId:         aws.String("integration-test-request-empty"),
+				},
+				Statuses: []*ecstcs.InstanceStatus{},
+			},
+			expectedRequestCount: 1,
+			description:          "Empty statuses should still result in one backend request",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			ctrl := gomock.NewController(t)
+			defer ctrl.Finish()
+
+			// Create mock websocket connection.
+			conn := mock_wsconn.NewMockWebsocketConn(ctrl)
+
+			// Buffered with capacity 1 so the send below does not block on the consumer.
+			instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1)
+
+			// Create TCS client with only the instanceStatus channel; the metrics and health channels are nil.
+			cs := testCSIntegration(conn, nil, nil, instanceStatusMessages).(*tcsClientServer)
+
+			ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
+			defer cancel()
+
+			// Set up mock expectations for the backend request.
+			conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(tc.expectedRequestCount)
+			conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).DoAndReturn(
+				func(messageType int, data []byte) error {
+					// Verify by substring match that the written bytes contain the data from the message.
+					dataStr := string(data)
+
+					// Verify metadata fields are present in the request.
+					if tc.instanceStatusMessage.Metadata != nil {
+						if tc.instanceStatusMessage.Metadata.Cluster != nil {
+							assert.Contains(t, dataStr, *tc.instanceStatusMessage.Metadata.Cluster,
+								"Backend request should contain cluster name")
+						}
+						if tc.instanceStatusMessage.Metadata.ContainerInstance != nil {
+							assert.Contains(t, dataStr, *tc.instanceStatusMessage.Metadata.ContainerInstance,
+								"Backend request should contain container instance")
+						}
+						if tc.instanceStatusMessage.Metadata.RequestId != nil {
+							assert.Contains(t, dataStr, *tc.instanceStatusMessage.Metadata.RequestId,
+								"Backend request should contain request ID")
+						}
+					}
+
+					// Verify status information is present in the request.
+					for _, status := range tc.instanceStatusMessage.Statuses {
+						if status.Status != nil {
+							assert.Contains(t, dataStr, *status.Status,
+								"Backend request should contain status value")
+						}
+						if status.Type != nil {
+							assert.Contains(t, dataStr, *status.Type,
+								"Backend request should contain status type")
+						}
+					}
+
+					// Verify timestamp is present (should be in all requests).
+					assert.Contains(t, dataStr, "timestamp",
+						"Backend request should contain timestamp field")
+
+					return nil
+				},
+			).Times(tc.expectedRequestCount)
+
+			// Start publishMessages in a goroutine.
+			go cs.publishMessages(ctx)
+
+			// Send the instanceStatus message through the channel.
+			instanceStatusMessages <- tc.instanceStatusMessage
+
+			// Fixed sleep to let publishMessages consume and publish the message; this is not a synchronization point.
+			time.Sleep(300 * time.Millisecond)
+
+			// Verify message was consumed from channel.
+			assert.Len(t, instanceStatusMessages, 0,
+				"InstanceStatus message should be consumed from channel")
+
+			// Cancel context to stop publishMessages.
+			cancel()
+		})
+	}
+}
+
+// TestInteractionBetweenMessageTypes tests that instanceStatus messages work correctly alongside metrics and health messages.
+func TestInteractionBetweenMessageTypes(t *testing.T) {
+	t.Parallel()
+
+	testCases := []struct {
+		name                  string
+		sendTelemetry         bool
+		sendHealth            bool
+		sendInstanceStatus    bool
+		expectedTotalRequests int    // NOTE(review): never enforced; the mock expectations below use AnyTimes()
+		description           string // documentation only; never referenced by the assertions below
+	}{
+		{
+			name:                  "all three message types together",
+			sendTelemetry:         true,
+			sendHealth:            true,
+			sendInstanceStatus:    true,
+			expectedTotalRequests: 3,
+			description:           "All three message types should be processed independently",
+		},
+		{
+			name:                  "instanceStatus with telemetry only",
+			sendTelemetry:         true,
+			sendHealth:            false,
+			sendInstanceStatus:    true,
+			expectedTotalRequests: 2,
+			description:           "InstanceStatus and telemetry should work together",
+		},
+		{
+			name:                  "instanceStatus with health only",
+			sendTelemetry:         false,
+			sendHealth:            true,
+			sendInstanceStatus:    true,
+			expectedTotalRequests: 2,
+			description:           "InstanceStatus and health should work together",
+		},
+		{
+			name:                  "instanceStatus only",
+			sendTelemetry:         false,
+			sendHealth:            false,
+			sendInstanceStatus:    true,
+			expectedTotalRequests: 1,
+			description:           "InstanceStatus should work independently",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			ctrl := gomock.NewController(t)
+			defer ctrl.Finish()
+
+			// Create mock websocket connection.
+			conn := mock_wsconn.NewMockWebsocketConn(ctrl)
+
+			// Create channels for all message types, each buffered with capacity 1 so the sends below do not block.
+			telemetryMessages := make(chan ecstcs.TelemetryMessage, 1)
+			healthMessages := make(chan ecstcs.HealthMessage, 1)
+			instanceStatusMessages := make(chan ecstcs.InstanceStatusMessage, 1)
+
+			// Create TCS client with all channels.
+			cs := testCSIntegration(conn, telemetryMessages, healthMessages, instanceStatusMessages).(*tcsClientServer)
+
+			ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
+			defer cancel()
+
+			// Set up mock expectations for backend requests.
+			// Use AnyTimes() to handle variable mock call expectations for different message types.
+			conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes()
+			conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
+
+			// Start publishMessages in a goroutine.
+			go cs.publishMessages(ctx)
+
+			// Send messages based on test case configuration.
+			if tc.sendTelemetry {
+				telemetryMessage := ecstcs.TelemetryMessage{
+					Metadata: &ecstcs.MetricsMetadata{
+						Cluster:           aws.String("integration-test-cluster"),
+						ContainerInstance: aws.String("integration-test-instance"),
+						Idle:              aws.Bool(true),
+						MessageId:         aws.String("integration-test-telemetry"),
+					},
+					TaskMetrics: []*ecstcs.TaskMetric{},
+				}
+				telemetryMessages <- telemetryMessage
+			}
+
+			if tc.sendHealth {
+				healthMessage := ecstcs.HealthMessage{
+					Metadata: &ecstcs.HealthMetadata{
+						Cluster:           aws.String("integration-test-cluster"),
+						ContainerInstance: aws.String("integration-test-instance"),
+						MessageId:         aws.String("integration-test-health"),
+					},
+					HealthMetrics: []*ecstcs.TaskHealth{},
+				}
+				healthMessages <- healthMessage
+			}
+
+			if tc.sendInstanceStatus {
+				instanceStatusMessage := ecstcs.InstanceStatusMessage{
+					Metadata: &ecstcs.InstanceStatusMetadata{
+						Cluster:           aws.String("integration-test-cluster"),
+						ContainerInstance: aws.String("integration-test-instance"),
+						RequestId:         aws.String("integration-test-instance-status"),
+					},
+					Statuses: []*ecstcs.InstanceStatus{
+						{
+							Status: aws.String("OK"),
+							Type:   aws.String("AGENT"),
+						},
+					},
+				}
+				instanceStatusMessages <- instanceStatusMessage
+			}
+
+			// Fixed sleep to let publishMessages drain the channels; this is not a synchronization point.
+			time.Sleep(500 * time.Millisecond)
+
+			// Verify all messages were consumed from their respective channels.
+			if tc.sendTelemetry {
+				assert.Len(t, telemetryMessages, 0,
+					"Telemetry message should be consumed from channel")
+			}
+			if tc.sendHealth {
+				assert.Len(t, healthMessages, 0,
+					"Health message should be consumed from channel")
+			}
+			if tc.sendInstanceStatus {
+				assert.Len(t, instanceStatusMessages, 0,
+					"InstanceStatus message should be consumed from channel")
+			}
+
+			// Cancel context to stop publishMessages.
+			cancel()
+		})
+	}
+}
+
+// containsSubstring reports whether substr occurs within s by comparing at every start offset; an empty substr always matches.
+func containsSubstring(s, substr string) bool {
+	for i := 0; i <= len(s)-len(substr); i++ {
+		if s[i:i+len(substr)] == substr {
+			return true
+		}
+	}
+	return false
 }
diff --git a/ecs-agent/tcs/handler/handler.go b/ecs-agent/tcs/handler/handler.go
index 74e16c30564..2bddcdb6219 100644
--- a/ecs-agent/tcs/handler/handler.go
+++ b/ecs-agent/tcs/handler/handler.go
@@ -161,7 +161,7 @@ func (session *telemetrySession) StartTelemetrySession(ctx context.Context) erro
 	tcsEndpointUrl := formatURL(endpoint, session.cluster, session.containerInstanceArn, session.agentVersion,
 		session.agentHash, containerRuntime, session.containerRuntimeVersion)
 	client := tcsclient.New(tcsEndpointUrl, session.cfg, session.doctor, session.disableMetrics, tcsclient.DefaultContainerMetricsPublishInterval,
-		session.credentialsCache, wsRWTimeout, session.metricsChannel, session.healthChannel, session.metricsFactory)
+		session.credentialsCache, wsRWTimeout, session.metricsChannel, session.healthChannel, nil, session.metricsFactory)
 	defer client.Close()
 
 	if session.deregisterInstanceEventStream != nil {
diff --git a/ecs-agent/tcs/model/ecstcs/types.go b/ecs-agent/tcs/model/ecstcs/types.go
index 294ac84d9de..e840dc84b40 100644
--- a/ecs-agent/tcs/model/ecstcs/types.go
+++ b/ecs-agent/tcs/model/ecstcs/types.go
@@ -14,6 +14,8 @@
 package ecstcs
 
 import (
+	"errors"
+	"strings"
 	"time"
 
 	"github.com/aws/amazon-ecs-agent/ecs-agent/utils"
@@ -50,3 +52,84 @@ type HealthMessage struct {
 	Metadata      *HealthMetadata
 	HealthMetrics []*TaskHealth
 }
+
+// InstanceStatusMessage represents a message containing instance health status
+// information to be published to the TCS backend.
+type InstanceStatusMessage struct {
+	// Metadata contains identifying information about the container instance
+	// including cluster name, container instance ARN, and request ID.
+	Metadata *InstanceStatusMetadata `json:"metadata,omitempty"`
+
+	// Statuses contains a collection of instance status checks that represent
+	// the health state of various components on the container instance.
+	Statuses []*InstanceStatus `json:"statuses,omitempty"`
+}
+
+const (
+	InstanceHealthCheckTypeContainerRuntime = "ContainerRuntime"
+	InstanceHealthCheckTypeAgent            = "Agent"
+	InstanceHealthCheckTypeEBSDaemon        = "EBSDaemon"
+	InstanceHealthCheckTypeNvidia           = "NvidiaAcceleratedHardware"
+)
+
+const (
+	// InstanceHealthCheckStatusInitializing is the zero state of an instance health check status.
+	InstanceHealthCheckStatusInitializing InstanceHealthCheckStatus = iota
+	// InstanceHealthCheckStatusOk represents a health check with a true/success result.
+	InstanceHealthCheckStatusOk
+	// InstanceHealthCheckStatusImpaired represents a health check with a false/fail result.
+	InstanceHealthCheckStatusImpaired
+)
+
+// InstanceHealthCheckStatus is an enumeration of possible instance health check statuses.
+type InstanceHealthCheckStatus int32
+
+var instanceHealthCheckStatusMap = map[string]InstanceHealthCheckStatus{
+	"INITIALIZING": InstanceHealthCheckStatusInitializing,
+	"OK":           InstanceHealthCheckStatusOk,
+	"IMPAIRED":     InstanceHealthCheckStatusImpaired,
+}
+
+// String returns the instanceHealthCheckStatusMap key for this status, or "NONE" if the value is not in the map.
+func (hs InstanceHealthCheckStatus) String() string {
+	for k, v := range instanceHealthCheckStatusMap {
+		if v == hs {
+			return k
+		}
+	}
+	// Reached only for values that are missing from instanceHealthCheckStatusMap.
+	return "NONE"
+}
+
+// Ok returns true if the instance health check status is OK or INITIALIZING.
+func (hs InstanceHealthCheckStatus) Ok() bool {
+	return hs == InstanceHealthCheckStatusOk || hs == InstanceHealthCheckStatusInitializing
+}
+
+// UnmarshalJSON parses a JSON string naming a status (or null) into hs; on any error hs is reset to INITIALIZING.
+func (hs *InstanceHealthCheckStatus) UnmarshalJSON(b []byte) error {
+	if strings.EqualFold(string(b), "null") {
+		*hs = InstanceHealthCheckStatusInitializing
+		return nil
+	}
+	if len(b) < 2 || b[0] != '"' || b[len(b)-1] != '"' { // len guard: empty input or a lone quote would index/slice out of range
+		*hs = InstanceHealthCheckStatusInitializing
+		return errors.New("instance health check status unmarshal: status must be a string or null; Got " + string(b))
+	}
+
+	stat, ok := instanceHealthCheckStatusMap[string(b[1:len(b)-1])]
+	if !ok {
+		*hs = InstanceHealthCheckStatusInitializing
+		return errors.New("instance health check status unmarshal: unrecognized status")
+	}
+	*hs = stat
+	return nil
+}
+
+// MarshalJSON encodes the status as its quoted String() form; a nil receiver encodes as JSON null.
+func (hs *InstanceHealthCheckStatus) MarshalJSON() ([]byte, error) {
+	if hs == nil {
+		return []byte("null"), nil
+	}
+	return []byte(`"` + hs.String() + `"`), nil
+}
diff --git a/ecs-agent/doctor/healthcheckstatus_test.go b/ecs-agent/tcs/model/ecstcs/types_test.go
similarity index 72%
rename from ecs-agent/doctor/healthcheckstatus_test.go
rename to ecs-agent/tcs/model/ecstcs/types_test.go
index eb6ca1aeecc..0fa2d1bf0a3 100644
--- a/ecs-agent/doctor/healthcheckstatus_test.go
+++ b/ecs-agent/tcs/model/ecstcs/types_test.go
@@ -14,7 +14,7 @@
 // express or implied. See the License for the specific language governing
 // permissions and limitations under the License.
-package doctor +package ecstcs import ( "encoding/json" @@ -25,33 +25,33 @@ import ( ) func TestOk(t *testing.T) { - initializingStatus := HealthcheckStatusInitializing - okStatus := HealthcheckStatusOk - impairedStatus := HealthcheckStatusImpaired + initializingStatus := InstanceHealthCheckStatusInitializing + okStatus := InstanceHealthCheckStatusOk + impairedStatus := InstanceHealthCheckStatusImpaired assert.True(t, initializingStatus.Ok()) assert.True(t, okStatus.Ok()) assert.False(t, impairedStatus.Ok()) } type testHealthcheckStatus struct { - SomeStatus HealthcheckStatus `json:"status"` + SomeStatus InstanceHealthCheckStatus `json:"status"` } func TestUnmarshalHealthcheckStatus(t *testing.T) { - status := HealthcheckStatusInitializing + status := InstanceHealthCheckStatusInitializing initializingStr := "INITIALIZING" err := json.Unmarshal([]byte(fmt.Sprintf(`"%s"`, initializingStr)), &status) assert.NoError(t, err) - // INITIALIZING should unmarshal to INITIALIZING - assert.Equal(t, HealthcheckStatusInitializing, status) + // INITIALIZING should unmarshal to INITIALIZING. + assert.Equal(t, InstanceHealthCheckStatusInitializing, status) assert.Equal(t, initializingStr, status.String()) var test testHealthcheckStatus impairedStr := "IMPAIRED" err = json.Unmarshal([]byte(fmt.Sprintf(`{"status":"%s"}`, impairedStr)), &test) assert.NoError(t, err) - // IMPAIRED should unmarshal to IMPAIRED - assert.Equal(t, HealthcheckStatusImpaired, test.SomeStatus) + // IMPAIRED should unmarshal to IMPAIRED. + assert.Equal(t, InstanceHealthCheckStatusImpaired, test.SomeStatus) assert.Equal(t, impairedStr, test.SomeStatus.String()) }