Change TorchXrayvision implementation to local (#26)
This PR:
- adjusts some HTML formatting to improve the visualization of the disclaimer and
chat history
- changes TorchXRayVision inference to run locally
- adds documentation on how to implement one's own expert model
(experimental)

---------

Signed-off-by: Mingxin Zheng <[email protected]>
mingxin-zheng authored Oct 22, 2024
1 parent 8e23c4e commit 0d064d7
Showing 5 changed files with 200 additions and 42 deletions.
23 changes: 21 additions & 2 deletions Makefile
@@ -1,4 +1,23 @@
-demo_monai_vila2d:
+demo_monai_vila2d: cxr_download
 	cd thirdparty/VILA; \
 	./environment_setup.sh
-	pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage]
+	pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage] torchxrayvision
+
+cxr_download:
+	mkdir -p $(HOME)/.torchxrayvision/models_data/ \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	-O $(HOME)/.torchxrayvision/models_data/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	-O $(HOME)/.torchxrayvision/models_data/chex-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_ch-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	-O $(HOME)/.torchxrayvision/models_data/mimic_ch-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/mimic_nb-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	-O $(HOME)/.torchxrayvision/models_data/mimic_nb-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	-O $(HOME)/.torchxrayvision/models_data/nih-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	-O $(HOME)/.torchxrayvision/models_data/pc-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	-O $(HOME)/.torchxrayvision/models_data/kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
+	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt \
+	-O $(HOME)/.torchxrayvision/models_data/pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt
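With these targets in place, the expert weights can be fetched on their own or as part of the demo setup. A minimal invocation (assuming GNU Make and `wget` are available) might look like:

```bash
# Download only the eight TorchXRayVision checkpoints into ~/.torchxrayvision/models_data/
make cxr_download

# Or run the full demo setup, which now depends on cxr_download
make demo_monai_vila2d
```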
11 changes: 9 additions & 2 deletions README.md
@@ -39,10 +39,15 @@ For details, see [here](./monai_vila2d/README.md).
 To install these, run
 ```bash
 sudo apt-get update
-sudo apt-get install -y python3.10 python3.10-venv python3.10-dev git
+sudo apt-get install -y wget python3.10 python3.10-venv python3.10-dev git
 ```
 NOTE: The commands are tailored for the Docker image `nvidia/cuda:12.2.2-devel-ubuntu22.04`. If using a different setup, adjust the commands accordingly.
 
+1. **GPU Memory**: Ensure that the GPU has sufficient memory to run the models:
+   - **VILA-M3**: 8B: ~18GB, 13B: ~30GB
+   - **CXR**: This expert loads various [TorchXRayVision](https://github.com/mlmed/torchxrayvision) models and performs ensemble predictions. The memory requirement is roughly 1.5GB in total.
+   - **VISTA3D**: (TBD)
+   - **BRATS**: (TBD)
 
#### Setup Environment

@@ -64,7 +69,6 @@ For details, see [here](./monai_vila2d/README.md).
 
 1. Set the API keys for calling the expert models:
    ```bash
-   export api_key=<your nvcf key>
    export NIM_API_KEY=<your NIM key>
    ```

@@ -77,6 +81,9 @@ For details, see [here](./monai_vila2d/README.md).
      --port 7860
    ```
 
+1. Adding your own model
+   - This is still a work in progress. Please refer to the [README](./monai_vila2d/demo/experts/README.md) for more details.

## Contributing

To lint the code, please install these packages:
33 changes: 33 additions & 0 deletions monai_vila2d/demo/experts/README.md
@@ -0,0 +1,33 @@
# Adding a new expert model

[BaseExpert](./base_expert.py) is a template for adding a new expert model to the repository. Please follow the instructions below to add one.

1. Create a new Python file in the `monai_vila2d/demo/experts` directory. The file name should be the name of the expert model you are adding. For example, if you are adding a new expert model called `MyExpert`, the file name should be `my_expert.py`.

2. Implement the expert model's `mentioned_by` method in the new Python file. The method should return a boolean value indicating whether the expert model is mentioned in the given text from the language model. For example:

```python
def mentioned_by(text: str) -> bool:
    return "my expert" in text
```

3. Implement the expert model's `run` method in the new Python file. The method should return the expert model's prediction for the given input image. For example:

```python
from typing import Tuple


def run(image_url: str) -> Tuple[str, str, str, str]:
    # Load the image
    image = load_image(image_url)

    # Perform inference
    prediction = my_model(image)

    # Return the prediction
    text_output = "The expert model prediction is a segmentation mask."
    # Save the prediction to a file
    save_image(prediction, "prediction.png")
    image_output = "prediction.png"
    instruction = "Use this mask to answer: what is the object in the image?"
    seg_file = "prediction.png"
    return text_output, image_output, instruction, seg_file
```
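Putting the pieces together, a complete expert might look like the sketch below. This is a minimal illustration, not a definitive implementation: `MyExpert`, `my_model`, and the `load_image`/`save_image` helpers are hypothetical, and the `__init__` and `run` signatures are modeled on the `ExpertTXRV` class in this repository.

```python
from typing import Tuple

from experts.base_expert import BaseExpert


class MyExpert(BaseExpert):
    """A hypothetical expert model illustrating the interface."""

    def __init__(self) -> None:
        self.model_name = "MyExpert"  # name used to identify this expert

    def mentioned_by(self, text: str) -> bool:
        # Trigger this expert whenever the language model mentions it.
        return "my expert" in text.lower()

    def run(self, image_url: str = "", prompt: str = "", **kwargs) -> Tuple[str, str, str, str]:
        image = load_image(image_url)        # hypothetical helper
        prediction = my_model(image)         # hypothetical model call
        save_image(prediction, "prediction.png")
        text_output = "The expert model prediction is a segmentation mask."
        instruction = "Use this mask to answer: what is the object in the image?"
        return text_output, "prediction.png", instruction, "prediction.png"
```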
125 changes: 110 additions & 15 deletions monai_vila2d/demo/experts/expert_torchxrayvision.py
@@ -9,21 +9,92 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
+"""
+The implementation of TorchXRayVision expert model is adapted from the Get-Started example in TorchXRayVision:
+https://github.com/mlmed/torchxrayvision
+"""
 
 import re
 
-import requests
+import skimage.io
+import torch
+import torchxrayvision as xrv
 from experts.base_expert import BaseExpert
 
+MODEL_NAMES = [
+    "densenet121-res224-all",
+    "densenet121-res224-chex",
+    "densenet121-res224-mimic_ch",
+    "densenet121-res224-mimic_nb",
+    "densenet121-res224-nih",
+    "densenet121-res224-pc",
+    "densenet121-res224-rsna",
+    "resnet50-res512-all",
+]
+
+# Taken from https://github.com/mlmed/torchxrayvision/blob/master/torchxrayvision/models.py
+valid_labels_model = {
+    "densenet121-res224-all": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
+    "densenet121-res224-chex": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
+    "densenet121-res224-mimic_ch": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
+    "densenet121-res224-mimic_nb": ['Atelectasis', 'Consolidation', '', 'Pneumothorax', 'Edema', '', '', 'Effusion', 'Pneumonia', '', 'Cardiomegaly', '', '', '', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
+    "densenet121-res224-nih": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', '', '', ''],
+    "densenet121-res224-pc": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', '', 'Fracture', '', ''],
+    "densenet121-res224-rsna": ['', '', '', '', '', '', '', '', 'Pneumonia', '', '', '', '', '', '', '', 'Lung Opacity', ''],
+    "resnet50-res512-all": ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'Lung Lesion', 'Fracture', 'Lung Opacity', 'Enlarged Cardiomediastinum'],
+}
+
+
+# Copied from https://github.com/Project-MONAI/VLM/blob/b74d7e444c6604ea83846b67bb98b5172a7e5495/monai_vila2d/data_prepare/experts/torchxrayvision/torchxray_cls.py#L47-L54
+group_0 = ["resnet50-res512-all"]
+group_1 = [
+    "densenet121-res224-all",
+    "densenet121-res224-nih",
+    "densenet121-res224-chex",
+    "densenet121-res224-mimic_ch",
+    "densenet121-res224-mimic_nb",
+    "densenet121-res224-rsna",
+    "densenet121-res224-pc",
+    "resnet50-res512-all",
+]
+group_2 = [
+    "densenet121-res224-all",
+    "densenet121-res224-chex",
+    "densenet121-res224-pc",
+    "resnet50-res512-all",
+]
+group_4 = [
+    "densenet121-res224-all",
+    "densenet121-res224-nih",
+    "densenet121-res224-chex",
+    "resnet50-res512-all",
+]
+
+cls_models = {
+    "Fracture": group_4,
+    "Pneumothorax": group_0,
+    "Lung Opacity": group_1,
+    "Atelectasis": group_2,
+    "Cardiomegaly": group_2,
+    "Consolidation": group_2,
+    "Edema": group_2,
+    "Effusion": group_2,
+}
+
+
 class ExpertTXRV(BaseExpert):
     """Expert model for the TorchXRayVision model."""
 
-    NIM_CXR = "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/5ac6a152-6d3f-4691-aef8-9e656557ee45"
-
     def __init__(self) -> None:
         """Initialize the CXR expert model."""
         self.model_name = "CXR"
+        self.models = {}
+        for name in MODEL_NAMES:
+            if "densenet" in name:
+                self.models[name] = xrv.models.DenseNet(weights=name).to("cuda")
+            elif "resnet" in name:
+                self.models[name] = xrv.models.ResNet(weights=name).to("cuda")
 
     def classification_to_string(self, outputs):
         """Format the classification outputs to a string."""
@@ -64,19 +135,43 @@ def run(self, image_url: str = "", prompt: str = "", **kwargs):
         Returns:
             tuple: The classification string, file path, and the next step instruction.
         """
-        api_key = os.getenv("api_key", "Invalid")
-        if api_key == "Invalid":
-            raise ValueError("API key not found. Please set the 'api_key' environment variable.")
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {api_key}",
-            "accept": "application/json",
-        }
+        img = skimage.io.imread(image_url)
+        img = xrv.datasets.normalize(img, 255)
+
+        # Check that images are 2D arrays
+        if len(img.shape) > 2:
+            img = img[:, :, 0]
+        elif len(img.shape) < 2:
+            raise ValueError("error, dimension lower than 2 for image")  # FIX TBD
+
+        # Add color channel
+        img = img[None, :, :]
+        img = xrv.datasets.XRayCenterCrop()(img)
+
+        preds_label = {
+            label: []
+            for label in xrv.datasets.default_pathologies
+        }
-        response = requests.post(self.NIM_CXR, headers=headers, json={"image": image_url})
-        response.raise_for_status()
 
+        with torch.no_grad():
+            img = torch.from_numpy(img).unsqueeze(0).to("cuda")
+            for name, model in self.models.items():
+                preds = model(img).cpu()
+                for k, v in zip(xrv.datasets.default_pathologies, preds[0].detach().numpy()):
+                    # TODO: Exclude invalid labels, which may be different from training
+                    # if k not in valid_labels_model[name]:
+                    #     continue
+                    # skip if k is one of the 8 pre-selected classes but the model isn't in the group
+                    if k in cls_models and name not in cls_models[k]:
+                        continue
+                    preds_label[k].append(float(v))
+        output = {
+            k: float(sum(v) / len(v))
+            for k, v in preds_label.items()
+        }
         return (
-            self.classification_to_string(response.json()),
+            self.classification_to_string(output),
             None,
             "Use this result to respond to this prompt:\n" + prompt,
             "",  # no file needs to be downloaded
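With the NIM endpoint call removed, the expert now runs entirely on the local GPU. A rough usage sketch is below; it assumes a CUDA device, the checkpoints fetched by `make cxr_download`, and a hypothetical local image file:

```python
from experts.expert_torchxrayvision import ExpertTXRV

# Instantiating the expert loads all eight TorchXRayVision models onto the GPU.
expert = ExpertTXRV()

# `run` reads the image, scores it with every model, averages the per-pathology
# scores across the model group registered in `cls_models`, and returns a 4-tuple.
text, image_out, instruction, download = expert.run(
    image_url="chest_xray.png",  # hypothetical file path
    prompt="Is there evidence of pneumothorax?",
)
print(text)  # the formatted classification string
```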
50 changes: 27 additions & 23 deletions monai_vila2d/demo/gradio_monai_vila2d.py
@@ -99,28 +99,32 @@
 CACHED_IMAGES = {}
 
 TITLE = """
-<div style="text-align: center; max-width: 800px; margin: 0 auto;">
-    <p>
-        <img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" alt="project monai" style="width: 50%; min-width: 500px; max-width: 800px; margin: auto; display: block;">
-    </p>
-    <div
-        style="
-            display: inline-flex;
-            align-items: center;
-            gap: 0.8rem;
-            font-size: 1.75rem;
-        "
-    >
-        <h1 style="font-weight: 900; margin-bottom: 7px;">
-            MONAI Multi-Modal Medical (M3) VLM Demo
-        </h1>
-    </div>
-    <p style="margin-bottom: 10px; font-size: 94%">
-        VILA-M3 is a vision-language model for medical applications that interprets medical images and text prompts to generate relevant responses.
-        Disclaimer: AI models generate responses and outputs based on complex algorithms and machine learning techniques, and those responses or outputs may be inaccurate, harmful, biased or indecent. By testing this model, you assume the risk of any harm caused by any response or output of the model. This model is for research purposes and not for clinical usage.
-    </p>
+<div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
+    <p>
+        <img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" alt="Project MONAI logo"
+             style="width: 50%; max-width: 600px; min-width: 300px; margin: auto; display: block;">
+    </p>
+    <h1 style="font-weight: 900; font-size: 1.5rem; margin-bottom: 10px;">
+        MONAI Multi-Modal (M3) VLM Demo
+    </h1>
+    <div style="font-size: 0.95rem; text-align: left; max-width: 800px; margin: 0 auto;">
+        <span>
+            VILA-M3 is a vision-language model for medical applications that interprets medical images and text prompts to generate relevant responses.
+        </span>
+        <details style="display: inline; cursor: pointer;">
+            <summary style="display: inline; font-weight: bold;">
+                <strong>DISCLAIMER</strong>
+            </summary>
+            <span style="font-size: 0.95rem;">
+                AI models generate responses and outputs based on complex algorithms and machine learning techniques,
+                and those responses or outputs may be inaccurate, harmful, biased, or indecent. By testing this model, you assume the risk of any
+                harm caused by any response or output of the model. This model is for research purposes and not for clinical usage.
+            </span>
+        </details>
+    </div>
+</div>
 """

CSS_STYLES = (
@@ -708,9 +712,9 @@ def create_demo(source, model_path, conv_mode, server_port):
 
         with gr.Column():
            with gr.Tab("In front of the scene"):
-                history_text = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts")
+                history_text = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts", max_height=600)
            with gr.Tab("Behind the scene"):
-                history_text_full = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts full")
+                history_text_full = gr.HTML(HTML_PLACEHOLDER, label="Previous prompts full", max_height=600)
            image_download = gr.DownloadButton("Download the file", visible=False)
            clear_btn = gr.Button("Clear Conversation")
        with gr.Row(variant="compact"):
