SAMLoader - ESAM support
ltdrdata committed Apr 6, 2024
1 parent b63b689 commit cc0afe7
Showing 3 changed files with 82 additions and 28 deletions.
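
The diff replaces the raw SAM model that used to be passed into the mask helpers with a small wrapper object, so the same code path can drive either a segment-anything model or EfficientSAM. Below is a minimal sketch of the shared interface and the call pattern the helpers follow; the Protocol class and the run_prediction helper are illustrative only and are not part of the commit, while the method names come from the diff.

from typing import Protocol

class SAMBackend(Protocol):
    # Interface satisfied by both SAMWrapper and ESAMWrapper in modules/impact/core.py.
    def prepare_device(self) -> None: ...   # move the model to the compute device, or no-op
    def release_device(self) -> None: ...   # move it back to CPU in AUTO mode, or no-op
    def predict(self, image, points, plabs, bbox, threshold) -> list: ...  # returns a list of masks

def run_prediction(sam_obj: SAMBackend, image, points, plabs, bbox, threshold):
    # Same lifecycle that make_sam_mask() / make_sam_mask_segmented() use below.
    sam_obj.prepare_device()
    try:
        return sam_obj.predict(image, points, plabs, bbox, threshold)
    finally:
        sam_obj.release_device()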
2 changes: 1 addition & 1 deletion modules/impact/config.py
@@ -2,7 +2,7 @@
import os


-version_code = [4, 86, 2]
+version_code = [4, 87]
version = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '')

dependency_version = 20
78 changes: 56 additions & 22 deletions modules/impact/core.py
@@ -471,16 +471,57 @@ def sam_predict(predictor, points, plabs, bbox, threshold):
    return total_masks


-def make_sam_mask(sam_model, segs, image, detection_hint, dilation,
+class SAMWrapper:
+    def __init__(self, model, is_auto_mode, safe_to_gpu=None):
+        self.model = model
+        self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU_stub()
+        self.is_auto_mode = is_auto_mode
+
+    def prepare_device(self):
+        if self.is_auto_mode:
+            device = comfy.model_management.get_torch_device()
+            self.safe_to_gpu.to_device(self.model, device=device)
+
+    def release_device(self):
+        if self.is_auto_mode:
+            self.model.to(device="cpu")
+
+    def predict(self, image, points, plabs, bbox, threshold):
+        predictor = SamPredictor(self.model)
+        predictor.set_image(image, "RGB")
+
+        return sam_predict(predictor, points, plabs, bbox, threshold)
+
+
+class ESAMWrapper:
+    def __init__(self, model, device):
+        self.model = model
+        self.func_inference = nodes.NODE_CLASS_MAPPINGS['Yoloworld_ESAM_Zho']
+        self.device = device
+
+    def prepare_device(self):
+        pass
+
+    def release_device(self):
+        pass
+
+    def predict(self, image, points, plabs, bbox, threshold):
+        if self.device == 'CPU':
+            self.device = 'cpu'
+        else:
+            self.device = 'cuda'
+
+        detected_masks = self.func_inference.inference_sam_with_boxes(image=image, xyxy=[bbox], model=self.model, device=self.device)
+        return [detected_masks.squeeze(0)]
+
+
+def make_sam_mask(sam_obj, segs, image, detection_hint, dilation,
                  threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
-    if sam_model.is_auto_mode:
-        device = comfy.model_management.get_torch_device()
-        sam_model.safe_to.to_device(sam_model, device=device)

+    sam_obj.prepare_device()

    try:
-        predictor = SamPredictor(sam_model)
        image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
-        predictor.set_image(image, "RGB")

        total_masks = []

@@ -503,7 +544,7 @@ def make_sam_mask(sam_model, segs, image, detection_hint, dilation,
                else:
                    plabs.append(1)

-            detected_masks = sam_predict(predictor, points, plabs, None, threshold)
+            detected_masks = sam_obj.predict(image, points, plabs, None, threshold)
            total_masks += detected_masks

        else:
@@ -572,15 +613,14 @@ def make_sam_mask(sam_model, segs, image, detection_hint, dilation,
                    points += npoints
                    plabs += nplabs

-                detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)
+                detected_masks = sam_obj.predict(image, points, plabs, dilated_bbox, threshold)
                total_masks += detected_masks

        # merge every collected masks
        mask = combine_masks2(total_masks)

    finally:
-        if sam_model.is_auto_mode:
-            sam_model.to(device="cpu")
+        sam_obj.release_device()

    if mask is not None:
        mask = mask.float()
@@ -735,16 +775,13 @@ def every_three_pick_last(stacked_masks):
    return selected_masks


-def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
+def make_sam_mask_segmented(sam_obj, segs, image, detection_hint, dilation,
                            threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
-    if sam_model.is_auto_mode:
-        device = comfy.model_management.get_torch_device()
-        sam_model.safe_to.to_device(sam_model, device=device)

+    sam_obj.prepare_device()

    try:
-        predictor = SamPredictor(sam_model)
        image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
-        predictor.set_image(image, "RGB")

        total_masks = []

@@ -767,7 +804,7 @@ def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
                else:
                    plabs.append(1)

-            detected_masks = sam_predict(predictor, points, plabs, None, threshold)
+            detected_masks = sam_obj.predict(image, points, plabs, None, threshold)
            total_masks += detected_masks

        else:
@@ -785,18 +822,15 @@ def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
                                                       mask_hint_threshold, use_small_negative,
                                                       mask_hint_use_negative)

-                detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold)
+                detected_masks = sam_obj.predict(image, points, plabs, dilated_bbox, threshold)

                total_masks += detected_masks

        # merge every collected masks
        mask = combine_masks2(total_masks)

    finally:
-        if sam_model.is_auto_mode:
-            sam_model.cpu()
-
-        pass
+        sam_obj.release_device()

    mask_working_device = torch.device("cpu")

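For reference, a hand-construction sketch of the two wrappers defined above, mirroring what the loader node does in the next file. The checkpoint path is a placeholder, the import paths simply follow the file layout shown in this diff, and the ESAM branch assumes the external ComfyUI-YoloWorld-EfficientSAM extension is installed.

import os
import nodes                                       # ComfyUI node registry
from segment_anything import sam_model_registry
from modules.impact import core

# Standard SAM backend: wrap a checkpoint; SafeToGPU handles deferred GPU placement.
ckpt = '/path/to/sam_vit_b.pth'                    # placeholder path
sam = sam_model_registry['vit_b'](checkpoint=ckpt)
sam_obj = core.SAMWrapper(sam, is_auto_mode=True,
                          safe_to_gpu=core.SafeToGPU(os.path.getsize(ckpt)))

# EfficientSAM backend: reuse the loader node from the external extension, then wrap it.
esam_loader = nodes.NODE_CLASS_MAPPINGS['ESAM_ModelLoader_Zho']()
esam = esam_loader.load_esam_model('CUDA')[0]
esam_obj = core.ESAMWrapper(esam, 'CUDA')
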
30 changes: 25 additions & 5 deletions modules/impact/impact_pack.py
@@ -85,7 +85,7 @@ def INPUT_TYPES(cls):
        models = [x for x in folder_paths.get_filename_list("sams") if 'hq' not in x]
        return {
            "required": {
-                "model_name": (models, ),
+                "model_name": (models + ['ESAM'], ),
                "device_mode": (["AUTO", "Prefer GPU", "CPU"],),
            }
        }
@@ -96,6 +96,24 @@ def INPUT_TYPES(cls):
    CATEGORY = "ImpactPack"

    def load_model(self, model_name, device_mode="auto"):
+        if model_name == 'ESAM':
+            if 'ESAM_ModelLoader_Zho' not in nodes.NODE_CLASS_MAPPINGS:
+                try_install_custom_node('https://github.com/ZHO-ZHO-ZHO/ComfyUI-YoloWorld-EfficientSAM',
+                                        "To use 'ESAM' model, 'ComfyUI-YoloWorld-EfficientSAM' extension is required.")
+                raise Exception("'ComfyUI-YoloWorld-EfficientSAM' node isn't installed.")
+
+            esam_loader = nodes.NODE_CLASS_MAPPINGS['ESAM_ModelLoader_Zho']()
+
+            if device_mode == 'CPU':
+                esam = esam_loader.load_esam_model('CPU')[0]
+            else:
+                device_mode = 'CUDA'
+                esam = esam_loader.load_esam_model('CUDA')[0]
+
+            sam_obj = core.ESAMWrapper(esam, device_mode)
+            print(f"Loads EfficientSAM model: (device:{device_mode})")
+            return (sam_obj, )
+
        modelname = folder_paths.get_full_path("sams", model_name)

        if 'vit_h' in model_name:
@@ -107,18 +125,20 @@ def load_model(self, model_name, device_mode="auto"):

        sam = sam_model_registry[model_kind](checkpoint=modelname)
        size = os.path.getsize(modelname)
-        sam.safe_to = core.SafeToGPU(size)
+        safe_to = core.SafeToGPU(size)

        # Unless user explicitly wants to use CPU, we use GPU
        device = comfy.model_management.get_torch_device() if device_mode == "Prefer GPU" else "CPU"

        if device_mode == "Prefer GPU":
-            sam.safe_to.to_device(sam, device)
+            safe_to.to_device(sam, device)

+        is_auto_mode = device_mode == "AUTO"

-        sam.is_auto_mode = device_mode == "AUTO"
+        sam_obj = core.SAMWrapper(sam, is_auto_mode=is_auto_mode, safe_to_gpu=safe_to)

        print(f"Loads SAM model: {modelname} (device:{device_mode})")
-        return (sam, )
+        return (sam_obj, )


class ONNXDetectorForEach:
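
A hypothetical node-level usage of the updated loader; 'SAMLoader' is the class name implied by the commit title and the checkpoint filename is a placeholder, so treat this as a sketch rather than the extension's documented API.

loader = SAMLoader()

# 'ESAM' routes to the ComfyUI-YoloWorld-EfficientSAM extension and yields a core.ESAMWrapper.
(esam_obj,) = loader.load_model('ESAM', device_mode='AUTO')

# Any regular checkpoint from the "sams" folder yields a core.SAMWrapper.
(sam_obj,) = loader.load_model('sam_vit_b_01ec64.pth', device_mode='CPU')

# Either object can be handed to core.make_sam_mask() or core.make_sam_mask_segmented(),
# which only rely on prepare_device(), predict() and release_device().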