diff --git a/go.mod b/go.mod index f7ba42f3c..dac59789d 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,6 @@ require ( github.com/stretchr/testify v1.9.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 - golang.org/x/crypto v0.23.0 k8s.io/api v0.29.3 k8s.io/apiextensions-apiserver v0.29.3 k8s.io/apimachinery v0.29.3 @@ -126,6 +125,7 @@ require ( go.opentelemetry.io/otel/trace v1.21.0 // indirect go.uber.org/automaxprocs v1.5.3 // indirect go.uber.org/mock v0.4.0 // indirect + golang.org/x/crypto v0.23.0 // indirect golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/net v0.25.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect diff --git a/karpenter-values-template.yaml b/karpenter-values-template.yaml index 5da0107d4..0deb08e90 100644 --- a/karpenter-values-template.yaml +++ b/karpenter-values-template.yaml @@ -34,15 +34,12 @@ controller: value: ${AZURE_SUBSCRIPTION_ID} - name: LOCATION value: ${AZURE_LOCATION} - # settings for managed workload ideneity - - name: ARM_USE_CREDENTIAL_FROM_ENVIRONMENT - value: "true" - - name: ARM_USE_MANAGED_IDENTITY_EXTENSION - value: "false" - - name: ARM_USER_ASSIGNED_IDENTITY_ID + - name: ARM_KUBELET_IDENTITY_CLIENT_ID value: "" - name: AZURE_NODE_RESOURCE_GROUP value: ${AZURE_RESOURCE_GROUP_MC} + - name: ARM_AUTH_METHOD + value: "workload-identity" serviceAccount: name: ${KARPENTER_SERVICE_ACCOUNT_NAME} annotations: diff --git a/pkg/auth/authmanager.go b/pkg/auth/authmanager.go new file mode 100644 index 000000000..a3d002d9c --- /dev/null +++ b/pkg/auth/authmanager.go @@ -0,0 +1,111 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "fmt" + + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/azure" + "k8s.io/klog/v2" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/jongio/azidext/go/azidext" +) + +const ( + // Supported auth methods; AuthManager.authMethod holds one of these values. + AuthMethodSysMSI = "system-assigned-msi" + AuthMethodWorkloadIdentity = "workload-identity" +) + +// AuthManager manages the authentication logic for Azure clients used by Karpenter to make requests. +// It is created through NewAuthManagerWorkloadIdentity or NewAuthManagerSystemAssignedMSI. +type AuthManager struct { + authMethod string // one of the AuthMethod* constants, set by the constructor + location string // value supplied to the constructor; read by NewAutorestAuthorizer +} + +// NewAuthManagerWorkloadIdentity returns an AuthManager whose auth method is AuthMethodWorkloadIdentity. +func NewAuthManagerWorkloadIdentity(location string) *AuthManager { + return &AuthManager{ + authMethod: AuthMethodWorkloadIdentity, + location: location, + } + +} + +// NewAuthManagerSystemAssignedMSI returns an AuthManager whose auth method is AuthMethodSysMSI. +func NewAuthManagerSystemAssignedMSI(location string) *AuthManager { + return &AuthManager{ + authMethod: AuthMethodSysMSI, + location: location, + } +} + +// NewCredential provides a token credential for the configured auth method: +// the result of azidentity.NewDefaultAzureCredential for workload identity, or of +// azidentity.NewManagedIdentityCredential (default options) for system-assigned MSI. +// Any other auth method yields a "cred: unsupported auth method" error. +func (am AuthManager) NewCredential() (azcore.TokenCredential, error) { + if am.authMethod == AuthMethodWorkloadIdentity { + klog.V(2).Infoln("cred: using workload identity for new credential") + return azidentity.NewDefaultAzureCredential(nil) + } + + if am.authMethod == AuthMethodSysMSI { + klog.V(2).Infoln("cred: using system assigned MSI for new credential") + msiCred, err := azidentity.NewManagedIdentityCredential(nil) + if err != nil { + return nil, err + } + return msiCred, nil + } + + return nil, fmt.Errorf("cred: unsupported auth method: %s", am.authMethod) +} + 
+// NewAutorestAuthorizer builds an autorest.Authorizer for the configured auth method. +// Workload identity wraps a default Azure credential in a token-credential adapter scoped to azidext.DefaultManagementScope; +// system-assigned MSI requests a service principal token from the MSI VM endpoint and returns a bearer authorizer. +// Any other auth method yields an "auth: unsupported auth method" error. +func (am AuthManager) NewAutorestAuthorizer() (autorest.Authorizer, error) { + // TODO (charliedmcb): need to get track 2 support for the skewer API, and align all auth under workload identity in the same way within cred.go + if am.authMethod == AuthMethodWorkloadIdentity { + klog.V(2).Infoln("auth: using workload identity for new authorizer") + cred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, fmt.Errorf("default cred: %w", err) + } + return azidext.NewTokenCredentialAdapter(cred, []string{azidext.DefaultManagementScope}), nil + } + + if am.authMethod == AuthMethodSysMSI { + klog.V(2).Infoln("auth: using system assigned MSI to retrieve access token") + msiEndpoint, err := adal.GetMSIVMEndpoint() + if err != nil { + return nil, fmt.Errorf("getting the managed service identity endpoint: %w", err) + } + + // NOTE(review): azure.EnvironmentFromName resolves a cloud name (e.g. "AzurePublicCloud"), but am.location looks like a region (the operator passes opts.Location) - confirm; if so, this lookup fails for every system-assigned-MSI caller and the cloud name should be used instead. + azureEnvironment, err := azure.EnvironmentFromName(am.location) + if err != nil { + return nil, fmt.Errorf("failed to get AzureEnvironment: %w", err) + } + + token, err := adal.NewServicePrincipalTokenFromMSI( + msiEndpoint, + azureEnvironment.ServiceManagementEndpoint) + if err != nil { + return nil, fmt.Errorf("retrieve service principal token: %w", err) + } + return autorest.NewBearerAuthorizer(token), nil + } + + return nil, fmt.Errorf("auth: unsupported auth method %s", am.authMethod) +} diff --git a/pkg/auth/autorest_auth.go b/pkg/auth/autorest_auth.go deleted file mode 100644 index 133e631a7..000000000 --- a/pkg/auth/autorest_auth.go +++ /dev/null @@ -1,105 +0,0 @@ -/* -Portions Copyright (c) Microsoft Corporation. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package auth - -import ( - "fmt" - "os" - - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/adal" - "github.com/Azure/go-autorest/autorest/azure" - "k8s.io/klog/v2" - - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/jongio/azidext/go/azidext" -) - -func NewAuthorizer(config *Config, env *azure.Environment) (autorest.Authorizer, error) { - // TODO (charliedmcb): need to get track 2 support for the skewer API, and align all auth under workload identity in the same way within cred.go - if config.UseCredentialFromEnvironment { - klog.V(2).Infoln("auth: using workload identity for new authorizer") - cred, err := azidentity.NewDefaultAzureCredential(nil) - if err != nil { - return nil, fmt.Errorf("default cred: %w", err) - } - return azidext.NewTokenCredentialAdapter(cred, []string{azidext.DefaultManagementScope}), nil - } - - token, err := newServicePrincipalTokenFromCredentials(config, env) - if err != nil { - return nil, fmt.Errorf("retrieve service principal token: %w", err) - } - return autorest.NewBearerAuthorizer(token), nil -} - -// newServicePrincipalTokenFromCredentials creates a new ServicePrincipalToken using values of the -// passed credentials map. 
-func newServicePrincipalTokenFromCredentials(config *Config, env *azure.Environment) (*adal.ServicePrincipalToken, error) { - oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, config.TenantID) - if err != nil { - return nil, fmt.Errorf("creating the OAuth config: %w", err) - } - - if config.UseManagedIdentityExtension { - klog.V(2).Infoln("azure: using managed identity extension to retrieve access token") - msiEndpoint, err := adal.GetMSIVMEndpoint() - if err != nil { - return nil, fmt.Errorf("getting the managed service identity endpoint: %w", err) - } - - if len(config.UserAssignedIdentityID) > 0 { - klog.V(4).Info("azure: using User Assigned MSI ID to retrieve access token") - return adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, - env.ServiceManagementEndpoint, - config.UserAssignedIdentityID) - } - klog.V(4).Info("azure: using System Assigned MSI to retrieve access token") - return adal.NewServicePrincipalTokenFromMSI( - msiEndpoint, - env.ServiceManagementEndpoint) - } - - if len(config.AADClientSecret) > 0 { - klog.V(2).Infoln("azure: using client_id+client_secret to retrieve access token") - return adal.NewServicePrincipalToken( - *oauthConfig, - config.AADClientID, - config.AADClientSecret, - env.ServiceManagementEndpoint) - } - - if len(config.AADClientCertPath) > 0 && len(config.AADClientCertPassword) > 0 { - klog.V(2).Infoln("azure: using jwt client_assertion (client_cert+client_private_key) to retrieve access token") - certData, err := os.ReadFile(config.AADClientCertPath) - if err != nil { - return nil, fmt.Errorf("reading the client certificate from file %s: %w", config.AADClientCertPath, err) - } - certificate, privateKey, err := decodePkcs12(certData, config.AADClientCertPassword) - if err != nil { - return nil, fmt.Errorf("decoding the client certificate: %w", err) - } - return adal.NewServicePrincipalTokenFromCertificate( - *oauthConfig, - config.AADClientID, - certificate, - privateKey, - 
env.ServiceManagementEndpoint) - } - - return nil, fmt.Errorf("no credentials provided for AAD application %s", config.AADClientID) -} diff --git a/pkg/auth/config.go b/pkg/auth/config.go deleted file mode 100644 index 83c927663..000000000 --- a/pkg/auth/config.go +++ /dev/null @@ -1,208 +0,0 @@ -/* -Portions Copyright (c) Microsoft Corporation. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package auth - -import ( - "fmt" - "os" - "strconv" - "strings" - - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" -) - -const ( - // auth methods - authMethodPrincipal = "principal" - authMethodCLI = "cli" -) - -const ( - // from azure_manager - vmTypeVMSS = "vmss" -) - -type cfgField struct { - val string - name string -} - -// ClientConfig contains all essential information to create an Azure client. 
-type ClientConfig struct { - CloudName string - Location string - SubscriptionID string - ResourceManagerEndpoint string - Authorizer autorest.Authorizer - UserAgent string -} - -// Config holds the configuration parsed from the --cloud-config flag -type Config struct { - Cloud string `json:"cloud" yaml:"cloud"` - Location string `json:"location" yaml:"location"` - TenantID string `json:"tenantId" yaml:"tenantId"` - SubscriptionID string `json:"subscriptionId" yaml:"subscriptionId"` - ResourceGroup string `json:"resourceGroup" yaml:"resourceGroup"` - VMType string `json:"vmType" yaml:"vmType"` - - // AuthMethod determines how to authorize requests for the Azure - // cloud. Valid options are "principal" (= the traditional - // service principle approach) and "cli" (= load az command line - // config file). The default is "principal". - AuthMethod string `json:"authMethod" yaml:"authMethod"` - - // Settings for a service principal. - AADClientID string `json:"aadClientId" yaml:"aadClientId"` - AADClientSecret string `json:"aadClientSecret" yaml:"aadClientSecret"` - AADClientCertPath string `json:"aadClientCertPath" yaml:"aadClientCertPath"` - AADClientCertPassword string `json:"aadClientCertPassword" yaml:"aadClientCertPassword"` - UseCredentialFromEnvironment bool `json:"useCredentialFromEnvironment" yaml:"useCredentialFromEnvironment"` - UseManagedIdentityExtension bool `json:"useManagedIdentityExtension" yaml:"useManagedIdentityExtension"` - UserAssignedIdentityID string `json:"userAssignedIdentityID" yaml:"userAssignedIdentityID"` - - //Configs only for AKS - ClusterName string `json:"clusterName" yaml:"clusterName"` - //Config only for AKS - NodeResourceGroup string `json:"nodeResourceGroup" yaml:"nodeResourceGroup"` -} - -func (cfg *Config) PrepareConfig() error { - cfg.BaseVars() - err := cfg.prepareID() - if err != nil { - return err - } - return nil -} - -func (cfg *Config) BaseVars() { - cfg.Cloud = os.Getenv("ARM_CLOUD") - cfg.Location = 
os.Getenv("LOCATION") - cfg.ResourceGroup = os.Getenv("ARM_RESOURCE_GROUP") - cfg.TenantID = os.Getenv("ARM_TENANT_ID") - cfg.SubscriptionID = os.Getenv("ARM_SUBSCRIPTION_ID") - cfg.AADClientID = os.Getenv("ARM_CLIENT_ID") - cfg.AADClientSecret = os.Getenv("ARM_CLIENT_SECRET") - cfg.VMType = strings.ToLower(os.Getenv("ARM_VM_TYPE")) - cfg.AADClientCertPath = os.Getenv("ARM_CLIENT_CERT_PATH") - cfg.AADClientCertPassword = os.Getenv("ARM_CLIENT_CERT_PASSWORD") - cfg.ClusterName = os.Getenv("AZURE_CLUSTER_NAME") - cfg.NodeResourceGroup = os.Getenv("AZURE_NODE_RESOURCE_GROUP") -} - -func (cfg *Config) prepareID() error { - useCredentialFromEnvironmentFromEnv := os.Getenv("ARM_USE_CREDENTIAL_FROM_ENVIRONMENT") - if len(useCredentialFromEnvironmentFromEnv) > 0 { - shouldUse, err := strconv.ParseBool(useCredentialFromEnvironmentFromEnv) - if err != nil { - return err - } - cfg.UseCredentialFromEnvironment = shouldUse - } - useManagedIdentityExtensionFromEnv := os.Getenv("ARM_USE_MANAGED_IDENTITY_EXTENSION") - if len(useManagedIdentityExtensionFromEnv) > 0 { - shouldUse, err := strconv.ParseBool(useManagedIdentityExtensionFromEnv) - if err != nil { - return err - } - cfg.UseManagedIdentityExtension = shouldUse - } - userAssignedIdentityIDFromEnv := os.Getenv("ARM_USER_ASSIGNED_IDENTITY_ID") - if userAssignedIdentityIDFromEnv != "" { - cfg.UserAssignedIdentityID = userAssignedIdentityIDFromEnv - } - return nil -} - -// BuildAzureConfig returns a Config object for the Azure clients -func BuildAzureConfig() (*Config, error) { - var err error - cfg := &Config{} - err = cfg.PrepareConfig() - if err != nil { - return nil, err - } - cfg.TrimSpace() - setVMType(cfg) - - if err := cfg.validate(); err != nil { - return nil, err - } - return cfg, nil -} - -func setVMType(cfg *Config) { - // Defaulting vmType to vmss. 
- if cfg.VMType == "" { - cfg.VMType = vmTypeVMSS - } -} - -func (cfg *Config) GetAzureClientConfig(authorizer autorest.Authorizer, env *azure.Environment) *ClientConfig { - azClientConfig := &ClientConfig{ - Location: cfg.Location, - SubscriptionID: cfg.SubscriptionID, - ResourceManagerEndpoint: env.ResourceManagerEndpoint, - Authorizer: authorizer, - } - - return azClientConfig -} - -// TrimSpace removes all leading and trailing white spaces. -func (cfg *Config) TrimSpace() { - cfg.Cloud = strings.TrimSpace(cfg.Cloud) - cfg.TenantID = strings.TrimSpace(cfg.TenantID) - cfg.SubscriptionID = strings.TrimSpace(cfg.SubscriptionID) - cfg.ResourceGroup = strings.TrimSpace(cfg.ResourceGroup) - cfg.VMType = strings.TrimSpace(cfg.VMType) - cfg.AADClientID = strings.TrimSpace(cfg.AADClientID) - cfg.AADClientSecret = strings.TrimSpace(cfg.AADClientSecret) - cfg.AADClientCertPath = strings.TrimSpace(cfg.AADClientCertPath) - cfg.AADClientCertPassword = strings.TrimSpace(cfg.AADClientCertPassword) - cfg.ClusterName = strings.TrimSpace(cfg.ClusterName) - cfg.NodeResourceGroup = strings.TrimSpace(cfg.NodeResourceGroup) -} - -func (cfg *Config) validate() error { - // Setup fields and validate all of them are not empty - fields := []cfgField{ - {cfg.SubscriptionID, "subscription ID"}, - {cfg.NodeResourceGroup, "node resource group"}, - {cfg.VMType, "VM type"}, - // Even though the config doesnt use some of these, - // its good to validate they were set in the environment - } - - for _, field := range fields { - if field.val == "" { - return fmt.Errorf("%s not set", field.name) - } - } - - if cfg.UseManagedIdentityExtension { - return nil - } - - if cfg.AuthMethod != "" && cfg.AuthMethod != authMethodPrincipal && cfg.AuthMethod != authMethodCLI { - return fmt.Errorf("unsupported authorization method: %s", cfg.AuthMethod) - } - - return nil -} diff --git a/pkg/auth/config_test.go b/pkg/auth/config_test.go index 01231bd1b..071a3fb8c 100644 --- a/pkg/auth/config_test.go +++ 
b/pkg/auth/config_test.go @@ -41,6 +41,7 @@ func TestBuildAzureConfig(t *testing.T) { ResourceGroup: "my-rg", NodeResourceGroup: "my-node-rg", VMType: "vmss", + ArmAuthMethod: "workload-identity", }, wantErr: false, env: map[string]string{ @@ -59,6 +60,7 @@ func TestBuildAzureConfig(t *testing.T) { ResourceGroup: "my-rg", NodeResourceGroup: "my-node-rg", VMType: "vm", + ArmAuthMethod: "workload-identity", }, wantErr: false, env: map[string]string{ @@ -72,39 +74,79 @@ func TestBuildAzureConfig(t *testing.T) { }, }, { - name: "bogus ARM_USE_MANAGED_IDENTITY_EXTENSION", + name: "bogus ARM_AUTH_METHOD", expected: nil, wantErr: true, env: map[string]string{ - "ARM_RESOURCE_GROUP": "my-rg", - "ARM_SUBSCRIPTION_ID": "12345", - "AZURE_NODE_RESOURCE_GROUP": "my-node-rg", - "AZURE_SUBNET_ID": "12345", - "AZURE_SUBNET_NAME": "my-subnet", - "AZURE_VNET_NAME": "my-vnet", - "ARM_USE_MANAGED_IDENTITY_EXTENSION": "foo", // this is not a supported value + "ARM_RESOURCE_GROUP": "my-rg", + "ARM_SUBSCRIPTION_ID": "12345", + "AZURE_NODE_RESOURCE_GROUP": "my-node-rg", + "AZURE_SUBNET_ID": "12345", + "AZURE_SUBNET_NAME": "my-subnet", + "AZURE_VNET_NAME": "my-vnet", + "ARM_AUTH_METHOD": "foo", // this is not a supported value + }, + }, + { + name: "auth method msi", + expected: &Config{ + SubscriptionID: "12345", + ResourceGroup: "my-rg", + NodeResourceGroup: "my-node-rg", + VMType: "vmss", + ArmAuthMethod: "system-assigned-msi", + }, + wantErr: false, + env: map[string]string{ + "ARM_RESOURCE_GROUP": "my-rg", + "ARM_SUBSCRIPTION_ID": "12345", + "AZURE_NODE_RESOURCE_GROUP": "my-node-rg", + "AZURE_SUBNET_ID": "12345", + "AZURE_SUBNET_NAME": "my-subnet", + "AZURE_VNET_NAME": "my-vnet", + "ARM_AUTH_METHOD": "system-assigned-msi", + }, + }, + { + name: "auth method workload identity", + expected: &Config{ + SubscriptionID: "12345", + ResourceGroup: "my-rg", + NodeResourceGroup: "my-node-rg", + VMType: "vmss", + ArmAuthMethod: "workload-identity", + }, + wantErr: false, + env: 
map[string]string{ + "ARM_RESOURCE_GROUP": "my-rg", + "ARM_SUBSCRIPTION_ID": "12345", + "AZURE_NODE_RESOURCE_GROUP": "my-node-rg", + "AZURE_SUBNET_ID": "12345", + "AZURE_SUBNET_NAME": "my-subnet", + "AZURE_VNET_NAME": "my-vnet", + "ARM_AUTH_METHOD": "workload-identity", }, }, { - name: "valid msi", + name: "valid kubelet identity", expected: &Config{ - SubscriptionID: "12345", - ResourceGroup: "my-rg", - NodeResourceGroup: "my-node-rg", - VMType: "vmss", - UseManagedIdentityExtension: true, - UserAssignedIdentityID: "12345", + SubscriptionID: "12345", + ResourceGroup: "my-rg", + NodeResourceGroup: "my-node-rg", + VMType: "vmss", + ArmAuthMethod: "system-assigned-msi", + KubeletIdentityClientID: "11111111-2222-3333-4444-555555555555", }, wantErr: false, env: map[string]string{ - "ARM_RESOURCE_GROUP": "my-rg", - "ARM_SUBSCRIPTION_ID": "12345", - "AZURE_NODE_RESOURCE_GROUP": "my-node-rg", - "AZURE_SUBNET_ID": "12345", - "AZURE_SUBNET_NAME": "my-subnet", - "AZURE_VNET_NAME": "my-vnet", - "ARM_USE_MANAGED_IDENTITY_EXTENSION": "true", - "ARM_USER_ASSIGNED_IDENTITY_ID": "12345", + "ARM_RESOURCE_GROUP": "my-rg", + "ARM_SUBSCRIPTION_ID": "12345", + "AZURE_NODE_RESOURCE_GROUP": "my-node-rg", + "AZURE_SUBNET_ID": "12345", + "AZURE_SUBNET_NAME": "my-subnet", + "AZURE_VNET_NAME": "my-vnet", + "ARM_AUTH_METHOD": "system-assigned-msi", + "ARM_KUBELET_IDENTITY_CLIENT_ID": "11111111-2222-3333-4444-555555555555", }, }, } diff --git a/pkg/auth/cred.go b/pkg/auth/cred.go deleted file mode 100644 index bd1df61fd..000000000 --- a/pkg/auth/cred.go +++ /dev/null @@ -1,55 +0,0 @@ -/* -Portions Copyright (c) Microsoft Corporation. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package auth - -import ( - "fmt" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "k8s.io/klog/v2" -) - -// NewCredential provides a token credential for msi and service principal auth -func NewCredential(cfg *Config) (azcore.TokenCredential, error) { - if cfg == nil { - return nil, fmt.Errorf("failed to create credential, nil config provided") - } - - if cfg.UseCredentialFromEnvironment { - klog.V(2).Infoln("cred: using workload identity for new credential") - return azidentity.NewDefaultAzureCredential(nil) - } - - if cfg.UseManagedIdentityExtension || cfg.AADClientID == "msi" { - klog.V(2).Infoln("cred: using msi for new credential") - msiCred, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ - ID: azidentity.ClientID(cfg.UserAssignedIdentityID), - }) - if err != nil { - return nil, err - } - return msiCred, nil - } - // service principal case - klog.V(2).Infoln("cred: using sp for new credential") - cred, err := azidentity.NewClientSecretCredential(cfg.TenantID, cfg.AADClientID, cfg.AADClientSecret, nil) - if err != nil { - return nil, err - } - return cred, nil -} diff --git a/pkg/auth/cred_test.go b/pkg/auth/cred_test.go index 22d48ab93..8235b0a5c 100644 --- a/pkg/auth/cred_test.go +++ b/pkg/auth/cred_test.go @@ -39,45 +39,35 @@ func TestNewCredential(t *testing.T) { wantErrStr: "failed to create credential, nil config provided", }, { - name: "AAD client ID is MSI", + name: "unsupported auth method", cfg: &Config{ - AADClientID: "msi", - TenantID: 
"00000000-0000-0000-0000-000000000000", - UserAssignedIdentityID: "12345678-1234-1234-1234-123456789012", + ArmAuthMethod: "unsupported", }, - want: reflect.TypeOf(&azidentity.ManagedIdentityCredential{}), - wantErr: false, + want: nil, + wantErr: true, + wantErrStr: "cred: unsupported auth method: unsupported", }, { - name: "AAD client ID is using MSI extension", - cfg: &Config{ - UseManagedIdentityExtension: true, - AADClientID: "msi", - TenantID: "00000000-0000-0000-0000-000000000000", - UserAssignedIdentityID: "12345678-1234-1234-1234-123456789012", - }, - want: reflect.TypeOf(&azidentity.ManagedIdentityCredential{}), - wantErr: false, + name: "empty auth method", + cfg: &Config{}, + want: nil, + wantErr: true, + wantErrStr: "cred: unsupported auth method: ", }, { - name: "AADClientID is not MSI", + name: "auth method system-assigned-msi", cfg: &Config{ - AADClientID: "test-client-id", - AADClientSecret: "test-client-secret", - TenantID: "00000000-0000-0000-0000-000000000000", + ArmAuthMethod: authMethodSysMSI, }, - want: reflect.TypeOf(&azidentity.ClientSecretCredential{}), + want: reflect.TypeOf(&azidentity.ManagedIdentityCredential{}), wantErr: false, }, { - name: "AADClientID is not MSI and UserAssignedIdentityID is set", + name: "auth method workload-identity", cfg: &Config{ - AADClientID: "test-client-id", - AADClientSecret: "test-client-secret", - TenantID: "00000000-0000-0000-0000-000000000000", - UserAssignedIdentityID: "12345678-1234-1234-1234-123456789012", + ArmAuthMethod: authMethodWorkloadIdentity, }, - want: reflect.TypeOf(&azidentity.ClientSecretCredential{}), + want: reflect.TypeOf(&azidentity.DefaultAzureCredential{}), wantErr: false, }, } diff --git a/pkg/auth/util.go b/pkg/auth/util.go index e3a5189d8..2bcd4af1f 100644 --- a/pkg/auth/util.go +++ b/pkg/auth/util.go @@ -17,30 +17,11 @@ limitations under the License. 
package auth import ( - "crypto/rsa" - "crypto/x509" "fmt" - "golang.org/x/crypto/pkcs12" - "github.com/Azure/karpenter-provider-azure/pkg/utils/project" ) -// decodePkcs12 decodes a PKCS#12 client certificate by extracting the public certificate and -// the private RSA key -func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) { - privateKey, certificate, err := pkcs12.Decode(pkcs, password) - if err != nil { - return nil, nil, fmt.Errorf("decoding the PKCS#12 client certificate: %w", err) - } - rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey) - if !isRsaKey { - return nil, nil, fmt.Errorf("PKCS#12 certificate must contain a RSA private key") - } - - return certificate, rsaPrivateKey, nil -} - func GetUserAgentExtension() string { return fmt.Sprintf("karpenter-aks/v%s", project.Version) } diff --git a/pkg/controllers/controllers.go b/pkg/controllers/controllers.go index 8da248639..38ecabcb6 100644 --- a/pkg/controllers/controllers.go +++ b/pkg/controllers/controllers.go @@ -26,15 +26,18 @@ import ( "github.com/Azure/karpenter-provider-azure/pkg/cloudprovider" nodeclaimgarbagecollection "github.com/Azure/karpenter-provider-azure/pkg/controllers/nodeclaim/garbagecollection" "github.com/Azure/karpenter-provider-azure/pkg/controllers/nodeclaim/inplaceupdate" + "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance" "github.com/Azure/karpenter-provider-azure/pkg/utils/project" ) func NewControllers(ctx context.Context, kubeClient client.Client, cloudProvider *cloudprovider.CloudProvider, instanceProvider *instance.Provider) []controller.Controller { logging.FromContext(ctx).With("version", project.Version).Debugf("discovered version") + opts := options.FromContext(ctx) + controllers := []controller.Controller{ nodeclaimgarbagecollection.NewController(kubeClient, cloudProvider), - inplaceupdate.NewController(kubeClient, instanceProvider), + 
inplaceupdate.NewController(kubeClient, instanceProvider, opts), } return controllers } diff --git a/pkg/controllers/nodeclaim/inplaceupdate/controller.go b/pkg/controllers/nodeclaim/inplaceupdate/controller.go index 557a4ab14..592aee016 100644 --- a/pkg/controllers/nodeclaim/inplaceupdate/controller.go +++ b/pkg/controllers/nodeclaim/inplaceupdate/controller.go @@ -44,6 +44,9 @@ import ( type Controller struct { kubeClient client.Client instanceProvider *instance.Provider + + // Will be used to calculate the goal state + opts *options.Options } var _ corecontroller.TypedController[*v1beta1.NodeClaim] = &Controller{} @@ -51,6 +54,7 @@ var _ corecontroller.TypedController[*v1beta1.NodeClaim] = &Controller{} func NewController( kubeClient client.Client, instanceProvider *instance.Provider, + opts *options.Options, ) corecontroller.Controller { controller := &Controller{ kubeClient: kubeClient, @@ -81,8 +85,7 @@ func (c *Controller) Reconcile(ctx context.Context, nodeClaim *v1beta1.NodeClaim // TODO: To look it up and use that as input to calculate the goal state as well // Compare the expected hash with the actual hash - options := options.FromContext(ctx) - goalHash, err := HashFromNodeClaim(options, nodeClaim) + goalHash, err := HashFromNodeClaim(c.opts, nodeClaim) if err != nil { return reconcile.Result{}, err } @@ -105,7 +108,7 @@ func (c *Controller) Reconcile(ctx context.Context, nodeClaim *v1beta1.NodeClaim return reconcile.Result{}, fmt.Errorf("getting azure VM for machine, %w", err) } - update := calculateVMPatch(options, vm) + update := calculateVMPatch(c.opts, vm) // This is safe only as long as we're not updating fields which we consider secret. // If we do/are, we need to redact them. 
logVMPatch(ctx, update) @@ -133,12 +136,12 @@ func (c *Controller) Reconcile(ctx context.Context, nodeClaim *v1beta1.NodeClaim } func calculateVMPatch( - options *options.Options, + opts *options.Options, // TODO: Can pass and consider NodeClaim and/or NodePool here if we need to in the future currentVM *armcompute.VirtualMachine, ) *armcompute.VirtualMachineUpdate { // Determine the differences between the current state and the goal state - expectedIdentities := options.NodeIdentities + expectedIdentities := opts.NodeIdentities var currentIdentities []string if currentVM.Identity != nil { currentIdentities = lo.Keys(currentVM.Identity.UserAssignedIdentities) diff --git a/pkg/controllers/nodeclaim/inplaceupdate/utils.go b/pkg/controllers/nodeclaim/inplaceupdate/utils.go index b64551e09..1ef787270 100644 --- a/pkg/controllers/nodeclaim/inplaceupdate/utils.go +++ b/pkg/controllers/nodeclaim/inplaceupdate/utils.go @@ -64,9 +64,9 @@ func HashFromVM(vm *armcompute.VirtualMachine) (string, error) { } // HashFromNodeClaim calculates an inplace update hash from the specified machine and options -func HashFromNodeClaim(options *options.Options, _ *v1beta1.NodeClaim) (string, error) { +func HashFromNodeClaim(opts *options.Options, _ *v1beta1.NodeClaim) (string, error) { hashStruct := &inPlaceUpdateFields{ - Identities: sets.New(options.NodeIdentities...), + Identities: sets.New(opts.NodeIdentities...), } return hashStruct.CalculateHash() diff --git a/pkg/operator/operator.go b/pkg/operator/operator.go index 4f7e70276..d5615497d 100644 --- a/pkg/operator/operator.go +++ b/pkg/operator/operator.go @@ -30,8 +30,6 @@ import ( corev1beta1 "sigs.k8s.io/karpenter/pkg/apis/v1beta1" "sigs.k8s.io/karpenter/pkg/operator/scheme" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" - "github.com/Azure/karpenter-provider-azure/pkg/apis" "github.com/Azure/karpenter-provider-azure/pkg/auth" azurecache "github.com/Azure/karpenter-provider-azure/pkg/cache" @@ -42,8 
+40,6 @@ import ( "github.com/Azure/karpenter-provider-azure/pkg/providers/launchtemplate" "github.com/Azure/karpenter-provider-azure/pkg/providers/loadbalancer" "github.com/Azure/karpenter-provider-azure/pkg/providers/pricing" - "github.com/Azure/karpenter-provider-azure/pkg/utils" - armopts "github.com/Azure/karpenter-provider-azure/pkg/utils/opts" "sigs.k8s.io/karpenter/pkg/operator" ) @@ -67,28 +63,31 @@ type Operator struct { } func NewOperator(ctx context.Context, operator *operator.Operator) (context.Context, *Operator) { - azConfig, err := GetAZConfig() - lo.Must0(err, "creating Azure config") // NOTE: we prefer this over the cleaner azConfig := lo.Must(GetAzConfig()), as when initializing the client there are helpful error messages in initializing clients and the azure config + opts := options.FromContext(ctx) - azClient, err := instance.CreateAZClient(ctx, azConfig) - lo.Must0(err, "creating Azure client") + var authManager *auth.AuthManager + if opts.ArmAuthMethod == auth.AuthMethodWorkloadIdentity { + authManager = auth.NewAuthManagerWorkloadIdentity(opts.Location) + } else if opts.ArmAuthMethod == auth.AuthMethodSysMSI { + authManager = auth.NewAuthManagerSystemAssignedMSI(opts.Location) + } - vnetGUID, err := getVnetGUID(azConfig, options.FromContext(ctx).SubnetID) - lo.Must0(err, "getting VNET GUID") + azClient, err := instance.CreateAZClient(ctx, opts.SubscriptionID, authManager) + lo.Must0(err, "creating Azure client") unavailableOfferingsCache := azurecache.NewUnavailableOfferings() pricingProvider := pricing.NewProvider( ctx, + opts, pricing.NewAPI(), - azConfig.Location, operator.Elected(), ) imageProvider := imagefamily.NewProvider( + opts, operator.KubernetesInterface, cache.New(azurecache.KubernetesVersionTTL, azurecache.DefaultCleanupInterval), azClient.ImageVersionsClient, - azConfig.Location, ) imageResolver := imagefamily.New( operator.GetClient(), @@ -96,39 +95,30 @@ func NewOperator(ctx context.Context, operator *operator.Operator) 
(context.Cont ) launchTemplateProvider := launchtemplate.NewProvider( ctx, + opts, imageResolver, imageProvider, lo.Must(getCABundle(operator.GetConfig())), - options.FromContext(ctx).ClusterEndpoint, - azConfig.TenantID, - azConfig.SubscriptionID, - azConfig.UserAssignedIdentityID, - azConfig.NodeResourceGroup, - azConfig.Location, - vnetGUID, ) instanceTypeProvider := instancetype.NewProvider( - azConfig.Location, + opts, cache.New(instancetype.InstanceTypesCacheTTL, azurecache.DefaultCleanupInterval), azClient.SKUClient, pricingProvider, unavailableOfferingsCache, ) loadBalancerProvider := loadbalancer.NewProvider( + opts, azClient.LoadBalancersClient, cache.New(loadbalancer.LoadBalancersCacheTTL, azurecache.DefaultCleanupInterval), - azConfig.NodeResourceGroup, ) instanceProvider := instance.NewProvider( + opts, azClient, instanceTypeProvider, launchTemplateProvider, loadBalancerProvider, unavailableOfferingsCache, - azConfig.Location, - azConfig.NodeResourceGroup, - options.FromContext(ctx).SubnetID, - azConfig.SubscriptionID, ) return ctx, &Operator{ @@ -144,14 +134,6 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont } } -func GetAZConfig() (*auth.Config, error) { - cfg, err := auth.BuildAzureConfig() - if err != nil { - return nil, err - } - return cfg, nil -} - func getCABundle(restConfig *rest.Config) (*string, error) { // Discover CA Bundle from the REST client. We could alternatively // have used the simpler client-go InClusterConfig() method. 
@@ -167,28 +149,3 @@ func getCABundle(restConfig *rest.Config) (*string, error) { } return ptr.String(base64.StdEncoding.EncodeToString(transportConfig.TLS.CAData)), nil } - -func getVnetGUID(cfg *auth.Config, subnetID string) (string, error) { - creds, err := auth.NewCredential(cfg) - if err != nil { - return "", err - } - opts := armopts.DefaultArmOpts() - vnetClient, err := armnetwork.NewVirtualNetworksClient(cfg.SubscriptionID, creds, opts) - if err != nil { - return "", err - } - - subnetParts, err := utils.GetVnetSubnetIDComponents(subnetID) - if err != nil { - return "", err - } - vnet, err := vnetClient.Get(context.Background(), subnetParts.ResourceGroupName, subnetParts.VNetName, nil) - if err != nil { - return "", err - } - if vnet.Properties == nil || vnet.Properties.ResourceGUID == nil { - return "", fmt.Errorf("vnet %s does not have a resource GUID", subnetParts.VNetName) - } - return *vnet.Properties.ResourceGUID, nil -} diff --git a/pkg/operator/options/options.go b/pkg/operator/options/options.go index 5ede66943..d1202d21e 100644 --- a/pkg/operator/options/options.go +++ b/pkg/operator/options/options.go @@ -21,13 +21,9 @@ import ( "errors" "flag" "fmt" - "hash/fnv" - "math/rand" - "net/url" "os" "strings" - "k8s.io/apimachinery/pkg/util/sets" coreoptions "sigs.k8s.io/karpenter/pkg/operator/options" "sigs.k8s.io/karpenter/pkg/utils/env" ) @@ -57,40 +53,67 @@ func (s *nodeIdentitiesValue) String() string { return strings.Join(*s, ",") } type optionsKey struct{} +// Options contains the configuration provided by the user. +// Currently we do not support changing these after initialization in any way. +// Even if we always get the updated value from the context/pointer, their copies that have been passed into external/vendored functions will not be updated. +// And some of their defaults might even depend on neighbouring fields (e.g., APIServerName). 
+// So, if one day we want to support dynamic configuration updates, consider creating a new instance of Options and reinitializing the operator/controllers. +// +// If some fields need to be updated/refreshed and have their own (valid) cascading update procedure, consider moving it away to prevent confusion with the above assumption. +// At that point, we should consider having a more clear distinction between user options and global variables. type Options struct { - ClusterName string - ClusterEndpoint string // => APIServerName in bootstrap, except needs to be w/o https/port - VMMemoryOverheadPercent float64 - ClusterID string + // Target cluster information; might be used for both bootstrapping and ARM authentication + Cloud string + Location string + TenantID string + SubscriptionID string + ResourceGroup string + ClusterName string + ClusterEndpoint string // => APIServerName in bootstrap, except needs to be w/o https/port + ClusterID string + APIServerName string + + // Node parameters + NodeResourceGroup string + KubeletIdentityClientID string KubeletClientTLSBootstrapToken string // => TLSBootstrapToken in bootstrap (may need to be per node/nodepool) + NodeIdentities []string // => Applied onto each VM SSHPublicKey string // ssh.publicKeys.keyData => VM SSH public key // TODO: move to v1alpha2.AKSNodeClass? 
NetworkPlugin string // => NetworkPlugin in bootstrap NetworkPolicy string // => NetworkPolicy in bootstrap - NodeIdentities []string // => Applied onto each VM - - SubnetID string // => VnetSubnetID to use (for nodes in Azure CNI Overlay and Azure CNI + pod subnet; for for nodes and pods in Azure CNI), unless overridden via AKSNodeClass + SubnetID string // => VnetSubnetID to use (for nodes in Azure CNI Overlay and Azure CNI + pod subnet; for nodes and pods in Azure CNI), unless overridden via AKSNodeClass + VnetGUID string - setFlags map[string]bool + // Behavioral configuration + ArmAuthMethod string + VMMemoryOverheadPercent float64 } func (o *Options) AddFlags(fs *coreoptions.FlagSet) { + fs.StringVar(&o.Cloud, "cloud", env.WithDefaultString("ARM_CLOUD", "AZUREPUBLICCLOUD"), "The cloud environment to use. Currently only supports 'AZUREPUBLICCLOUD'.") + fs.StringVar(&o.Location, "location", env.WithDefaultString("LOCATION", ""), "[REQUIRED] The location of the cluster.") + fs.StringVar(&o.TenantID, "tenant-id", env.WithDefaultString("ARM_TENANT_ID", ""), "The tenant ID of the cluster.") + fs.StringVar(&o.SubscriptionID, "subscription-id", env.WithDefaultString("ARM_SUBSCRIPTION_ID", ""), "[REQUIRED] The subscription ID of the cluster.") + fs.StringVar(&o.ResourceGroup, "resource-group", env.WithDefaultString("ARM_RESOURCE_GROUP", ""), "The resource group of the cluster.") fs.StringVar(&o.ClusterName, "cluster-name", env.WithDefaultString("CLUSTER_NAME", ""), "[REQUIRED] The kubernetes cluster name for resource tags.") fs.StringVar(&o.ClusterEndpoint, "cluster-endpoint", env.WithDefaultString("CLUSTER_ENDPOINT", ""), "[REQUIRED] The external kubernetes cluster endpoint for new nodes to connect with.") - fs.Float64Var(&o.VMMemoryOverheadPercent, "vm-memory-overhead-percent", env.WithDefaultFloat64("VM_MEMORY_OVERHEAD_PERCENT", 0.075), "The VM memory overhead as a percent that will be subtracted from the total memory for all instance types.") + + 
fs.StringVar(&o.NodeResourceGroup, "node-resource-group", env.WithDefaultString("AZURE_NODE_RESOURCE_GROUP", ""), "[REQUIRED] The resource group of the nodes.") fs.StringVar(&o.KubeletClientTLSBootstrapToken, "kubelet-bootstrap-token", env.WithDefaultString("KUBELET_BOOTSTRAP_TOKEN", ""), "[REQUIRED] The bootstrap token for new nodes to join the cluster.") + fs.StringVar(&o.KubeletIdentityClientID, "kubelet-identity-client-id", env.WithDefaultString("ARM_KUBELET_IDENTITY_CLIENT_ID", ""), "[REQUIRED] The client ID of the user assigned identity for kubelet.") + fs.Var(newNodeIdentitiesValue(env.WithDefaultString("NODE_IDENTITIES", ""), &o.NodeIdentities), "node-identities", "Additional identities to be assigned to the provisioned VMs. Allow support for AKS features like Addons.") fs.StringVar(&o.SSHPublicKey, "ssh-public-key", env.WithDefaultString("SSH_PUBLIC_KEY", ""), "[REQUIRED] VM SSH public key.") fs.StringVar(&o.NetworkPlugin, "network-plugin", env.WithDefaultString("NETWORK_PLUGIN", "azure"), "The network plugin used by the cluster.") fs.StringVar(&o.NetworkPolicy, "network-policy", env.WithDefaultString("NETWORK_POLICY", ""), "The network policy used by the cluster.") - fs.StringVar(&o.SubnetID, "vnet-subnet-id", env.WithDefaultString("VNET_SUBNET_ID", ""), "The default subnet ID to use for new nodes. This must be a valid ARM resource ID for subnet that does not overlap with the service CIDR or the pod CIDR") - fs.Var(newNodeIdentitiesValue(env.WithDefaultString("NODE_IDENTITIES", ""), &o.NodeIdentities), "node-identities", "User assigned identities for nodes.") -} + fs.StringVar(&o.SubnetID, "vnet-subnet-id", env.WithDefaultString("VNET_SUBNET_ID", ""), "[REQUIRED] The default subnet ID to use for new nodes. 
This must be a valid ARM resource ID for subnet that does not overlap with the service CIDR or the pod CIDR") -func (o Options) GetAPIServerName() string { - endpoint, _ := url.Parse(o.ClusterEndpoint) // assume to already validated - return endpoint.Hostname() + fs.StringVar(&o.ArmAuthMethod, "auth-method", env.WithDefaultString("ARM_AUTH_METHOD", "workload-identity"), "The authentication method to use.") + fs.Float64Var(&o.VMMemoryOverheadPercent, "vm-memory-overhead-percent", env.WithDefaultFloat64("VM_MEMORY_OVERHEAD_PERCENT", 0.075), "The VM memory overhead as a percent that will be subtracted from the total memory for all instance types.") } func (o *Options) Parse(fs *coreoptions.FlagSet, args ...string) error { + ctx := context.Background() + if err := fs.Parse(args); err != nil { if errors.Is(err, flag.ErrHelp) { os.Exit(0) @@ -98,25 +121,13 @@ func (o *Options) Parse(fs *coreoptions.FlagSet, args ...string) error { return fmt.Errorf("parsing flags, %w", err) } - // Check if each option has been set. This is a little brute force and better options might exist, - // but this only needs to be here for one version - o.setFlags = map[string]bool{} - cliFlags := sets.New[string]() - fs.Visit(func(f *flag.Flag) { - cliFlags.Insert(f.Name) - }) - fs.VisitAll(func(f *flag.Flag) { - envName := strings.ReplaceAll(strings.ToUpper(f.Name), "-", "_") - _, ok := os.LookupEnv(envName) - o.setFlags[f.Name] = ok || cliFlags.Has(f.Name) - }) - if err := o.Validate(); err != nil { return fmt.Errorf("validating options, %w", err) } - // ClusterID is generated from cluster endpoint - o.ClusterID = getAKSClusterID(o.GetAPIServerName()) + if err := o.Default(ctx); err != nil { + return fmt.Errorf("defaulting options, %w", err) + } return nil } @@ -136,14 +147,3 @@ func FromContext(ctx context.Context) *Options { } return retval.(*Options) } - -// getAKSClusterID returns cluster ID based on the DNS prefix of the cluster. 
-// The logic comes from AgentBaker and other places, originally from aks-engine -// with the additional assumption of DNS prefix being the first 33 chars of FQDN -func getAKSClusterID(apiServerFQDN string) string { - dnsPrefix := apiServerFQDN[:33] - h := fnv.New64a() - h.Write([]byte(dnsPrefix)) - r := rand.New(rand.NewSource(int64(h.Sum64()))) //nolint:gosec - return fmt.Sprintf("%08d", r.Uint32())[:8] -} diff --git a/pkg/operator/options/options_defaulting.go b/pkg/operator/options/options_defaulting.go new file mode 100644 index 000000000..c1d1c9f56 --- /dev/null +++ b/pkg/operator/options/options_defaulting.go @@ -0,0 +1,108 @@ +/* +Portions Copyright (c) Microsoft Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package options + +import ( + "context" + "fmt" + "hash/fnv" + "math/rand" + "net/url" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/Azure/karpenter-provider-azure/pkg/auth" + "github.com/Azure/karpenter-provider-azure/pkg/utils" + + armopts "github.com/Azure/karpenter-provider-azure/pkg/utils/opts" +) + +// Default sets the default values for the options, but only for those too complicated to be in env default (e.g., those that depend on other envs) +func (o *Options) Default(ctx context.Context) error { + var err error + + var authManager *auth.AuthManager + if o.ArmAuthMethod == auth.AuthMethodWorkloadIdentity { + authManager = auth.NewAuthManagerWorkloadIdentity(o.Location) + } else if o.ArmAuthMethod == auth.AuthMethodSysMSI { + authManager = auth.NewAuthManagerSystemAssignedMSI(o.Location) + } + + if o.APIServerName, err = getAPIServerName(o.ClusterEndpoint); err != nil { + return fmt.Errorf("failed to get APIServerName: %w", err) + } + + if o.ClusterID, err = getAKSClusterID(o.APIServerName); err != nil { + return fmt.Errorf("failed to get ClusterID: %w", err) + } + + if o.VnetGUID, err = getVnetGUID(ctx, o.SubscriptionID, o.SubnetID, authManager); err != nil { + return fmt.Errorf("failed to get VnetGUID: %w", err) + + } + + return nil +} + +func getAPIServerName(clusterEndpoint string) (string, error) { + endpoint, err := url.Parse(clusterEndpoint) // assumed to be already validated; on a parse error endpoint is nil and the Hostname() call below would panic + return endpoint.Hostname(), err +} + +// getAKSClusterID returns cluster ID based on the DNS prefix of the cluster. 
+// The logic comes from AgentBaker and other places, originally from aks-engine +// with the additional assumption of DNS prefix being the first 33 chars of FQDN +func getAKSClusterID(apiServerFQDN string) (string, error) { + dnsPrefix := apiServerFQDN[:33] + h := fnv.New64a() + h.Write([]byte(dnsPrefix)) + r := rand.New(rand.NewSource(int64(h.Sum64()))) //nolint:gosec + return fmt.Sprintf("%08d", r.Uint32())[:8], nil +} + +func getVnetGUID(ctx context.Context, subscriptionID string, VnetSubnetID string, authManager *auth.AuthManager) (string, error) { + creds, err := authManager.NewCredential() + if err != nil { + return "", err + } + armOpts := armopts.DefaultArmOpts() + vnetClient, err := armnetwork.NewVirtualNetworksClient(subscriptionID, creds, armOpts) + if err != nil { + return "", err + } + + subnetParts, err := utils.GetVnetSubnetIDComponents(VnetSubnetID) + if err != nil { + return "", err + } + vnet, err := vnetClient.Get(ctx, subnetParts.ResourceGroupName, subnetParts.VNetName, nil) + if err != nil { + return "", err + } + if vnet.Properties == nil || vnet.Properties.ResourceGUID == nil { + return "", fmt.Errorf("vnet %s does not have a resource GUID", subnetParts.VNetName) + } + return *vnet.Properties.ResourceGUID, nil +} + +func contains(slice []string, target string) bool { + for _, element := range slice { + if target == element { + return true + } + } + return false +} diff --git a/pkg/operator/options/options_validation.go b/pkg/operator/options/options_validation.go index 9ed1bb21c..e03984c78 100644 --- a/pkg/operator/options/options_validation.go +++ b/pkg/operator/options/options_validation.go @@ -29,22 +29,45 @@ func (o Options) Validate() error { validate := validator.New() return multierr.Combine( o.validateRequiredFields(), - o.validateEndpoint(), + o.validateClusterEndpoint(), + o.validateArmAuthMethod(), o.validateVMMemoryOverheadPercent(), o.validateVnetSubnetID(), validate.Struct(o), ) } -func (o Options) validateVnetSubnetID() error 
{ - _, err := utils.GetVnetSubnetIDComponents(o.SubnetID) - if err != nil { - return fmt.Errorf("vnet-subnet-id is invalid: %w", err) +func (o Options) validateRequiredFields() error { + if o.Location == "" { + return fmt.Errorf("missing field, location") + } + if o.SubscriptionID == "" { + return fmt.Errorf("missing field, subscription-id") + } + if o.ClusterEndpoint == "" { + return fmt.Errorf("missing field, cluster-endpoint") + } + if o.ClusterName == "" { + return fmt.Errorf("missing field, cluster-name") } + + if o.NodeResourceGroup == "" { + return fmt.Errorf("missing field, node-resource-group") + } + if o.KubeletClientTLSBootstrapToken == "" { + return fmt.Errorf("missing field, kubelet-bootstrap-token") + } + if o.SSHPublicKey == "" { + return fmt.Errorf("missing field, ssh-public-key") + } + if o.SubnetID == "" { + return fmt.Errorf("missing field, vnet-subnet-id") + } + return nil } -func (o Options) validateEndpoint() error { +func (o Options) validateClusterEndpoint() error { if o.ClusterEndpoint == "" { return nil } @@ -57,6 +80,13 @@ func (o Options) validateEndpoint() error { return nil } +func (o Options) validateArmAuthMethod() error { + if o.ArmAuthMethod != "system-assigned-msi" && o.ArmAuthMethod != "workload-identity" { + return fmt.Errorf("unsupported authorization method: %s", o.ArmAuthMethod) + } + return nil +} + func (o Options) validateVMMemoryOverheadPercent() error { if o.VMMemoryOverheadPercent < 0 { return fmt.Errorf("vm-memory-overhead-percent cannot be negative") @@ -64,21 +94,10 @@ func (o Options) validateVMMemoryOverheadPercent() error { return nil } -func (o Options) validateRequiredFields() error { - if o.ClusterEndpoint == "" { - return fmt.Errorf("missing field, cluster-endpoint") - } - if o.ClusterName == "" { - return fmt.Errorf("missing field, cluster-name") - } - if o.KubeletClientTLSBootstrapToken == "" { - return fmt.Errorf("missing field, kubelet-bootstrap-token") - } - if o.SSHPublicKey == "" { - return 
fmt.Errorf("missing field, ssh-public-key") - } - if o.SubnetID == "" { - return fmt.Errorf("missing field, vnet-subnet-id") +func (o Options) validateVnetSubnetID() error { + _, err := utils.GetVnetSubnetIDComponents(o.SubnetID) + if err != nil { + return fmt.Errorf("vnet-subnet-id is invalid: %w", err) } return nil } diff --git a/pkg/providers/imagefamily/azlinux.go b/pkg/providers/imagefamily/azlinux.go index 3d36d8cc2..f4f4beff4 100644 --- a/pkg/providers/imagefamily/azlinux.go +++ b/pkg/providers/imagefamily/azlinux.go @@ -92,7 +92,7 @@ func (u AzureLinux) UserData(kubeletConfig *corev1beta1.KubeletConfiguration, ta TenantID: u.Options.TenantID, SubscriptionID: u.Options.SubscriptionID, Location: u.Options.Location, - UserAssignedIdentityID: u.Options.UserAssignedIdentityID, + KubeletIdentityClientID: u.Options.KubeletIdentityClientID, ResourceGroup: u.Options.ResourceGroup, ClusterID: u.Options.ClusterID, APIServerName: u.Options.APIServerName, diff --git a/pkg/providers/imagefamily/bootstrap/aksbootstrap.go b/pkg/providers/imagefamily/bootstrap/aksbootstrap.go index f4e9b2260..83ad73a2c 100644 --- a/pkg/providers/imagefamily/bootstrap/aksbootstrap.go +++ b/pkg/providers/imagefamily/bootstrap/aksbootstrap.go @@ -40,7 +40,7 @@ type AKS struct { Arch string TenantID string SubscriptionID string - UserAssignedIdentityID string + KubeletIdentityClientID string Location string ResourceGroup string ClusterID string @@ -429,7 +429,7 @@ func (a AKS) applyOptions(nbv *NodeBootstrapVariables) { nbv.SubscriptionID = a.SubscriptionID nbv.Location = a.Location nbv.ResourceGroup = a.ResourceGroup - nbv.UserAssignedIdentityID = a.UserAssignedIdentityID + nbv.UserAssignedIdentityID = a.KubeletIdentityClientID nbv.NetworkPlugin = a.NetworkPlugin nbv.NetworkPolicy = a.NetworkPolicy diff --git a/pkg/providers/imagefamily/image.go b/pkg/providers/imagefamily/image.go index ff5a7e70c..4b2a6f419 100644 --- a/pkg/providers/imagefamily/image.go +++ 
b/pkg/providers/imagefamily/image.go @@ -25,6 +25,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5" "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2" + "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/patrickmn/go-cache" "github.com/samber/lo" "k8s.io/client-go/kubernetes" @@ -51,11 +52,11 @@ const ( imageIDFormat = "/CommunityGalleries/%s/images/%s/versions/%s" ) -func NewProvider(kubernetesInterface kubernetes.Interface, kubernetesVersionCache *cache.Cache, versionsClient CommunityGalleryImageVersionsAPI, location string) *Provider { +func NewProvider(opts *options.Options, kubernetesInterface kubernetes.Interface, kubernetesVersionCache *cache.Cache, versionsClient CommunityGalleryImageVersionsAPI) *Provider { return &Provider{ kubernetesVersionCache: kubernetesVersionCache, imageCache: cache.New(imageExpirationInterval, imageCacheCleaningInterval), - location: location, + location: opts.Location, imageVersionsClient: versionsClient, cm: pretty.NewChangeMonitor(), kubernetesInterface: kubernetesInterface, diff --git a/pkg/providers/imagefamily/ubuntu_2204.go b/pkg/providers/imagefamily/ubuntu_2204.go index a7c3b8ee2..4542fbef7 100644 --- a/pkg/providers/imagefamily/ubuntu_2204.go +++ b/pkg/providers/imagefamily/ubuntu_2204.go @@ -91,7 +91,7 @@ func (u Ubuntu2204) UserData(kubeletConfig *corev1beta1.KubeletConfiguration, ta TenantID: u.Options.TenantID, SubscriptionID: u.Options.SubscriptionID, Location: u.Options.Location, - UserAssignedIdentityID: u.Options.UserAssignedIdentityID, + KubeletIdentityClientID: u.Options.KubeletIdentityClientID, ResourceGroup: u.Options.ResourceGroup, ClusterID: u.Options.ClusterID, APIServerName: u.Options.APIServerName, diff --git a/pkg/providers/instance/azure_client.go b/pkg/providers/instance/azure_client.go index 025e02caa..f00b18a0a 100644 --- a/pkg/providers/instance/azure_client.go +++ b/pkg/providers/instance/azure_client.go @@ -24,7 +24,6 
@@ import ( armcomputev5 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resourcegraph/armresourcegraph" - "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/karpenter-provider-azure/pkg/auth" "github.com/Azure/karpenter-provider-azure/pkg/providers/imagefamily" "github.com/Azure/karpenter-provider-azure/pkg/providers/instance/skuclient" @@ -88,61 +87,44 @@ func NewAZClientFromAPI( } } -func CreateAZClient(ctx context.Context, cfg *auth.Config) (*AZClient, error) { - // Defaulting env to Azure Public Cloud. - env := azure.PublicCloud - var err error - if cfg.Cloud != "" { - env, err = azure.EnvironmentFromName(cfg.Cloud) - if err != nil { - return nil, err - } - } - - azClient, err := NewAZClient(ctx, cfg, &env) - if err != nil { - return nil, err - } - - return azClient, nil -} - -func NewAZClient(ctx context.Context, cfg *auth.Config, env *azure.Environment) (*AZClient, error) { - cred, err := auth.NewCredential(cfg) +func CreateAZClient(ctx context.Context, subscriptionID string, authManager *auth.AuthManager) (*AZClient, error) { + cred, err := authManager.NewCredential() if err != nil { return nil, err } + armOpts := armopts.DefaultArmOpts() - opts := armopts.DefaultArmOpts() - extensionsClient, err := armcompute.NewVirtualMachineExtensionsClient(cfg.SubscriptionID, cred, opts) + extensionsClient, err := armcompute.NewVirtualMachineExtensionsClient(subscriptionID, cred, armOpts) if err != nil { return nil, err } + klog.V(5).Infof("Created virtual machine extensions client %v, using a token credential", extensionsClient) - interfacesClient, err := armnetwork.NewInterfacesClient(cfg.SubscriptionID, cred, opts) + interfacesClient, err := armnetwork.NewInterfacesClient(subscriptionID, cred, armOpts) if err != nil { return nil, err } klog.V(5).Infof("Created network interface client %v using 
token credential", interfacesClient) - virtualMachinesClient, err := armcompute.NewVirtualMachinesClient(cfg.SubscriptionID, cred, opts) + virtualMachinesClient, err := armcompute.NewVirtualMachinesClient(subscriptionID, cred, armOpts) if err != nil { return nil, err } klog.V(5).Infof("Created virtual machines client %v, using a token credential", virtualMachinesClient) - azureResourceGraphClient, err := armresourcegraph.NewClient(cred, opts) + + azureResourceGraphClient, err := armresourcegraph.NewClient(cred, armOpts) if err != nil { return nil, err } klog.V(5).Infof("Created azure resource graph client %v, using a token credential", azureResourceGraphClient) - imageVersionsClient, err := armcomputev5.NewCommunityGalleryImageVersionsClient(cfg.SubscriptionID, cred, opts) + imageVersionsClient, err := armcomputev5.NewCommunityGalleryImageVersionsClient(subscriptionID, cred, armOpts) if err != nil { return nil, err } klog.V(5).Infof("Created image versions client %v, using a token credential", imageVersionsClient) - loadBalancersClient, err := armnetwork.NewLoadBalancersClient(cfg.SubscriptionID, cred, opts) + loadBalancersClient, err := armnetwork.NewLoadBalancersClient(subscriptionID, cred, armOpts) if err != nil { return nil, err } @@ -150,7 +132,11 @@ func NewAZClient(ctx context.Context, cfg *auth.Config, env *azure.Environment) // TODO: this one is not enabled for rate limiting / throttling ... 
// TODO Move this over to track 2 when skewer is migrated - skuClient := skuclient.NewSkuClient(ctx, cfg, env) + skuClient, err := skuclient.NewSkuClient(ctx, subscriptionID, authManager) + if err != nil { + return nil, err + } + klog.V(5).Infof("Created sku client %v, using a token credential", skuClient) return NewAZClientFromAPI(virtualMachinesClient, azureResourceGraphClient, diff --git a/pkg/providers/instance/instance.go b/pkg/providers/instance/instance.go index dd1fdc74d..39775e73f 100644 --- a/pkg/providers/instance/instance.go +++ b/pkg/providers/instance/instance.go @@ -89,30 +89,31 @@ type Provider struct { subnetID string subscriptionID string unavailableOfferings *cache.UnavailableOfferings + sshPublicKey string + nodeIdentities []string } func NewProvider( + opts *options.Options, azClient *AZClient, instanceTypeProvider *instancetype.Provider, launchTemplateProvider *launchtemplate.Provider, loadBalancerProvider *loadbalancer.Provider, offeringsCache *cache.UnavailableOfferings, - location string, - resourceGroup string, - subnetID string, - subscriptionID string, ) *Provider { - listQuery = GetListQueryBuilder(resourceGroup).String() + listQuery = GetListQueryBuilder(opts.NodeResourceGroup).String() return &Provider{ azClient: azClient, instanceTypeProvider: instanceTypeProvider, launchTemplateProvider: launchTemplateProvider, loadBalancerProvider: loadBalancerProvider, - location: location, - resourceGroup: resourceGroup, - subnetID: subnetID, - subscriptionID: subscriptionID, + location: opts.Location, + resourceGroup: opts.NodeResourceGroup, + subnetID: opts.SubnetID, + subscriptionID: opts.SubscriptionID, unavailableOfferings: offeringsCache, + sshPublicKey: opts.SSHPublicKey, + nodeIdentities: opts.NodeIdentities, } } @@ -393,9 +394,7 @@ func (p *Provider) launchInstance( return nil, nil, err } - sshPublicKey := options.FromContext(ctx).SSHPublicKey - nodeIdentityIDs := options.FromContext(ctx).NodeIdentities - vm := newVMObject(resourceName, 
nicReference, zone, capacityType, p.location, sshPublicKey, nodeIdentityIDs, nodeClass, launchTemplate, instanceType) + vm := newVMObject(resourceName, nicReference, zone, capacityType, p.location, p.sshPublicKey, p.nodeIdentities, nodeClass, launchTemplate, instanceType) logging.FromContext(ctx).Debugf("Creating virtual machine %s (%s)", resourceName, instanceType.Name) // Uses AZ Client to create a new virtual machine using the vm object we prepared earlier diff --git a/pkg/providers/instance/skuclient/skuclient.go b/pkg/providers/instance/skuclient/skuclient.go index c69e304d0..5db105fbf 100644 --- a/pkg/providers/instance/skuclient/skuclient.go +++ b/pkg/providers/instance/skuclient/skuclient.go @@ -22,7 +22,6 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute" - "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/karpenter-provider-azure/pkg/auth" "github.com/Azure/skewer" klog "k8s.io/klog/v2" @@ -37,40 +36,20 @@ type SkuClient interface { } type skuClient struct { - cfg *auth.Config - env *azure.Environment + subscriptionID string + authManager *auth.AuthManager mu sync.RWMutex instance compute.ResourceSkusClient } -func (sc *skuClient) updateInstance() { - sc.mu.RLock() - defer sc.mu.RUnlock() - - authorizer, err := auth.NewAuthorizer(sc.cfg, sc.env) - if err != nil { - klog.V(5).Infof("Error creating authorizer for sku client: %s", err) - return - } - - azClientConfig := sc.cfg.GetAzureClientConfig(authorizer, sc.env) - azClientConfig.UserAgent = auth.GetUserAgentExtension() - - skuClient := compute.NewResourceSkusClient(sc.cfg.SubscriptionID) - skuClient.Authorizer = azClientConfig.Authorizer - klog.V(5).Infof("Created sku client with authorizer: %v", skuClient) - - sc.instance = skuClient -} - -func NewSkuClient(ctx context.Context, cfg *auth.Config, env *azure.Environment) SkuClient { +func NewSkuClient(ctx context.Context, subscriptionID string, authManager *auth.AuthManager) (SkuClient, error) { sc 
:= &skuClient{ - cfg: cfg, - env: env, + subscriptionID: subscriptionID, + authManager: authManager, } - sc.updateInstance() + sc.updateInstance() go func() { for { select { @@ -81,7 +60,25 @@ func NewSkuClient(ctx context.Context, cfg *auth.Config, env *azure.Environment) } } }() - return sc + + return sc, nil +} + +func (sc *skuClient) updateInstance() { + sc.mu.RLock() + defer sc.mu.RUnlock() + + updatedSkuClient := compute.NewResourceSkusClient(sc.subscriptionID) + + authorizer, err := sc.authManager.NewAutorestAuthorizer() + if err != nil { + klog.V(5).Infof("Error creating authorizer for sku client: %s", err) + return + } + updatedSkuClient.Authorizer = authorizer + + klog.V(5).Infof("Created sku client with authorizer: %v", updatedSkuClient) + sc.instance = updatedSkuClient } func (sc *skuClient) GetInstance() skewer.ResourceClient { diff --git a/pkg/providers/instancetype/instancetype.go b/pkg/providers/instancetype/instancetype.go index abc17cdea..26295a3a7 100644 --- a/pkg/providers/instancetype/instancetype.go +++ b/pkg/providers/instancetype/instancetype.go @@ -17,7 +17,6 @@ limitations under the License. 
package instancetype import ( - "context" "fmt" "math" @@ -33,8 +32,6 @@ import ( "sigs.k8s.io/karpenter/pkg/cloudprovider" "sigs.k8s.io/karpenter/pkg/scheduling" - "github.com/Azure/karpenter-provider-azure/pkg/operator/options" - "sigs.k8s.io/karpenter/pkg/utils/resources" ) @@ -118,13 +115,13 @@ func (t TaxBrackets) Calculate(amount float64) float64 { return tax } -func NewInstanceType(ctx context.Context, sku *skewer.SKU, vmsize *skewer.VMSizeType, kc *corev1beta1.KubeletConfiguration, region string, - offerings cloudprovider.Offerings, nodeClass *v1alpha2.AKSNodeClass, architecture string) *cloudprovider.InstanceType { +func NewInstanceType(sku *skewer.SKU, vmsize *skewer.VMSizeType, kc *corev1beta1.KubeletConfiguration, region string, + offerings cloudprovider.Offerings, nodeClass *v1alpha2.AKSNodeClass, architecture string, vmMemoryOverheadPercent float64) *cloudprovider.InstanceType { return &cloudprovider.InstanceType{ Name: sku.GetName(), Requirements: computeRequirements(sku, vmsize, architecture, offerings, region), Offerings: offerings, - Capacity: computeCapacity(ctx, sku, kc, nodeClass), + Capacity: computeCapacity(sku, kc, nodeClass, vmMemoryOverheadPercent), Overhead: &cloudprovider.InstanceTypeOverhead{ KubeReserved: KubeReservedResources(lo.Must(sku.VCPU()), lo.Must(sku.Memory())), SystemReserved: SystemReservedResources(), @@ -263,10 +260,10 @@ func getArchitecture(architecture string) string { return architecture // unrecognized } -func computeCapacity(ctx context.Context, sku *skewer.SKU, kc *corev1beta1.KubeletConfiguration, nodeClass *v1alpha2.AKSNodeClass) v1.ResourceList { +func computeCapacity(sku *skewer.SKU, kc *corev1beta1.KubeletConfiguration, nodeClass *v1alpha2.AKSNodeClass, vmMemoryOverheadPercent float64) v1.ResourceList { return v1.ResourceList{ v1.ResourceCPU: *cpu(sku), - v1.ResourceMemory: *memory(ctx, sku), + v1.ResourceMemory: *memory(sku, vmMemoryOverheadPercent), v1.ResourceEphemeralStorage: *ephemeralStorage(nodeClass), 
v1.ResourcePods: *pods(sku, kc), v1.ResourceName("nvidia.com/gpu"): *gpuNvidiaCount(sku), @@ -298,11 +295,11 @@ func memoryMiB(sku *skewer.SKU) int64 { return int64(memoryGiB(sku) * 1024) } -func memory(ctx context.Context, sku *skewer.SKU) *resource.Quantity { +func memory(sku *skewer.SKU, vmMemoryOverheadPercent float64) *resource.Quantity { memory := resources.Quantity(fmt.Sprintf("%dGi", int64(memoryGiB(sku)))) // Account for VM overhead in calculation memory.Sub(resource.MustParse(fmt.Sprintf("%dMi", int64(math.Ceil( - float64(memory.Value())*options.FromContext(ctx).VMMemoryOverheadPercent/1024/1024))))) + float64(memory.Value())*vmMemoryOverheadPercent/1024/1024))))) return memory } diff --git a/pkg/providers/instancetype/instancetypes.go b/pkg/providers/instancetype/instancetypes.go index a8442b2a7..4907722df 100644 --- a/pkg/providers/instancetype/instancetypes.go +++ b/pkg/providers/instancetype/instancetypes.go @@ -32,6 +32,7 @@ import ( "github.com/Azure/go-autorest/autorest/to" "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2" kcache "github.com/Azure/karpenter-provider-azure/pkg/cache" + "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/Azure/karpenter-provider-azure/pkg/utils" "github.com/patrickmn/go-cache" "k8s.io/apimachinery/pkg/util/sets" @@ -68,18 +69,21 @@ type Provider struct { cm *pretty.ChangeMonitor // instanceTypesSeqNum is a monotonically increasing change counter used to avoid the expensive hashing operation on instance types instanceTypesSeqNum uint64 + + vmMemoryOverheadPercent float64 } -func NewProvider(region string, cache *cache.Cache, skuClient skuclient.SkuClient, pricingProvider *pricing.Provider, offeringsCache *kcache.UnavailableOfferings) *Provider { +func NewProvider(opts *options.Options, cache *cache.Cache, skuClient skuclient.SkuClient, pricingProvider *pricing.Provider, offeringsCache *kcache.UnavailableOfferings) *Provider { return &Provider{ // TODO: skewer api, subnetprovider, 
pricing provider, unavailable offerings, ... - region: region, - skuClient: skuClient, - pricingProvider: pricingProvider, - unavailableOfferings: offeringsCache, - cache: cache, - cm: pretty.NewChangeMonitor(), - instanceTypesSeqNum: 0, + region: opts.Location, + skuClient: skuClient, + pricingProvider: pricingProvider, + unavailableOfferings: offeringsCache, + cache: cache, + cm: pretty.NewChangeMonitor(), + instanceTypesSeqNum: 0, + vmMemoryOverheadPercent: opts.VMMemoryOverheadPercent, } } @@ -120,7 +124,7 @@ func (p *Provider) List( continue } instanceTypeZones := instanceTypeZones(sku, p.region) - instanceType := NewInstanceType(ctx, sku, vmsize, kc, p.region, p.createOfferings(sku, instanceTypeZones), nodeClass, architecture) + instanceType := NewInstanceType(sku, vmsize, kc, p.region, p.createOfferings(sku, instanceTypeZones), nodeClass, architecture, p.vmMemoryOverheadPercent) if len(instanceType.Offerings) == 0 { continue } diff --git a/pkg/providers/launchtemplate/launchtemplate.go b/pkg/providers/launchtemplate/launchtemplate.go index 3c5e5ac95..1506374d8 100644 --- a/pkg/providers/launchtemplate/launchtemplate.go +++ b/pkg/providers/launchtemplate/launchtemplate.go @@ -53,34 +53,20 @@ type Template struct { } type Provider struct { - imageFamily *imagefamily.Resolver - imageProvider *imagefamily.Provider - caBundle *string - clusterEndpoint string - tenantID string - subscriptionID string - userAssignedIdentityID string - resourceGroup string - location string - vnetGUID string + opts *options.Options + imageFamily *imagefamily.Resolver + imageProvider *imagefamily.Provider + caBundle *string } // TODO: add caching of launch templates -func NewProvider(_ context.Context, imageFamily *imagefamily.Resolver, imageProvider *imagefamily.Provider, caBundle *string, clusterEndpoint string, - tenantID, subscriptionID, userAssignedIdentityID, resourceGroup, location, vnetGUID string, -) *Provider { +func NewProvider(_ context.Context, opts *options.Options, 
imageFamily *imagefamily.Resolver, imageProvider *imagefamily.Provider, caBundle *string) *Provider { return &Provider{ - imageFamily: imageFamily, - imageProvider: imageProvider, - caBundle: caBundle, - clusterEndpoint: clusterEndpoint, - tenantID: tenantID, - subscriptionID: subscriptionID, - userAssignedIdentityID: userAssignedIdentityID, - resourceGroup: resourceGroup, - location: location, - vnetGUID: vnetGUID, + opts: opts, + imageFamily: imageFamily, + imageProvider: imageProvider, + caBundle: caBundle, } } @@ -108,13 +94,13 @@ func (p *Provider) GetTemplate(ctx context.Context, nodeClass *v1alpha2.AKSNodeC return launchTemplate, nil } -func (p *Provider) getStaticParameters(ctx context.Context, instanceType *cloudprovider.InstanceType, nodeClass *v1alpha2.AKSNodeClass, labels map[string]string) (*parameters.StaticParameters, error) { +func (p *Provider) getStaticParameters(_ context.Context, instanceType *cloudprovider.InstanceType, nodeClass *v1alpha2.AKSNodeClass, labels map[string]string) (*parameters.StaticParameters, error) { var arch string = corev1beta1.ArchitectureAmd64 if err := instanceType.Requirements.Compatible(scheduling.NewRequirements(scheduling.NewRequirement(v1.LabelArchStable, v1.NodeSelectorOpIn, corev1beta1.ArchitectureArm64))); err == nil { arch = corev1beta1.ArchitectureArm64 } // TODO: make conditional on either Azure CNI Overlay or pod subnet - vnetLabels, err := p.getVnetInfoLabels(ctx, nodeClass) + vnetLabels, err := p.getVnetInfoLabels(nodeClass) if err != nil { return nil, err } @@ -130,8 +116,8 @@ func (p *Provider) getStaticParameters(ctx context.Context, instanceType *cloudp labels[vnetDataPlaneLabel] = networkDataplaneCilium return &parameters.StaticParameters{ - ClusterName: options.FromContext(ctx).ClusterName, - ClusterEndpoint: p.clusterEndpoint, + ClusterName: p.opts.ClusterName, + ClusterEndpoint: p.opts.ClusterEndpoint, Tags: nodeClass.Spec.Tags, Labels: labels, CABundle: p.caBundle, @@ -139,17 +125,17 @@ func (p
*Provider) getStaticParameters(ctx context.Context, instanceType *cloudp GPUNode: utils.IsNvidiaEnabledSKU(instanceType.Name), GPUDriverVersion: utils.GetGPUDriverVersion(instanceType.Name), GPUImageSHA: utils.GetAKSGPUImageSHA(instanceType.Name), - TenantID: p.tenantID, - SubscriptionID: p.subscriptionID, - UserAssignedIdentityID: p.userAssignedIdentityID, - ResourceGroup: p.resourceGroup, - Location: p.location, - ClusterID: options.FromContext(ctx).ClusterID, - APIServerName: options.FromContext(ctx).GetAPIServerName(), - KubeletClientTLSBootstrapToken: options.FromContext(ctx).KubeletClientTLSBootstrapToken, - NetworkPlugin: options.FromContext(ctx).NetworkPlugin, - NetworkPolicy: options.FromContext(ctx).NetworkPolicy, - SubnetID: options.FromContext(ctx).SubnetID, + TenantID: p.opts.TenantID, + SubscriptionID: p.opts.SubscriptionID, + KubeletIdentityClientID: p.opts.KubeletIdentityClientID, + ResourceGroup: p.opts.ResourceGroup, + Location: p.opts.Location, + ClusterID: p.opts.ClusterID, + APIServerName: p.opts.APIServerName, + KubeletClientTLSBootstrapToken: p.opts.KubeletClientTLSBootstrapToken, + NetworkPlugin: p.opts.NetworkPlugin, + NetworkPolicy: p.opts.NetworkPolicy, + SubnetID: p.opts.SubnetID, }, nil } @@ -178,15 +164,15 @@ func mergeTags(tags ...map[string]string) (result map[string]*string) { }) } -func (p *Provider) getVnetInfoLabels(ctx context.Context, _ *v1alpha2.AKSNodeClass) (map[string]string, error) { - // TODO(bsoghigian): this should be refactored to lo.Ternary(nodeClass.Spec.VnetSubnetID != nil, lo.FromPtr(nodeClass.Spec.VnetSubnetID), os.Getenv("AZURE_SUBNET_ID")) when we add VnetSubnetID to the nodeclass - vnetSubnetComponents, err := utils.GetVnetSubnetIDComponents(options.FromContext(ctx).SubnetID) +func (p *Provider) getVnetInfoLabels(_ *v1alpha2.AKSNodeClass) (map[string]string, error) { + // TODO(bsoghigian): this should be refactored to lo.Ternary(nodeClass.Spec.SubnetID != nil, lo.FromPtr(nodeClass.Spec.SubnetID), 
p.opts.SubnetID) when we add VnetSubnetID to the nodeclass + vnetSubnetComponents, err := utils.GetVnetSubnetIDComponents(p.opts.SubnetID) if err != nil { return nil, err } vnetLabels := map[string]string{ vnetSubnetNameLabel: vnetSubnetComponents.SubnetName, - vnetGUIDLabel: p.vnetGUID, + vnetGUIDLabel: p.opts.VnetGUID, vnetPodNetworkTypeLabel: networkModeOverlay, } return vnetLabels, nil diff --git a/pkg/providers/launchtemplate/parameters/types.go b/pkg/providers/launchtemplate/parameters/types.go index 238ce0710..5d2033e24 100644 --- a/pkg/providers/launchtemplate/parameters/types.go +++ b/pkg/providers/launchtemplate/parameters/types.go @@ -31,7 +31,7 @@ type StaticParameters struct { GPUImageSHA string TenantID string SubscriptionID string - UserAssignedIdentityID string + KubeletIdentityClientID string Location string ResourceGroup string ClusterID string diff --git a/pkg/providers/loadbalancer/loadbalancer.go b/pkg/providers/loadbalancer/loadbalancer.go index d85a707cc..fdb5ecd8f 100644 --- a/pkg/providers/loadbalancer/loadbalancer.go +++ b/pkg/providers/loadbalancer/loadbalancer.go @@ -24,6 +24,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/patrickmn/go-cache" "github.com/samber/lo" "knative.dev/pkg/logging" @@ -69,11 +70,11 @@ type BackendAddressPools struct { } // NewProvider creates a new LoadBalancer provider -func NewProvider(loadBalancersAPI LoadBalancersAPI, cache *cache.Cache, resourceGroup string) *Provider { +func NewProvider(opts *options.Options, loadBalancersAPI LoadBalancersAPI, cache *cache.Cache) *Provider { return &Provider{ loadBalancersAPI: loadBalancersAPI, cache: cache, - resourceGroup: resourceGroup, + resourceGroup: opts.NodeResourceGroup, } } diff --git a/pkg/providers/pricing/pricing.go b/pkg/providers/pricing/pricing.go index 8e05a7472..19aba465c 100644 --- a/pkg/providers/pricing/pricing.go +++ 
b/pkg/providers/pricing/pricing.go @@ -24,6 +24,7 @@ import ( "sync" "time" + "github.com/Azure/karpenter-provider-azure/pkg/operator/options" "github.com/Azure/karpenter-provider-azure/pkg/providers/pricing/client" "github.com/samber/lo" "knative.dev/pkg/logging" @@ -64,7 +65,9 @@ func NewAPI() client.PricingAPI { return client.New() } -func NewProvider(ctx context.Context, pricing client.PricingAPI, region string, startAsync <-chan struct{}) *Provider { +func NewProvider(ctx context.Context, opts *options.Options, pricing client.PricingAPI, startAsync <-chan struct{}) *Provider { + region := opts.Location + // see if we've got region specific pricing data staticPricing, ok := initialOnDemandPrices[region] if !ok {