Skip to content

Add support for SageMaker HyperPod nodes by skipping them in cluster autoscaler #8195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions cluster-autoscaler/cloudprovider/aws/aws_cloud_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ func (aws *awsCloudProvider) NodeGroupForNode(node *apiv1.Node) (cloudprovider.N
klog.Warningf("Node %v has no providerId", node.Name)
return nil, nil
}

// Skip SageMaker HyperPod instances
if strings.HasPrefix(node.GetName(), "hyperpod") {
klog.V(4).Infof("Skipping SageMaker HyperPod node %s", node.Name)
return nil, nil
}

ref, err := AwsRefFromProviderId(node.Spec.ProviderID)
if err != nil {
// Dropping this into V as it will be noisy with many Hybrid Nodes
Expand Down Expand Up @@ -143,6 +150,11 @@ func (aws *awsCloudProvider) HasInstance(node *apiv1.Node) (bool, error) {
return true, cloudprovider.ErrNotImplemented
}

// Skip SageMaker HyperPod instances
if strings.HasPrefix(node.GetName(), "hyperpod") {
return true, cloudprovider.ErrNotImplemented
}

// avoid log spam for not autoscaled asgs:
// Nodes that belong to an asg that is not autoscaled will not be found in the asgCache below,
// so do not trigger warning spam by returning an error from being unable to find them.
Expand Down Expand Up @@ -205,10 +217,19 @@ type AwsInstanceRef struct {
}

var validAwsRefIdRegex = regexp.MustCompile(fmt.Sprintf(`^aws\:\/\/\/[-0-9a-z]*\/[-0-9a-z]*(\/[-0-9a-z\.]*)?$|aws\:\/\/\/[-0-9a-z]*\/%s.*$`, placeholderInstanceNamePrefix))
var sageMakerRefIdRegex = regexp.MustCompile(`^aws:///[-0-9a-z]+/sagemaker/.*$`)

// AwsRefFromProviderId creates AwsInstanceRef object from provider id which
// must be in format: aws:///zone/name
func AwsRefFromProviderId(id string) (*AwsInstanceRef, error) {
// Special case for SageMaker format: aws:///<region>/sagemaker/...
if sageMakerRefIdRegex.MatchString(id) {
return &AwsInstanceRef{
ProviderID: id,
Name: "sagemaker-node",
}, nil
}

if validAwsRefIdRegex.FindStringSubmatch(id) == nil {
return nil, fmt.Errorf("wrong id: expected format aws:///<zone>/<name>, got %v", id)
}
Expand Down Expand Up @@ -313,6 +334,11 @@ func (ng *AwsNodeGroup) DecreaseTargetSize(delta int) error {

// Belongs returns true if the given node belongs to the NodeGroup.
func (ng *AwsNodeGroup) Belongs(node *apiv1.Node) (bool, error) {
// Skip SageMaker HyperPod instances
if strings.HasPrefix(node.GetName(), "hyperpod") {
return false, nil
}

ref, err := AwsRefFromProviderId(node.Spec.ProviderID)
if err != nil {
return false, err
Expand Down
23 changes: 18 additions & 5 deletions cluster-autoscaler/cloudprovider/aws/aws_cloud_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,21 +725,34 @@ func TestHasInstance(t *testing.T) {
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
assert.True(t, present)

// Case 3: correct node - not present in AWS
// Case 3: incorrect node - sagemaker hyperpod is unsupported
node3 := &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "hyperpod-node-1",
},
Spec: apiv1.NodeSpec{
ProviderID: "aws:///use1-az2/sagemaker/cluster/hyperpod-abc123-i-abc123",
},
}
present, err = provider.HasInstance(node3)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
assert.True(t, present)

// Case 4: correct node - not present in AWS
node4 := &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "node-2",
},
Spec: apiv1.NodeSpec{
ProviderID: "aws:///us-east-1a/test-instance-id-2",
},
}
present, err = provider.HasInstance(node3)
present, err = provider.HasInstance(node4)
assert.ErrorContains(t, err, nodeNotPresentErr)
assert.False(t, present)

// Case 4: correct node - not autoscaled -> not present in AWS -> no warning
node4 := &apiv1.Node{
// Case 5: correct node - not autoscaled -> not present in AWS -> no warning
node5 := &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "node-2",
Annotations: map[string]string{
Expand All @@ -750,7 +763,7 @@ func TestHasInstance(t *testing.T) {
ProviderID: "aws:///us-east-1a/test-instance-id-2",
},
}
present, err = provider.HasInstance(node4)
present, err = provider.HasInstance(node5)
assert.NoError(t, err)
assert.False(t, present)
}
Expand Down
Loading