diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py b/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py index 77e52f7..64b2213 100644 --- a/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py +++ b/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py @@ -238,14 +238,14 @@ def update_slice( continue inputs = inputs_l[..., start_idx - (n_z_slices // 2) : start_idx + (n_z_slices // 2) + 1].permute(2, 0, 1) - if device and (device == "cuda" or isinstance(device, torch.device) and device.type == "cuda"): + if device and ((isinstance(device, str) and device.startswith('cuda')) or isinstance(device, torch.device) and device.type == "cuda"): inputs = inputs.cuda() data, unique_labels = prepare_sam_val_input( inputs, class_prompts, point_prompts, start_idx, original_affine, device=device ) predictor.eval() - if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"): + if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"): with torch.cuda.amp.autocast(): outputs = predictor(data) logit = outputs[0]["high_res_logits"] @@ -297,14 +297,14 @@ def iterate_all( ) for start_idx in start_range: inputs = inputs_l[..., start_idx - n_z_slices // 2 : start_idx + n_z_slices // 2 + 1].permute(2, 0, 1) - if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"): + if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"): inputs = inputs.cuda() data, unique_labels = prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, device=device) predictor = predictor.eval() with autocast(): if cachedEmbedding: curr_embedding = cachedEmbedding[start_idx] - if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"): + if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"): curr_embedding = curr_embedding.cuda() outputs = predictor.get_mask_prediction(data, curr_embedding) else: diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py b/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py index 695aeee..b2344ca 100644 --- a/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py +++ b/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py @@ -88,7 +88,7 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi class_list = [[i + 1] for i in class_prompts] unique_labels = torch.tensor(class_list).long() - if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"): + if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"): unique_labels = unique_labels.cuda() volume_point_coords = [cp for cp in foreground_all] @@ -133,7 +133,7 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi if point_coords: point_coords = torch.tensor(point_coords).long() point_labels = torch.tensor(point_labels).long() - if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"): + if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"): point_coords = point_coords.cuda() point_labels = point_labels.cuda()