From 86c349ce10a455f28aa00c73f5ebb2dda49ce40c Mon Sep 17 00:00:00 2001 From: mabingqi Date: Thu, 26 Sep 2024 19:46:17 +0800 Subject: [PATCH] model loading revision --- IVM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/IVM.py b/IVM.py index 5829660..496c253 100644 --- a/IVM.py +++ b/IVM.py @@ -27,7 +27,7 @@ def load(ckpt_path, low_gpu_memory = False): url = "https://drive.google.com/uc?export=download&id=1OyVci6rAwnb2sJPxhObgK7AvlLYDLLHw" sam_ckpt = _download(url, "sam_vit_h_4b8939.pth", os.path.expanduser(f"~/.cache/IVM/Sam")) ckpt = torch.load(ckpt_path, map_location="cpu") - model = IVM(sam_model=sam_ckpt) + model = IVM(sam_model=sam_ckpt).eval() model.load_state_dict(ckpt, strict=False) if low_gpu_memory: return accelerate.cpu_offload(model, "cuda") else: return model.cuda()