diff --git a/pkg/azureclients/armclient/azure_armclient.go b/pkg/azureclients/armclient/azure_armclient.go index 2c6b5c4709..7eb068c9bb 100644 --- a/pkg/azureclients/armclient/azure_armclient.go +++ b/pkg/azureclients/armclient/azure_armclient.go @@ -19,6 +19,7 @@ package armclient import ( "context" "crypto/tls" + "encoding/json" "fmt" "html" "net" @@ -430,6 +431,16 @@ func (c *Client) waitAsync(ctx context.Context, futures map[string]*azure.Future // PutResourcesInBatches is similar with PutResources, but it sends sync request concurrently in batches. func (c *Client) PutResourcesInBatches(ctx context.Context, resources map[string]interface{}, batchSize int) map[string]*PutResourcesResponse { + return c.PutResourcesInBatchesBase(ctx, resources, batchSize, false) +} + +// PutResourcesInBatchesWithEtag is similar with PutResources, but it sends sync request concurrently in batches with Etag header when Etag field is not empty. +func (c *Client) PutResourcesInBatchesWithEtag(ctx context.Context, resources map[string]interface{}, batchSize int) map[string]*PutResourcesResponse { + return c.PutResourcesInBatchesBase(ctx, resources, batchSize, true) +} + +// PutResourcesInBatchesBase is similar with PutResources, but it sends sync request concurrently in batches; when enableEtag is true, an If-Match header is set from the resource's etag field if present. 
+func (c *Client) PutResourcesInBatchesBase(ctx context.Context, resources map[string]interface{}, batchSize int, enableEtag bool) map[string]*PutResourcesResponse { if len(resources) == 0 { return nil } @@ -458,7 +469,23 @@ func (c *Client) PutResourcesInBatches(ctx context.Context, resources map[string go func(resourceID string, parameters interface{}) { defer wg.Done() defer func() { <-rateLimiter }() - future, rerr := c.PutResourceAsync(ctx, resourceID, parameters) + decorators := []autorest.PrepareDecorator{} + if enableEtag { + type etagPlaceholder struct { + Etag *string `json:"etag,omitempty"` + } + + p := &etagPlaceholder{} + b, err := json.Marshal(parameters) + if err == nil { + err = json.Unmarshal(b, &p) + if err == nil && p.Etag != nil { + decorators = append(decorators, autorest.WithHeader("If-Match", autorest.String(*p.Etag))) + } + } + } + + future, rerr := c.PutResourceAsync(ctx, resourceID, parameters, decorators...) if rerr != nil { responseLock.Lock() responses[resourceID] = &PutResourcesResponse{ diff --git a/pkg/azureclients/armclient/interface.go b/pkg/azureclients/armclient/interface.go index e66ba8a18b..f988986f2e 100644 --- a/pkg/azureclients/armclient/interface.go +++ b/pkg/azureclients/armclient/interface.go @@ -71,6 +71,9 @@ type Interface interface { // PutResourcesInBatches is similar with PutResources, but it sends sync request concurrently in batches. PutResourcesInBatches(ctx context.Context, resources map[string]interface{}, batchSize int) map[string]*PutResourcesResponse + // PutResourcesInBatchesWithEtag is similar with PutResources, but it sends sync request concurrently in batches with Etag header when Etag field is not empty. 
+ PutResourcesInBatchesWithEtag(ctx context.Context, resources map[string]interface{}, batchSize int) map[string]*PutResourcesResponse + // PatchResource patches a resource by resource ID PatchResource(ctx context.Context, resourceID string, parameters interface{}, decorators ...autorest.PrepareDecorator) (*http.Response, *retry.Error) diff --git a/pkg/azureclients/armclient/mockarmclient/interface.go b/pkg/azureclients/armclient/mockarmclient/interface.go index 9dc11c6128..72b692cb5c 100644 --- a/pkg/azureclients/armclient/mockarmclient/interface.go +++ b/pkg/azureclients/armclient/mockarmclient/interface.go @@ -401,6 +401,20 @@ func (mr *MockInterfaceMockRecorder) PutResourcesInBatches(ctx, resources, batch return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutResourcesInBatches", reflect.TypeOf((*MockInterface)(nil).PutResourcesInBatches), ctx, resources, batchSize) } +// PutResourcesInBatchesWithEtag mocks base method. +func (m *MockInterface) PutResourcesInBatchesWithEtag(ctx context.Context, resources map[string]any, batchSize int) map[string]*armclient.PutResourcesResponse { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PutResourcesInBatchesWithEtag", ctx, resources, batchSize) + ret0, _ := ret[0].(map[string]*armclient.PutResourcesResponse) + return ret0 +} + +// PutResourcesInBatchesWithEtag indicates an expected call of PutResourcesInBatchesWithEtag. +func (mr *MockInterfaceMockRecorder) PutResourcesInBatchesWithEtag(ctx, resources, batchSize any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutResourcesInBatchesWithEtag", reflect.TypeOf((*MockInterface)(nil).PutResourcesInBatchesWithEtag), ctx, resources, batchSize) +} + // Send mocks base method. 
func (m *MockInterface) Send(ctx context.Context, request *http.Request, decorators ...autorest.SendDecorator) (*http.Response, *retry.Error) { m.ctrl.T.Helper() diff --git a/pkg/azureclients/vmssclient/azure_vmssclient.go b/pkg/azureclients/vmssclient/azure_vmssclient.go index d58707929c..255059a9fd 100644 --- a/pkg/azureclients/vmssclient/azure_vmssclient.go +++ b/pkg/azureclients/vmssclient/azure_vmssclient.go @@ -87,20 +87,20 @@ func New(config *azclients.ClientConfig) *Client { } // Get gets a VirtualMachineScaleSet. -func (c *Client) Get(ctx context.Context, resourceGroupName string, VMScaleSetName string) (compute.VirtualMachineScaleSet, *retry.Error) { +func (c *Client) Get(ctx context.Context, resourceGroupName string, VMScaleSetName string) (VirtualMachineScaleSet, *retry.Error) { mc := metrics.NewMetricContext("vmss", "get", resourceGroupName, c.subscriptionID, "") // Report errors if the client is rate limited. if !c.rateLimiterReader.TryAccept() { mc.RateLimitedCount() - return compute.VirtualMachineScaleSet{}, retry.GetRateLimitError(false, "VMSSGet") + return VirtualMachineScaleSet{}, retry.GetRateLimitError(false, "VMSSGet") } // Report errors if the client is throttled. if c.RetryAfterReader.After(time.Now()) { mc.ThrottledCount() rerr := retry.GetThrottlingError("VMSSGet", "client throttled", c.RetryAfterReader) - return compute.VirtualMachineScaleSet{}, rerr + return VirtualMachineScaleSet{}, rerr } result, rerr := c.getVMSS(ctx, resourceGroupName, VMScaleSetName) @@ -118,14 +118,14 @@ func (c *Client) Get(ctx context.Context, resourceGroupName string, VMScaleSetNa } // getVMSS gets a VirtualMachineScaleSet. 
-func (c *Client) getVMSS(ctx context.Context, resourceGroupName string, VMScaleSetName string) (compute.VirtualMachineScaleSet, *retry.Error) { +func (c *Client) getVMSS(ctx context.Context, resourceGroupName string, VMScaleSetName string) (VirtualMachineScaleSet, *retry.Error) { resourceID := armclient.GetResourceID( c.subscriptionID, resourceGroupName, vmssResourceType, VMScaleSetName, ) - result := compute.VirtualMachineScaleSet{} + result := VirtualMachineScaleSet{} response, rerr := c.armClient.GetResource(ctx, resourceID) defer c.armClient.CloseResponse(ctx, response) @@ -148,7 +148,7 @@ func (c *Client) getVMSS(ctx context.Context, resourceGroupName string, VMScaleS } // List gets a list of VirtualMachineScaleSets in the resource group. -func (c *Client) List(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachineScaleSet, *retry.Error) { +func (c *Client) List(ctx context.Context, resourceGroupName string) ([]VirtualMachineScaleSet, *retry.Error) { mc := metrics.NewMetricContext("vmss", "list", resourceGroupName, c.subscriptionID, "") // Report errors if the client is rate limited. @@ -179,13 +179,13 @@ func (c *Client) List(ctx context.Context, resourceGroupName string) ([]compute. } // listVMSS gets a list of VirtualMachineScaleSets in the resource group. -func (c *Client) listVMSS(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachineScaleSet, *retry.Error) { +func (c *Client) listVMSS(ctx context.Context, resourceGroupName string) ([]VirtualMachineScaleSet, *retry.Error) { resourceID := armclient.GetResourceListID( c.subscriptionID, resourceGroupName, vmssResourceType, ) - result := make([]compute.VirtualMachineScaleSet, 0) + result := make([]VirtualMachineScaleSet, 0) page := &VirtualMachineScaleSetListResultPage{} page.fn = c.listNextResults @@ -221,7 +221,7 @@ func (c *Client) listVMSS(ctx context.Context, resourceGroupName string) ([]comp } // CreateOrUpdate creates or updates a VirtualMachineScaleSet. 
-func (c *Client) CreateOrUpdate(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) *retry.Error { +func (c *Client) CreateOrUpdate(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters VirtualMachineScaleSet, etag string) *retry.Error { mc := metrics.NewMetricContext("vmss", "create_or_update", resourceGroupName, c.subscriptionID, "") // Report errors if the client is rate limited. @@ -237,7 +237,7 @@ func (c *Client) CreateOrUpdate(ctx context.Context, resourceGroupName string, V return rerr } - rerr := c.createOrUpdateVMSS(ctx, resourceGroupName, VMScaleSetName, parameters) + rerr := c.createOrUpdateVMSS(ctx, resourceGroupName, VMScaleSetName, parameters, etag) mc.Observe(rerr) if rerr != nil { if rerr.IsThrottled() { @@ -252,7 +252,7 @@ func (c *Client) CreateOrUpdate(ctx context.Context, resourceGroupName string, V } // CreateOrUpdateAsync sends the request to arm client and DO NOT wait for the response -func (c *Client) CreateOrUpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) (*azure.Future, *retry.Error) { +func (c *Client) CreateOrUpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters VirtualMachineScaleSet, etag string) (*azure.Future, *retry.Error) { mc := metrics.NewMetricContext("vmss", "create_or_update_async", resourceGroupName, c.subscriptionID, "") // Report errors if the client is rate limited. @@ -275,7 +275,12 @@ func (c *Client) CreateOrUpdateAsync(ctx context.Context, resourceGroupName stri VMScaleSetName, ) - future, rerr := c.armClient.PutResourceAsync(ctx, resourceID, parameters) + decorators := []autorest.PrepareDecorator{} + if etag != "" { + decorators = append(decorators, autorest.WithHeader("If-Match", autorest.String(etag))) + } + + future, rerr := c.armClient.PutResourceAsync(ctx, resourceID, parameters, decorators...) 
mc.Observe(rerr) if rerr != nil { if rerr.IsThrottled() { @@ -318,14 +323,20 @@ func (c *Client) WaitForAsyncOperationResult(ctx context.Context, future *azure. } // createOrUpdateVMSS creates or updates a VirtualMachineScaleSet. -func (c *Client) createOrUpdateVMSS(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) *retry.Error { +func (c *Client) createOrUpdateVMSS(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters VirtualMachineScaleSet, etag string) *retry.Error { resourceID := armclient.GetResourceID( c.subscriptionID, resourceGroupName, vmssResourceType, VMScaleSetName, ) - response, rerr := c.armClient.PutResource(ctx, resourceID, parameters) + + decorators := []autorest.PrepareDecorator{} + if etag != "" { + decorators = append(decorators, autorest.WithHeader("If-Match", autorest.String(etag))) + } + + response, rerr := c.armClient.PutResource(ctx, resourceID, parameters, decorators...) defer c.armClient.CloseResponse(ctx, response) if rerr != nil { klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vmss.put.request", resourceID, rerr.Error()) @@ -353,7 +364,7 @@ func (c *Client) createOrUpdateResponder(resp *http.Response) (*compute.VirtualM return result, retry.GetError(resp, err) } -func (c *Client) listResponder(resp *http.Response) (result compute.VirtualMachineScaleSetListResult, err error) { +func (c *Client) listResponder(resp *http.Response) (result VirtualMachineScaleSetListResult, err error) { err = autorest.Respond( resp, autorest.ByIgnoring(), @@ -365,7 +376,7 @@ func (c *Client) listResponder(resp *http.Response) (result compute.VirtualMachi // virtualMachineScaleSetListResultPreparer prepares a request to retrieve the next set of results. // It returns nil if no more results exist. 
-func (c *Client) virtualMachineScaleSetListResultPreparer(ctx context.Context, vmsslr compute.VirtualMachineScaleSetListResult) (*http.Request, error) { +func (c *Client) virtualMachineScaleSetListResultPreparer(ctx context.Context, vmsslr VirtualMachineScaleSetListResult) (*http.Request, error) { if vmsslr.NextLink == nil || len(ptr.Deref(vmsslr.NextLink, "")) < 1 { return nil, nil } @@ -377,7 +388,7 @@ func (c *Client) virtualMachineScaleSetListResultPreparer(ctx context.Context, v } // listNextResults retrieves the next set of results, if any. -func (c *Client) listNextResults(ctx context.Context, lastResults compute.VirtualMachineScaleSetListResult) (result compute.VirtualMachineScaleSetListResult, err error) { +func (c *Client) listNextResults(ctx context.Context, lastResults VirtualMachineScaleSetListResult) (result VirtualMachineScaleSetListResult, err error) { req, err := c.virtualMachineScaleSetListResultPreparer(ctx, lastResults) if err != nil { return result, autorest.NewErrorWithError(err, "vmssclient", "listNextResults", nil, "Failure preparing next results request") @@ -403,8 +414,8 @@ func (c *Client) listNextResults(ctx context.Context, lastResults compute.Virtua // VirtualMachineScaleSetListResultPage contains a page of VirtualMachineScaleSet values. type VirtualMachineScaleSetListResultPage struct { - fn func(context.Context, compute.VirtualMachineScaleSetListResult) (compute.VirtualMachineScaleSetListResult, error) - vmsslr compute.VirtualMachineScaleSetListResult + fn func(context.Context, VirtualMachineScaleSetListResult) (VirtualMachineScaleSetListResult, error) + vmsslr VirtualMachineScaleSetListResult } // NextWithContext advances to the next page of values. If there was an error making @@ -431,12 +442,12 @@ func (page VirtualMachineScaleSetListResultPage) NotDone() bool { } // Response returns the raw server response from the last page request. 
-func (page VirtualMachineScaleSetListResultPage) Response() compute.VirtualMachineScaleSetListResult { +func (page VirtualMachineScaleSetListResultPage) Response() VirtualMachineScaleSetListResult { return page.vmsslr } // Values returns the slice of values for the current page or nil if there are no values. -func (page VirtualMachineScaleSetListResultPage) Values() []compute.VirtualMachineScaleSet { +func (page VirtualMachineScaleSetListResultPage) Values() []VirtualMachineScaleSet { if page.vmsslr.IsEmpty() { return nil } diff --git a/pkg/azureclients/vmssclient/azure_vmssclient_test.go b/pkg/azureclients/vmssclient/azure_vmssclient_test.go index 5129645033..d5e0a7860b 100644 --- a/pkg/azureclients/vmssclient/azure_vmssclient_test.go +++ b/pkg/azureclients/vmssclient/azure_vmssclient_test.go @@ -40,6 +40,7 @@ import ( azclients "sigs.k8s.io/cloud-provider-azure/pkg/azureclients" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient/mockarmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) @@ -103,7 +104,7 @@ func TestGet(t *testing.T) { armClient.EXPECT().GetResource(gomock.Any(), testResourceID).Return(response, nil).Times(1) armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - expected := compute.VirtualMachineScaleSet{Response: autorest.Response{Response: response}} + expected := VirtualMachineScaleSet{VirtualMachineScaleSet: compute.VirtualMachineScaleSet{Response: autorest.Response{Response: response}}} vmssClient := getTestVMSSClient(armClient) result, rerr := vmssClient.Get(context.TODO(), "rg", "vmss1") assert.Equal(t, expected, result) @@ -121,7 +122,7 @@ func TestGetNeverRateLimiter(t *testing.T) { armClient := mockarmclient.NewMockInterface(ctrl) vmssClient := getTestVMSSClientWithNeverRateLimiter(armClient) - expected := compute.VirtualMachineScaleSet{} + expected := VirtualMachineScaleSet{} result, rerr := 
vmssClient.Get(context.TODO(), "rg", "vmss1") assert.Equal(t, expected, result) assert.Equal(t, vmssGetErr, rerr) @@ -139,7 +140,7 @@ func TestGetRetryAfterReader(t *testing.T) { armClient := mockarmclient.NewMockInterface(ctrl) vmssClient := getTestVMSSClientWithRetryAfterReader(armClient) - expected := compute.VirtualMachineScaleSet{} + expected := VirtualMachineScaleSet{} result, rerr := vmssClient.Get(context.TODO(), "rg", "vmss1") assert.Equal(t, expected, result) assert.Equal(t, vmssGetErr, rerr) @@ -158,7 +159,7 @@ func TestGetNotFound(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSClient(armClient) - expectedVMSS := compute.VirtualMachineScaleSet{Response: autorest.Response{}} + expectedVMSS := VirtualMachineScaleSet{VirtualMachineScaleSet: compute.VirtualMachineScaleSet{Response: autorest.Response{}}} result, rerr := vmssClient.Get(context.TODO(), "rg", "vmss1") assert.Equal(t, expectedVMSS, result) assert.NotNil(t, rerr) @@ -178,7 +179,7 @@ func TestGetInternalError(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSClient(armClient) - expectedVMSS := compute.VirtualMachineScaleSet{Response: autorest.Response{}} + expectedVMSS := VirtualMachineScaleSet{VirtualMachineScaleSet: compute.VirtualMachineScaleSet{Response: autorest.Response{}}} result, rerr := vmssClient.Get(context.TODO(), "rg", "vmss1") assert.Equal(t, expectedVMSS, result) assert.NotNil(t, rerr) @@ -214,8 +215,8 @@ func TestList(t *testing.T) { defer ctrl.Finish() armClient := mockarmclient.NewMockInterface(ctrl) - vmssList := []compute.VirtualMachineScaleSet{getTestVMSS("vmss1"), getTestVMSS("vmss2"), getTestVMSS("vmss3")} - responseBody, err := json.Marshal(compute.VirtualMachineScaleSetListResult{Value: &vmssList}) + vmssList := []VirtualMachineScaleSet{getTestVMSS("vmss1"), getTestVMSS("vmss2"), getTestVMSS("vmss3")} + responseBody, err := 
json.Marshal(VirtualMachineScaleSetListResult{Value: &vmssList}) assert.NoError(t, err) armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( &http.Response{ @@ -243,7 +244,7 @@ func TestListNotFound(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSClient(armClient) - expected := []compute.VirtualMachineScaleSet{} + expected := []VirtualMachineScaleSet{} result, rerr := vmssClient.List(context.TODO(), "rg") assert.Equal(t, expected, result) assert.NotNil(t, rerr) @@ -263,7 +264,7 @@ func TestListInternalError(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSClient(armClient) - expected := []compute.VirtualMachineScaleSet{} + expected := []VirtualMachineScaleSet{} result, rerr := vmssClient.List(context.TODO(), "rg") assert.Equal(t, expected, result) assert.NotNil(t, rerr) @@ -299,8 +300,8 @@ func TestListWithListResponderError(t *testing.T) { defer ctrl.Finish() armClient := mockarmclient.NewMockInterface(ctrl) - vmssList := []compute.VirtualMachineScaleSet{getTestVMSS("vmss1"), getTestVMSS("vmss2"), getTestVMSS("vmss3")} - responseBody, err := json.Marshal(compute.VirtualMachineScaleSetListResult{Value: &vmssList}) + vmssList := []VirtualMachineScaleSet{getTestVMSS("vmss1"), getTestVMSS("vmss2"), getTestVMSS("vmss3")} + responseBody, err := json.Marshal(VirtualMachineScaleSetListResult{Value: &vmssList}) assert.NoError(t, err) armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( &http.Response{ @@ -319,10 +320,10 @@ func TestListWithNextPage(t *testing.T) { defer ctrl.Finish() armClient := mockarmclient.NewMockInterface(ctrl) - vmssList := []compute.VirtualMachineScaleSet{getTestVMSS("vmss1"), getTestVMSS("vmss2"), getTestVMSS("vmss3")} - partialResponse, err := json.Marshal(compute.VirtualMachineScaleSetListResult{Value: &vmssList, NextLink: ptr.To("nextLink")}) + vmssList := 
[]VirtualMachineScaleSet{getTestVMSS("vmss1"), getTestVMSS("vmss2"), getTestVMSS("vmss3")} + partialResponse, err := json.Marshal(VirtualMachineScaleSetListResult{Value: &vmssList, NextLink: ptr.To("nextLink")}) assert.NoError(t, err) - pagedResponse, err := json.Marshal(compute.VirtualMachineScaleSetListResult{Value: &vmssList}) + pagedResponse, err := json.Marshal(VirtualMachineScaleSetListResult{Value: &vmssList}) assert.NoError(t, err) armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(&http.Request{}, nil) armClient.EXPECT().Send(gomock.Any(), gomock.Any()).Return( @@ -404,7 +405,7 @@ func TestListNextResultsMultiPages(t *testing.T) { }, } - lastResult := compute.VirtualMachineScaleSetListResult{ + lastResult := VirtualMachineScaleSetListResult{ NextLink: ptr.To("next"), } @@ -460,7 +461,7 @@ func TestListNextResultsMultiPagesWithListResponderError(t *testing.T) { }, } - lastResult := compute.VirtualMachineScaleSetListResult{ + lastResult := VirtualMachineScaleSetListResult{ NextLink: ptr.To("next"), } @@ -482,7 +483,7 @@ func TestListNextResultsMultiPagesWithListResponderError(t *testing.T) { StatusCode: http.StatusNotFound, Body: io.NopCloser(bytes.NewBuffer([]byte(`{"foo":"bar"}`))), } - expected := compute.VirtualMachineScaleSetListResult{} + expected := VirtualMachineScaleSetListResult{} expected.Response = autorest.Response{Response: response} vmssClient := getTestVMSSClient(armClient) result, err := vmssClient.listNextResults(context.TODO(), lastResult) @@ -509,7 +510,7 @@ func TestCreateOrUpdate(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSClient(armClient) - rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss) + rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss, "") assert.Nil(t, rerr) } @@ -526,7 +527,7 @@ func TestCreateOrUpdateWithCreateOrUpdateResponderError(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), 
gomock.Any()).Times(1) vmssClient := getTestVMSSClient(armClient) - rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss) + rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss, "") assert.NotNil(t, rerr) } @@ -539,7 +540,7 @@ func TestCreateOrUpdateNeverRateLimiter(t *testing.T) { armClient := mockarmclient.NewMockInterface(ctrl) vmssClient := getTestVMSSClientWithNeverRateLimiter(armClient) vmss := getTestVMSS("vmss1") - rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss) + rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss, "") assert.NotNil(t, rerr) assert.Equal(t, vmssCreateOrUpdateErr, rerr) } @@ -553,7 +554,7 @@ func TestCreateOrUpdateRetryAfterReader(t *testing.T) { vmss := getTestVMSS("vmss1") armClient := mockarmclient.NewMockInterface(ctrl) vmssClient := getTestVMSSClientWithRetryAfterReader(armClient) - rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss) + rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss, "") assert.NotNil(t, rerr) assert.Equal(t, vmssCreateOrUpdateErr, rerr) } @@ -579,7 +580,7 @@ func TestCreateOrUpdateThrottle(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSClient(armClient) - rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss) + rerr := vmssClient.CreateOrUpdate(context.TODO(), "rg", "vmss1", vmss, "") assert.NotNil(t, rerr) assert.Equal(t, throttleErr, rerr) } @@ -594,12 +595,12 @@ func TestCreateOrUpdateAsync(t *testing.T) { armClient.EXPECT().PutResourceAsync(gomock.Any(), ptr.Deref(vmss.ID, ""), vmss).Return(future, nil).Times(1) vmssClient := getTestVMSSClient(armClient) - _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss) + _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss, "") assert.Nil(t, rerr) retryErr := &retry.Error{RawError: fmt.Errorf("error")} 
armClient.EXPECT().PutResourceAsync(gomock.Any(), ptr.Deref(vmss.ID, ""), vmss).Return(future, retryErr).Times(1) - _, rerr = vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss) + _, rerr = vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss, "") assert.Equal(t, retryErr, rerr) } @@ -615,7 +616,7 @@ func TestCreateOrUpdateAsyncNeverRateLimiter(t *testing.T) { armClient := mockarmclient.NewMockInterface(ctrl) vmssClient := getTestVMSSClientWithNeverRateLimiter(armClient) vmss := getTestVMSS("vmss1") - _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss) + _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss, "") assert.NotNil(t, rerr) assert.Equal(t, vmssCreateOrUpdateAsyncErr, rerr) } @@ -633,7 +634,7 @@ func TestCreateOrUpdateAsyncRetryAfterReader(t *testing.T) { vmss := getTestVMSS("vmss1") armClient := mockarmclient.NewMockInterface(ctrl) vmssClient := getTestVMSSClientWithRetryAfterReader(armClient) - _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss) + _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss, "") assert.NotNil(t, rerr) assert.Equal(t, vmssCreateOrUpdateAsyncErr, rerr) } @@ -655,7 +656,7 @@ func TestCreateOrUpdateAsyncThrottle(t *testing.T) { armClient.EXPECT().PutResourceAsync(gomock.Any(), ptr.Deref(vmss.ID, ""), vmss).Return(future, throttleErr).Times(1) vmssClient := getTestVMSSClient(armClient) - _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss) + _, rerr := vmssClient.CreateOrUpdateAsync(context.TODO(), "rg", "vmss1", vmss, "") assert.NotNil(t, rerr) assert.Equal(t, throttleErr, rerr) } @@ -834,14 +835,16 @@ func TestDeleteInstancesAsync(t *testing.T) { assert.Equal(t, retryErr, rerr) } -func getTestVMSS(name string) compute.VirtualMachineScaleSet { - return compute.VirtualMachineScaleSet{ - ID: 
ptr.To("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss1"), - Name: ptr.To(name), - Location: ptr.To("eastus"), - Sku: &compute.Sku{ - Name: ptr.To("Standard"), - Capacity: ptr.To(int64(3)), +func getTestVMSS(name string) VirtualMachineScaleSet { + return VirtualMachineScaleSet{ + VirtualMachineScaleSet: compute.VirtualMachineScaleSet{ + ID: ptr.To("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss1"), + Name: ptr.To(name), + Location: ptr.To("eastus"), + Sku: &compute.Sku{ + Name: ptr.To("Standard"), + Capacity: ptr.To(int64(3)), + }, }, } } @@ -883,3 +886,129 @@ func getTestVMSSClientWithRetryAfterReader(armClient armclient.Interface) *Clien func getFutureTime() time.Time { return time.Unix(3000000000, 0) } + +func getFakeVmssVM() VirtualMachineScaleSet { + testLBBackendpoolID := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0" + virtualMachineScaleSetNetworkConfiguration := compute.VirtualMachineScaleSetNetworkConfiguration{ + VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + { + VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + LoadBalancerBackendAddressPools: &[]compute.SubResource{{ID: ptr.To(testLBBackendpoolID)}}, + Primary: ptr.To(true), + }, + }, + }, + }, + } + vmssVM := VirtualMachineScaleSet{ + VirtualMachineScaleSet: compute.VirtualMachineScaleSet{ + Location: ptr.To("eastus"), + VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: &[]compute.VirtualMachineScaleSetNetworkConfiguration{ + 
virtualMachineScaleSetNetworkConfiguration, + }, + }, + }, + OrchestrationMode: compute.Flexible, + }, + Tags: map[string]*string{ + consts.VMSetCIDRIPV4TagKey: ptr.To("24"), + consts.VMSetCIDRIPV6TagKey: ptr.To("64"), + }, + }, + Etag: ptr.To("\"120\""), + } + return vmssVM +} + +func TestMarshal(t *testing.T) { + fakeVmss := getFakeVmssVM() + fakeVmssWithoutEtag := getFakeVmssVM() + fakeVmssWithoutEtag.Etag = nil + fakeVmssWithoutCompueVMSS := getFakeVmssVM() + fakeVmssWithoutCompueVMSS.VirtualMachineScaleSet = compute.VirtualMachineScaleSet{} + testcases := []struct { + name string + vmss VirtualMachineScaleSet + expectJSON string + }{ + + { + name: "should return empty json when vmss is empty", + vmss: VirtualMachineScaleSet{}, + expectJSON: "{}", + }, + { + name: "should return only VirtualMachineScaleSet when etag is empty", + vmss: fakeVmssWithoutEtag, + expectJSON: `{"location":"eastus","properties":{"orchestrationMode":"Flexible","virtualMachineProfile":{"networkProfile":{"networkInterfaceConfigurations":[{"properties":{"ipConfigurations":[{"properties":{"primary":true,"loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]}}},"tags":{"kubernetesNodeCIDRMaskIPV4":"24","kubernetesNodeCIDRMaskIPV6":"64"}}`, + }, + { + name: "should return only etag json when vmss is empty", + vmss: fakeVmssWithoutCompueVMSS, + expectJSON: `{"etag":"\"120\""}`, + }, + { + name: "should return full json when both VirtualMachineScaleSet and etag are set", + vmss: fakeVmss, + expectJSON: 
`{"location":"eastus","properties":{"orchestrationMode":"Flexible","virtualMachineProfile":{"networkProfile":{"networkInterfaceConfigurations":[{"properties":{"ipConfigurations":[{"properties":{"primary":true,"loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]}}},"tags":{"kubernetesNodeCIDRMaskIPV4":"24","kubernetesNodeCIDRMaskIPV6":"64"},"etag":"\"120\""}`, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + actualJSON, err := json.Marshal(tt.vmss) + assert.Nil(t, err) + assert.Equal(t, string(actualJSON), tt.expectJSON) + }) + } +} + +func TestUnMarshal(t *testing.T) { + fakeVmss := getFakeVmssVM() + fakeVmssWithoutEtag := getFakeVmssVM() + fakeVmssWithoutEtag.Etag = nil + fakeVmssWithoutCompueVMSS := getFakeVmssVM() + fakeVmssWithoutCompueVMSS.VirtualMachineScaleSet = compute.VirtualMachineScaleSet{} + testcases := []struct { + name string + expectedVmss VirtualMachineScaleSet + inputJSON string + }{ + { + name: "should return empty json when vmss is empty", + expectedVmss: VirtualMachineScaleSet{}, + inputJSON: "{}", + }, + + { + name: "should return only compute.VirtualMachineScaleSetVM when etag is empty", + expectedVmss: fakeVmssWithoutEtag, + inputJSON: `{"location":"eastus","properties":{"orchestrationMode":"Flexible","virtualMachineProfile":{"networkProfile":{"networkInterfaceConfigurations":[{"properties":{"ipConfigurations":[{"properties":{"primary":true,"loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]}}},"tags":{"kubernetesNodeCIDRMaskIPV4":"24","kubernetesNodeCIDRMaskIPV6":"64"}}`, + }, + + { + name: "should return only etag json when vmss is empty", + expectedVmss: fakeVmssWithoutCompueVMSS, + inputJSON: `{"etag":"\"120\""}`, + }, + + { + name: "should return full json when both 
VirtualMachineScaleSetVM and etag are set", + expectedVmss: fakeVmss, + inputJSON: `{"location":"eastus","properties":{"orchestrationMode":"Flexible","virtualMachineProfile":{"networkProfile":{"networkInterfaceConfigurations":[{"properties":{"ipConfigurations":[{"properties":{"primary":true,"loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]}}},"tags":{"kubernetesNodeCIDRMaskIPV4":"24","kubernetesNodeCIDRMaskIPV6":"64"},"etag":"\"120\""}`, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + var actualVmss VirtualMachineScaleSet + err := json.Unmarshal([]byte(tt.inputJSON), &actualVmss) + assert.Nil(t, err) + assert.Equal(t, actualVmss, tt.expectedVmss) + }) + } +} diff --git a/pkg/azureclients/vmssclient/interface.go b/pkg/azureclients/vmssclient/interface.go index 10c9d04703..65aedbe6c5 100644 --- a/pkg/azureclients/vmssclient/interface.go +++ b/pkg/azureclients/vmssclient/interface.go @@ -28,7 +28,7 @@ import ( const ( // APIVersion is the API version for VMSS. - APIVersion = "2022-03-01" + APIVersion = "2024-03-01" // AzureStackCloudAPIVersion is the API version for Azure Stack AzureStackCloudAPIVersion = "2019-07-01" // AzureStackCloudName is the cloud name of Azure Stack @@ -36,19 +36,20 @@ const ( ) // Interface is the client interface for VirtualMachineScaleSet. +// For backward compatibility, pass an empty etag to CreateOrUpdate/CreateOrUpdateAsync to keep the previous behavior (no If-Match header is sent). // Don't forget to run "hack/update-mock-clients.sh" command to generate the mock client. type Interface interface { // Get gets a VirtualMachineScaleSet. - Get(ctx context.Context, resourceGroupName string, VMScaleSetName string) (result compute.VirtualMachineScaleSet, rerr *retry.Error) + Get(ctx context.Context, resourceGroupName string, VMScaleSetName string) (result VirtualMachineScaleSet, rerr *retry.Error) // List gets a list of VirtualMachineScaleSets in the resource group. 
- List(ctx context.Context, resourceGroupName string) (result []compute.VirtualMachineScaleSet, rerr *retry.Error) + List(ctx context.Context, resourceGroupName string) (result []VirtualMachineScaleSet, rerr *retry.Error) // CreateOrUpdate creates or updates a VirtualMachineScaleSet. - CreateOrUpdate(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) *retry.Error + CreateOrUpdate(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters VirtualMachineScaleSet, etag string) *retry.Error // CreateOrUpdateSync sends the request to arm client and DO NOT wait for the response - CreateOrUpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) (*azure.Future, *retry.Error) + CreateOrUpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, parameters VirtualMachineScaleSet, etag string) (*azure.Future, *retry.Error) // WaitForAsyncOperationResult waits for the response of the request WaitForAsyncOperationResult(ctx context.Context, future *azure.Future, resourceGroupName, request, asyncOpName string) (*http.Response, error) diff --git a/pkg/azureclients/vmssclient/mockvmssclient/interface.go b/pkg/azureclients/vmssclient/mockvmssclient/interface.go index f3d166012a..d57f303d5d 100644 --- a/pkg/azureclients/vmssclient/mockvmssclient/interface.go +++ b/pkg/azureclients/vmssclient/mockvmssclient/interface.go @@ -34,6 +34,7 @@ import ( compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" azure "github.com/Azure/go-autorest/autorest/azure" gomock "go.uber.org/mock/gomock" + vmssclient "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" retry "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) @@ -61,32 +62,32 @@ func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder { } // CreateOrUpdate mocks base method. 
-func (m *MockInterface) CreateOrUpdate(ctx context.Context, resourceGroupName, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) *retry.Error { +func (m *MockInterface) CreateOrUpdate(ctx context.Context, resourceGroupName, VMScaleSetName string, parameters vmssclient.VirtualMachineScaleSet, etag string) *retry.Error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdate", ctx, resourceGroupName, VMScaleSetName, parameters) + ret := m.ctrl.Call(m, "CreateOrUpdate", ctx, resourceGroupName, VMScaleSetName, parameters, etag) ret0, _ := ret[0].(*retry.Error) return ret0 } // CreateOrUpdate indicates an expected call of CreateOrUpdate. -func (mr *MockInterfaceMockRecorder) CreateOrUpdate(ctx, resourceGroupName, VMScaleSetName, parameters any) *gomock.Call { +func (mr *MockInterfaceMockRecorder) CreateOrUpdate(ctx, resourceGroupName, VMScaleSetName, parameters, etag any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdate", reflect.TypeOf((*MockInterface)(nil).CreateOrUpdate), ctx, resourceGroupName, VMScaleSetName, parameters) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdate", reflect.TypeOf((*MockInterface)(nil).CreateOrUpdate), ctx, resourceGroupName, VMScaleSetName, parameters, etag) } // CreateOrUpdateAsync mocks base method. 
-func (m *MockInterface) CreateOrUpdateAsync(ctx context.Context, resourceGroupName, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) (*azure.Future, *retry.Error) { +func (m *MockInterface) CreateOrUpdateAsync(ctx context.Context, resourceGroupName, VMScaleSetName string, parameters vmssclient.VirtualMachineScaleSet, etag string) (*azure.Future, *retry.Error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdateAsync", ctx, resourceGroupName, VMScaleSetName, parameters) + ret := m.ctrl.Call(m, "CreateOrUpdateAsync", ctx, resourceGroupName, VMScaleSetName, parameters, etag) ret0, _ := ret[0].(*azure.Future) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } // CreateOrUpdateAsync indicates an expected call of CreateOrUpdateAsync. -func (mr *MockInterfaceMockRecorder) CreateOrUpdateAsync(ctx, resourceGroupName, VMScaleSetName, parameters any) *gomock.Call { +func (mr *MockInterfaceMockRecorder) CreateOrUpdateAsync(ctx, resourceGroupName, VMScaleSetName, parameters, etag any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdateAsync", reflect.TypeOf((*MockInterface)(nil).CreateOrUpdateAsync), ctx, resourceGroupName, VMScaleSetName, parameters) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdateAsync", reflect.TypeOf((*MockInterface)(nil).CreateOrUpdateAsync), ctx, resourceGroupName, VMScaleSetName, parameters, etag) } // DeallocateInstancesAsync mocks base method. @@ -134,10 +135,10 @@ func (mr *MockInterfaceMockRecorder) DeleteInstancesAsync(ctx, resourceGroupName } // Get mocks base method. 
-func (m *MockInterface) Get(ctx context.Context, resourceGroupName, VMScaleSetName string) (compute.VirtualMachineScaleSet, *retry.Error) { +func (m *MockInterface) Get(ctx context.Context, resourceGroupName, VMScaleSetName string) (vmssclient.VirtualMachineScaleSet, *retry.Error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, VMScaleSetName) - ret0, _ := ret[0].(compute.VirtualMachineScaleSet) + ret0, _ := ret[0].(vmssclient.VirtualMachineScaleSet) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } @@ -149,10 +150,10 @@ func (mr *MockInterfaceMockRecorder) Get(ctx, resourceGroupName, VMScaleSetName } // List mocks base method. -func (m *MockInterface) List(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachineScaleSet, *retry.Error) { +func (m *MockInterface) List(ctx context.Context, resourceGroupName string) ([]vmssclient.VirtualMachineScaleSet, *retry.Error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "List", ctx, resourceGroupName) - ret0, _ := ret[0].([]compute.VirtualMachineScaleSet) + ret0, _ := ret[0].([]vmssclient.VirtualMachineScaleSet) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } diff --git a/pkg/azureclients/vmssclient/models.go b/pkg/azureclients/vmssclient/models.go new file mode 100644 index 0000000000..94f25b3c8b --- /dev/null +++ b/pkg/azureclients/vmssclient/models.go @@ -0,0 +1,131 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package vmssclient + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/to" +) + +// VirtualMachineScaleSet wraps the original VirtualMachineScaleSet struct and adds an Etag field. +type VirtualMachineScaleSet struct { + compute.VirtualMachineScaleSet `json:",inline"` + // READ-ONLY; Etag is property returned in Create/Update/Get response of the VMSS, so that customer can supply it in the header + // to ensure optimistic updates + Etag *string `json:"etag,omitempty"` +} + +// VirtualMachineScaleSetListResult the List Virtual Machine operation response. +type VirtualMachineScaleSetListResult struct { + autorest.Response `json:"-"` + // Value - The list of virtual machine scale sets. + Value *[]VirtualMachineScaleSet `json:"value,omitempty"` + // NextLink - The uri to fetch the next page of Virtual Machine Scale Sets. Call ListNext() with this to fetch the next page of VMSS. + NextLink *string `json:"nextLink,omitempty"` +} + +// IsEmpty returns true if the ListResult contains no values. +func (vmsslr VirtualMachineScaleSetListResult) IsEmpty() bool { + return vmsslr.Value == nil || len(*vmsslr.Value) == 0 +} + +// hasNextLink returns true if the NextLink is not empty. +func (vmsslr VirtualMachineScaleSetListResult) hasNextLink() bool { + return vmsslr.NextLink != nil && len(*vmsslr.NextLink) != 0 +} + +// virtualMachineScaleSetListResultPreparer prepares a request to retrieve the next set of results. +// It returns nil if no more results exist. 
+func (vmsslr VirtualMachineScaleSetListResult) virtualMachineScaleSetListResultPreparer(ctx context.Context) (*http.Request, error) { + if !vmsslr.hasNextLink() { + return nil, nil + } + return autorest.Prepare((&http.Request{}).WithContext(ctx), + autorest.AsJSON(), + autorest.AsGet(), + autorest.WithBaseURL(to.String(vmsslr.NextLink))) +} + +// UnmarshalJSON is the custom unmarshaler for the VirtualMachineScaleSet struct. +// compute.VirtualMachineScaleSet implements the `UnmarshalJSON` method, and when the response is unmarshaled into VirtualMachineScaleSet, +// compute.VirtualMachineScaleSet.UnmarshalJSON is called, leading to the loss of the Etag field. +func (vmss *VirtualMachineScaleSet) UnmarshalJSON(data []byte) error { + // Unmarshal Etag first + etagPlaceholder := struct { + Etag *string `json:"etag,omitempty"` + }{} + if err := json.Unmarshal(data, &etagPlaceholder); err != nil { + return err + } + // Unmarshal Nested VirtualMachineScaleSet + nestedVirtualMachineScaleSet := struct { + compute.VirtualMachineScaleSet `json:",inline"` + }{} + // the nested struct implements UnmarshalJSON, so it must be unmarshaled separately + if err := json.Unmarshal(data, &nestedVirtualMachineScaleSet); err != nil { + return err + } + (vmss).Etag = etagPlaceholder.Etag + (vmss).VirtualMachineScaleSet = nestedVirtualMachineScaleSet.VirtualMachineScaleSet + return nil +} + +// MarshalJSON is the custom marshaler for VirtualMachineScaleSet.
+func (vmss VirtualMachineScaleSet) MarshalJSON() ([]byte, error) { + var err error + var nestedVirtualMachineScaleSetJSON, etagJSON []byte + if nestedVirtualMachineScaleSetJSON, err = vmss.VirtualMachineScaleSet.MarshalJSON(); err != nil { + return nil, err + } + + if vmss.Etag != nil { + if etagJSON, err = json.Marshal(map[string]interface{}{ + "etag": vmss.Etag, + }); err != nil { + return nil, err + } + } + + // empty struct can be Unmarshaled to "{}" + nestedVirtualMachineScaleSetJSONEmpty := true + if string(nestedVirtualMachineScaleSetJSON) != "{}" { + nestedVirtualMachineScaleSetJSONEmpty = false + } + etagJSONEmpty := true + if len(etagJSON) != 0 { + etagJSONEmpty = false + } + + // when both parts not empty, join the two parts with a comma but remove the open brace of nestedVirtualMachineScaleSetVMJson and the close brace of the etagJson + // {"location": "eastus"} + {"etag": "\"120\""} will be merged into {"location": "eastus", "etag": "\"120\""} + if !nestedVirtualMachineScaleSetJSONEmpty && !etagJSONEmpty { + etagJSON[0] = ',' + return append(nestedVirtualMachineScaleSetJSON[:len(nestedVirtualMachineScaleSetJSON)-1], etagJSON...), nil + } + if !nestedVirtualMachineScaleSetJSONEmpty { + return nestedVirtualMachineScaleSetJSON, nil + } + if !etagJSONEmpty { + return etagJSON, nil + } + return []byte("{}"), nil +} diff --git a/pkg/azureclients/vmssvmclient/azure_vmssvmclient.go b/pkg/azureclients/vmssvmclient/azure_vmssvmclient.go index cdc6a169d4..3bed0a90bb 100644 --- a/pkg/azureclients/vmssvmclient/azure_vmssvmclient.go +++ b/pkg/azureclients/vmssvmclient/azure_vmssvmclient.go @@ -92,20 +92,20 @@ func New(config *azclients.ClientConfig) *Client { } // Get gets a VirtualMachineScaleSetVM. 
-func (c *Client) Get(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, expand compute.InstanceViewTypes) (compute.VirtualMachineScaleSetVM, *retry.Error) { +func (c *Client) Get(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, expand compute.InstanceViewTypes) (VirtualMachineScaleSetVM, *retry.Error) { mc := metrics.NewMetricContext("vmssvm", "get", resourceGroupName, c.subscriptionID, "") // Report errors if the client is rate limited. if !c.rateLimiterReader.TryAccept() { mc.RateLimitedCount() - return compute.VirtualMachineScaleSetVM{}, retry.GetRateLimitError(false, "VMSSVMGet") + return VirtualMachineScaleSetVM{}, retry.GetRateLimitError(false, "VMSSVMGet") } // Report errors if the client is throttled. if c.RetryAfterReader.After(time.Now()) { mc.ThrottledCount() rerr := retry.GetThrottlingError("VMSSVMGet", "client throttled", c.RetryAfterReader) - return compute.VirtualMachineScaleSetVM{}, rerr + return VirtualMachineScaleSetVM{}, rerr } result, rerr := c.getVMSSVM(ctx, resourceGroupName, VMScaleSetName, instanceID, expand) @@ -123,7 +123,7 @@ func (c *Client) Get(ctx context.Context, resourceGroupName string, VMScaleSetNa } // getVMSSVM gets a VirtualMachineScaleSetVM. 
-func (c *Client) getVMSSVM(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, expand compute.InstanceViewTypes) (compute.VirtualMachineScaleSetVM, *retry.Error) { +func (c *Client) getVMSSVM(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, expand compute.InstanceViewTypes) (VirtualMachineScaleSetVM, *retry.Error) { resourceID := armclient.GetChildResourceID( c.subscriptionID, resourceGroupName, @@ -132,7 +132,7 @@ func (c *Client) getVMSSVM(ctx context.Context, resourceGroupName string, VMScal vmResourceType, instanceID, ) - result := compute.VirtualMachineScaleSetVM{} + result := VirtualMachineScaleSetVM{} response, rerr := c.armClient.GetResourceWithExpandQuery(ctx, resourceID, string(expand)) defer c.armClient.CloseResponse(ctx, response) @@ -155,7 +155,7 @@ func (c *Client) getVMSSVM(ctx context.Context, resourceGroupName string, VMScal } // List gets a list of VirtualMachineScaleSetVMs in the virtualMachineScaleSet. -func (c *Client) List(ctx context.Context, resourceGroupName string, virtualMachineScaleSetName string, expand string) ([]compute.VirtualMachineScaleSetVM, *retry.Error) { +func (c *Client) List(ctx context.Context, resourceGroupName string, virtualMachineScaleSetName string, expand string) ([]VirtualMachineScaleSetVM, *retry.Error) { mc := metrics.NewMetricContext("vmssvm", "list", resourceGroupName, c.subscriptionID, "") // Report errors if the client is rate limited. @@ -186,7 +186,7 @@ func (c *Client) List(ctx context.Context, resourceGroupName string, virtualMach } // listVMSSVM gets a list of VirtualMachineScaleSetVMs in the virtualMachineScaleSet. 
-func (c *Client) listVMSSVM(ctx context.Context, resourceGroupName string, virtualMachineScaleSetName string, expand string) ([]compute.VirtualMachineScaleSetVM, *retry.Error) { +func (c *Client) listVMSSVM(ctx context.Context, resourceGroupName string, virtualMachineScaleSetName string, expand string) ([]VirtualMachineScaleSetVM, *retry.Error) { resourceID := armclient.GetChildResourcesListID( c.subscriptionID, resourceGroupName, @@ -195,7 +195,7 @@ func (c *Client) listVMSSVM(ctx context.Context, resourceGroupName string, virtu vmResourceType, ) - result := make([]compute.VirtualMachineScaleSetVM, 0) + result := make([]VirtualMachineScaleSetVM, 0) page := &VirtualMachineScaleSetVMListResultPage{} page.fn = c.listNextResults @@ -231,7 +231,7 @@ func (c *Client) listVMSSVM(ctx context.Context, resourceGroupName string, virtu } // Update updates a VirtualMachineScaleSetVM. -func (c *Client) Update(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters compute.VirtualMachineScaleSetVM, source string) (*compute.VirtualMachineScaleSetVM, *retry.Error) { +func (c *Client) Update(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters VirtualMachineScaleSetVM, source, etag string) (*VirtualMachineScaleSetVM, *retry.Error) { mc := metrics.NewMetricContext("vmssvm", "update", resourceGroupName, c.subscriptionID, source) // Report errors if the client is rate limited. 
@@ -247,7 +247,7 @@ func (c *Client) Update(ctx context.Context, resourceGroupName string, VMScaleSe return nil, rerr } - result, rerr := c.updateVMSSVM(ctx, resourceGroupName, VMScaleSetName, instanceID, parameters) + result, rerr := c.updateVMSSVM(ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, etag) mc.Observe(rerr) if rerr != nil { if rerr.IsThrottled() { @@ -260,7 +260,7 @@ func (c *Client) Update(ctx context.Context, resourceGroupName string, VMScaleSe } // UpdateAsync updates a VirtualMachineScaleSetVM asynchronously -func (c *Client) UpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters compute.VirtualMachineScaleSetVM, source string) (*azure.Future, *retry.Error) { +func (c *Client) UpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters VirtualMachineScaleSetVM, source string, etag string) (*azure.Future, *retry.Error) { mc := metrics.NewMetricContext("vmssvm", "updateasync", resourceGroupName, c.subscriptionID, source) // Report errors if the client is rate limited. @@ -285,7 +285,12 @@ func (c *Client) UpdateAsync(ctx context.Context, resourceGroupName string, VMSc instanceID, ) - future, rerr := c.armClient.PutResourceAsync(ctx, resourceID, parameters) + decorators := []autorest.PrepareDecorator{} + if etag != "" { + decorators = append(decorators, autorest.WithHeader("If-Match", autorest.String(etag))) + } + + future, rerr := c.armClient.PutResourceAsync(ctx, resourceID, parameters, decorators...) 
mc.Observe(rerr) if rerr != nil { if rerr.IsThrottled() { @@ -300,7 +305,7 @@ func (c *Client) UpdateAsync(ctx context.Context, resourceGroupName string, VMSc } // WaitForUpdateResult waits for the response of the update request -func (c *Client) WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*compute.VirtualMachineScaleSetVM, *retry.Error) { +func (c *Client) WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*VirtualMachineScaleSetVM, *retry.Error) { mc := metrics.NewMetricContext("vmss", "wait_for_update_result", resourceGroupName, c.subscriptionID, source) response, err := c.armClient.WaitForAsyncOperationResult(ctx, future, "VMSSWaitForUpdateResult") mc.Observe(retry.NewErrorOrNil(false, err)) @@ -324,13 +329,13 @@ func (c *Client) WaitForUpdateResult(ctx context.Context, future *azure.Future, return result, rerr } - result := &compute.VirtualMachineScaleSetVM{} + result := &VirtualMachineScaleSetVM{} result.Response = autorest.Response{Response: response} return result, nil } // updateVMSSVM updates a VirtualMachineScaleSetVM. 
-func (c *Client) updateVMSSVM(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters compute.VirtualMachineScaleSetVM) (*compute.VirtualMachineScaleSetVM, *retry.Error) { +func (c *Client) updateVMSSVM(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters VirtualMachineScaleSetVM, etag string) (*VirtualMachineScaleSetVM, *retry.Error) { resourceID := armclient.GetChildResourceID( c.subscriptionID, resourceGroupName, @@ -340,7 +345,12 @@ func (c *Client) updateVMSSVM(ctx context.Context, resourceGroupName string, VMS instanceID, ) - response, rerr := c.armClient.PutResource(ctx, resourceID, parameters) + decorators := []autorest.PrepareDecorator{} + if etag != "" { + decorators = append(decorators, autorest.WithHeader("If-Match", autorest.String(etag))) + } + + response, rerr := c.armClient.PutResource(ctx, resourceID, parameters, decorators...) defer c.armClient.CloseResponse(ctx, response) if rerr != nil { klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vmssvm.put.request", resourceID, rerr.Error()) @@ -355,13 +365,13 @@ func (c *Client) updateVMSSVM(ctx context.Context, resourceGroupName string, VMS return result, rerr } - result := &compute.VirtualMachineScaleSetVM{} + result := &VirtualMachineScaleSetVM{} result.Response = autorest.Response{Response: response} return result, nil } -func (c *Client) updateResponder(resp *http.Response) (*compute.VirtualMachineScaleSetVM, *retry.Error) { - result := &compute.VirtualMachineScaleSetVM{} +func (c *Client) updateResponder(resp *http.Response) (*VirtualMachineScaleSetVM, *retry.Error) { + result := &VirtualMachineScaleSetVM{} err := autorest.Respond( resp, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusCreated), @@ -370,7 +380,7 @@ func (c *Client) updateResponder(resp *http.Response) (*compute.VirtualMachineSc return result, retry.GetError(resp, err) } -func (c *Client) listResponder(resp 
*http.Response) (result compute.VirtualMachineScaleSetVMListResult, err error) { +func (c *Client) listResponder(resp *http.Response) (result VirtualMachineScaleSetVMListResult, err error) { err = autorest.Respond( resp, autorest.ByIgnoring(), @@ -382,7 +392,7 @@ func (c *Client) listResponder(resp *http.Response) (result compute.VirtualMachi // virtualMachineScaleSetListResultPreparer prepares a request to retrieve the next set of results. // It returns nil if no more results exist. -func (c *Client) virtualMachineScaleSetVMListResultPreparer(ctx context.Context, vmssvmlr compute.VirtualMachineScaleSetVMListResult) (*http.Request, error) { +func (c *Client) virtualMachineScaleSetVMListResultPreparer(ctx context.Context, vmssvmlr VirtualMachineScaleSetVMListResult) (*http.Request, error) { if vmssvmlr.NextLink == nil || len(ptr.Deref(vmssvmlr.NextLink, "")) < 1 { return nil, nil } @@ -394,7 +404,7 @@ func (c *Client) virtualMachineScaleSetVMListResultPreparer(ctx context.Context, } // listNextResults retrieves the next set of results, if any. -func (c *Client) listNextResults(ctx context.Context, lastResults compute.VirtualMachineScaleSetVMListResult) (result compute.VirtualMachineScaleSetVMListResult, err error) { +func (c *Client) listNextResults(ctx context.Context, lastResults VirtualMachineScaleSetVMListResult) (result VirtualMachineScaleSetVMListResult, err error) { req, err := c.virtualMachineScaleSetVMListResultPreparer(ctx, lastResults) if err != nil { return result, autorest.NewErrorWithError(err, "vmssvmclient", "listNextResults", nil, "Failure preparing next results request") @@ -420,8 +430,8 @@ func (c *Client) listNextResults(ctx context.Context, lastResults compute.Virtua // VirtualMachineScaleSetVMListResultPage contains a page of VirtualMachineScaleSetVM values. 
type VirtualMachineScaleSetVMListResultPage struct { - fn func(context.Context, compute.VirtualMachineScaleSetVMListResult) (compute.VirtualMachineScaleSetVMListResult, error) - vmssvlr compute.VirtualMachineScaleSetVMListResult + fn func(context.Context, VirtualMachineScaleSetVMListResult) (VirtualMachineScaleSetVMListResult, error) + vmssvlr VirtualMachineScaleSetVMListResult } // NextWithContext advances to the next page of values. If there was an error making @@ -448,12 +458,12 @@ func (page VirtualMachineScaleSetVMListResultPage) NotDone() bool { } // Response returns the raw server response from the last page request. -func (page VirtualMachineScaleSetVMListResultPage) Response() compute.VirtualMachineScaleSetVMListResult { +func (page VirtualMachineScaleSetVMListResultPage) Response() VirtualMachineScaleSetVMListResult { return page.vmssvlr } // Values returns the slice of values for the current page or nil if there are no values. -func (page VirtualMachineScaleSetVMListResultPage) Values() []compute.VirtualMachineScaleSetVM { +func (page VirtualMachineScaleSetVMListResultPage) Values() []VirtualMachineScaleSetVM { if page.vmssvlr.IsEmpty() { return nil } @@ -463,7 +473,7 @@ func (page VirtualMachineScaleSetVMListResultPage) Values() []compute.VirtualMac // UpdateVMs updates a list of VirtualMachineScaleSetVM from map[instanceID]compute.VirtualMachineScaleSetVM. // If the batch size > 0, it will send sync requests concurrently in batches, or it will send sync requests in sequence. // No matter what the batch size is, it will process the async requests concurrently in one single batch. 
-func (c *Client) UpdateVMs(ctx context.Context, resourceGroupName string, VMScaleSetName string, instances map[string]compute.VirtualMachineScaleSetVM, source string, batchSize int) *retry.Error { +func (c *Client) UpdateVMs(ctx context.Context, resourceGroupName string, VMScaleSetName string, instances map[string]VirtualMachineScaleSetVM, source string, batchSize int) *retry.Error { mc := metrics.NewMetricContext("vmssvm", "update_vms", resourceGroupName, c.subscriptionID, source) // Report errors if the client is rate limited. @@ -494,7 +504,7 @@ func (c *Client) UpdateVMs(ctx context.Context, resourceGroupName string, VMScal } // updateVMSSVMs updates a list of VirtualMachineScaleSetVM from map[instanceID]compute.VirtualMachineScaleSetVM. -func (c *Client) updateVMSSVMs(ctx context.Context, resourceGroupName string, VMScaleSetName string, instances map[string]compute.VirtualMachineScaleSetVM, batchSize int) *retry.Error { +func (c *Client) updateVMSSVMs(ctx context.Context, resourceGroupName string, VMScaleSetName string, instances map[string]VirtualMachineScaleSetVM, batchSize int) *retry.Error { resources := make(map[string]interface{}) for instanceID, parameter := range instances { resourceID := armclient.GetChildResourceID( @@ -508,14 +518,14 @@ func (c *Client) updateVMSSVMs(ctx context.Context, resourceGroupName string, VM resources[resourceID] = parameter } - responses := c.armClient.PutResourcesInBatches(ctx, resources, batchSize) + responses := c.armClient.PutResourcesInBatchesWithEtag(ctx, resources, batchSize) errors, retryIDs := c.parseResp(ctx, responses, true) if len(retryIDs) > 0 { retryResources := make(map[string]interface{}) for _, id := range retryIDs { retryResources[id] = resources[id] } - resps := c.armClient.PutResourcesInBatches(ctx, retryResources, batchSize) + resps := c.armClient.PutResourcesInBatchesWithEtag(ctx, retryResources, batchSize) errs, _ := c.parseResp(ctx, resps, false) errors = append(errors, errs...) 
} diff --git a/pkg/azureclients/vmssvmclient/azure_vmssvmclient_test.go b/pkg/azureclients/vmssvmclient/azure_vmssvmclient_test.go index 269ed208b5..e5fc8a3c38 100644 --- a/pkg/azureclients/vmssvmclient/azure_vmssvmclient_test.go +++ b/pkg/azureclients/vmssvmclient/azure_vmssvmclient_test.go @@ -104,7 +104,7 @@ func TestGet(t *testing.T) { armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "InstanceView").Return(response, nil).Times(1) armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - expected := compute.VirtualMachineScaleSetVM{Response: autorest.Response{Response: response}} + expected := VirtualMachineScaleSetVM{VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{Response: autorest.Response{Response: response}}} vmssvmClient := getTestVMSSVMClient(armClient) result, rerr := vmssvmClient.Get(context.TODO(), "rg", "vmss1", "0", "InstanceView") assert.Equal(t, expected, result) @@ -122,7 +122,7 @@ func TestGetNeverRateLimiter(t *testing.T) { armClient := mockarmclient.NewMockInterface(ctrl) vmssvmClient := getTestVMSSVMClientWithNeverRateLimiter(armClient) - expected := compute.VirtualMachineScaleSetVM{} + expected := VirtualMachineScaleSetVM{} result, rerr := vmssvmClient.Get(context.TODO(), "rg", "vmss1", "0", "InstanceView") assert.Equal(t, expected, result) assert.Equal(t, vmssvmGetErr, rerr) @@ -140,7 +140,7 @@ func TestGetRetryAfterReader(t *testing.T) { armClient := mockarmclient.NewMockInterface(ctrl) vmssvmClient := getTestVMSSVMClientWithRetryAfterReader(armClient) - expected := compute.VirtualMachineScaleSetVM{} + expected := VirtualMachineScaleSetVM{} result, rerr := vmssvmClient.Get(context.TODO(), "rg", "vmss1", "0", "InstanceView") assert.Equal(t, expected, result) assert.Equal(t, vmssvmGetErr, rerr) @@ -159,7 +159,7 @@ func TestGetNotFound(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSVMClient(armClient) - expectedVM := 
compute.VirtualMachineScaleSetVM{Response: autorest.Response{}} + expectedVM := VirtualMachineScaleSetVM{VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{Response: autorest.Response{}}} result, rerr := vmssClient.Get(context.TODO(), "rg", "vmss1", "0", "InstanceView") assert.Equal(t, expectedVM, result) assert.NotNil(t, rerr) @@ -180,7 +180,7 @@ func TestGetInternalError(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssClient := getTestVMSSVMClient(armClient) - expectedVM := compute.VirtualMachineScaleSetVM{Response: autorest.Response{}} + expectedVM := VirtualMachineScaleSetVM{VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{Response: autorest.Response{}}} result, rerr := vmssClient.Get(context.TODO(), "rg", "vmss1", "1", "InstanceView") assert.Equal(t, expectedVM, result) assert.NotNil(t, rerr) @@ -216,8 +216,8 @@ func TestList(t *testing.T) { defer ctrl.Finish() armClient := mockarmclient.NewMockInterface(ctrl) - vmssList := []compute.VirtualMachineScaleSetVM{getTestVMSSVM("vmss1", "1"), getTestVMSSVM("vmss1", "2"), getTestVMSSVM("vmss1", "3")} - responseBody, err := json.Marshal(compute.VirtualMachineScaleSetVMListResult{Value: &vmssList}) + vmssList := []VirtualMachineScaleSetVM{getTestVMSSVM("vmss1", "1"), getTestVMSSVM("vmss1", "2"), getTestVMSSVM("vmss1", "3")} + responseBody, err := json.Marshal(VirtualMachineScaleSetVMListResult{Value: &vmssList}) assert.NoError(t, err) armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourcePrefix, "InstanceView").Return( &http.Response{ @@ -245,7 +245,7 @@ func TestListNotFound(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssvmClient := getTestVMSSVMClient(armClient) - expected := []compute.VirtualMachineScaleSetVM{} + expected := []VirtualMachineScaleSetVM{} result, rerr := vmssvmClient.List(context.TODO(), "rg", "vmss1", "InstanceView") assert.Equal(t, expected, result) assert.NotNil(t, rerr) @@ 
-265,7 +265,7 @@ func TestListInternalError(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssvmClient := getTestVMSSVMClient(armClient) - expected := []compute.VirtualMachineScaleSetVM{} + expected := []VirtualMachineScaleSetVM{} result, rerr := vmssvmClient.List(context.TODO(), "rg", "vmss1", "InstanceView") assert.Equal(t, expected, result) assert.NotNil(t, rerr) @@ -301,8 +301,8 @@ func TestListWithListResponderError(t *testing.T) { defer ctrl.Finish() armClient := mockarmclient.NewMockInterface(ctrl) - vmssvmList := []compute.VirtualMachineScaleSetVM{getTestVMSSVM("vmss1", "1"), getTestVMSSVM("vmss1", "2"), getTestVMSSVM("vmss1", "3")} - responseBody, err := json.Marshal(compute.VirtualMachineScaleSetVMListResult{Value: &vmssvmList}) + vmssvmList := []VirtualMachineScaleSetVM{getTestVMSSVM("vmss1", "1"), getTestVMSSVM("vmss1", "2"), getTestVMSSVM("vmss1", "3")} + responseBody, err := json.Marshal(VirtualMachineScaleSetVMListResult{Value: &vmssvmList}) assert.NoError(t, err) armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourcePrefix, "InstanceView").Return( &http.Response{ @@ -321,10 +321,10 @@ func TestListWithNextPage(t *testing.T) { defer ctrl.Finish() armClient := mockarmclient.NewMockInterface(ctrl) - vmssvmList := []compute.VirtualMachineScaleSetVM{getTestVMSSVM("vmss1", "1"), getTestVMSSVM("vmss1", "2"), getTestVMSSVM("vmss1", "3")} - partialResponse, err := json.Marshal(compute.VirtualMachineScaleSetVMListResult{Value: &vmssvmList, NextLink: ptr.To("nextLink")}) + vmssvmList := []VirtualMachineScaleSetVM{getTestVMSSVM("vmss1", "1"), getTestVMSSVM("vmss1", "2"), getTestVMSSVM("vmss1", "3")} + partialResponse, err := json.Marshal(VirtualMachineScaleSetVMListResult{Value: &vmssvmList, NextLink: ptr.To("nextLink")}) assert.NoError(t, err) - pagedResponse, err := json.Marshal(compute.VirtualMachineScaleSetVMListResult{Value: &vmssvmList}) + pagedResponse, err := 
json.Marshal(VirtualMachineScaleSetVMListResult{Value: &vmssvmList}) assert.NoError(t, err) armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(&http.Request{}, nil) armClient.EXPECT().Send(gomock.Any(), gomock.Any()).Return( @@ -406,7 +406,7 @@ func TestListNextResultsMultiPages(t *testing.T) { }, } - lastResult := compute.VirtualMachineScaleSetVMListResult{ + lastResult := VirtualMachineScaleSetVMListResult{ NextLink: ptr.To("next"), } @@ -462,7 +462,7 @@ func TestListNextResultsMultiPagesWithListResponderError(t *testing.T) { }, } - lastResult := compute.VirtualMachineScaleSetVMListResult{ + lastResult := VirtualMachineScaleSetVMListResult{ NextLink: ptr.To("next"), } @@ -510,11 +510,11 @@ func TestUpdate(t *testing.T) { armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(vmssVM.ID, ""), vmssVM).Return(response, nil).Times(1) armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - expected := &compute.VirtualMachineScaleSetVM{} + expected := &VirtualMachineScaleSetVM{} expected.Response = autorest.Response{Response: response} vmssClient := getTestVMSSVMClient(armClient) - result, rerr := vmssClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test") + result, rerr := vmssClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test", "") assert.Nil(t, rerr) assert.Equal(t, expected, result) } @@ -528,7 +528,7 @@ func TestUpdateAsync(t *testing.T) { armClient.EXPECT().PutResourceAsync(gomock.Any(), ptr.Deref(vmssVM.ID, ""), vmssVM).Return(nil, nil).Times(1) vmssClient := getTestVMSSVMClient(armClient) - future, rerr := vmssClient.UpdateAsync(context.TODO(), "rg", "vmss1", "0", vmssVM, "test") + future, rerr := vmssClient.UpdateAsync(context.TODO(), "rg", "vmss1", "0", vmssVM, "test", "") assert.Nil(t, rerr) assert.Nil(t, future) } @@ -583,9 +583,9 @@ func TestWaitForUpdateResult(t *testing.T) { vmssClient := getTestVMSSVMClient(armClient) response, err := vmssClient.WaitForUpdateResult(context.TODO(), 
&azure.Future{}, "rg", "test") assert.Equal(t, err, test.expectedResult) - var output *compute.VirtualMachineScaleSetVM + var output *VirtualMachineScaleSetVM if err == nil { - output = &compute.VirtualMachineScaleSetVM{} + output = &VirtualMachineScaleSetVM{} output.Response = autorest.Response{Response: test.response} } assert.Equal(t, response, output) @@ -604,11 +604,11 @@ func TestUpdateWithUpdateResponderError(t *testing.T) { } armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(vmssVM.ID, ""), vmssVM).Return(response, nil).Times(1) armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - expected := &compute.VirtualMachineScaleSetVM{} + expected := &VirtualMachineScaleSetVM{} expected.Response = autorest.Response{Response: response} vmssvmClient := getTestVMSSVMClient(armClient) - result, rerr := vmssvmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test") + result, rerr := vmssvmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test", "") assert.NotNil(t, rerr) assert.Equal(t, expected, result) } @@ -625,8 +625,8 @@ func TestUpdateNeverRateLimiter(t *testing.T) { armClient := mockarmclient.NewMockInterface(ctrl) vmssvmClient := getTestVMSSVMClientWithNeverRateLimiter(armClient) vmssVM := getTestVMSSVM("vmss1", "0") - var expected *compute.VirtualMachineScaleSetVM - result, rerr := vmssvmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test") + var expected *VirtualMachineScaleSetVM + result, rerr := vmssvmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test", "") assert.NotNil(t, rerr) assert.Equal(t, vmssvmUpdateErr, rerr) assert.Equal(t, expected, result) @@ -645,8 +645,8 @@ func TestUpdateRetryAfterReader(t *testing.T) { vmssVM := getTestVMSSVM("vmss1", "0") armClient := mockarmclient.NewMockInterface(ctrl) vmClient := getTestVMSSVMClientWithRetryAfterReader(armClient) - var expected *compute.VirtualMachineScaleSetVM - result, rerr := vmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, 
"test") + var expected *VirtualMachineScaleSetVM + result, rerr := vmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test", "") assert.NotNil(t, rerr) assert.Equal(t, vmssvmUpdateErr, rerr) assert.Equal(t, expected, result) @@ -673,8 +673,8 @@ func TestUpdateThrottle(t *testing.T) { armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) vmssvmClient := getTestVMSSVMClient(armClient) - var expected *compute.VirtualMachineScaleSetVM - result, rerr := vmssvmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test") + var expected *VirtualMachineScaleSetVM + result, rerr := vmssvmClient.Update(context.TODO(), "rg", "vmss1", "0", vmssVM, "test", "") assert.NotNil(t, rerr) assert.Equal(t, throttleErr, rerr) assert.Equal(t, expected, result) @@ -686,7 +686,7 @@ func TestUpdateVMs(t *testing.T) { vmssVM1 := getTestVMSSVM("vmss1", "1") vmssVM2 := getTestVMSSVM("vmss1", "2") - instances := map[string]compute.VirtualMachineScaleSetVM{ + instances := map[string]VirtualMachineScaleSetVM{ "1": vmssVM1, "2": vmssVM2, } @@ -721,7 +721,7 @@ func TestUpdateVMsWithUpdateVMsResponderError(t *testing.T) { defer ctrl.Finish() vmssVM := getTestVMSSVM("vmss1", "1") - instances := map[string]compute.VirtualMachineScaleSetVM{ + instances := map[string]VirtualMachineScaleSetVM{ "1": vmssVM, } testvmssVMs := map[string]interface{}{ @@ -751,7 +751,7 @@ func TestUpdateVMsPreemptedRetry(t *testing.T) { vmssVM1 := getTestVMSSVM("vmss1", "1") vmssVM2 := getTestVMSSVM("vmss1", "2") - instances := map[string]compute.VirtualMachineScaleSetVM{ + instances := map[string]VirtualMachineScaleSetVM{ "1": vmssVM1, "2": vmssVM2, } @@ -807,7 +807,7 @@ func TestUpdateVMsNeverRateLimiter(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - instances := map[string]compute.VirtualMachineScaleSetVM{} + instances := map[string]VirtualMachineScaleSetVM{} vmssvmUpdateVMsErr := &retry.Error{ RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", 
"write", "VMSSVMUpdateVMs"), Retriable: true, @@ -824,7 +824,7 @@ func TestUpdateVMsRetryAfterReader(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - instances := map[string]compute.VirtualMachineScaleSetVM{} + instances := map[string]VirtualMachineScaleSetVM{} vmssvmUpdateVMsErr := &retry.Error{ RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "VMSSVMUpdateVMs", "client throttled"), Retriable: true, @@ -843,7 +843,7 @@ func TestUpdateVMsThrottle(t *testing.T) { defer ctrl.Finish() vmssVM := getTestVMSSVM("vmss1", "1") - instances := map[string]compute.VirtualMachineScaleSetVM{ + instances := map[string]VirtualMachineScaleSetVM{ "1": vmssVM, } testvmssVMs := map[string]interface{}{ @@ -883,7 +883,7 @@ func TestUpdateVMsIgnoreError(t *testing.T) { vmssVM2 := getTestVMSSVM("vmss1", "2") vmssVM3 := getTestVMSSVM("vmss1", "3") vmssVM4 := getTestVMSSVM("vmss1", "4") - instances := map[string]compute.VirtualMachineScaleSetVM{ + instances := map[string]VirtualMachineScaleSetVM{ "1": vmssVM, "2": vmssVM2, "3": vmssVM3, @@ -936,13 +936,16 @@ func TestUpdateVMsIgnoreError(t *testing.T) { assert.Equal(t, rerr.Error().Error(), "Retriable: false, RetryAfter: 4s, HTTPStatusCode: 0, RawError: Retriable: true, RetryAfter: 4s, HTTPStatusCode: 0, RawError: The request failed due to conflict with a concurrent request.") } -func getTestVMSSVM(vmssName, instanceID string) compute.VirtualMachineScaleSetVM { +func getTestVMSSVM(vmssName, instanceID string) VirtualMachineScaleSetVM { resourceID := fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/%s/virtualMachines/%s", vmssName, instanceID) - return compute.VirtualMachineScaleSetVM{ - ID: ptr.To(resourceID), - InstanceID: ptr.To(instanceID), - Location: ptr.To("eastus"), + return VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + ID: ptr.To(resourceID), + InstanceID: 
ptr.To(instanceID), + Location: ptr.To("eastus"), + }, } + } func getTestVMSSVMClient(armClient armclient.Interface) *Client { @@ -982,3 +985,153 @@ func getTestVMSSVMClientWithRetryAfterReader(armClient armclient.Interface) *Cli func getFutureTime() time.Time { return time.Unix(3000000000, 0) } + +func getFakeVmssVM() VirtualMachineScaleSetVM { + index := 0 + scaleSetName := "fakevmss" + interfaceID := fmt.Sprintf("/subscriptions/fakesub/resourceGroups/fakerg/providers/Microsoft.Compute/virtualMachineScaleSets/%s/virtualMachines/%d/networkInterfaces/fakenic", scaleSetName, index) + nodeName := fmt.Sprintf("%s000000", scaleSetName) + testLBBackendpoolID := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0" + + // set vmss virtual machine. + networkInterfaces := []compute.NetworkInterfaceReference{ + { + ID: &interfaceID, + NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Primary: ptr.To(true), + }, + }, + } + ipConfigurations := []compute.VirtualMachineScaleSetIPConfiguration{ + { + Name: ptr.To("ipconfig1"), + VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Primary: ptr.To(true), + LoadBalancerBackendAddressPools: &[]compute.SubResource{{ID: ptr.To(testLBBackendpoolID)}}, + PrivateIPAddressVersion: compute.IPv4, + }, + }, + } + networkConfigurations := []compute.VirtualMachineScaleSetNetworkConfiguration{ + { + Name: ptr.To("vmss-nic"), + ID: ptr.To("fakeNetworkConfiguration"), + VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: &ipConfigurations, + Primary: ptr.To(true), + }, + }, + } + + vmssVM := VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + OsProfile: &compute.OSProfile{ + ComputerName: 
&nodeName, + }, + NetworkProfile: &compute.NetworkProfile{ + NetworkInterfaces: &networkInterfaces, + }, + HardwareProfile: &compute.HardwareProfile{ + VMSize: compute.StandardD2sV3, + }, + NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: &networkConfigurations, + }, + }, + Location: ptr.To("eastus"), + }, + Etag: ptr.To("\"120\""), + } + return vmssVM +} + +func TestMarshal(t *testing.T) { + fakeVmssVM := getFakeVmssVM() + fakeVmssVMWithoutEtag := getFakeVmssVM() + fakeVmssVMWithoutEtag.Etag = nil + fakeVmssVMWithoutCompueVMSSVM := getFakeVmssVM() + fakeVmssVMWithoutCompueVMSSVM.VirtualMachineScaleSetVM = compute.VirtualMachineScaleSetVM{} + testcases := []struct { + name string + vmss VirtualMachineScaleSetVM + expectJSON string + }{ + + { + name: "should return empty json when vmss is empty", + vmss: VirtualMachineScaleSetVM{}, + expectJSON: "{}", + }, + { + name: "should return only compute.VirtualMachineScaleSetVM when etag is empty", + vmss: fakeVmssVMWithoutEtag, + expectJSON: `{"location":"eastus","properties":{"hardwareProfile":{"vmSize":"Standard_D2s_v3"},"networkProfile":{"networkInterfaces":[{"id":"/subscriptions/fakesub/resourceGroups/fakerg/providers/Microsoft.Compute/virtualMachineScaleSets/fakevmss/virtualMachines/0/networkInterfaces/fakenic","properties":{"primary":true}}]},"networkProfileConfiguration":{"networkInterfaceConfigurations":[{"id":"fakeNetworkConfiguration","name":"vmss-nic","properties":{"primary":true,"ipConfigurations":[{"name":"ipconfig1","properties":{"primary":true,"privateIPAddressVersion":"IPv4","loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]},"osProfile":{"computerName":"fakevmss000000"}}}`, + }, + { + name: "should return only etag json when vmss is empty", + vmss: fakeVmssVMWithoutCompueVMSSVM, + expectJSON: `{"etag":"\"120\""}`, + 
}, + + { + name: "should return full json when both VirtualMachineScaleSetVM and etag are set", + vmss: fakeVmssVM, + expectJSON: `{"location":"eastus","properties":{"hardwareProfile":{"vmSize":"Standard_D2s_v3"},"networkProfile":{"networkInterfaces":[{"id":"/subscriptions/fakesub/resourceGroups/fakerg/providers/Microsoft.Compute/virtualMachineScaleSets/fakevmss/virtualMachines/0/networkInterfaces/fakenic","properties":{"primary":true}}]},"networkProfileConfiguration":{"networkInterfaceConfigurations":[{"id":"fakeNetworkConfiguration","name":"vmss-nic","properties":{"primary":true,"ipConfigurations":[{"name":"ipconfig1","properties":{"primary":true,"privateIPAddressVersion":"IPv4","loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]},"osProfile":{"computerName":"fakevmss000000"}},"etag":"\"120\""}`, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + actualJSON, err := json.Marshal(tt.vmss) + assert.Nil(t, err) + assert.Equal(t, string(actualJSON), tt.expectJSON) + }) + } +} + +func TestUnMarshal(t *testing.T) { + fakeVmssVM := getFakeVmssVM() + fakeVmssVMWithoutEtag := getFakeVmssVM() + fakeVmssVMWithoutEtag.Etag = nil + fakeVmssVMWithoutCompueVMSSVM := getFakeVmssVM() + fakeVmssVMWithoutCompueVMSSVM.VirtualMachineScaleSetVM = compute.VirtualMachineScaleSetVM{} + testcases := []struct { + name string + expectedVmssVM VirtualMachineScaleSetVM + inputJSON string + }{ + { + name: "should return empty json when vmss is empty", + expectedVmssVM: VirtualMachineScaleSetVM{}, + inputJSON: "{}", + }, + + { + name: "should return only compute.VirtualMachineScaleSetVM when etag is empty", + expectedVmssVM: fakeVmssVMWithoutEtag, + inputJSON: 
`{"location":"eastus","properties":{"hardwareProfile":{"vmSize":"Standard_D2s_v3"},"networkProfile":{"networkInterfaces":[{"id":"/subscriptions/fakesub/resourceGroups/fakerg/providers/Microsoft.Compute/virtualMachineScaleSets/fakevmss/virtualMachines/0/networkInterfaces/fakenic","properties":{"primary":true}}]},"networkProfileConfiguration":{"networkInterfaceConfigurations":[{"id":"fakeNetworkConfiguration","name":"vmss-nic","properties":{"primary":true,"ipConfigurations":[{"name":"ipconfig1","properties":{"primary":true,"privateIPAddressVersion":"IPv4","loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]},"osProfile":{"computerName":"fakevmss000000"}}}`, + }, + + { + name: "should return only etag json when vmss is empty", + expectedVmssVM: fakeVmssVMWithoutCompueVMSSVM, + inputJSON: `{"etag":"\"120\""}`, + }, + + { + name: "should return full json when both VirtualMachineScaleSetVM and etag are set", + expectedVmssVM: fakeVmssVM, + inputJSON: `{"location":"eastus","properties":{"hardwareProfile":{"vmSize":"Standard_D2s_v3"},"networkProfile":{"networkInterfaces":[{"id":"/subscriptions/fakesub/resourceGroups/fakerg/providers/Microsoft.Compute/virtualMachineScaleSets/fakevmss/virtualMachines/0/networkInterfaces/fakenic","properties":{"primary":true}}]},"networkProfileConfiguration":{"networkInterfaceConfigurations":[{"id":"fakeNetworkConfiguration","name":"vmss-nic","properties":{"primary":true,"ipConfigurations":[{"name":"ipconfig1","properties":{"primary":true,"privateIPAddressVersion":"IPv4","loadBalancerBackendAddressPools":[{"id":"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0"}]}}]}}]},"osProfile":{"computerName":"fakevmss000000"}},"etag":"\"120\""}`, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + var actualVmssVM VirtualMachineScaleSetVM + err 
:= json.Unmarshal([]byte(tt.inputJSON), &actualVmssVM) + assert.Nil(t, err) + assert.Equal(t, actualVmssVM, tt.expectedVmssVM) + }) + } +} diff --git a/pkg/azureclients/vmssvmclient/interface.go b/pkg/azureclients/vmssvmclient/interface.go index 352c3aa642..c5108985ba 100644 --- a/pkg/azureclients/vmssvmclient/interface.go +++ b/pkg/azureclients/vmssvmclient/interface.go @@ -27,7 +27,7 @@ import ( const ( // APIVersion is the API version for VMSS. - APIVersion = "2022-03-01" + APIVersion = "2024-03-01" // AzureStackCloudAPIVersion is the API version for Azure Stack AzureStackCloudAPIVersion = "2019-07-01" // AzureStackCloudName is the cloud name of Azure Stack @@ -38,20 +38,20 @@ const ( // Don't forget to run "hack/update-mock-clients.sh" command to generate the mock client. type Interface interface { // Get gets a VirtualMachineScaleSetVM. - Get(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, expand compute.InstanceViewTypes) (compute.VirtualMachineScaleSetVM, *retry.Error) + Get(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, expand compute.InstanceViewTypes) (VirtualMachineScaleSetVM, *retry.Error) // List gets a list of VirtualMachineScaleSetVMs in the virtualMachineScaleSet. - List(ctx context.Context, resourceGroupName string, virtualMachineScaleSetName string, expand string) ([]compute.VirtualMachineScaleSetVM, *retry.Error) + List(ctx context.Context, resourceGroupName string, virtualMachineScaleSetName string, expand string) ([]VirtualMachineScaleSetVM, *retry.Error) // Update updates a VirtualMachineScaleSetVM. 
- Update(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters compute.VirtualMachineScaleSetVM, source string) (*compute.VirtualMachineScaleSetVM, *retry.Error) + Update(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters VirtualMachineScaleSetVM, source string, etag string) (*VirtualMachineScaleSetVM, *retry.Error) // UpdateAsync updates a VirtualMachineScaleSetVM asynchronously - UpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters compute.VirtualMachineScaleSetVM, source string) (*azure.Future, *retry.Error) + UpdateAsync(ctx context.Context, resourceGroupName string, VMScaleSetName string, instanceID string, parameters VirtualMachineScaleSetVM, source string, etag string) (*azure.Future, *retry.Error) // WaitForUpdateResult waits for the response of the update request - WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*compute.VirtualMachineScaleSetVM, *retry.Error) + WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*VirtualMachineScaleSetVM, *retry.Error) // UpdateVMs updates a list of VirtualMachineScaleSetVM from map[instanceID]compute.VirtualMachineScaleSetVM. 
- UpdateVMs(ctx context.Context, resourceGroupName string, VMScaleSetName string, instances map[string]compute.VirtualMachineScaleSetVM, source string, batchSize int) *retry.Error + UpdateVMs(ctx context.Context, resourceGroupName string, VMScaleSetName string, instances map[string]VirtualMachineScaleSetVM, source string, batchSize int) *retry.Error } diff --git a/pkg/azureclients/vmssvmclient/mockvmssvmclient/interface.go b/pkg/azureclients/vmssvmclient/mockvmssvmclient/interface.go index 0bc207ca6a..4683f94ee0 100644 --- a/pkg/azureclients/vmssvmclient/mockvmssvmclient/interface.go +++ b/pkg/azureclients/vmssvmclient/mockvmssvmclient/interface.go @@ -33,6 +33,7 @@ import ( compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" azure "github.com/Azure/go-autorest/autorest/azure" gomock "go.uber.org/mock/gomock" + vmssvmclient "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" retry "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) @@ -60,10 +61,10 @@ func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder { } // Get mocks base method. -func (m *MockInterface) Get(ctx context.Context, resourceGroupName, VMScaleSetName, instanceID string, expand compute.InstanceViewTypes) (compute.VirtualMachineScaleSetVM, *retry.Error) { +func (m *MockInterface) Get(ctx context.Context, resourceGroupName, VMScaleSetName, instanceID string, expand compute.InstanceViewTypes) (vmssvmclient.VirtualMachineScaleSetVM, *retry.Error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, VMScaleSetName, instanceID, expand) - ret0, _ := ret[0].(compute.VirtualMachineScaleSetVM) + ret0, _ := ret[0].(vmssvmclient.VirtualMachineScaleSetVM) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } @@ -75,10 +76,10 @@ func (mr *MockInterfaceMockRecorder) Get(ctx, resourceGroupName, VMScaleSetName, } // List mocks base method. 
-func (m *MockInterface) List(ctx context.Context, resourceGroupName, virtualMachineScaleSetName, expand string) ([]compute.VirtualMachineScaleSetVM, *retry.Error) { +func (m *MockInterface) List(ctx context.Context, resourceGroupName, virtualMachineScaleSetName, expand string) ([]vmssvmclient.VirtualMachineScaleSetVM, *retry.Error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "List", ctx, resourceGroupName, virtualMachineScaleSetName, expand) - ret0, _ := ret[0].([]compute.VirtualMachineScaleSetVM) + ret0, _ := ret[0].([]vmssvmclient.VirtualMachineScaleSetVM) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } @@ -90,37 +91,37 @@ func (mr *MockInterfaceMockRecorder) List(ctx, resourceGroupName, virtualMachine } // Update mocks base method. -func (m *MockInterface) Update(ctx context.Context, resourceGroupName, VMScaleSetName, instanceID string, parameters compute.VirtualMachineScaleSetVM, source string) (*compute.VirtualMachineScaleSetVM, *retry.Error) { +func (m *MockInterface) Update(ctx context.Context, resourceGroupName, VMScaleSetName, instanceID string, parameters vmssvmclient.VirtualMachineScaleSetVM, source, etag string) (*vmssvmclient.VirtualMachineScaleSetVM, *retry.Error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source) - ret0, _ := ret[0].(*compute.VirtualMachineScaleSetVM) + ret := m.ctrl.Call(m, "Update", ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source, etag) + ret0, _ := ret[0].(*vmssvmclient.VirtualMachineScaleSetVM) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } // Update indicates an expected call of Update. 
-func (mr *MockInterfaceMockRecorder) Update(ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source any) *gomock.Call { +func (mr *MockInterfaceMockRecorder) Update(ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source, etag any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockInterface)(nil).Update), ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockInterface)(nil).Update), ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source, etag) } // UpdateAsync mocks base method. -func (m *MockInterface) UpdateAsync(ctx context.Context, resourceGroupName, VMScaleSetName, instanceID string, parameters compute.VirtualMachineScaleSetVM, source string) (*azure.Future, *retry.Error) { +func (m *MockInterface) UpdateAsync(ctx context.Context, resourceGroupName, VMScaleSetName, instanceID string, parameters vmssvmclient.VirtualMachineScaleSetVM, source, etag string) (*azure.Future, *retry.Error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAsync", ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source) + ret := m.ctrl.Call(m, "UpdateAsync", ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source, etag) ret0, _ := ret[0].(*azure.Future) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } // UpdateAsync indicates an expected call of UpdateAsync. 
-func (mr *MockInterfaceMockRecorder) UpdateAsync(ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source any) *gomock.Call { +func (mr *MockInterfaceMockRecorder) UpdateAsync(ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source, etag any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAsync", reflect.TypeOf((*MockInterface)(nil).UpdateAsync), ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAsync", reflect.TypeOf((*MockInterface)(nil).UpdateAsync), ctx, resourceGroupName, VMScaleSetName, instanceID, parameters, source, etag) } // UpdateVMs mocks base method. -func (m *MockInterface) UpdateVMs(ctx context.Context, resourceGroupName, VMScaleSetName string, instances map[string]compute.VirtualMachineScaleSetVM, source string, batchSize int) *retry.Error { +func (m *MockInterface) UpdateVMs(ctx context.Context, resourceGroupName, VMScaleSetName string, instances map[string]vmssvmclient.VirtualMachineScaleSetVM, source string, batchSize int) *retry.Error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateVMs", ctx, resourceGroupName, VMScaleSetName, instances, source, batchSize) ret0, _ := ret[0].(*retry.Error) @@ -134,10 +135,10 @@ func (mr *MockInterfaceMockRecorder) UpdateVMs(ctx, resourceGroupName, VMScaleSe } // WaitForUpdateResult mocks base method. 
-func (m *MockInterface) WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*compute.VirtualMachineScaleSetVM, *retry.Error) { +func (m *MockInterface) WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*vmssvmclient.VirtualMachineScaleSetVM, *retry.Error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "WaitForUpdateResult", ctx, future, resourceGroupName, source) - ret0, _ := ret[0].(*compute.VirtualMachineScaleSetVM) + ret0, _ := ret[0].(*vmssvmclient.VirtualMachineScaleSetVM) ret1, _ := ret[1].(*retry.Error) return ret0, ret1 } diff --git a/pkg/azureclients/vmssvmclient/models.go b/pkg/azureclients/vmssvmclient/models.go new file mode 100644 index 0000000000..7a7b984c66 --- /dev/null +++ b/pkg/azureclients/vmssvmclient/models.go @@ -0,0 +1,131 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vmssvmclient + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/to" +) + +// VirtualMachineScaleSetVM wraps the original VirtualMachineScaleSetVM struct and adds an Etag field. 
+type VirtualMachineScaleSetVM struct { + compute.VirtualMachineScaleSetVM `json:",inline"` + // READ-ONLY; Etag is property returned in Update/Get response of the VMSS VM, so that customer can supply it in the header + // to ensure optimistic updates. + Etag *string `json:"etag,omitempty"` +} + +// VirtualMachineScaleSetVMListResult the List Virtual Machine operation response. +type VirtualMachineScaleSetVMListResult struct { + autorest.Response `json:"-"` + // Value - The list of virtual machine scale sets. + Value *[]VirtualMachineScaleSetVM `json:"value,omitempty"` + // NextLink - The uri to fetch the next page of Virtual Machine Scale Sets. Call ListNext() with this to fetch the next page of VMSS. + NextLink *string `json:"nextLink,omitempty"` +} + +// IsEmpty returns true if the ListResult contains no values. +func (vmssvmlr VirtualMachineScaleSetVMListResult) IsEmpty() bool { + return vmssvmlr.Value == nil || len(*vmssvmlr.Value) == 0 +} + +// hasNextLink returns true if the NextLink is not empty. +func (vmssvmlr VirtualMachineScaleSetVMListResult) hasNextLink() bool { + return vmssvmlr.NextLink != nil && len(*vmssvmlr.NextLink) != 0 +} + +// virtualMachineScaleSetListResultPreparer prepares a request to retrieve the next set of results. +// It returns nil if no more results exist. +func (vmssvmlr VirtualMachineScaleSetVMListResult) virtualMachineScaleSetListResultPreparer(ctx context.Context) (*http.Request, error) { + if !vmssvmlr.hasNextLink() { + return nil, nil + } + return autorest.Prepare((&http.Request{}).WithContext(ctx), + autorest.AsJSON(), + autorest.AsGet(), + autorest.WithBaseURL(to.String(vmssvmlr.NextLink))) +} + +// UnmarshalJSON is the custom unmarshaler for VirtualMachineScaleSetVM struct. +// compute.VirtualMachineScaleSetVM implemented `UnmarshalJSON` method, and when the response is unmarshaled into VirtualMachineScaleSetVM, +// compute.VirtualMachineScaleSetVM.UnmarshalJSON is called, leading to the loss of the Etag field. 
+func (vmssvm *VirtualMachineScaleSetVM) UnmarshalJSON(data []byte) error { + // Unmarshal Etag first + etagPlaceholder := struct { + Etag *string `json:"etag,omitempty"` + }{} + if err := json.Unmarshal(data, &etagPlaceholder); err != nil { + return err + } + // Unmarshal Nested VirtualMachineScaleSetVM + nestedVirtualMachineScaleSetVM := struct { + compute.VirtualMachineScaleSetVM `json:",inline"` + }{} + // the Nested impl UnmarshalJSON, so it should be unmarshaled alone + if err := json.Unmarshal(data, &nestedVirtualMachineScaleSetVM); err != nil { + return err + } + (vmssvm).Etag = etagPlaceholder.Etag + (vmssvm).VirtualMachineScaleSetVM = nestedVirtualMachineScaleSetVM.VirtualMachineScaleSetVM + return nil +} + +// MarshalJSON is the custom marshaler for VirtualMachineScaleSetVM. +func (vmssv VirtualMachineScaleSetVM) MarshalJSON() ([]byte, error) { + var err error + var nestedVirtualMachineScaleSetVMJSON, etagJSON []byte + if nestedVirtualMachineScaleSetVMJSON, err = vmssv.VirtualMachineScaleSetVM.MarshalJSON(); err != nil { + return nil, err + } + + if vmssv.Etag != nil { + if etagJSON, err = json.Marshal(map[string]interface{}{ + "etag": vmssv.Etag, + }); err != nil { + return nil, err + } + } + + // empty struct can be Unmarshaled to "{}" + nestedVirtualMachineScaleSetVMJSONEmpty := true + if string(nestedVirtualMachineScaleSetVMJSON) != "{}" { + nestedVirtualMachineScaleSetVMJSONEmpty = false + } + etagJSONEmpty := true + if len(etagJSON) != 0 { + etagJSONEmpty = false + } + + // when both parts not empty, join the two parts with a comma but remove the open brace of nestedVirtualMachineScaleSetVMJson and the close brace of the etagJSON + // {"location": "eastus"} + {"etag": "\"120\""} will be merged into {"location": "eastus", "etag": "\"120\""} + if !nestedVirtualMachineScaleSetVMJSONEmpty && !etagJSONEmpty { + etagJSON[0] = ',' + return append(nestedVirtualMachineScaleSetVMJSON[:len(nestedVirtualMachineScaleSetVMJSON)-1], etagJSON...), nil + } + if 
!nestedVirtualMachineScaleSetVMJSONEmpty { + return nestedVirtualMachineScaleSetVMJSON, nil + } + if !etagJSONEmpty { + return etagJSON, nil + } + return []byte("{}"), nil +} diff --git a/pkg/provider/azure_controller_vmss.go b/pkg/provider/azure_controller_vmss.go index 48e4185048..6c61580f8a 100644 --- a/pkg/provider/azure_controller_vmss.go +++ b/pkg/provider/azure_controller_vmss.go @@ -30,6 +30,7 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/retry" @@ -100,23 +101,25 @@ func (ss *ScaleSet) AttachDisk(ctx context.Context, nodeName types.NodeName, dis }) } - newVM := compute.VirtualMachineScaleSetVM{ - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := vmssvmclient.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + StorageProfile: &compute.StorageProfile{ + DataDisks: &disks, + }, }, }, } klog.V(2).Infof("azureDisk - update: rg(%s) vm(%s) - attach disk list(%+v)", nodeResourceGroup, nodeName, diskMap) - future, rerr := ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "attach_disk") + future, rerr := ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "attach_disk", "") if rerr != nil { klog.Errorf("azureDisk - attach disk list(%+v) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, nodeName, rerr) if rerr.HTTPStatusCode == http.StatusNotFound { klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, nodeName) disks := FilterNonExistingDisks(ctx, ss.DisksClient, 
*newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks) newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks = &disks - future, rerr = ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "attach_disk") + future, rerr = ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "attach_disk", "") } } @@ -217,15 +220,17 @@ func (ss *ScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, dis } } - newVM := compute.VirtualMachineScaleSetVM{ - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := vmssvmclient.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + StorageProfile: &compute.StorageProfile{ + DataDisks: &disks, + }, }, }, } - var result *compute.VirtualMachineScaleSetVM + var result *vmssvmclient.VirtualMachineScaleSetVM var rerr *retry.Error defer func() { @@ -242,14 +247,14 @@ func (ss *ScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, dis klog.V(2).Infof("azureDisk - update(%s): vm(%s) - detach disk list(%s)", nodeResourceGroup, nodeName, diskMap) result, rerr = ss.VirtualMachineScaleSetVMsClient.Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, - "detach_disk") + "detach_disk", "") if rerr != nil { klog.Errorf("azureDisk - detach disk list(%s) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, nodeName, rerr) if rerr.HTTPStatusCode == http.StatusNotFound { klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, nodeName) disks := FilterNonExistingDisks(ctx, ss.DisksClient, *newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks) newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks = 
&disks - result, rerr = ss.VirtualMachineScaleSetVMsClient.Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "detach_disk") + result, rerr = ss.VirtualMachineScaleSetVMsClient.Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "detach_disk", "") } } @@ -282,7 +287,7 @@ func (ss *ScaleSet) UpdateVMAsync(ctx context.Context, nodeName types.NodeName) return nil, err } - future, rerr := ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, compute.VirtualMachineScaleSetVM{}, "update_vmss_instance") + future, rerr := ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, vmssvmclient.VirtualMachineScaleSetVM{}, "update_vmss_instance", "") if rerr != nil { return future, rerr.Error() } diff --git a/pkg/provider/azure_controller_vmss_test.go b/pkg/provider/azure_controller_vmss_test.go index 5dff138cfb..5e91caddb0 100644 --- a/pkg/provider/azure_controller_vmss_test.go +++ b/pkg/provider/azure_controller_vmss_test.go @@ -37,6 +37,7 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/retry" @@ -104,7 +105,7 @@ func TestAttachDiskWithVMSS(t *testing.T) { mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, 
gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), "").Return(nil).MaxTimes(1) mockVMClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachine{}, nil).AnyTimes() @@ -133,9 +134,9 @@ func TestAttachDiskWithVMSS(t *testing.T) { mockVMSSVMClient := testCloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() if scaleSetName == string(fakeStatusNotFoundVMSSName) { - mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any(), "").Return(nil, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() } else { - mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any(), "").Return(nil, nil).AnyTimes() mockVMSSVMClient.EXPECT().WaitForUpdateResult(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() } @@ -244,10 +245,10 @@ func TestDetachDiskWithVMSS(t *testing.T) { mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() 
mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), "").Return(nil).MaxTimes(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(testCloud, scaleSetName, "", 0, test.vmssVMList, "succeeded", false) - var updatedVMSSVM *compute.VirtualMachineScaleSetVM + var updatedVMSSVM *vmssvmclient.VirtualMachineScaleSetVM for itr, vmssvm := range expectedVMSSVMs { vmssvm.StorageProfile = &compute.StorageProfile{ OsDisk: &compute.OSDisk{ @@ -282,9 +283,9 @@ func TestDetachDiskWithVMSS(t *testing.T) { mockVMSSVMClient := testCloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() if scaleSetName == strings.ToLower(string(fakeStatusNotFoundVMSSName)) { - mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any(), "").Return(updatedVMSSVM, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() } else { - mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any(), "").Return(updatedVMSSVM, nil).AnyTimes() } mockVMClient := 
testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) @@ -374,10 +375,10 @@ func TestUpdateVMWithVMSS(t *testing.T) { mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), "").Return(nil).MaxTimes(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(testCloud, scaleSetName, "", 0, test.vmssVMList, "succeeded", false) - var updatedVMSSVM *compute.VirtualMachineScaleSetVM + var updatedVMSSVM *vmssvmclient.VirtualMachineScaleSetVM for itr, vmssvm := range expectedVMSSVMs { vmssvm.StorageProfile = &compute.StorageProfile{ @@ -408,7 +409,7 @@ func TestUpdateVMWithVMSS(t *testing.T) { future, err := azure.NewFutureFromResponse(r) - mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&future, err).AnyTimes() + mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "").Return(&future, err).AnyTimes() if scaleSetName == strings.ToLower(string(fakeStatusNotFoundVMSSName)) { mockVMSSVMClient.EXPECT().WaitForUpdateResult(gomock.Any(), &future, testCloud.ResourceGroup, gomock.Any()).Return(updatedVMSSVM, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() @@ -496,7 +497,7 @@ func TestGetDataDisksWithVMSS(t *testing.T) { mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) 
mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), "").Return(nil).MaxTimes(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(testCloud, scaleSetName, "", 0, []string{"vmss00-vm-000000"}, "succeeded", false) if !test.isDataDiskNull { @@ -512,7 +513,7 @@ func TestGetDataDisksWithVMSS(t *testing.T) { updatedVMSSVM := &expectedVMSSVMs[0] mockVMSSVMClient := testCloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any(), "").Return(updatedVMSSVM, nil).AnyTimes() dataDisks, _, err := ss.GetDataDisks(context.TODO(), test.nodeName, test.crt) assert.Equal(t, test.expectedDataDisks, dataDisks, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, test.expectedErr, err != nil, "TestCase[%d]: %s", i, test.desc) diff --git a/pkg/provider/azure_mock_vmsets.go b/pkg/provider/azure_mock_vmsets.go index 5476b3b6b2..a7a84fe0e4 100644 --- a/pkg/provider/azure_mock_vmsets.go +++ b/pkg/provider/azure_mock_vmsets.go @@ -16,13 +16,12 @@ // // Code generated by MockGen. DO NOT EDIT. 
-// Source: /Users/niqi/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_vmsets.go +// Source: /home/azureuser/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_vmsets.go // // Generated by this command: // -// mockgen -destination=/Users/niqi/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_mock_vmsets.go -source=/Users/niqi/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_vmsets.go -package=provider VMSet +// mockgen -copyright_file=/home/azureuser/go/src/sigs.k8s.io/cloud-provider-azure/hack/boilerplate/boilerplate.generatego.txt -destination=/home/azureuser/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_mock_vmsets.go -source=/home/azureuser/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_vmsets.go -package=provider VMSet // - // Package provider is a generated GoMock package. package provider @@ -31,12 +30,12 @@ import ( reflect "reflect" armcompute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" network "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" gomock "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" types "k8s.io/apimachinery/pkg/types" cloudprovider "k8s.io/cloud-provider" + vmssvmclient "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" cache "sigs.k8s.io/cloud-provider-azure/pkg/cache" ) @@ -135,13 +134,13 @@ func (mr *MockVMSetMockRecorder) EnsureBackendPoolDeletedFromVMSets(ctx, vmSetNa } // EnsureHostInPool mocks base method. 
-func (m *MockVMSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID, vmSetName string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (m *MockVMSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID, vmSetName string) (string, string, string, *vmssvmclient.VirtualMachineScaleSetVM, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EnsureHostInPool", ctx, service, nodeName, backendPoolID, vmSetName) ret0, _ := ret[0].(string) ret1, _ := ret[1].(string) ret2, _ := ret[2].(string) - ret3, _ := ret[3].(*compute.VirtualMachineScaleSetVM) + ret3, _ := ret[3].(*vmssvmclient.VirtualMachineScaleSetVM) ret4, _ := ret[4].(error) return ret0, ret1, ret2, ret3, ret4 } diff --git a/pkg/provider/azure_standard.go b/pkg/provider/azure_standard.go index 79b13824c9..8d18172150 100644 --- a/pkg/provider/azure_standard.go +++ b/pkg/provider/azure_standard.go @@ -39,6 +39,7 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/metrics" @@ -859,7 +860,7 @@ func (as *availabilitySet) getPrimaryInterfaceWithVMSet(ctx context.Context, nod // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool. 
-func (as *availabilitySet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (as *availabilitySet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *vmssvmclient.VirtualMachineScaleSetVM, error) { vmName := mapNodeNameToVMName(nodeName) serviceName := getServiceName(service) nic, _, err := as.getPrimaryInterfaceWithVMSet(ctx, vmName, vmSetName) diff --git a/pkg/provider/azure_vmsets.go b/pkg/provider/azure_vmsets.go index 5b7c62acd9..247efa052d 100644 --- a/pkg/provider/azure_vmsets.go +++ b/pkg/provider/azure_vmsets.go @@ -27,6 +27,7 @@ import ( "k8s.io/apimachinery/pkg/types" cloudprovider "k8s.io/cloud-provider" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" ) @@ -69,7 +70,7 @@ type VMSet interface { EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID string, vmSetName string) error // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool. - EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) + EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *vmssvmclient.VirtualMachineScaleSetVM, error) // EnsureBackendPoolDeleted ensures the loadBalancer backendAddressPools deleted from the specified nodes. 
EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools *[]network.BackendAddressPool, deleteFromVMSet bool) (bool, error) // EnsureBackendPoolDeletedFromVMSets ensures the loadBalancer backendAddressPools deleted from the specified VMSS/VMAS diff --git a/pkg/provider/azure_vmss.go b/pkg/provider/azure_vmss.go index 0db63e3e14..54dde58927 100644 --- a/pkg/provider/azure_vmss.go +++ b/pkg/provider/azure_vmss.go @@ -36,10 +36,13 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/metrics" "sigs.k8s.io/cloud-provider-azure/pkg/provider/virtualmachine" + "sigs.k8s.io/cloud-provider-azure/pkg/retry" vmutil "sigs.k8s.io/cloud-provider-azure/pkg/util/vm" ) @@ -164,8 +167,8 @@ func newScaleSet(az *Cloud) (VMSet, error) { return ss, nil } -func (ss *ScaleSet) getVMSS(ctx context.Context, vmssName string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSet, error) { - getter := func(vmssName string) (*compute.VirtualMachineScaleSet, error) { +func (ss *ScaleSet) getVMSS(ctx context.Context, vmssName string, crt azcache.AzureCacheReadType) (*vmssclient.VirtualMachineScaleSet, error) { + getter := func(vmssName string) (*vmssclient.VirtualMachineScaleSet, error) { cached, err := ss.vmssCache.Get(ctx, consts.VMSSKey, crt) if err != nil { return nil, err @@ -223,6 +226,7 @@ func (ss *ScaleSet) getVmssVMByNodeIdentity(ctx context.Context, node *nodeIdent return nil, true, nil } found = true + return virtualmachine.FromVirtualMachineScaleSetVM(result.VirtualMachine, virtualmachine.ByVMSS(result.VMSSName)), found, nil } @@ -339,8 +343,8 @@ func (ss *ScaleSet) GetProvisioningStateByNodeName(ctx context.Context, name str // 
getCachedVirtualMachineByInstanceID gets scaleSetVMInfo from cache. // The node must belong to one of scale sets. -func (ss *ScaleSet) getVmssVMByInstanceID(ctx context.Context, resourceGroup, scaleSetName, instanceID string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSetVM, error) { - getter := func(ctx context.Context, crt azcache.AzureCacheReadType) (vm *compute.VirtualMachineScaleSetVM, found bool, err error) { +func (ss *ScaleSet) getVmssVMByInstanceID(ctx context.Context, resourceGroup, scaleSetName, instanceID string, crt azcache.AzureCacheReadType) (*vmssvmclient.VirtualMachineScaleSetVM, error) { + getter := func(ctx context.Context, crt azcache.AzureCacheReadType) (vm *vmssvmclient.VirtualMachineScaleSetVM, found bool, err error) { virtualMachines, err := ss.getVMSSVMsFromCache(ctx, resourceGroup, scaleSetName, crt) if err != nil { return nil, false, err @@ -816,7 +820,7 @@ func (ss *ScaleSet) getNodeIdentityByNodeName(ctx context.Context, nodeName stri } // listScaleSetVMs lists VMs belonging to the specified scale set. -func (ss *ScaleSet) listScaleSetVMs(scaleSetName, resourceGroup string) ([]compute.VirtualMachineScaleSetVM, error) { +func (ss *ScaleSet) listScaleSetVMs(scaleSetName, resourceGroup string) ([]vmssvmclient.VirtualMachineScaleSetVM, error) { ctx, cancel := getContextWithCancel() defer cancel() @@ -1037,7 +1041,7 @@ func getPrimaryIPConfigFromVMSSNetworkConfig(config *compute.VirtualMachineScale // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool, which returns (resourceGroup, vmasName, instanceID, vmssVM, error). 
-func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *vmssvmclient.VirtualMachineScaleSetVM, error) { logger := klog.Background().WithName("EnsureHostInPool"). WithValues("nodeName", nodeName, "backendPoolID", backendPoolID, "vmSetNameOfLB", vmSetNameOfLB) vmName := mapNodeNameToVMName(nodeName) @@ -1142,14 +1146,17 @@ func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeNam ID: ptr.To(backendPoolID), }) primaryIPConfiguration.LoadBalancerBackendAddressPools = &newBackendPools - newVM := &compute.VirtualMachineScaleSetVM{ - Location: &vm.Location, - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - HardwareProfile: vm.VirtualMachineScaleSetVMProperties.HardwareProfile, - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &networkInterfaceConfigurations, + newVM := &vmssvmclient.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + Location: &vm.Location, + VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + HardwareProfile: vm.VirtualMachineScaleSetVMProperties.HardwareProfile, + NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: &networkInterfaceConfigurations, + }, }, }, + Etag: vm.Etag, } // Get the node resource group. 
@@ -1309,21 +1316,30 @@ func (ss *ScaleSet) ensureVMSSInPool(ctx context.Context, _ *v1.Service, nodes [ ID: ptr.To(backendPoolID), }) primaryIPConfig.LoadBalancerBackendAddressPools = &loadBalancerBackendAddressPools - newVMSS := compute.VirtualMachineScaleSet{ - Location: vmss.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, + newVMSS := vmssclient.VirtualMachineScaleSet{ + VirtualMachineScaleSet: compute.VirtualMachineScaleSet{ + Location: vmss.Location, + VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: &vmssNIC, + }, }, }, }, + Etag: vmss.Etag, } klog.V(2).Infof("ensureVMSSInPool begins to update vmss(%s) with new backendPoolID %s", vmssName, backendPoolID) rerr := ss.CreateOrUpdateVMSS(ss.ResourceGroup, vmssName, newVMSS) + // VMSS cache must be refreshed when etagmismatch error happens. + // TODO(mainred): we need to update the cache from the response of a successful request. 
+ if rerr != nil && errors.Is(rerr.Error(), &retry.EtagMismatchError{}) { + klog.V(3).Infof("ensureVMSSInPool invalidate the vmss cache for EtagMismatchError") + _ = ss.vmssCache.Delete(consts.VMSSKey) + } if rerr != nil { - klog.Errorf("ensureVMSSInPool CreateOrUpdateVMSS(%s) with new backendPoolID %s, err: %v", vmssName, backendPoolID, err) + klog.Errorf("ensureVMSSInPool CreateOrUpdateVMSS(%s) with new backendPoolID %s, err: %v", vmssName, backendPoolID, rerr.Error()) return rerr.Error() } } @@ -1331,7 +1347,7 @@ func (ss *ScaleSet) ensureVMSSInPool(ctx context.Context, _ *v1.Service, nodes [ } // isWindows2019 checks if the ImageReference on the VMSS matches a Windows Server 2019 image. -func isWindows2019(vmss *compute.VirtualMachineScaleSet) bool { +func isWindows2019(vmss *vmssclient.VirtualMachineScaleSet) bool { if vmss == nil { return false } @@ -1385,8 +1401,8 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, } hostUpdates := make([]func() error, 0, len(nodes)) - nodeUpdates := make(map[vmssMetaInfo]map[string]compute.VirtualMachineScaleSetVM) - errors := make([]error, 0) + nodeUpdates := make(map[vmssMetaInfo]map[string]vmssvmclient.VirtualMachineScaleSetVM) + errs := make([]error, 0) for _, node := range nodes { localNodeName := node.Name @@ -1408,7 +1424,7 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, nodeResourceGroup, nodeVMSS, nodeInstanceID, nodeVMSSVM, err := ss.EnsureHostInPool(ctx, service, types.NodeName(localNodeName), backendPoolID, vmSetNameOfLB) if err != nil { klog.Errorf("EnsureHostInPool(%s): backendPoolID(%s) - failed to ensure host in pool: %q", getServiceName(service), backendPoolID, err) - errors = append(errors, err) + errs = append(errs, err) continue } @@ -1421,7 +1437,7 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, if v, ok := nodeUpdates[nodeVMSSMetaInfo]; ok { v[nodeInstanceID] = *nodeVMSSVM } else { - 
nodeUpdates[nodeVMSSMetaInfo] = map[string]compute.VirtualMachineScaleSetVM{ + nodeUpdates[nodeVMSSMetaInfo] = map[string]vmssvmclient.VirtualMachineScaleSetVM{ nodeInstanceID: *nodeVMSSVM, } } @@ -1455,6 +1471,8 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, } klog.V(2).InfoS("Begin to update VMs for VMSS with new backendPoolID", logFields...) + //NOTE(mainred): We don't have to invalidate the cache in case of ETagMismatch error, since the cache is already invalidated anyway + // in the last nodes loop in defer function. rerr := ss.VirtualMachineScaleSetVMsClient.UpdateVMs(ctx, meta.resourceGroup, meta.vmssName, update, "network_update", batchSize) if rerr != nil { klog.ErrorS(err, "Failed to update VMs for VMSS", logFields...) @@ -1464,14 +1482,20 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, return nil }) } - errs := utilerrors.AggregateGoroutines(hostUpdates...) - if errs != nil { - return utilerrors.Flatten(errs) + updateErrors := utilerrors.AggregateGoroutines(hostUpdates...) + if updateErrors != nil { + // TODO(mainred): Update vm cache from response when a sucessful update is done instead of always invalidating the cache for a refresh. + // Invalidates the vm cache only when an etag mismatch error happens to reduce the cache triggered API call. + aggUpdateErrors := utilerrors.Flatten(updateErrors) + if errors.Is(aggUpdateErrors, &retry.EtagMismatchError{}) { + klog.V(3).Info("EnsureHostInPool UpdateVMs failed for EtagMismatchError") + } + return aggUpdateErrors } // Fail if there are other errors. 
- if len(errors) > 0 { - return utilerrors.Flatten(utilerrors.NewAggregate(errors)) + if len(errs) > 0 { + return utilerrors.Flatten(utilerrors.NewAggregate(errs)) } isOperationSucceeded = true @@ -1557,7 +1581,7 @@ func (ss *ScaleSet) EnsureHostsInPool(ctx context.Context, service *v1.Service, // ensureBackendPoolDeletedFromNode ensures the loadBalancer backendAddressPools deleted // from the specified node, which returns (resourceGroup, vmasName, instanceID, vmssVM, error). -func (ss *ScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, nodeName string, backendPoolIDs []string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (ss *ScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, nodeName string, backendPoolIDs []string) (string, string, string, *vmssvmclient.VirtualMachineScaleSetVM, error) { logger := klog.Background().WithName("ensureBackendPoolDeletedFromNode").WithValues("nodeName", nodeName, "backendPoolIDs", backendPoolIDs) vm, err := ss.getVmssVM(ctx, nodeName, azcache.CacheReadTypeDefault) if err != nil { @@ -1604,14 +1628,17 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, nodeNa } // Compose a new vmssVM with added backendPoolID. 
- newVM := &compute.VirtualMachineScaleSetVM{ - Location: &vm.Location, - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - HardwareProfile: vm.VirtualMachineScaleSetVMProperties.HardwareProfile, - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &networkInterfaceConfigurations, + newVM := &vmssvmclient.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + Location: &vm.Location, + VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + HardwareProfile: vm.VirtualMachineScaleSetVMProperties.HardwareProfile, + NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: &networkInterfaceConfigurations, + }, }, }, + Etag: vm.Etag, } // Get the node resource group. @@ -1743,10 +1770,10 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromVmssUniform(ctx context.Context, vmssUniformMap := cachedUniform.(*sync.Map) var errorList []error walk := func(_, value interface{}) bool { - var vmss *compute.VirtualMachineScaleSet + var vmss *vmssclient.VirtualMachineScaleSet if vmssEntry, ok := value.(*VMSSEntry); ok { vmss = vmssEntry.VMSS - } else if v, ok := value.(*compute.VirtualMachineScaleSet); ok { + } else if v, ok := value.(*vmssclient.VirtualMachineScaleSet); ok { vmss = v } klog.V(2).Infof("ensureBackendPoolDeletedFromVmssUniform: vmss %q, backendPoolIDs %q", ptr.Deref(vmss.Name, ""), backendPoolIDs) @@ -1845,7 +1872,7 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se // Ensure the backendPoolID is deleted from the VMSS VMs. 
hostUpdates := make([]func() error, 0, len(ipConfigurationIDs)) - nodeUpdates := make(map[vmssMetaInfo]map[string]compute.VirtualMachineScaleSetVM) + nodeUpdates := make(map[vmssMetaInfo]map[string]vmssvmclient.VirtualMachineScaleSetVM) allErrs := make([]error, 0) visitedIPConfigIDPrefix := map[string]bool{} for i := range ipConfigurationIDs { @@ -1900,7 +1927,7 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se if v, ok := nodeUpdates[nodeVMSSMetaInfo]; ok { v[nodeInstanceID] = *nodeVMSSVM } else { - nodeUpdates[nodeVMSSMetaInfo] = map[string]compute.VirtualMachineScaleSetVM{ + nodeUpdates[nodeVMSSMetaInfo] = map[string]vmssvmclient.VirtualMachineScaleSetVM{ nodeInstanceID: *nodeVMSSVM, } } @@ -1932,6 +1959,8 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se } klog.V(2).InfoS("Begin to update VMs for VMSS with new backendPoolID", logFields...) + //NOTE(mainred): We don't have to invalidate the cache in case of ETagMismatch error, since the cache is already invalidated anyway + // in the last nodes loop in defer function. rerr := ss.VirtualMachineScaleSetVMsClient.UpdateVMs(ctx, meta.resourceGroup, meta.vmssName, update, "network_update", batchSize) if rerr != nil { klog.ErrorS(err, "Failed to update VMs for VMSS", logFields...) @@ -1942,9 +1971,15 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se return nil }) } - errs := utilerrors.AggregateGoroutines(hostUpdates...) - if errs != nil { - return updatedVM.Load(), utilerrors.Flatten(errs) + updateErrors := utilerrors.AggregateGoroutines(hostUpdates...) + if updateErrors != nil { + // TODO(mainred): Update vm cache from response when a sucessful update is done instead of always invalidating the cache for a refresh. + // Invalidates the vm cache only when an etag mismatch error happens to reduce the cache triggered API call. 
+ aggUpdateErrors := utilerrors.Flatten(updateErrors) + if errors.Is(aggUpdateErrors, &retry.EtagMismatchError{}) { + klog.V(3).Info("EnsureBackendPoolDeleted UpdateVMs failed for EtagMismatchError") + } + return updatedVM.Load(), aggUpdateErrors } // Fail if there are other errors. @@ -2146,13 +2181,13 @@ func deleteBackendPoolFromIPConfig(msg, backendPoolID, resource string, primaryN // EnsureBackendPoolDeletedFromVMSets ensures the loadBalancer backendAddressPools deleted from the specified VMSS func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmssNamesMap map[string]bool, backendPoolIDs []string) error { vmssUpdaters := make([]func() error, 0, len(vmssNamesMap)) - errors := make([]error, 0, len(vmssNamesMap)) + errs := make([]error, 0, len(vmssNamesMap)) for vmssName := range vmssNamesMap { vmssName := vmssName vmss, err := ss.getVMSS(ctx, vmssName, azcache.CacheReadTypeDefault) if err != nil { - klog.Errorf("ensureBackendPoolDeletedFromVMSS: failed to get VMSS %s: %v", vmssName, err) - errors = append(errors, err) + klog.Errorf("EnsureBackendPoolDeletedFromVMSets: failed to get VMSS %s: %v", vmssName, err) + errs = append(errs, err) continue } @@ -2170,14 +2205,14 @@ func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmss primaryNIC, err := getPrimaryNetworkInterfaceConfiguration(vmssNIC, vmssName) if err != nil { klog.Errorf("EnsureBackendPoolDeletedFromVMSets: failed to get the primary network interface config of the VMSS %s: %v", vmssName, err) - errors = append(errors, err) + errs = append(errs, err) continue } foundTotal := false for _, backendPoolID := range backendPoolIDs { found, err := deleteBackendPoolFromIPConfig("EnsureBackendPoolDeletedFromVMSets", backendPoolID, vmssName, primaryNIC) if err != nil { - errors = append(errors, err) + errs = append(errs, err) continue } if found { @@ -2190,19 +2225,30 @@ func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmss vmssUpdaters = 
append(vmssUpdaters, func() error { // Compose a new vmss with added backendPoolID. - newVMSS := compute.VirtualMachineScaleSet{ - Location: vmss.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, + newVMSS := vmssclient.VirtualMachineScaleSet{ + VirtualMachineScaleSet: compute.VirtualMachineScaleSet{ + Location: vmss.Location, + VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: &vmssNIC, + }, }, }, }, + Etag: vmss.Etag, } klog.V(2).Infof("EnsureBackendPoolDeletedFromVMSets begins to update vmss(%s) with backendPoolIDs %q", vmssName, backendPoolIDs) rerr := ss.CreateOrUpdateVMSS(ss.ResourceGroup, vmssName, newVMSS) + + // VMSS cache must be refreshed when etagmismatch error happens. + // TODO(mainred): we need to update the cache from the response of a successful request. + if rerr != nil && errors.Is(rerr.Error(), &retry.EtagMismatchError{}) { + klog.V(3).Infof("EnsureBackendPoolDeletedFromVMSets invalidate the vmss cache for EtagMismatchError") + _ = ss.vmssCache.Delete(consts.VMSSKey) + } + if rerr != nil { klog.Errorf("EnsureBackendPoolDeletedFromVMSets CreateOrUpdateVMSS(%s) with new backendPoolIDs %q, err: %v", vmssName, backendPoolIDs, rerr) return rerr.Error() @@ -2212,13 +2258,13 @@ func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmss }) } - errs := utilerrors.AggregateGoroutines(vmssUpdaters...) + aggregateErrs := utilerrors.AggregateGoroutines(vmssUpdaters...) if errs != nil { - return utilerrors.Flatten(errs) + return utilerrors.Flatten(aggregateErrs) } // Fail if there are other errors. 
- if len(errors) > 0 { - return utilerrors.Flatten(utilerrors.NewAggregate(errors)) + if len(errs) > 0 { + return utilerrors.Flatten(utilerrors.NewAggregate(errs)) } return nil diff --git a/pkg/provider/azure_vmss_cache.go b/pkg/provider/azure_vmss_cache.go index e011f85e9a..09d23a5f2a 100644 --- a/pkg/provider/azure_vmss_cache.go +++ b/pkg/provider/azure_vmss_cache.go @@ -28,6 +28,8 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" @@ -37,12 +39,12 @@ type VMSSVirtualMachineEntry struct { ResourceGroup string VMSSName string InstanceID string - VirtualMachine *compute.VirtualMachineScaleSetVM + VirtualMachine *vmssvmclient.VirtualMachineScaleSetVM LastUpdate time.Time } type VMSSEntry struct { - VMSS *compute.VirtualMachineScaleSet + VMSS *vmssclient.VirtualMachineScaleSet ResourceGroup string LastUpdate time.Time } @@ -287,7 +289,7 @@ func (ss *ScaleSet) DeleteCacheForNode(ctx context.Context, nodeName string) err return nil } -func (ss *ScaleSet) updateCache(ctx context.Context, nodeName, resourceGroupName, vmssName, instanceID string, updatedVM *compute.VirtualMachineScaleSetVM) error { +func (ss *ScaleSet) updateCache(ctx context.Context, nodeName, resourceGroupName, vmssName, instanceID string, updatedVM *vmssvmclient.VirtualMachineScaleSetVM) error { // lock the VMSS entry to ensure a consistent view of the VM map when there are concurrent updates. 
cacheKey := getVMSSVMCacheKey(resourceGroupName, vmssName) ss.lockMap.LockEntry(cacheKey) diff --git a/pkg/provider/azure_vmss_repo.go b/pkg/provider/azure_vmss_repo.go index 89b5a900f2..662e4aa604 100644 --- a/pkg/provider/azure_vmss_repo.go +++ b/pkg/provider/azure_vmss_repo.go @@ -19,15 +19,15 @@ package provider import ( "strings" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "k8s.io/klog/v2" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) // CreateOrUpdateVMSS invokes az.VirtualMachineScaleSetsClient.Update(). -func (az *Cloud) CreateOrUpdateVMSS(resourceGroupName string, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) *retry.Error { +func (az *Cloud) CreateOrUpdateVMSS(resourceGroupName string, VMScaleSetName string, parameters vmssclient.VirtualMachineScaleSet) *retry.Error { ctx, cancel := getContextWithCancel() defer cancel() @@ -44,8 +44,13 @@ func (az *Cloud) CreateOrUpdateVMSS(resourceGroupName string, VMScaleSetName str return nil } - rerr = az.VirtualMachineScaleSetsClient.CreateOrUpdate(ctx, resourceGroupName, VMScaleSetName, parameters) - klog.V(10).Infof("UpdateVmssVMWithRetry: VirtualMachineScaleSetsClient.CreateOrUpdate(%s): end", VMScaleSetName) + etag := "" + if parameters.Etag != nil { + etag = *parameters.Etag + } + + rerr = az.VirtualMachineScaleSetsClient.CreateOrUpdate(ctx, resourceGroupName, VMScaleSetName, parameters, etag) + klog.V(10).Infof("CreateOrUpdateVMSS: VirtualMachineScaleSetsClient.CreateOrUpdate(%s): end", VMScaleSetName) if rerr != nil { klog.Errorf("CreateOrUpdateVMSS: error CreateOrUpdate vmss(%s): %v", VMScaleSetName, rerr) return rerr diff --git a/pkg/provider/azure_vmss_repo_test.go b/pkg/provider/azure_vmss_repo_test.go index 4528990c08..266c4ad0c9 100644 --- a/pkg/provider/azure_vmss_repo_test.go +++ b/pkg/provider/azure_vmss_repo_test.go @@ -35,6 +35,7 @@ 
import ( "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" @@ -77,7 +78,7 @@ func TestCreateOrUpdateVMSS(t *testing.T) { mockVMSSClient := az.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, testVMSSName).Return(test.vmss, test.clientErr) - err := az.CreateOrUpdateVMSS(az.ResourceGroup, testVMSSName, compute.VirtualMachineScaleSet{}) + err := az.CreateOrUpdateVMSS(az.ResourceGroup, testVMSSName, vmssclient.VirtualMachineScaleSet{}) assert.Equal(t, test.expectedErr, err) } } diff --git a/pkg/provider/azure_vmss_test.go b/pkg/provider/azure_vmss_test.go index e0d6aad139..45e0185157 100644 --- a/pkg/provider/azure_vmss_test.go +++ b/pkg/provider/azure_vmss_test.go @@ -39,6 +39,7 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" @@ -158,8 +159,8 @@ func buildTestVMSS(name, computerNamePrefix string) compute.VirtualMachineScaleS } } -func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomain int32, vmList []string, state string, isIPv6 bool) ([]compute.VirtualMachineScaleSetVM, network.Interface, 
network.PublicIPAddress) { - expectedVMSSVMs := make([]compute.VirtualMachineScaleSetVM, 0) +func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomain int32, vmList []string, state string, isIPv6 bool) ([]vmssvmclient.VirtualMachineScaleSetVM, network.Interface, network.PublicIPAddress) { + expectedVMSSVMs := make([]vmssvmclient.VirtualMachineScaleSetVM, 0) expectedInterface := network.Interface{} expectedPIP := network.PublicIPAddress{} @@ -211,30 +212,32 @@ func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomai }, } - vmssVM := compute.VirtualMachineScaleSetVM{ - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - ProvisioningState: ptr.To(state), - OsProfile: &compute.OSProfile{ - ComputerName: &nodeName, - }, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &networkInterfaces, - }, - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &networkConfigurations, - }, - InstanceView: &compute.VirtualMachineScaleSetVMInstanceView{ - PlatformFaultDomain: &faultDomain, - Statuses: &[]compute.InstanceViewStatus{ - {Code: ptr.To(testVMPowerState)}, + vmssVM := vmssvmclient.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + ProvisioningState: ptr.To(state), + OsProfile: &compute.OSProfile{ + ComputerName: &nodeName, + }, + NetworkProfile: &compute.NetworkProfile{ + NetworkInterfaces: &networkInterfaces, + }, + NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: &networkConfigurations, + }, + InstanceView: &compute.VirtualMachineScaleSetVMInstanceView{ + PlatformFaultDomain: &faultDomain, + Statuses: &[]compute.InstanceViewStatus{ + {Code: ptr.To(testVMPowerState)}, + }, }, }, + ID: &ID, + InstanceID: &instanceID, + 
Name: &vmName, + Location: &ss.Location, + Sku: &compute.Sku{Name: ptr.To("sku")}, }, - ID: &ID, - InstanceID: &instanceID, - Name: &vmName, - Location: &ss.Location, - Sku: &compute.Sku{Name: ptr.To("sku")}, } if zone != "" { zones := []string{zone} @@ -790,7 +793,7 @@ func TestGetVmssVM(t *testing.T) { mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, test.existedVMSSName, "", 0, test.existedNodeNames, "", false) - var expectedVMSSVM compute.VirtualMachineScaleSetVM + var expectedVMSSVM vmssvmclient.VirtualMachineScaleSetVM for _, expected := range expectedVMSSVMs { if strings.EqualFold(*expected.OsProfile.ComputerName, test.nodeName) { expectedVMSSVM = expected @@ -1126,11 +1129,13 @@ func TestGetPrimaryInterfaceID(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") existedInterfaces := test.existedInterfaces - vm := compute.VirtualMachineScaleSetVM{ - Name: ptr.To("vm"), - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &existedInterfaces, + vm := vmssvmclient.VirtualMachineScaleSetVM{ + VirtualMachineScaleSetVM: compute.VirtualMachineScaleSetVM{ + Name: ptr.To("vm"), + VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + NetworkProfile: &compute.NetworkProfile{ + NetworkInterfaces: &existedInterfaces, + }, }, }, } @@ -2567,7 +2572,7 @@ func TestEnsureVMSSInPool(t *testing.T) { vmssPutTimes = 1 mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil) } - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil).Times(vmssPutTimes) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any(), "").Return(nil).Times(vmssPutTimes) 
expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000"}, "", test.setIPv6Config) mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) @@ -2661,7 +2666,7 @@ func TestEnsureHostsInPool(t *testing.T) { mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any(), "").Return(nil).MaxTimes(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000", "vmss-vm-000001", "vmss-vm-000002"}, "", false) mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) @@ -2869,7 +2874,7 @@ func TestEnsureBackendPoolDeletedFromVMSS(t *testing.T) { vmssPutTimes = 1 mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil) } - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(test.vmssClientErr).Times(vmssPutTimes) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any(), "").Return(test.vmssClientErr).Times(vmssPutTimes) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000"}, "", false) mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) @@ -2975,7 +2980,7 @@ func TestEnsureBackendPoolDeleted(t *testing.T) { mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) 
mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil).Times(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any(), "").Return(nil).Times(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000", "vmss-vm-000001", "vmss-vm-000002"}, "", false) mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) @@ -3048,7 +3053,7 @@ func TestEnsureBackendPoolDeletedConcurrently(t *testing.T) { expectedVMSSVMsOfVMSS0, _, _ := buildTestVirtualMachineEnv(ss.Cloud, "vmss-0", "", 0, []string{"vmss-0-vm-000000"}, "succeeded", false) expectedVMSSVMsOfVMSS1, _, _ := buildTestVirtualMachineEnv(ss.Cloud, "vmss-1", "", 0, []string{"vmss-1-vm-000001"}, "succeeded", false) - for _, expectedVMSSVMs := range [][]compute.VirtualMachineScaleSetVM{expectedVMSSVMsOfVMSS0, expectedVMSSVMsOfVMSS1} { + for _, expectedVMSSVMs := range [][]vmssvmclient.VirtualMachineScaleSetVM{expectedVMSSVMsOfVMSS0, expectedVMSSVMsOfVMSS1} { vmssVMNetworkConfigs := expectedVMSSVMs[0].NetworkProfileConfiguration vmssVMIPConfigs := (*vmssVMNetworkConfigs.NetworkInterfaceConfigurations)[0].VirtualMachineScaleSetNetworkConfigurationProperties.IPConfigurations lbBackendpools := (*vmssVMIPConfigs)[0].LoadBalancerBackendAddressPools @@ -3063,7 +3068,7 @@ func TestEnsureBackendPoolDeletedConcurrently(t *testing.T) { mockVMSSClient.EXPECT().List(gomock.Any(), "rg1").Return(nil, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, "vmss-0").Return(vmss0, nil).MaxTimes(2) mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, "vmss-1").Return(vmss1, nil).MaxTimes(2) - 
mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil).Times(2) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, gomock.Any(), gomock.Any(), "").Return(nil).Times(2) mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) mockVMSSVMClient.EXPECT().List(gomock.Any(), "rg1", "vmss-0", gomock.Any()).Return(nil, nil).AnyTimes() diff --git a/pkg/provider/azure_vmssflex.go b/pkg/provider/azure_vmssflex.go index 662eded7be..2f5481e81b 100644 --- a/pkg/provider/azure_vmssflex.go +++ b/pkg/provider/azure_vmssflex.go @@ -35,6 +35,8 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/metrics" @@ -444,7 +446,7 @@ func (fs *FlexScaleSet) GetNodeCIDRMasksByProviderID(ctx context.Context, provid // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool, which returns (resourceGroup, vmasName, instanceID, vmssVM, error). 
-func (fs *FlexScaleSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (fs *FlexScaleSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *vmssvmclient.VirtualMachineScaleSetVM, error) { serviceName := getServiceName(service) name := mapNodeNameToVMName(nodeName) vmssFlexName, err := fs.getNodeVmssFlexName(ctx, name) @@ -676,13 +678,15 @@ func (fs *FlexScaleSet) ensureVMSSFlexInPool(ctx context.Context, _ *v1.Service, ID: ptr.To(backendPoolID), }) primaryIPConfig.LoadBalancerBackendAddressPools = &loadBalancerBackendAddressPools - newVMSS := compute.VirtualMachineScaleSet{ - Location: vmssFlex.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, - NetworkAPIVersion: compute.TwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne, + newVMSS := vmssclient.VirtualMachineScaleSet{ + VirtualMachineScaleSet: compute.VirtualMachineScaleSet{ + Location: vmssFlex.Location, + VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: &vmssNIC, + NetworkAPIVersion: compute.TwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne, + }, }, }, }, @@ -763,7 +767,7 @@ func (fs *FlexScaleSet) ensureBackendPoolDeletedFromVmssFlex(ctx context.Context } vmssFlexes := cached.(*sync.Map) vmssFlexes.Range(func(_, value interface{}) bool { - vmssFlex := value.(*compute.VirtualMachineScaleSet) + vmssFlex := value.(*vmssclient.VirtualMachineScaleSet) 
vmssNamesMap[ptr.Deref(vmssFlex.Name, "")] = true return true }) @@ -820,13 +824,15 @@ func (fs *FlexScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmssUpdaters = append(vmssUpdaters, func() error { // Compose a new vmss with added backendPoolID. - newVMSS := compute.VirtualMachineScaleSet{ - Location: vmss.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, - NetworkAPIVersion: compute.TwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne, + newVMSS := vmssclient.VirtualMachineScaleSet{ + VirtualMachineScaleSet: compute.VirtualMachineScaleSet{ + Location: vmss.Location, + VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: &vmssNIC, + NetworkAPIVersion: compute.TwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne, + }, }, }, }, diff --git a/pkg/provider/azure_vmssflex_cache.go b/pkg/provider/azure_vmssflex_cache.go index 59c423a8ad..3c1146d7c5 100644 --- a/pkg/provider/azure_vmssflex_cache.go +++ b/pkg/provider/azure_vmssflex_cache.go @@ -30,6 +30,7 @@ import ( "k8s.io/klog/v2" "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" ) @@ -185,7 +186,7 @@ func (fs *FlexScaleSet) getNodeVmssFlexID(ctx context.Context, nodeName string) var vmssFlexIDs []string vmssFlexes.Range(func(key, value interface{}) bool { vmssFlexID := key.(string) - vmssFlex := value.(*compute.VirtualMachineScaleSet) + vmssFlex := value.(*vmssclient.VirtualMachineScaleSet) vmssPrefix := ptr.Deref(vmssFlex.Name, "") if vmssFlex.VirtualMachineProfile != nil && 
vmssFlex.VirtualMachineProfile.OsProfile != nil && @@ -244,14 +245,14 @@ func (fs *FlexScaleSet) getVmssFlexVM(ctx context.Context, nodeName string, crt return *(cachedVM.(*compute.VirtualMachine)), nil } -func (fs *FlexScaleSet) getVmssFlexByVmssFlexID(ctx context.Context, vmssFlexID string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSet, error) { +func (fs *FlexScaleSet) getVmssFlexByVmssFlexID(ctx context.Context, vmssFlexID string, crt azcache.AzureCacheReadType) (*vmssclient.VirtualMachineScaleSet, error) { cached, err := fs.vmssFlexCache.Get(ctx, consts.VmssFlexKey, crt) if err != nil { return nil, err } vmssFlexes := cached.(*sync.Map) if vmssFlex, ok := vmssFlexes.Load(vmssFlexID); ok { - result := vmssFlex.(*compute.VirtualMachineScaleSet) + result := vmssFlex.(*vmssclient.VirtualMachineScaleSet) return result, nil } @@ -262,13 +263,13 @@ func (fs *FlexScaleSet) getVmssFlexByVmssFlexID(ctx context.Context, vmssFlexID } vmssFlexes = cached.(*sync.Map) if vmssFlex, ok := vmssFlexes.Load(vmssFlexID); ok { - result := vmssFlex.(*compute.VirtualMachineScaleSet) + result := vmssFlex.(*vmssclient.VirtualMachineScaleSet) return result, nil } return nil, cloudprovider.InstanceNotFound } -func (fs *FlexScaleSet) getVmssFlexByNodeName(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSet, error) { +func (fs *FlexScaleSet) getVmssFlexByNodeName(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (*vmssclient.VirtualMachineScaleSet, error) { vmssFlexID, err := fs.getNodeVmssFlexID(ctx, nodeName) if err != nil { return nil, err @@ -305,17 +306,17 @@ func (fs *FlexScaleSet) getVmssFlexIDByName(ctx context.Context, vmssFlexName st return "", cloudprovider.InstanceNotFound } -func (fs *FlexScaleSet) getVmssFlexByName(ctx context.Context, vmssFlexName string) (*compute.VirtualMachineScaleSet, error) { +func (fs *FlexScaleSet) getVmssFlexByName(ctx context.Context, vmssFlexName string) 
(*vmssclient.VirtualMachineScaleSet, error) { cached, err := fs.vmssFlexCache.Get(ctx, consts.VmssFlexKey, azcache.CacheReadTypeDefault) if err != nil { return nil, err } - var targetVmssFlex *compute.VirtualMachineScaleSet + var targetVmssFlex *vmssclient.VirtualMachineScaleSet vmssFlexes := cached.(*sync.Map) vmssFlexes.Range(func(key, value interface{}) bool { vmssFlexID := key.(string) - vmssFlex := value.(*compute.VirtualMachineScaleSet) + vmssFlex := value.(*vmssclient.VirtualMachineScaleSet) name, err := getLastSegment(vmssFlexID, "/") if err != nil { return true diff --git a/pkg/provider/azure_vmssflex_test.go b/pkg/provider/azure_vmssflex_test.go index 0adc72a4ee..01b9008250 100644 --- a/pkg/provider/azure_vmssflex_test.go +++ b/pkg/provider/azure_vmssflex_test.go @@ -1262,7 +1262,7 @@ func TestEnsureVMSSFlexInPool(t *testing.T) { mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(expectedestVmssFlexList, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "").Return(tc.vmssPutErr).AnyTimes() mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() @@ -1369,7 +1369,7 @@ func TestEnsureHostsInPoolVmssFlex(t *testing.T) { mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{genreteTestVmssFlex("vmssflex1", testVmssFlex1ID)}, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), 
gomock.Any(), gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "").Return(tc.vmssPutErr).AnyTimes() mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() @@ -1519,7 +1519,7 @@ func TestEnsureBackendPoolDeletedFromVMSetsVmssFlex(t *testing.T) { mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(vmssFlexList, nil).Times(tc.vmssListCallingTimes) mockVMSSClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), "").Return(tc.vmssPutErr).AnyTimes() err = fs.EnsureBackendPoolDeletedFromVMSets(context.TODO(), tc.vmssNamesMap, []string{tc.backendPoolID}) _, _ = fs.getVmssFlexByName(context.TODO(), "vmssflex1") @@ -1707,7 +1707,7 @@ func TestEnsureBackendPoolDeletedVmssFlex(t *testing.T) { mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(vmssFlexList, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), 
"").Return(tc.vmssPutErr).AnyTimes() mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() diff --git a/pkg/provider/virtualmachine/virtualmachine.go b/pkg/provider/virtualmachine/virtualmachine.go index f9b482298f..8977d7f1c6 100644 --- a/pkg/provider/virtualmachine/virtualmachine.go +++ b/pkg/provider/virtualmachine/virtualmachine.go @@ -21,6 +21,7 @@ import ( "k8s.io/utils/ptr" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" ) @@ -51,7 +52,7 @@ func ByVMSS(vmssName string) ManageOption { type VirtualMachine struct { Variant Variant vm *compute.VirtualMachine - vmssVM *compute.VirtualMachineScaleSetVM + vmssVM *vmssvmclient.VirtualMachineScaleSetVM Manage Manage VMSSName string @@ -67,6 +68,8 @@ type VirtualMachine struct { Plan *compute.Plan Resources *[]compute.VirtualMachineExtension + Etag *string + // fields of VirtualMachine Identity *compute.VirtualMachineIdentity VirtualMachineProperties *compute.VirtualMachineProperties @@ -102,7 +105,7 @@ func FromVirtualMachine(vm *compute.VirtualMachine, opt ...ManageOption) *Virtua return v } -func FromVirtualMachineScaleSetVM(vm *compute.VirtualMachineScaleSetVM, opt ManageOption) *VirtualMachine { +func FromVirtualMachineScaleSetVM(vm *vmssvmclient.VirtualMachineScaleSetVM, opt ManageOption) *VirtualMachine { v := &VirtualMachine{ Variant: VariantVirtualMachineScaleSetVM, vmssVM: vm, @@ -115,6 +118,7 @@ func FromVirtualMachineScaleSetVM(vm *compute.VirtualMachineScaleSetVM, opt Mana Zones: stringSlice(vm.Zones), Plan: vm.Plan, Resources: vm.Resources, + Etag: vm.Etag, SKU: vm.Sku, InstanceID: ptr.Deref(vm.InstanceID, ""), @@ -144,7 +148,7 @@ func (vm *VirtualMachine) AsVirtualMachine() *compute.VirtualMachine { return vm.vm } -func (vm *VirtualMachine) AsVirtualMachineScaleSetVM() 
*compute.VirtualMachineScaleSetVM { +func (vm *VirtualMachine) AsVirtualMachineScaleSetVM() *vmssvmclient.VirtualMachineScaleSetVM { return vm.vmssVM } diff --git a/pkg/retry/azure_error.go b/pkg/retry/azure_error.go index 923e33de0d..7c6cf32e03 100644 --- a/pkg/retry/azure_error.go +++ b/pkg/retry/azure_error.go @@ -60,7 +60,8 @@ type Error struct { // RetryAfter indicates the time when the request should retry after throttling. // A throttled request is retriable. RetryAfter time.Time - // RetryAfter indicates the raw error from API. + // RawError indicates the raw error from API. + // It's beneficial to errors.Is() or errors.As() to check the real error type. RawError error } @@ -197,6 +198,10 @@ func getRawError(resp *http.Response, err error) error { return fmt.Errorf("HTTP status code (%d)", resp.StatusCode) } + if IsPreconditionFailedEtagMismatch(resp.StatusCode, string(respBody)) { + return NewEtagMismatchError(resp.StatusCode, string(respBody)) + } + // return the raw response body. 
return fmt.Errorf("%s", string(respBody)) } diff --git a/pkg/retry/azure_error_test.go b/pkg/retry/azure_error_test.go index 9fa91d334a..9b07c643b5 100644 --- a/pkg/retry/azure_error_test.go +++ b/pkg/retry/azure_error_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/Azure/go-autorest/autorest/mocks" "github.com/stretchr/testify/assert" ) @@ -159,6 +160,16 @@ func TestGetErrorNil(t *testing.T) { assert.Equal(t, fmt.Errorf("empty HTTP response"), rerr.RawError) } +func TestGetErrorErrorIs(t *testing.T) { + fakeResp := &http.Response{ + StatusCode: http.StatusPreconditionFailed, + Body: mocks.NewBody("Etag provided in if-match header \"771\" does not match etag \"773\" of resource."), + } + err := GetError(fakeResp, nil) + fmt.Println(errors.Is(err.Error(), &EtagMismatchError{})) + assert.True(t, errors.Is(err.Error(), &EtagMismatchError{})) +} + func TestGetStatusNotFoundAndForbiddenIgnoredError(t *testing.T) { now = func() time.Time { return time.Time{} diff --git a/pkg/retry/etagmismatch_error.go b/pkg/retry/etagmismatch_error.go new file mode 100644 index 0000000000..7facfb4253 --- /dev/null +++ b/pkg/retry/etagmismatch_error.go @@ -0,0 +1,98 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retry + +import ( + "fmt" + "net/http" + "regexp" + "strings" +) + +const ( + // TODO(mainred): etag mismatch error message is not consistent across Azure APIs, we need to normalize it or add more error messages. 
+ // Example error message for CRP + // Etag provided in if-match header {0} does not match etag {1} of resource. + // where {0} and {1} are the etags provided in the request and the resource respectively. + EtagMismatchPattern = `Etag provided in if-match header ([^\s]+) does not match etag ([^\s]+) of resource` + + EtagMismatchErrorTag = "EtagMismatchError" +) + +type EtagMismatchError struct { + currentEtag string + latestEtag string +} + +func NewEtagMismatchError(httpStatusCode int, respBody string) *EtagMismatchError { + if httpStatusCode != http.StatusPreconditionFailed { + return nil + } + + currentEtag, latestEtag, match := getMatchedLatestAndCurrentEtags(respBody) + if !match { + return nil + } + + return &EtagMismatchError{ + currentEtag: currentEtag, + latestEtag: latestEtag, + } +} + +func (e *EtagMismatchError) Error() string { + return fmt.Sprintf("%s: etag %s does not match etag %s of resource", EtagMismatchErrorTag, e.currentEtag, e.latestEtag) +} + +func (e *EtagMismatchError) CurrentEtag() string { + return e.currentEtag +} + +func (e *EtagMismatchError) LatestEtag() string { + return e.latestEtag +} + +func (e *EtagMismatchError) Is(target error) bool { + return strings.Contains(target.Error(), EtagMismatchErrorTag) +} + +// IsPreconditionFailedEtagMismatch returns true if the request failed for Etag mismatch +func IsPreconditionFailedEtagMismatch(httpStatusCode int, respBody string) bool { + + if httpStatusCode != http.StatusPreconditionFailed { + return false + } + + _, _, match := getMatchedLatestAndCurrentEtags(respBody) + return match +} + +func getMatchedLatestAndCurrentEtags(respBody string) (string, string, bool) { + + var currentEtag, latestEtag string + re := regexp.MustCompile(EtagMismatchPattern) + matches := re.FindStringSubmatch(respBody) + + if len(matches) != 3 { + return currentEtag, latestEtag, false + } + + currentEtag = matches[1] + latestEtag = matches[2] + + return currentEtag, latestEtag, true +} diff --git
a/pkg/retry/etagmismatch_error_test.go b/pkg/retry/etagmismatch_error_test.go
new file mode 100644
index 0000000000..832b683ba4
--- /dev/null
+++ b/pkg/retry/etagmismatch_error_test.go
@@ -0,0 +1,72 @@
+/*
+Copyright 2022 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package retry
+
+import (
+	"io"
+	"net/http"
+	"testing"
+
+	"github.com/Azure/go-autorest/autorest/mocks"
+)
+
+func TestIsPreconditionFailedEtagMismatch(t *testing.T) {
+	tests := []struct {
+		desc          string
+		resp          *http.Response
+		expectedMatch bool
+	}{
+		{
+			desc: "status code and response body both match should return true",
+			resp: &http.Response{
+				StatusCode: http.StatusPreconditionFailed,
+				Body:       mocks.NewBody("Etag provided in if-match header \"771\" does not match etag \"773\" of resource."),
+			},
+			expectedMatch: true,
+		},
+		{
+			desc: "status code match and response body mismatch should return false",
+			resp: &http.Response{
+				StatusCode: http.StatusPreconditionFailed,
+				Body:       mocks.NewBody(""),
+			},
+			expectedMatch: false,
+		},
+		{
+			desc: "status code mismatch should return false",
+			resp: &http.Response{
+				StatusCode: http.StatusOK,
+				Body:       mocks.NewBody(""),
+			},
+			expectedMatch: false,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.desc, func(t *testing.T) {
+			// Close each body when its subtest ends, instead of deferring
+			// every close to the end of the whole test function.
+			defer test.resp.Body.Close()
+			respBody, _ := io.ReadAll(test.resp.Body)
+			actualMatch := IsPreconditionFailedEtagMismatch(test.resp.StatusCode, string(respBody))
+			if actualMatch != test.expectedMatch {
+				t.Errorf("get unexpected result: %v != %v", actualMatch, test.expectedMatch)
+			}
+		})
+	}
+}