Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
Signed-off-by: Mingxin Zheng <[email protected]>
  • Loading branch information
mingxin-zheng committed Oct 23, 2024
1 parent 82234e4 commit 0a43aa1
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 37 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ venv.bak/

# Data files
data/

# Logs
*.log
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
demo_monai_vila2d:
cd thirdparty/VILA; \
./environment_setup.sh
pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage, fire] torchxrayvision
pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage,fire,ignite] torchxrayvision
mkdir -p $(HOME)/.torchxrayvision/models_data/ \
&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
-O $(HOME)/.torchxrayvision/models_data/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ For details, see [here](./monai_vila2d/README.md).
1. **GPU Memory**: Ensure that the GPU has sufficient memory to run the models:
- **VILA-M3**: 8B: ~18GB, 13B: ~30GB
- **CXR**: This expert loads various [TorchXRayVision](https://github.com/mlmed/torchxrayvision) models and performs ensemble predictions. The memory requirement is roughly 1.5GB in total.
- **VISTA3D**: This expert model is based on the [VISTA3D](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/monaitoolkit/models/monai_vista3d).
- **VISTA3D**: This expert model is based on the [VISTA3D](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/monaitoolkit/models/monai_vista3d). The memory requirement is roughly 12GB, and peak memory usage can be higher, depending on the input size of the 3D volume.
- **BRATS**: (TBD)

#### Setup Environment
Expand Down
80 changes: 45 additions & 35 deletions monai_vila2d/demo/experts/expert_monai_vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import os
import re
import sys
import tempfile
from pathlib import Path
from shutil import move

import requests
from experts.base_expert import BaseExpert
from experts.utils import get_monai_transforms, get_slice_filenames
from monai.bundle import create_workflow
Expand All @@ -27,24 +29,7 @@ class ExpertVista3D(BaseExpert):
def __init__(self) -> None:
"""Initialize the VISTA-3D expert model."""
self.model_name = "VISTA3D"
self.bundle_root = os.path.expanduser("~/.cache/torch/hub/bundle/vista3d_v0.5.4")
self.result_dir = "inference_results"
sys.path = [self.bundle_root] + sys.path
override = {
"dataset#data": [{}],
"output_dtype": "uint8",
"seperate_folder": False,
"output_ext": ".nrrd",
"output_dir": self.result_dir,
}
self.workflow = create_workflow(
workflow_type="infer",
bundle_root=self.bundle_root,
config_file=os.path.join(self.bundle_root, f"configs/inference.json"),
logging_file=os.path.join(self.bundle_root, "configs/logging.conf"),
meta_file=os.path.join(self.bundle_root, "configs/metadata.json"),
**override,
)
self.bundle_root = os.path.expanduser("~/.cache/torch/hub/bundle/vista3d_v0.5.4/vista3d")

def label_id_to_name(self, label_id: int, label_dict: dict):
"""
Expand Down Expand Up @@ -129,12 +114,26 @@ def mentioned_by(self, input: str):
return False
return self.model_name in str(matches[0])

def download_file(self, url: str, img_file: str):
    """
    Download a file from a URL and save it to the given path.

    Args:
        url (str): The URL to download from.
        img_file (str): The destination file path. Parent directories are
            created if they do not already exist.

    Raises:
        requests.HTTPError: If the server responds with an error status code.
        requests.RequestException: On connection failure or timeout.
    """
    parent_dir = os.path.dirname(img_file)
    # Guard: dirname is "" when img_file has no directory component, and
    # os.makedirs("") raises FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # Perform the request BEFORE opening the destination file so a failed
    # download does not leave an empty/partial file behind. Stream the body
    # instead of buffering it all in memory — the payload here is expected
    # to be a large 3D volume.
    response = requests.get(url, stream=True, timeout=300)
    response.raise_for_status()  # fail loudly instead of saving an error page
    with open(img_file, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

def run(
self,
img_file: str = "",
image_url: str = "",
input: str = "",
output_dir: str = "",
img_file: str = "",
slice_index: int = 0,
prompt: str = "",
**kwargs,
Expand All @@ -146,11 +145,16 @@ def run(
image_url (str): The image URL.
input (str): The input text.
output_dir (str): The output directory.
img_file (str): The image file path.
img_file (str): The image file path. If not provided, download from the URL.
slice_index (int): The slice index.
prompt (str): The prompt text from the original request.
**kwargs: Additional keyword arguments.
"""
if not img_file:
# Download from the URL
img_file = os.path.join(output_dir, os.path.basename(image_url))
self.download_file(image_url, img_file)

output_dir = Path(output_dir)
matches = re.findall(r"<(.*?)>", input)
if len(matches) != 1:
Expand Down Expand Up @@ -183,23 +187,29 @@ def run(
vista3d_prompts = {"classes": list(label_groups[arg_matches[0]].keys())}

# Trigger the VISTA-3D model
api_key = os.getenv("NIM_API_KEY", "Invalid")
if api_key == "Invalid":
raise ValueError(f"Expert model API key not found to trigger {self.NIM_VISTA3D}")
input_dict = {"image": img_file}
if vista3d_prompts is not None:
input_dict["prompts"] = vista3d_prompts

sys.path = [self.bundle_root] + sys.path

input_data = {"image": image_url}
if vista3d_prompts is not None:
input_data["prompts"] = vista3d_prompts

self.workflow.parser.ref_resolver.items["dataset"].config["data"][0] = input_data
self.workflow.evaluator.run()
output_root_dir = os.path.join(self.bundle_root, self.result_dir)
if not os.path.exists(output_root_dir):
raise ValueError("VISTA-3D model failed to run.")
output_file = os.listdir(output_root_dir)[0]
seg_file = os.path.join(output_root_dir, "segmentation.nrrd")
move(output_file, seg_file)
with tempfile.TemporaryDirectory() as temp_dir:
workflow = create_workflow(
workflow_type="infer",
bundle_root=self.bundle_root,
config_file=os.path.join(self.bundle_root, f"configs/inference.json"),
logging_file=os.path.join(self.bundle_root, "configs/logging.conf"),
meta_file=os.path.join(self.bundle_root, "configs/metadata.json"),
input_dict=input_dict,
output_dtype="uint8",
separate_folder=False,
output_ext=".nii.gz",
output_dir=temp_dir,
)
workflow.evaluator.run()
output_file = os.path.join(temp_dir, os.listdir(temp_dir)[0])
seg_file = os.path.join(output_dir, "segmentation.nii.gz")
move(output_file, seg_file)

text_output = self.segmentation_to_string(
output_dir,
Expand Down

0 comments on commit 0a43aa1

Please sign in to comment.