Add huggingface checkpoint autodownload and rename sys_msg to model_cards (#47)

This PR adds:
- Auto-download of the 13B huggingface checkpoint (see the sketch below)
- Disabling of some "metadata" prompts by default, which is needed to work with the 13B checkpoint:
  - The imaging modality is not auto-determined unless the user selects "Auto"
  - Model cards are not added unless the user checks "Use Model Cards"

Changes:
- Renames some "sys_msg" variables and functions to "model_cards"
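For reference, the auto-download is built on `huggingface_hub.snapshot_download`, which this PR wires into `gradio_m3.py` (see the diff below). A minimal sketch of the same call in isolation, using the 13B repo id released with this work:

```python
from huggingface_hub import snapshot_download

# Download the checkpoint from the Hugging Face Hub (or reuse the local
# cache if it is already present) and return the local directory path.
model_path = snapshot_download("MONAI/Llama3-VILA-M3-13B")
print(f"Checkpoint available at: {model_path}")
```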

---------

Signed-off-by: Mingxin Zheng <[email protected]>
mingxin-zheng authored Oct 31, 2024
1 parent 3f65e36 commit a9d281b
Showing 3 changed files with 78 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,7 +1,7 @@
 demo_m3:
 	cd thirdparty/VILA; \
 	./environment_setup.sh
-	pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage,fire,ignite] torchxrayvision
+	pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage,fire,ignite] torchxrayvision huggingface_hub
 	mkdir -p $(HOME)/.torchxrayvision/models_data/ \
 	&& wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
 	-O $(HOME)/.torchxrayvision/models_data/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt \
14 changes: 5 additions & 9 deletions README.md
@@ -7,9 +7,9 @@ The repository provides a collection of vision language models, benchmarks, and

## 💡 News

- [2024/10/31] We released the [VILA-M3-3B](https://huggingface.co/MONAI/Llama3-VILA-M3-3B), [VILA-M3-8B](https://huggingface.co/MONAI/Llama3-VILA-M3-8B), and [VILA-M3-13B](https://huggingface.co/MONAI/Llama3-VILA-M3-13B) checkpoints on [HuggingFace](https://huggingface.co/MONAI).
- [2024/10/24] We presented VILA-M3 and the VLM module in MONAI at MONAI Day ([slides](./m3/docs/materials/VILA-M3_MONAI-Day_2024.pdf), [recording - coming soon!](https://www.youtube.com/c/Project-MONAI))
- [2024/10/24] Interactive [VILA-M3 Demo](https://vila-m3-demo.monai.ngc.nvidia.com/) is available online!
- [Coming soon!] Several fine-tuned healthcare checkpoints will be released.

## VILA-M3

@@ -45,8 +45,7 @@ Please visit the [VILA-M3 Demo](https://vila-m3-demo.monai.ngc.nvidia.com/) to t
If CUDA is not installed, use one of the following methods:
- **Recommended** Use the Docker image: `nvidia/cuda:12.2.2-devel-ubuntu22.04`
```bash
-docker run -it --rm --ipc host --gpus all --net host \
-    -v <ckpts_dir>:/data/checkpoints \
+docker run -it --rm --ipc host --gpus all --net host \
 nvidia/cuda:12.2.2-devel-ubuntu22.04 bash
```
- **Manual Installation (not recommended)** Download the appropriate package from the [NVIDIA official page](https://developer.nvidia.com/cuda-12-2-2-download-archive)
@@ -86,14 +85,11 @@ Please visit the [VILA-M3 Demo](https://vila-m3-demo.monai.ngc.nvidia.com/) to t
1. Start the Gradio demo:
```bash
-python gradio_m3.py \
-    --modelpath /data/checkpoints/<8B-checkpoint-name> \
-    --convmode llama_3 \
-    --port 7860
+python gradio_m3.py
```
-1. Adding your own expert model
-   - This is still a work in progress. Please refer to the [README](m3/demo/experts/README.md) for more details.
+#### Adding your own expert model
+- This is still a work in progress. Please refer to the [README](m3/demo/experts/README.md) for more details.
## Contributing
104 changes: 72 additions & 32 deletions m3/demo/gradio_m3.py
@@ -14,6 +14,7 @@
 import logging
 import os
 import tempfile
+import time
 from copy import deepcopy
 from glob import glob
 from shutil import copyfile, rmtree
@@ -32,6 +33,7 @@
     load_image,
     save_image_url_to_file,
 )
+from huggingface_hub import snapshot_download
 from llava.constants import IMAGE_TOKEN_INDEX
 from llava.conversation import SeparatorStyle, conv_templates
 from llava.mm_utils import KeywordsStoppingCriteria, get_model_name_from_path, process_images, tokenizer_image_token
@@ -71,7 +73,7 @@
"Chest X-ray Sample 3": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_c2af2ab3-6a11cbae-d9fa4d64-21ab221e-cf6f2146_v1.jpg",
 }

-SYS_MSG = "Here is a list of available expert models:\n<BRATS(args)> Modality: MRI, Task: segmentation, Overview: A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data, Accuracy: Tumor core (TC): 0.8559 - Whole tumor (WT): 0.9026 - Enhancing tumor (ET): 0.7905 - Average: 0.8518, Valid args are: None\n<VISTA3D(args)> Modality: CT, Task: segmentation, Overview: domain-specialized interactive foundation model developed for segmenting and annotating human anatomies with precision, Accuracy: 127 organs: 0.792 Dice on average, Valid args are: 'everything', 'hepatic tumor', 'pancreatic tumor', 'lung tumor', 'bone lesion', 'organs', 'cardiovascular', 'gastrointestinal', 'skeleton', or 'muscles'\n<VISTA2D(args)> Modality: cell imaging, Task: segmentation, Overview: model for cell segmentation, which was trained on a variety of cell imaging outputs, including brightfield, phase-contrast, fluorescence, confocal, or electron microscopy, Accuracy: Good accuracy across several cell imaging datasets, Valid args are: None\n<CXR(args)> Modality: chest x-ray (CXR), Task: classification, Overview: pre-trained model which are trained on large cohorts of data, Accuracy: Good accuracy across several diverse chest x-rays datasets, Valid args are: None\nGive the model <NAME(args)> when selecting a suitable expert model.\n"
+MODEL_CARDS = "Here is a list of available expert models:\n<BRATS(args)> Modality: MRI, Task: segmentation, Overview: A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data, Accuracy: Tumor core (TC): 0.8559 - Whole tumor (WT): 0.9026 - Enhancing tumor (ET): 0.7905 - Average: 0.8518, Valid args are: None\n<VISTA3D(args)> Modality: CT, Task: segmentation, Overview: domain-specialized interactive foundation model developed for segmenting and annotating human anatomies with precision, Accuracy: 127 organs: 0.792 Dice on average, Valid args are: 'everything', 'hepatic tumor', 'pancreatic tumor', 'lung tumor', 'bone lesion', 'organs', 'cardiovascular', 'gastrointestinal', 'skeleton', or 'muscles'\n<VISTA2D(args)> Modality: cell imaging, Task: segmentation, Overview: model for cell segmentation, which was trained on a variety of cell imaging outputs, including brightfield, phase-contrast, fluorescence, confocal, or electron microscopy, Accuracy: Good accuracy across several cell imaging datasets, Valid args are: None\n<CXR(args)> Modality: chest x-ray (CXR), Task: classification, Overview: pre-trained model which are trained on large cohorts of data, Accuracy: Good accuracy across several diverse chest x-rays datasets, Valid args are: None\nGive the model <NAME(args)> when selecting a suitable expert model.\n"

 SYS_PROMPT = None  # set when the script initializes
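The MODEL_CARDS text instructs the model to answer with an expert-selection tag of the form `<NAME(args)>`. A hypothetical sketch of extracting such a tag from a response (the helper and regex are illustrative, not part of this diff):

```python
import re

# Hypothetical helper (not in this repo): pull the first expert-selection
# tag of the form <NAME(args)> out of a model response.
TAG_PATTERN = re.compile(r"<([A-Z0-9]+)\(([^)]*)\)>")

def parse_expert_tag(response: str):
    """Return (model_name, args) for the first <NAME(args)> tag, or None."""
    match = TAG_PATTERN.search(response)
    if match is None:
        return None
    return match.group(1), match.group(2) or None

print(parse_expert_tag("I will use <VISTA3D(everything)> for this CT image."))
# -> ('VISTA3D', 'everything')
```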

@@ -271,7 +273,8 @@ class SessionVariables:
     def __init__(self):
         """Initialize the session variables"""
         self.sys_prompt = SYS_PROMPT
-        self.sys_msg = SYS_MSG
+        self.sys_msg = MODEL_CARDS
+        self.use_model_cards = True
         self.slice_index = None  # Slice index for 3D images
         self.image_url = None  # Image URL to the image on the web
         self.axis = 2
@@ -300,23 +303,31 @@ def new_session_variables(**kwargs):
 class M3Generator:
     """Class to generate M3 responses"""

-    def __init__(self, source="local", model_path="", conv_mode=""):
+    def __init__(self, source="huggingface", model_path="", conv_mode=""):
         """Initialize the M3 generator"""
+        global SYS_PROMPT
         self.source = source
-        # TODO: support huggingface models
         if source == "local":
             # TODO: allow setting the device
             disable_torch_init()
             self.conv_mode = conv_mode
-            self.model_name = get_model_name_from_path(model_path)
+            model_name = get_model_name_from_path(model_path)
             self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-                model_path, self.model_name
+                model_path, model_name
             )
-            logger.info(f"Model {self.model_name} loaded successfully. Context length: {self.context_len}")
+            logger.info(f"Model {model_name} loaded successfully. Context length: {self.context_len}")
         elif source == "huggingface":
-            pass
+            model_path = snapshot_download(model_path)
+            self.conv_mode = conv_mode
+            model_name = get_model_name_from_path(model_path)
+            self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
+                model_path, model_name
+            )
+            logger.info(f"Model {model_name} loaded successfully. Context length: {self.context_len}")
         else:
             raise NotImplementedError(f"Source {source} is not supported.")
+        SYS_PROMPT = conv_templates[self.conv_mode].system  # only set once

     def generate_response_local(
         self,
@@ -368,6 +379,7 @@ def generate_response_local(
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

+        start_time = time.time()
         with torch.inference_mode():
             output_ids = self.model.generate(
                 input_ids,
@@ -382,6 +394,9 @@
                 pad_token_id=self.tokenizer.eos_token_id,
                 min_new_tokens=2,
             )
+        end_time = time.time()
+        logger.debug(f"Time taken to generate {len(output_ids[0])} tokens: {end_time - start_time:.2f} seconds")
+        logger.debug(f"Tokens per second: {len(output_ids[0]) / (end_time - start_time):.2f}")

         outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
         outputs = outputs.strip()
@@ -394,8 +409,9 @@

     def generate_response(self, **kwargs):
         """Generate the response"""
-        if self.source == "local":
+        if self.source == "local" or self.source == "huggingface":
             return self.generate_response_local(**kwargs)
+        raise NotImplementedError(f"Source {self.source} is not supported.")

     def squash_expert_messages_into_user(self, messages: list):
         """Squash consecutive expert messages into a single user message."""
@@ -433,15 +449,18 @@ def process_prompt(self, prompt, sv, chat_history):
         modality = sv.modality_prompt
         mod_msg = f"This is a {modality} image.\n" if modality != "Unknown" else ""

+        model_cards = sv.sys_msg if sv.use_model_cards else ""
+
         img_file = CACHED_IMAGES.get(sv.image_url, None)

         if isinstance(img_file, str):
             if "<image>" not in prompt:
-                _prompt = sv.sys_msg + "<image>" + mod_msg + prompt
-                sv.sys_msgs_to_hide.append(sv.sys_msg + "<image>" + mod_msg)
+                _prompt = model_cards + "<image>" + mod_msg + prompt
+                sv.sys_msgs_to_hide.append(model_cards + "<image>" + mod_msg)
             else:
-                _prompt = sv.sys_msg + mod_msg + prompt
-                sv.sys_msgs_to_hide.append(sv.sys_msg + mod_msg)
+                _prompt = model_cards + mod_msg + prompt
+                if model_cards + mod_msg != "":
+                    sv.sys_msgs_to_hide.append(model_cards + mod_msg)

             if img_file.endswith(".nii.gz"):  # Take the specific slice from a volume
                 chat_history.append(
@@ -511,6 +530,7 @@ def process_prompt(self, prompt, sv, chat_history):
             # Keep these parameters across one conversation
             sys_prompt=sv.sys_prompt,
             sys_msg=sv.sys_msg,
+            use_model_cards=sv.use_model_cards,
             download_file_path=download_pkg,
             temp_working_dir=sv.temp_working_dir,
             interactive=True,
@@ -582,6 +602,7 @@ def colorcode_message(text="", data_url=None, show_all=False, role="user", sys_m
     if not show_all and role == "expert":
         return ""
     if not show_all and sys_msgs_to_hide and isinstance(sys_msgs_to_hide, list):
+        sys_msgs_to_hide.sort(key=len, reverse=True)
         for sys_msg in sys_msgs_to_hide:
             text = text.replace(sys_msg, "")
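The added sort applies the hidden messages longest-first, so an entry that is a prefix of another cannot match first and leave the remainder of the longer message visible. A small illustration with made-up strings:

```python
# Made-up strings: one hidden message is a prefix of the other.
hidden = ["MODEL CARDS <image>", "MODEL CARDS <image>This is a CT image.\n"]
text = "MODEL CARDS <image>This is a CT image.\nWhat is in this image?"

# Insertion order: the shorter prefix is stripped first and leaves residue.
shown = text
for msg in hidden:
    shown = shown.replace(msg, "")
print(repr(shown))  # 'This is a CT image.\nWhat is in this image?'

# Longest-first (the fix): the full message is removed cleanly.
shown = text
for msg in sorted(hidden, key=len, reverse=True):
    shown = shown.replace(msg, "")
print(repr(shown))  # 'What is in this image?'
```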

@@ -629,14 +650,15 @@ def clear_all_convs(sv: SessionVariables):
     if sv.temp_working_dir is not None:
         rmtree(sv.temp_working_dir)
     new_sv = new_session_variables()
-    # Order of output: prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, model_cards_text
+    # Order of output: prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, model_cards_checkbox, model_cards_text, modality_prompt_dropdown
     return (
         new_sv,
         "Enter your prompt here",
         ChatHistory(),
         HTML_PLACEHOLDER,
         HTML_PLACEHOLDER,
         new_sv.sys_prompt,
+        new_sv.use_model_cards,
         new_sv.sys_msg,
         new_sv.modality_prompt,
     )
@@ -670,10 +692,17 @@ def update_sys_prompt(sys_prompt, sv):
     return sv


-def update_sys_message(sys_message, sv):
+def update_model_cards_text(model_cards, sv):
     """Update the model cards"""
-    logger.debug(f"Updating the model cards")
-    sv.sys_msg = sys_message
+    logger.debug(f"Updating the model cards contents")
+    sv.sys_msg = model_cards
     return sv


+def update_model_cards_checkbox(use_model_cards, sv):
+    """Update the model cards checkbox"""
+    logger.debug(f"Updating the model cards checkbox")
+    sv.use_model_cards = use_model_cards
+    return sv


@@ -720,20 +749,21 @@ def create_demo(source, model_path, conv_mode, server_port):
             )

         with gr.Accordion("System Prompt and Message", open=False):
+            modality_prompt_dropdown = gr.Dropdown(
+                label="Select Modality",
+                choices=["Unknown", "Auto", "CT", "MRI", "CXR"],
+            )
+            model_cards_checkbox = gr.Checkbox(
+                label="Use Model Cards",
+                value=sv.value.use_model_cards,
+                info="Check this to include the model cards of the experts.",
+            )
+            model_cards_text = gr.Textbox(label="Model Cards", value=sv.value.sys_msg, lines=8)
             sys_prompt_text = gr.Textbox(
                 label="System Prompt",
                 value=sv.value.sys_prompt,
                 lines=4,
             )
-            model_cards_text = gr.Textbox(
-                label="Model Cards",
-                value=sv.value.sys_msg,
-                lines=10,
-            )
-            modality_prompt_dropdown = gr.Dropdown(
-                label="Select Modality",
-                choices=["Auto", "CT", "MRI", "X-ray", "Unknown"],
-            )

         with gr.Column():
             with gr.Tab("In front of the scene"):
@@ -780,7 +810,8 @@ def create_demo(source, model_path, conv_mode, server_port):
         top_p_slider.change(fn=update_top_p, inputs=[top_p_slider, sv], outputs=[sv])
         max_tokens_slider.change(fn=update_max_tokens, inputs=[max_tokens_slider, sv], outputs=[sv])
         sys_prompt_text.change(fn=update_sys_prompt, inputs=[sys_prompt_text, sv], outputs=[sv])
-        model_cards_text.change(fn=update_sys_message, inputs=[model_cards_text, sv], outputs=[sv])
+        model_cards_checkbox.change(fn=update_model_cards_checkbox, inputs=[model_cards_checkbox, sv], outputs=[sv])
+        model_cards_text.change(fn=update_model_cards_text, inputs=[model_cards_text, sv], outputs=[sv])
         modality_prompt_dropdown.change(fn=update_modality_prompt, inputs=[modality_prompt_dropdown, sv], outputs=[sv])
         # Reset button
         clear_btn.click(
@@ -793,6 +824,7 @@
                 history_text,
                 history_text_full,
                 sys_prompt_text,
+                model_cards_checkbox,
                 model_cards_text,
                 modality_prompt_dropdown,
             ],
@@ -820,12 +852,21 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # TODO: Add the argument to load multiple models from a JSON file
-    parser.add_argument("--convmode", type=str, default="llama_3", help="The conversation mode to use.")
+    parser.add_argument(
+        "--convmode",
+        type=str,
+        default="llama_3",
+        help="The conversation mode to use. For 8B models, use 'llama_3'. For 3B and 13B models, use 'vicuna_v1'.",
+    )
     parser.add_argument(
         "--modelpath",
         type=str,
-        default="/data/checkpoints/vila-m3-8b",
-        help="The path to the model to load.",
+        default="MONAI/Llama3-VILA-M3-8B",
+        help=(
+            "The path to the model to load. "
+            "If source is 'local', it can be '/data/checkpoints/vila-m3-8b'. "
+            "If source is 'huggingface', it can be 'MONAI/Llama3-VILA-M3-8B'."
+        ),
+    )
     parser.add_argument(
         "--port",
@@ -836,11 +877,10 @@
     parser.add_argument(
         "--source",
         type=str,
-        default="local",
-        help="The source of the model. Option is only 'local'.",
+        default="huggingface",
+        help="The source of the model. Option is 'huggingface' or 'local'.",
     )
     args = parser.parse_args()
-    SYS_PROMPT = conv_templates[args.convmode].system
     cache_images()
     create_demo(args.source, args.modelpath, args.convmode, args.port)
     cache_cleanup()
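With the new defaults, `python gradio_m3.py` starts the demo against `MONAI/Llama3-VILA-M3-8B` from Hugging Face, while `--source local` keeps the old behavior. A sketch of constructing the generator directly with either source, based on the constructor signature in this diff (assumes `gradio_m3` and its VILA dependencies are importable):

```python
# Sketch only: the constructor signature is taken from the diff above;
# importing gradio_m3 assumes the VILA environment from the Makefile is set up.
from gradio_m3 import M3Generator

# Hugging Face source (new default): the checkpoint is auto-downloaded
# via snapshot_download before loading.
generator = M3Generator(
    source="huggingface",
    model_path="MONAI/Llama3-VILA-M3-8B",
    conv_mode="llama_3",  # per the --convmode help: 8B -> 'llama_3', 3B/13B -> 'vicuna_v1'
)

# Local source: point model_path at a checkpoint directory on disk.
# generator = M3Generator(
#     source="local",
#     model_path="/data/checkpoints/vila-m3-8b",
#     conv_mode="llama_3",
# )
```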
