From 0d064d7ce7ab5964ab4eb67a678a3abf8b503444 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Tue, 22 Oct 2024 22:06:53 +0800 Subject: [PATCH] Change TorchXrayvision implementation to local (#26) This PR: - addresses some HTML formats to improve visualization of disclaimer and chat history - changes the torchxrayvision running to local - adds documentation how to implement one's own expert model (experimental) --------- Signed-off-by: Mingxin Zheng --- Makefile | 23 +++- README.md | 11 +- monai_vila2d/demo/experts/README.md | 33 +++++ .../demo/experts/expert_torchxrayvision.py | 125 +++++++++++++++--- monai_vila2d/demo/gradio_monai_vila2d.py | 50 +++---- 5 files changed, 200 insertions(+), 42 deletions(-) create mode 100644 monai_vila2d/demo/experts/README.md diff --git a/Makefile b/Makefile index 4f8bd07..7069344 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,23 @@ -demo_monai_vila2d: +demo_monai_vila2d: cxr_download cd thirdparty/VILA; \ ./environment_setup.sh - pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage] + pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage] torchxrayvision + +cxr_download: + mkdir -p $(HOME)/.torchxrayvision/models_data/ \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + -O $(HOME)/.torchxrayvision/models_data/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + -O $(HOME)/.torchxrayvision/models_data/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_ch-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + -O $(HOME)/.torchxrayvision/models_data/mimic_ch-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_nb-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + -O $(HOME)/.torchxrayvision/models_data/mimic_nb-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + -O $(HOME)/.torchxrayvision/models_data/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + -O $(HOME)/.torchxrayvision/models_data/pc-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + -O $(HOME)/.torchxrayvision/models_data/kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \ + && wget https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt \ + -O $(HOME)/.torchxrayvision/models_data/pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt diff --git a/README.md b/README.md index b3bedb9..464f41c 100644 --- a/README.md +++ b/README.md @@ -39,10 +39,15 @@ For details, see [here](./monai_vila2d/README.md). To install these, run ```bash sudo apt-get update - sudo apt-get install -y python3.10 python3.10-venv python3.10-dev git + sudo apt-get install -y wget python3.10 python3.10-venv python3.10-dev git ``` NOTE: The commands are tailored for the Docker image `nvidia/cuda:12.2.2-devel-ubuntu22.04`. If using a different setup, adjust the commands accordingly. +1. **GPU Memory**: Ensure that the GPU has sufficient memory to run the models: + - **VILA-M3**: 8B: ~18GB, 13B: ~30GB + - **CXR**: This expert loads various [TorchXRayVision](https://github.com/mlmed/torchxrayvision) models and performs ensemble predictions. The memory requirement is roughly 1.5GB in total. + - **VISTA3D**: (TBD) + - **BRATS**: (TBD) #### Setup Environment @@ -64,7 +69,6 @@ For details, see [here](./monai_vila2d/README.md). 1. Set the API keys for calling the expert models: ```bash - export api_key= export NIM_API_KEY= ``` @@ -77,6 +81,9 @@ For details, see [here](./monai_vila2d/README.md). --port 7860 ``` +1. Adding your own model + - This is still a work in progress. Please refer to the [README](./monai_vila2d/demo/experts/README.md) for more details. + ## Contributing To lint the code, please install these packages: diff --git a/monai_vila2d/demo/experts/README.md b/monai_vila2d/demo/experts/README.md new file mode 100644 index 0000000..be91693 --- /dev/null +++ b/monai_vila2d/demo/experts/README.md @@ -0,0 +1,33 @@ +# Adding a new expert model + +[BaseExpert](./base_expert.py) is a template for adding a new expert model to the repository. Please follow the instructions below to add a new expert model. + +1. Create a new python file in the `monai_vila2d/demo/experts` directory. The file name should be the name of the expert model you are adding. For example, if you are adding a new expert model called `MyExpert`, the file name should be `my_expert.py`. + +2. Implement the expert model `mentioned_by` method in the new python file. The method should return a boolean value indicating whether the expert model is mentioned by the given text from the language model. For example: + +```python +def mentioned_by(text: str) -> bool: + return "my expert" in text +``` + +3. Implement the expert model `run` method in the new python file. The method should return the expert model's prediction given the input image. For example: + +```python + +def run(image_url: str) -> Tuple[str, str, str, str]: + # Load the image + image = load_image(image_url) + + # Perform inference + prediction = my_model(image) + + # Return the prediction + text_output = "The expert model prediction is a segmentation mask." + # save the prediction to a file + save_image(prediction, "prediction.png") + image_output = "prediction.png" + instruction = "Use this mask to answer: what is the object in the image?" + seg_file = "prediction.png" + return text_output, image_output, instruction, seg_file +``` diff --git a/monai_vila2d/demo/experts/expert_torchxrayvision.py b/monai_vila2d/demo/experts/expert_torchxrayvision.py index b545c03..808be67 100644 --- a/monai_vila2d/demo/experts/expert_torchxrayvision.py +++ b/monai_vila2d/demo/experts/expert_torchxrayvision.py @@ -9,21 +9,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +""" +The implementation of TorchXRayVision expert model is adapted from the Get-Started example in TorchXRayVision: +https://github.com/mlmed/torchxrayvision +""" + import re -import requests +import skimage.io +import torch +import torchxrayvision as xrv from experts.base_expert import BaseExpert +MODEL_NAMES = [ + "densenet121-res224-all", + "densenet121-res224-chex", + "densenet121-res224-mimic_ch", + "densenet121-res224-mimic_nb", + "densenet121-res224-nih", + "densenet121-res224-pc", + "densenet121-res224-rsna", + "resnet50-res512-all", +] + +# Taken from https://github.com/mlmed/torchxrayvision/blob/master/torchxrayvision/models.py +valid_labels_model = { + "densenet121-res224-all": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], + "densenet121-res224-chex": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], + "densenet121-res224-mimic_ch": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], + "densenet121-res224-mimic_nb": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], + "densenet121-res224-nih": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', '', '', ''], + "densenet121-res224-pc": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', 'Fracture', '', ''], + "densenet121-res224-rsna": ['', '', '', '', '', '', '', '', 'Pneumonia', '', '', '', '', '', '', '', 'Lung Opacity', ''], + "resnet50-res512-all": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'], +} + + +# Copied from https://github.com/Project-MONAI/VLM/blob/b74d7e444c6604ea83846b67bb98b5172a7e5495/monai_vila2d/data_prepare/experts/torchxrayvision/torchxray_cls.py#L47-L54 +group_0 = ["resnet50-res512-all"] +group_1 = [ + "densenet121-res224-all", + "densenet121-res224-nih", + "densenet121-res224-chex", + "densenet121-res224-mimic_ch", + "densenet121-res224-mimic_nb", + "densenet121-res224-rsna", + "densenet121-res224-pc", + "resnet50-res512-all", +] +group_2 = [ + "densenet121-res224-all", + "densenet121-res224-chex", + "densenet121-res224-pc", + "resnet50-res512-all", +] +group_4 = [ + "densenet121-res224-all", + "densenet121-res224-nih", + "densenet121-res224-chex", + "resnet50-res512-all", +] + +cls_models = { + "Fracture": group_4, + "Pneumothorax": group_0, + "Lung Opacity": group_1, + "Atelectasis": group_2, + "Cardiomegaly": group_2, + "Consolidation": group_2, + "Edema": group_2, + "Effusion": group_2, +} + class ExpertTXRV(BaseExpert): """Expert model for the TorchXRayVision model.""" - NIM_CXR = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/5ac6a152-6d3f-4691-aef8-9e656557ee45" - def __init__(self) -> None: """Initialize the CXR expert model.""" self.model_name = "CXR" + self.models = {} + for name in MODEL_NAMES: + if "densenet" in name: + self.models[name] = xrv.models.DenseNet(weights=name).to("cuda") + elif "resnet" in name: + self.models[name] = xrv.models.ResNet(weights=name).to("cuda") + def classification_to_string(self, outputs): """Format the classification outputs to a string.""" @@ -64,19 +135,43 @@ def run(self, image_url: str = "", prompt: str = "", **kwargs): Returns: tuple: The classification string, file path, and the next step instruction. """ - - api_key = os.getenv("api_key", "Invalid") - if api_key == "Invalid": - raise ValueError("API key not found. Please set the 'api_key' environment variable.") - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - "accept": "application/json", + + img = skimage.io.imread(image_url) + img = xrv.datasets.normalize(img, 255) + + # Check that images are 2D arrays + if len(img.shape) > 2: + img = img[:, :, 0] + elif len(img.shape) < 2: + raise ValueError("error, dimension lower than 2 for image") # FIX TBD + + # Add color channel + img = img[None, :, :] + img = xrv.datasets.XRayCenterCrop()(img) + + preds_label = { + label: [] + for label in xrv.datasets.default_pathologies } - response = requests.post(self.NIM_CXR, headers=headers, json={"image": image_url}) - response.raise_for_status() + + with torch.no_grad(): + img = torch.from_numpy(img).unsqueeze(0).to("cuda") + for name, model in self.models.items(): + preds = model(img).cpu() + for k, v in zip(xrv.datasets.default_pathologies, preds[0].detach().numpy()): + # TODO: Exclude invalid labels, which may be different from training + # if k not in valid_labels_model[name]: + # continue + # skip if k is one of the 8 pre-selected classes but the model isn't in the group + if k in cls_models and name not in cls_models[k]: + continue + preds_label[k].append(float(v)) + output = { + k: float(sum(v) / len(v)) + for k, v in preds_label.items() + } return ( - self.classification_to_string(response.json()), + self.classification_to_string(output), None, "Use this result to respond to this prompt:\n" + prompt, "", # no file needs to be downloaded diff --git a/monai_vila2d/demo/gradio_monai_vila2d.py b/monai_vila2d/demo/gradio_monai_vila2d.py index 3da03d4..64a9456 100644 --- a/monai_vila2d/demo/gradio_monai_vila2d.py +++ b/monai_vila2d/demo/gradio_monai_vila2d.py @@ -99,28 +99,32 @@ CACHED_IMAGES = {} TITLE = """ -
-

- project monai -

-
-

- MONAI Multi-Modal Medical (M3) VLM Demo -

-
-

- VILA-M3 is a vision-language model for medical applications that interprets medical images and text prompts to generate relevant responses. - Disclaimer: AI models generate responses and outputs based on complex algorithms and machine learning techniques, and those responses or outputs may be inaccurate, harmful, biased or indecent. By testing this model, you assume the risk of any harm caused by any response or output of the model. This model is for research purposes and not for clinical usage. -

- +
+

+ Project MONAI logo +

+ +

+ MONAI Multi-Modal (M3) VLM Demo +

+ +
+ + VILA-M3 is a vision-language model for medical applications that interprets medical images and text prompts to generate relevant responses. + +
+ + DISCLAIMER + + + AI models generate responses and outputs based on complex algorithms and machine learning techniques, + and those responses or outputs may be inaccurate, harmful, biased, or indecent. By testing this model, you assume the risk of any + harm caused by any response or output of the model. This model is for research purposes and not for clinical usage. + +
+
""" CSS_STYLES = ( @@ -708,9 +712,9 @@ def create_demo(source, model_path, conv_mode, server_port): with gr.Column(): with gr.Tab("In front of the scene"): - history_text = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts") + history_text = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts", max_height=600) with gr.Tab("Behind the scene"): - history_text_full = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts full") + history_text_full = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts full", max_height=600) image_download = gr.DownloadButton("Download the file", visible=False) clear_btn = gr.Button("Clear Conversation") with gr.Row(variant="compact"):