diff --git a/api/pod.go b/api/pod.go index e35e9c4..9a39579 100644 --- a/api/pod.go +++ b/api/pod.go @@ -355,11 +355,11 @@ func StartOnDemandPod(id string) (pod map[string]interface{}, err error) { return } -func StartSpotPod(id string, bidPerGpu float32) (podBidResume map[string]interface{}, err error) { +func StartSpotPod(id string, bidPerGpu float32, gpuCount int) (podBidResume map[string]interface{}, err error) { input := Input{ Query: ` mutation Mutation($podId: String!, $bidPerGpu: Float!) { - podBidResume(input: {podId: $podId, bidPerGpu: $bidPerGpu}) { + podBidResume(input: {podId: $podId, bidPerGpu: $bidPerGpu, gpuCount: $gpuCount}) { id costPerHr desiredStatus @@ -367,7 +367,7 @@ func StartSpotPod(id string, bidPerGpu float32) (podBidResume map[string]interfa } } `, - Variables: map[string]interface{}{"podId": id, "bidPerGpu": bidPerGpu}, + Variables: map[string]interface{}{"podId": id, "bidPerGpu": bidPerGpu, "gpuCount": gpuCount}, } res, err := Query(input) if err != nil { diff --git a/cmd/pod/startPod.go b/cmd/pod/startPod.go index 0e4a273..b908350 100644 --- a/cmd/pod/startPod.go +++ b/cmd/pod/startPod.go @@ -9,6 +9,7 @@ import ( ) var bidPerGpu float32 +var gpuCount int var StartPodCmd = &cobra.Command{ Use: "pod [podId]", @@ -19,7 +20,7 @@ var StartPodCmd = &cobra.Command{ var err error var pod map[string]interface{} if bidPerGpu > 0 { - pod, err = api.StartSpotPod(args[0], bidPerGpu) + pod, err = api.StartSpotPod(args[0], bidPerGpu, gpuCount) } else { pod, err = api.StartOnDemandPod(args[0]) } @@ -36,4 +37,5 @@ var StartPodCmd = &cobra.Command{ func init() { StartPodCmd.Flags().Float32Var(&bidPerGpu, "bid", 0, "bid per gpu for spot price") + StartPodCmd.Flags().IntVar(&gpuCount, "gpuCount", 1, "number of GPUs to request") }