Skip to content

Commit

Permalink
feat: Specify subnet id list and allow selection of availability zone (
Browse files Browse the repository at this point in the history
…#39)

* feat: Specify subnet id list and allow selection of availability zone

* fix: redundant fetch of subnet ids from env variables
  • Loading branch information
roehrijn authored Nov 29, 2024
1 parent 566b812 commit 7169a2c
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 60 deletions.
6 changes: 5 additions & 1 deletion hack/provider/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ optionGroups:
- AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER
- AWS_USE_ROUTE53
- AWS_ROUTE53_ZONE_NAME
- AWS_AVAILABILITY_ZONE
name: "AWS options"
defaultVisible: false
- options:
Expand Down Expand Up @@ -83,11 +84,14 @@ options:
description: The vpc id to use.
default: ""
AWS_SUBNET_ID:
description: The subnet id to use.
description: The subnet id to use. Can also be multiple once separated by a comma. By default the one with the most available IPs is chosen. Can be overridden by AWS_AVAILABILITY_ZONE.
default: ""
AWS_SECURITY_GROUP_ID:
description: The security group id to use. Multiple can be specified by separating with a comma.
default: ""
AWS_AVAILABILITY_ZONE:
description: The name of the AWS availability zone can be specified to choose a subnet out of the desired zone.
default: ""
AWS_AMI:
description: The disk image to use.
default: ""
Expand Down
136 changes: 79 additions & 57 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,65 +124,98 @@ type AwsProvider struct {
WorkingDirectory string
}

func GetSubnetID(ctx context.Context, provider *AwsProvider) (string, error) {
svc := ec2.NewFromConfig(provider.AwsConfig)

// first search for a default devpod specific subnet, if it fails
// we search the subnet with most free IPs that can do also public-ipv4
input := &ec2.DescribeSubnetsInput{
Filters: []types.Filter{
{
Name: aws.String("tag:devpod"),
Values: []string{
"devpod",
},
},
},
func GetSubnet(ctx context.Context, provider *AwsProvider) (string, error) {
// in case a single subnet ID is specified, use it without further checks
if len(provider.Config.SubnetIDs) == 1 {
return provider.Config.SubnetIDs[0], nil
}

result, err := svc.DescribeSubnets(ctx, input)
if err != nil {
return "", err
}
svc := ec2.NewFromConfig(provider.AwsConfig)
// in case multiple subnet IDs are specified, we return the one with most free IPs
if len(provider.Config.SubnetIDs) > 1 {
subnets, err := svc.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{
SubnetIds: provider.Config.SubnetIDs,
})
if err != nil {
return "", fmt.Errorf("list specified subnets %q: %w", provider.Config.SubnetIDs, err)
}
var maxIPCount int32
var subnet *types.Subnet
for _, s := range subnets.Subnets {
if provider.Config.AvailabilityZone != "" && *s.AvailabilityZone != provider.Config.AvailabilityZone {
continue
}
if *s.AvailableIpAddressCount > maxIPCount {
maxIPCount = *s.AvailableIpAddressCount
subnet = &s
}
}

if subnet == nil {
if provider.Config.AvailabilityZone == "" {
return "", fmt.Errorf("no subnets found with IDs %q", provider.Config.SubnetIDs)
} else {
return "", fmt.Errorf("no subnets found with IDs %q in availability zone %q", provider.Config.SubnetIDs, provider.Config.AvailabilityZone)
}
}

if len(result.Subnets) > 0 {
return *result.Subnets[0].SubnetId, nil
return *subnet.SubnetId, nil
}

input = &ec2.DescribeSubnetsInput{
Filters: []types.Filter{
{
Name: aws.String("vpc-id"),
Values: []string{
provider.Config.VpcID,
},
},
// retrieve and index all visible subnets
input := &ec2.DescribeSubnetsInput{}
if provider.Config.AvailabilityZone != "" {
input.Filters = []types.Filter{
{
Name: aws.String("map-public-ip-on-launch"),
Name: aws.String("availability-zone"),
Values: []string{
"true",
provider.Config.AvailabilityZone,
},
},
},
}
}
p := ec2.NewDescribeSubnetsPaginator(svc, input)
var taggedSubnetMaxIPCount, vpcedSubnetMaxIPCount int32
var taggedSubnet, vpcedSubnet *types.Subnet
for p.HasMorePages() {
page, err := p.NextPage(ctx)
if err != nil {
return "", fmt.Errorf("list all subnets: %w", err)
}

result, err = svc.DescribeSubnets(ctx, input)
if err != nil {
return "", err
for _, s := range page.Subnets {
for _, tag := range s.Tags {
if *tag.Key == "devpod" && *tag.Value == "devpod" {
if *s.AvailableIpAddressCount > taggedSubnetMaxIPCount {
taggedSubnetMaxIPCount = *s.AvailableIpAddressCount
taggedSubnet = &s
}
}
}
if provider.Config.VpcID != "" && *s.VpcId == provider.Config.VpcID &&
*s.AvailableIpAddressCount > vpcedSubnetMaxIPCount &&
*s.MapPublicIpOnLaunch {
vpcedSubnetMaxIPCount = *s.AvailableIpAddressCount
vpcedSubnet = &s
}
}
}

var maxIPCount int32
// if we found tagged subnets, we return the one with the most free IPs
if taggedSubnet != nil {
return *taggedSubnet.SubnetId, nil
}

subnetID := ""
// we found no tagged subnet so far. If a VPC is specified, we search for a subnet with the most free IPs that can do also public-ipv4
if vpcedSubnet != nil {
return *vpcedSubnet.SubnetId, nil
}

for _, v := range result.Subnets {
if *v.AvailableIpAddressCount > maxIPCount {
maxIPCount = *v.AvailableIpAddressCount
subnetID = *v.SubnetId
}
if provider.Config.VpcID == "" {
return "", errors.New("could not find a suitable subnet. Please either specify a subnet ID or VPC ID, or tag the desired subnets with devpod:devpod")
}

return subnetID, nil
return "", nil
}

func GetDevpodVPC(ctx context.Context, provider *AwsProvider) (string, error) {
Expand Down Expand Up @@ -802,22 +835,11 @@ func Create(
}
}

if providerAws.Config.VpcID != "" && providerAws.Config.SubnetID == "" {
subnetID, err := GetSubnetID(ctx, providerAws)
if err != nil {
return Machine{}, err
}

if subnetID == "" {
return Machine{}, fmt.Errorf("could not find a matching SubnetID in VPC %s, please specify one", providerAws.Config.VpcID)
}

instance.SubnetId = &subnetID
}

if providerAws.Config.SubnetID != "" {
instance.SubnetId = &providerAws.Config.SubnetID
subnetID, err := GetSubnet(ctx, providerAws)
if err != nil {
return Machine{}, fmt.Errorf("determine subnet ID: %w", err)
}
instance.SubnetId = &subnetID

result, err := svc.RunInstances(ctx, instance)
if err != nil {
Expand Down
14 changes: 12 additions & 2 deletions pkg/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"strconv"
"strings"
)

var (
Expand All @@ -15,6 +16,7 @@ var (
AWS_SECURITY_GROUP_ID = "AWS_SECURITY_GROUP_ID"
AWS_SUBNET_ID = "AWS_SUBNET_ID"
AWS_VPC_ID = "AWS_VPC_ID"
AWS_AVAILABILITY_ZONE = "AWS_AVAILABILITY_ZONE"
AWS_INSTANCE_TAGS = "AWS_INSTANCE_TAGS"
AWS_INSTANCE_PROFILE_ARN = "AWS_INSTANCE_PROFILE_ARN"
AWS_USE_INSTANCE_CONNECT_ENDPOINT = "AWS_USE_INSTANCE_CONNECT_ENDPOINT"
Expand All @@ -35,7 +37,8 @@ type Options struct {
MachineID string
MachineType string
VpcID string
SubnetID string
SubnetIDs []string
AvailabilityZone string
SecurityGroupID string
InstanceProfileArn string
InstanceTags string
Expand Down Expand Up @@ -74,8 +77,8 @@ func FromEnv(init bool) (*Options, error) {
retOptions.DiskImage = os.Getenv(AWS_AMI)
retOptions.RootDevice = os.Getenv(AWS_ROOT_DEVICE)
retOptions.SecurityGroupID = os.Getenv(AWS_SECURITY_GROUP_ID)
retOptions.SubnetID = os.Getenv(AWS_SUBNET_ID)
retOptions.VpcID = os.Getenv(AWS_VPC_ID)
retOptions.AvailabilityZone = os.Getenv(AWS_AVAILABILITY_ZONE)
retOptions.InstanceTags = os.Getenv(AWS_INSTANCE_TAGS)
retOptions.InstanceProfileArn = os.Getenv(AWS_INSTANCE_PROFILE_ARN)
retOptions.Zone = os.Getenv(AWS_REGION)
Expand All @@ -87,6 +90,13 @@ func FromEnv(init bool) (*Options, error) {
retOptions.UseRoute53Hostnames = os.Getenv(AWS_USE_ROUTE53) == "true"
retOptions.Route53ZoneName = os.Getenv(AWS_ROUTE53_ZONE_NAME)

subnetIDs := os.Getenv(AWS_SUBNET_ID)
if subnetIDs != "" {
for _, subnetID := range strings.Split(subnetIDs, ",") {
retOptions.SubnetIDs = append(retOptions.SubnetIDs, strings.TrimSpace(subnetID))
}
}

// Return early if we're just doing init
if init {
return retOptions, nil
Expand Down

0 comments on commit 7169a2c

Please sign in to comment.