Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(creation): remove polling from creation #249

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
12 changes: 9 additions & 3 deletions pkg/fake/azureresourcegraphapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,15 @@ func TestAzureResourceGraphAPI_Resources_VM(t *testing.T) {
for _, c := range cases {
t.Run(c.testName, func(t *testing.T) {
for _, name := range c.vmNames {
_, err := instance.CreateVirtualMachine(context.Background(), virtualMachinesAPI, resourceGroup, name, armcompute.VirtualMachine{Tags: c.tags})
if err != nil {
t.Errorf("Unexpected error %v", err)
ctx := context.Background()
_, errRetriever := instance.CreateVirtualMachine(ctx, virtualMachinesAPI, resourceGroup, name, armcompute.VirtualMachine{Tags: c.tags})
if errRetriever.GetFrontendErr() != nil {
t.Errorf("Unexpected frontend error %v", errRetriever.GetFrontendErr())
return
}
asyncErr := errRetriever.WaitForAsyncErr(ctx)
if asyncErr != nil {
t.Errorf("Unexpected async error %v", asyncErr)
return
}
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/fake/virtualmachineextensionsapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ func (c *VirtualMachineExtensionsAPI) BeginCreateOrUpdate(_ context.Context, res
})
}

func (c *VirtualMachineExtensionsAPI) Get(ctx context.Context, resourceGroupName string, vmName string, vmExtensionName string, options *armcompute.VirtualMachineExtensionsClientGetOptions) (armcompute.VirtualMachineExtensionsClientGetResponse, error) {
return armcompute.VirtualMachineExtensionsClientGetResponse{}, nil
}

func mkVMExtensionID(resourceGroupName, vmName, extensionName string) string {
const idFormat = "/subscriptions/subscriptionID/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s/extensions/%s"
return fmt.Sprintf(idFormat, resourceGroupName, vmName, extensionName)
Expand Down
12 changes: 9 additions & 3 deletions pkg/fake/virtualmachinesapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@ func TestComputeAPI_BeginCreateOrUpdate(t *testing.T) {
// return nil, nil
//})
// test
vm, err := instance.CreateVirtualMachine(context.Background(), computeAPI, "resourceGroupName", "vmName", armcompute.VirtualMachine{})
ctx := context.Background()
vm, errRetriever := instance.CreateVirtualMachine(ctx, computeAPI, "resourceGroupName", "vmName", armcompute.VirtualMachine{})
// verify
if err != nil {
t.Errorf("Unexpected error %v", err)
if errRetriever.GetFrontendErr() != nil {
t.Errorf("Unexpected frontend error %v", errRetriever.GetFrontendErr())
return
}
asyncErr := errRetriever.WaitForAsyncErr(ctx)
if asyncErr != nil {
t.Errorf("Unexpected async error %v", asyncErr)
return
}
if vm == nil {
Expand Down
64 changes: 48 additions & 16 deletions pkg/providers/instance/armutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,53 @@ import (
"context"

sdkerrors "github.com/Azure/azure-sdk-for-go-extensions/pkg/errors"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
)

func CreateVirtualMachine(ctx context.Context, client VirtualMachinesAPI, rg, vmName string, vm armcompute.VirtualMachine) (*armcompute.VirtualMachine, error) {
type ErrorRetriever interface {
GetFrontendErr() error
WaitForAsyncErr(ctx context.Context) error
}

type errorRetriever struct {
frontendErr error
asyncErrPoller func(ctx context.Context) error
}

// T is the type of the arm response object
func NewErrorRetriever[T any](frontendErr error, asyncPoller *runtime.Poller[T]) ErrorRetriever {
return &errorRetriever{
frontendErr: frontendErr,
asyncErrPoller: func(ctx context.Context) error {
if asyncPoller == nil {
return nil
}
_, err := asyncPoller.PollUntilDone(ctx, nil)
return err
},
}
}

func (er *errorRetriever) GetFrontendErr() error {
return er.frontendErr
}

func (er *errorRetriever) WaitForAsyncErr(ctx context.Context) error {
return er.asyncErrPoller(ctx)
}

func CreateVirtualMachine(ctx context.Context, client VirtualMachinesAPI, rg, vmName string, vm armcompute.VirtualMachine) (*armcompute.VirtualMachine, ErrorRetriever) {
poller, err := client.BeginCreateOrUpdate(ctx, rg, vmName, vm, nil)
if err != nil {
return nil, err
return nil, NewErrorRetriever[armcompute.VirtualMachinesClientCreateOrUpdateResponse](err, poller)
}
res, err := poller.PollUntilDone(ctx, nil)
vmget, err := client.Get(ctx, rg, vmName, nil)
if err != nil {
return nil, err
return nil, NewErrorRetriever[armcompute.VirtualMachinesClientCreateOrUpdateResponse](err, poller)
}
return &res.VirtualMachine, nil
return &vmget.VirtualMachine, NewErrorRetriever[armcompute.VirtualMachinesClientCreateOrUpdateResponse](nil, poller)
}

func UpdateVirtualMachine(ctx context.Context, client VirtualMachinesAPI, rg, vmName string, updates armcompute.VirtualMachineUpdate) error {
Expand Down Expand Up @@ -63,29 +96,28 @@ func deleteVirtualMachine(ctx context.Context, client VirtualMachinesAPI, rg, vm
return nil
}

func createVirtualMachineExtension(ctx context.Context, client VirtualMachineExtensionsAPI, rg, vmName, extensionName string, vmExt armcompute.VirtualMachineExtension) (*armcompute.VirtualMachineExtension, error) {
func createVirtualMachineExtension(ctx context.Context, client VirtualMachineExtensionsAPI, rg, vmName, extensionName string, vmExt armcompute.VirtualMachineExtension) (*armcompute.VirtualMachineExtension, ErrorRetriever) {
poller, err := client.BeginCreateOrUpdate(ctx, rg, vmName, extensionName, vmExt, nil)
if err != nil {
return nil, err
return nil, NewErrorRetriever[armcompute.VirtualMachineExtensionsClientCreateOrUpdateResponse](err, poller)
}
res, err := poller.PollUntilDone(ctx, nil)
getExt, err := client.Get(ctx, rg, vmName, extensionName, nil)
if err != nil {
return nil, err
return nil, NewErrorRetriever[armcompute.VirtualMachineExtensionsClientCreateOrUpdateResponse](err, poller)
}
return &res.VirtualMachineExtension, nil
return &getExt.VirtualMachineExtension, NewErrorRetriever[armcompute.VirtualMachineExtensionsClientCreateOrUpdateResponse](nil, poller)
}

func createNic(ctx context.Context, client NetworkInterfacesAPI, rg, nicName string, nic armnetwork.Interface) (*armnetwork.Interface, error) {
func createNic(ctx context.Context, client NetworkInterfacesAPI, rg, nicName string, nic armnetwork.Interface) (*armnetwork.Interface, ErrorRetriever) {
poller, err := client.BeginCreateOrUpdate(ctx, rg, nicName, nic, nil)
if err != nil {
return nil, err
return nil, NewErrorRetriever[armnetwork.InterfacesClientCreateOrUpdateResponse](err, poller)
}
res, err := poller.PollUntilDone(ctx, nil)

getNic, err := client.Get(ctx, rg, nicName, nil)
if err != nil {
return nil, err
return nil, NewErrorRetriever[armnetwork.InterfacesClientCreateOrUpdateResponse](err, poller)
}
return &res.Interface, nil
return &getNic.Interface, NewErrorRetriever[armnetwork.InterfacesClientCreateOrUpdateResponse](nil, poller)
}

func deleteNic(ctx context.Context, client NetworkInterfacesAPI, rg, nicName string) error {
Expand Down
1 change: 1 addition & 0 deletions pkg/providers/instance/azure_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type AzureResourceGraphAPI interface {

type VirtualMachineExtensionsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, vmExtensionName string, extensionParameters armcompute.VirtualMachineExtension, options *armcompute.VirtualMachineExtensionsClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcompute.VirtualMachineExtensionsClientCreateOrUpdateResponse], error)
Get(ctx context.Context, resourceGroupName string, vmName string, vmExtensionName string, options *armcompute.VirtualMachineExtensionsClientGetOptions) (armcompute.VirtualMachineExtensionsClientGetResponse, error)
}

type NetworkInterfacesAPI interface {
Expand Down
46 changes: 26 additions & 20 deletions pkg/providers/instance/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,12 @@ func (p *Provider) createAKSIdentifyingExtension(ctx context.Context, vmName str
vmExt := p.getAKSIdentifyingExtension()
vmExtName := *vmExt.Name
logging.FromContext(ctx).Debugf("Creating virtual machine AKS identifying extension for %s", vmName)
v, err := createVirtualMachineExtension(ctx, p.azClient.virtualMachinesExtensionClient, p.resourceGroup, vmName, vmExtName, *vmExt)
if err != nil {
logging.FromContext(ctx).Errorf("Creating VM AKS identifying extension for VM %q failed, %w", vmName, err)
return fmt.Errorf("creating VM AKS identifying extension for VM %q, %w failed", vmName, err)
v, errRetriever := createVirtualMachineExtension(ctx, p.azClient.virtualMachinesExtensionClient, p.resourceGroup, vmName, vmExtName, *vmExt)
if errRetriever.GetFrontendErr() != nil {
logging.FromContext(ctx).Errorf("Creating VM AKS identifying extension for VM %q failed on frontend request, %w", vmName, errRetriever.GetFrontendErr())
return fmt.Errorf("creating VM AKS identifying extension for VM %q, %w failed on frontend request", vmName, errRetriever.GetFrontendErr())
}
logging.FromContext(ctx).Debugf("Created virtual machine AKS identifying extension for %s, with an id of %s", vmName, *v.ID)
logging.FromContext(ctx).Debugf("Created virtual machine AKS identifying extension for %s, with an id of %s", vmName, *v.ID)
return nil
}

Expand Down Expand Up @@ -272,9 +272,9 @@ func (p *Provider) createNetworkInterface(ctx context.Context, opts *createNICOp
nic := p.newNetworkInterfaceForVM(opts)
p.applyTemplateToNic(&nic, opts.LaunchTemplate)
logging.FromContext(ctx).Debugf("Creating network interface %s", opts.NICName)
res, err := createNic(ctx, p.azClient.networkInterfacesClient, p.resourceGroup, opts.NICName, nic)
if err != nil {
return "", err
res, errRetriever := createNic(ctx, p.azClient.networkInterfacesClient, p.resourceGroup, opts.NICName, nic)
if errRetriever.GetFrontendErr() != nil {
return "", errRetriever.GetFrontendErr()
}
logging.FromContext(ctx).Debugf("Successfully created network interface: %v", *res.ID)
return *res.ID, nil
Expand Down Expand Up @@ -380,19 +380,18 @@ func setVMPropertiesBillingProfile(vmProperties *armcompute.VirtualMachineProper

// setNodePoolNameTag sets "karpenter.sh/nodepool" tag
func setNodePoolNameTag(tags map[string]*string, nodeClaim *corev1beta1.NodeClaim) {
if val, ok := nodeClaim.Labels[corev1beta1.NodePoolLabelKey]; ok {
tags[NodePoolTagKey] = &val
}
nodePoolLabel := nodeClaim.Labels[corev1beta1.NodePoolLabelKey]
tags[NodePoolTagKey] = &nodePoolLabel
Comment on lines -383 to +384
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we change this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, good catch. This is unrelated to the change.

Pretty sure its a bug I found months back. Will separate it out into a new PR.

}

func (p *Provider) createVirtualMachine(ctx context.Context, vm armcompute.VirtualMachine, vmName string) (*armcompute.VirtualMachine, error) {
result, err := CreateVirtualMachine(ctx, p.azClient.virtualMachinesClient, p.resourceGroup, vmName, vm)
if err != nil {
logging.FromContext(ctx).Errorf("Creating virtual machine %q failed: %v", vmName, err)
return nil, fmt.Errorf("virtualMachine.BeginCreateOrUpdate for VM %q failed: %w", vmName, err)
func (p *Provider) createVirtualMachine(ctx context.Context, vm armcompute.VirtualMachine, vmName string) (*armcompute.VirtualMachine, ErrorRetriever) {
result, errRetriever := CreateVirtualMachine(ctx, p.azClient.virtualMachinesClient, p.resourceGroup, vmName, vm)
if errRetriever.GetFrontendErr() != nil {
logging.FromContext(ctx).Errorf("Creating virtual machine %q failed on frontend request: %v", vmName, errRetriever.GetFrontendErr())
return result, errRetriever
}
logging.FromContext(ctx).Debugf("Created virtual machine %s", *result.ID)
return result, nil
return result, errRetriever
}

func (p *Provider) launchInstance(
Expand Down Expand Up @@ -437,11 +436,18 @@ func (p *Provider) launchInstance(

logging.FromContext(ctx).Debugf("Creating virtual machine %s (%s)", resourceName, instanceType.Name)
// Uses AZ Client to create a new virtual machine using the vm object we prepared earlier
resp, err := p.createVirtualMachine(ctx, vm, resourceName)
if err != nil {
azErr := p.handleResponseErrors(ctx, instanceType, zone, capacityType, err)
resp, errRetriever := p.createVirtualMachine(ctx, vm, resourceName)
if errRetriever.GetFrontendErr() != nil {
azErr := p.handleResponseErrors(ctx, instanceType, zone, capacityType, errRetriever.GetFrontendErr())
return nil, nil, azErr
}
go func() {
asyncRrr := errRetriever.WaitForAsyncErr(ctx)
if asyncRrr != nil {
azErr := p.handleResponseErrors(ctx, instanceType, zone, capacityType, asyncRrr)
logging.FromContext(ctx).Errorf("Creating virtual machine %q had async failure: %v", resourceName, azErr)
}
}()

err = p.createAKSIdentifyingExtension(ctx, resourceName)
if err != nil {
Expand Down
Loading