diff --git a/Makefile b/Makefile index a1a852e..4f8bd07 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ demo_monai_vila2d: cd thirdparty/VILA; \ ./environment_setup.sh - pip install -U nibabel python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage] + pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage] diff --git a/README.md b/README.md index 27e5283..b3bedb9 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ For details, see [here](./monai_vila2d/README.md). ```bash git clone https://github.com/Project-MONAI/VLM --recursive cd VLM - python -m venv .venv + python3.10 -m venv .venv source .venv/bin/activate make demo_monai_vila2d ``` @@ -70,9 +70,10 @@ For details, see [here](./monai_vila2d/README.md). 1. Start the Gradio demo: ```bash - python demo/gradio_monai_vila2d.py \ - --modelpath /data/checkpoints/ \ - --convmode \ + cd monai_vila2d/demo + python gradio_monai_vila2d.py \ + --modelpath /data/checkpoints/<8B-checkpoint-name> \ + --convmode llama_3 \ --port 7860 ``` diff --git a/demo/experts/__init__.py b/monai_vila2d/demo/experts/__init__.py similarity index 100% rename from demo/experts/__init__.py rename to monai_vila2d/demo/experts/__init__.py diff --git a/demo/experts/base_expert.py b/monai_vila2d/demo/experts/base_expert.py similarity index 100% rename from demo/experts/base_expert.py rename to monai_vila2d/demo/experts/base_expert.py diff --git a/demo/experts/expert_monai_vista3d.py b/monai_vila2d/demo/experts/expert_monai_vista3d.py similarity index 100% rename from demo/experts/expert_monai_vista3d.py rename to monai_vila2d/demo/experts/expert_monai_vista3d.py diff --git a/demo/experts/expert_torchxrayvision.py b/monai_vila2d/demo/experts/expert_torchxrayvision.py similarity index 100% rename from demo/experts/expert_torchxrayvision.py rename to monai_vila2d/demo/experts/expert_torchxrayvision.py diff --git a/demo/experts/label_groups_dict.json b/monai_vila2d/demo/experts/label_groups_dict.json similarity index 100% rename from demo/experts/label_groups_dict.json rename to monai_vila2d/demo/experts/label_groups_dict.json diff --git a/demo/experts/utils.py b/monai_vila2d/demo/experts/utils.py similarity index 98% rename from demo/experts/utils.py rename to monai_vila2d/demo/experts/utils.py index 3301db8..868c531 100644 --- a/demo/experts/utils.py +++ b/monai_vila2d/demo/experts/utils.py @@ -16,8 +16,6 @@ import re from io import BytesIO from pathlib import Path -from typing import List -from urllib.parse import urlparse import numpy as np import requests @@ -144,11 +142,11 @@ def get_filename_from_cd(url, cd): return fname[0].strip('"').strip("'") -def get_slice_filenames(image_file, slice_index): +def get_slice_filenames(image_file: str, slice_index: int, ext: str = "jpg"): """Small helper function to get the slice filenames""" base_name = os.path.basename(image_file) - image_filename = base_name.replace(".nii.gz", f"_slice{slice_index}.jpg") - seg_filename = base_name.replace(".nii.gz", f"_slice{slice_index}_seg.jpg") + image_filename = base_name.replace(".nii.gz", f"_slice{slice_index}_img.{ext}") + seg_filename = base_name.replace(".nii.gz", f"_slice{slice_index}_seg.{ext}") return image_filename, seg_filename diff --git a/demo/gradio_monai_vila2d.py b/monai_vila2d/demo/gradio_monai_vila2d.py similarity index 84% rename from demo/gradio_monai_vila2d.py rename to monai_vila2d/demo/gradio_monai_vila2d.py index 9d71d99..3da03d4 100644 --- a/demo/gradio_monai_vila2d.py +++ b/monai_vila2d/demo/gradio_monai_vila2d.py @@ -15,6 +15,8 @@ import os import tempfile from copy import deepcopy +from glob import glob +from shutil import copyfile, rmtree import gradio as gr import nibabel as nib @@ -35,6 +37,7 @@ from llava.mm_utils import KeywordsStoppingCriteria, get_model_name_from_path, process_images, tokenizer_image_token from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init +from tqdm import tqdm load_dotenv() @@ -59,8 +62,8 @@ # Suppress logging from dependent libraries logging.getLogger("gradio").setLevel(logging.WARNING) -# Sample images dictionary -IMAGES_URLS = { +# Sample images dictionary. It accepts either a URL or a local path. +IMG_URLS_OR_PATHS = { "CT Sample 1": "https://developer.download.nvidia.com/assets/Clara/monai/samples/liver_0.nii.gz", "CT Sample 2": "https://developer.download.nvidia.com/assets/Clara/monai/samples/ct_sample.nii.gz", "Chest X-ray Sample 1": "https://developer.download.nvidia.com/assets/Clara/monai/samples/cxr_ce3d3d98-bf5170fa-8e962da1-97422442-6653c48a_v1.jpg", @@ -134,18 +137,36 @@ def cache_images(): """Cache the image and return the file path""" - logger.debug(f"Caching the image") - for _, image_url in IMAGES_URLS.items(): - CACHED_IMAGES[image_url] = save_image_url_to_file(image_url, CACHED_DIR) + logger.debug(f"Caching the image to {CACHED_DIR}") + for _, item in IMG_URLS_OR_PATHS.items(): + if item.startswith("http"): + CACHED_IMAGES[item] = save_image_url_to_file(item, CACHED_DIR) + elif os.path.exists(item): + # move the file to the cache directory + file_name = os.path.basename(item) + CACHED_IMAGES[item] = os.path.join(CACHED_DIR, file_name) + if not os.path.isfile(CACHED_IMAGES[item]): + copyfile(item, CACHED_IMAGES[item]) + + if CACHED_IMAGES[item].endswith(".nii.gz"): + data = nib.load(CACHED_IMAGES[item]).get_fdata() + for slice_index in tqdm(range(data.shape[2])): + image_filename = get_slice_filenames(CACHED_IMAGES[item], slice_index)[0] + if not os.path.exists(os.path.join(CACHED_DIR, image_filename)): + compose = get_monai_transforms( + ["image"], + CACHED_DIR, + modality="CT", # TODO: Get the modality from the image/prompt/metadata + slice_index=slice_index, + image_filename=image_filename, + ) + compose({"image": CACHED_IMAGES[item]}) def cache_cleanup(): """Clean up the cache""" logger.debug(f"Cleaning up the cache") - for _, cache_file_name in CACHED_IMAGES.items(): - if os.path.exists(cache_file_name): - os.remove(cache_file_name) - print(f"Cache file {cache_file_name} cleaned up") + rmtree(CACHED_DIR) class ChatHistory: @@ -212,7 +233,7 @@ def append(self, prompt_or_answer, image_path=None, role="user"): self.messages.append({"role": role, "content": new_contents}) - def get_html(self, show_all=False): + def get_html(self, show_all: bool = False, sys_msgs_to_hide: list | None = None): """Returns the chat history as an HTML string to display""" history = [] @@ -222,7 +243,9 @@ def get_html(self, show_all=False): history_text_html = "" for content in contents: if content["type"] == "text": - history_text_html += colorcode_message(text=content["text"], show_all=show_all, role=role) + history_text_html += colorcode_message( + text=content["text"], show_all=show_all, role=role, sys_msgs_to_hide=sys_msgs_to_hide + ) else: history_text_html += colorcode_message( data_url=image_to_data_url(content["image_path"], max_size=(300, 300)), show_all=True, role=role @@ -386,7 +409,7 @@ def process_prompt(self, prompt, sv, chat_history): if not sv.interactive: # Do not process the prompt if the image is not provided - return None, sv, chat_history, "Please select an image", "Please select an image" + return None, sv, chat_history, "Please select an image", "Please select an image" if sv.temp_working_dir is None: sv.temp_working_dir = tempfile.mkdtemp() @@ -395,16 +418,19 @@ def process_prompt(self, prompt, sv, chat_history): mod_msg = f"This is a {modality} image.\n" if modality != "Unknown" else "" img_file = CACHED_IMAGES.get(sv.image_url, None) + sys_msgs_to_hide = [] if isinstance(img_file, str): if "" not in prompt: _prompt = sv.sys_msg + "" + mod_msg + prompt + sys_msgs_to_hide.append(sv.sys_msg + "" + mod_msg) else: _prompt = sv.sys_msg + mod_msg + prompt + sys_msgs_to_hide.append(sv.sys_msg + mod_msg) if img_file.endswith(".nii.gz"): # Take the specific slice from a volume chat_history.append( _prompt, - image_path=os.path.join(sv.temp_working_dir, get_slice_filenames(img_file, sv.slice_index)[0]), + image_path=os.path.join(CACHED_DIR, get_slice_filenames(img_file, sv.slice_index)[0]), ) else: chat_history.append(_prompt, image_path=img_file) @@ -470,9 +496,16 @@ def process_prompt(self, prompt, sv, chat_history): sys_prompt=sv.sys_prompt, sys_msg=sv.sys_msg, download_file_path=download_pkg, + temp_working_dir=sv.temp_working_dir, interactive=True, ) - return None, new_sv, chat_history, chat_history.get_html(show_all=False), chat_history.get_html(show_all=True) + return ( + None, + new_sv, + chat_history, + chat_history.get_html(show_all=False, sys_msgs_to_hide=sys_msgs_to_hide), + chat_history.get_html(show_all=True), + ) def input_image(image, sv: SessionVariables): @@ -482,83 +515,62 @@ def input_image(image, sv: SessionVariables): return image, sv -def update_image_selection(selected_image, sv: SessionVariables, slice_index_html, increment=None): +def update_image_selection(selected_image, sv: SessionVariables, slice_index=None): """Update the gradio components based on the selected image""" logger.debug(f"Updating display image for {selected_image}") - sv.image_url = IMAGES_URLS.get(selected_image, None) + sv.image_url = IMG_URLS_OR_PATHS.get(selected_image, None) img_file = CACHED_IMAGES.get(sv.image_url, None) - if sv.image_url is None: - return None, sv, slice_index_html, [[""]] - - if sv.temp_working_dir is None: - sv.temp_working_dir = tempfile.mkdtemp() + if sv.image_url is None or img_file is None: + return None, sv, gr.Slider(0, 2, 1, 0, visible=False), [[""]] sv.interactive = True if img_file.endswith(".nii.gz"): - if sv.slice_index is None: - data = nib.load(img_file).get_fdata() - sv.slice_index = data.shape[sv.axis] // 2 - sv.idx_range = (0, data.shape[sv.axis] - 1) - - if increment is not None: - sv.slice_index += increment - sv.slice_index = max(sv.idx_range[0], min(sv.idx_range[1], sv.slice_index)) + if slice_index is None: + slice_file_pttn = img_file.replace(".nii.gz", "*_img.jpg") + # glob the image files + slice_files = glob(slice_file_pttn) + sv.slice_index = len(slice_files) // 2 + sv.idx_range = (0, len(slice_files) - 1) + else: + # Slice index is updated by the slidebar. + # There is no need to update the idx_range. + sv.slice_index = slice_index image_filename = get_slice_filenames(img_file, sv.slice_index)[0] - if not os.path.exists(image_filename): - compose = get_monai_transforms( - ["image"], - sv.temp_working_dir, - modality="CT", # TODO: Get the modality from the image/prompt/metadata - slice_index=sv.slice_index, - image_filename=image_filename, - ) - compose({"image": img_file}) + if not os.path.exists(os.path.join(CACHED_DIR, image_filename)): + raise ValueError(f"Image file {image_filename} does not exist.") return ( - os.path.join(sv.temp_working_dir, image_filename), + os.path.join(CACHED_DIR, image_filename), sv, - f"Slice Index: {sv.slice_index}", + gr.Slider(sv.idx_range[0], sv.idx_range[1], value=sv.slice_index, step=1, visible=True, interactive=True), gr.Dataset(samples=EXAMPLE_PROMPTS_3D), ) sv.slice_index = None + sv.idx_range = (None, None) return ( img_file, sv, - "Slice Index: N/A for 2D images, clicking prev/next will not change the image.", + gr.Slider(0, 2, 1, 0, visible=False), gr.Dataset(samples=EXAMPLE_PROMPTS_2D), ) -def update_image_next_10(selected_image, sv, slice_index_html): - """Update the image to the next 10 slices""" - return update_image_selection(selected_image, sv, slice_index_html, increment=10) - - -def update_image_next_1(selected_image, sv, slice_index_html): - """Update the image to the next slice""" - return update_image_selection(selected_image, sv, slice_index_html, increment=1) - - -def update_image_prev_1(selected_image, sv, slice_index_html): - """Update the image to the previous slice""" - return update_image_selection(selected_image, sv, slice_index_html, increment=-1) - - -def update_image_prev_10(selected_image, sv, slice_index_html): - """Update the image to the previous 10 slices""" - return update_image_selection(selected_image, sv, slice_index_html, increment=-10) - - -def colorcode_message(text="", data_url=None, show_all=False, role="user"): +def colorcode_message(text="", data_url=None, show_all=False, role="user", sys_msgs_to_hide: list = None): """Color the text based on the role and return the HTML text""" logger.debug(f"Preparing the HTML text with {show_all} and role: {role}") # if content is not a data URL, escape the text if not show_all and role == "expert": return "" + if not show_all and sys_msgs_to_hide and isinstance(sys_msgs_to_hide, list): + for sys_msg in sys_msgs_to_hide: + text = text.replace(sys_msg, "") + escaped_text = html.escape(text) + # replace newlines with
tags + escaped_text = escaped_text.replace("\n", "
") if data_url is not None: escaped_text += f'' if role == "user": @@ -570,7 +582,7 @@ def colorcode_message(text="", data_url=None, show_all=False, role="user"): raise ValueError(f"Invalid role: {role}") -def clear_one_conv(sv): +def clear_one_conv(sv: SessionVariables): """ Post-event hook indicating the session ended.It's called when `new_session_variables` finishes. Particularly, it resets the non-text parameters. So it excludes: @@ -590,13 +602,15 @@ def clear_one_conv(sv): d_btn = gr.DownloadButton(label=f"Download {name}", value=filepath, visible=True) else: d_btn = gr.DownloadButton(visible=False) - # Order of output: image, image_selector, slice_index_html, temperature_slider, top_p_slider, max_tokens_slider, download_button - return sv, None, None, "Slice Index: N/A", sv.temperature, sv.top_p, sv.max_tokens, d_btn + # Order of output: image, image_selector, temperature_slider, top_p_slider, max_tokens_slider, download_button, image_slider + return sv, None, None, sv.temperature, sv.top_p, sv.max_tokens, d_btn, gr.Slider(0, 2, 1, 0, visible=False) -def clear_all_convs(): +def clear_all_convs(sv: SessionVariables): """Clear and reset everything, Inputs/outputs are the gradio components.""" logger.debug(f"Clearing all conversations") + if sv.temp_working_dir is not None: + rmtree(sv.temp_working_dir) new_sv = new_session_variables() # Order of output: prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, sys_message_text return ( @@ -661,15 +675,13 @@ def create_demo(source, model_path, conv_mode, server_port): with gr.Row(): with gr.Column(): - image_dropdown = gr.Dropdown(label="Select an image", choices=["Please select .."] + list(IMAGES_URLS.keys())) - image_input = gr.Image(label="Image", sources=[], placeholder="Please select an image from the dropdown list.") - with gr.Accordion("3D image panel", open=False): - slice_index_html = gr.HTML("Slice Index: N/A") - with gr.Row(): - prev10_btn = gr.Button("<<") - prev01_btn = gr.Button("<") - next01_btn = gr.Button(">") - next10_btn = gr.Button(">>") + image_dropdown = gr.Dropdown( + label="Select an image", choices=["Please select .."] + list(IMG_URLS_OR_PATHS.keys()) + ) + image_input = gr.Image( + label="Image", sources=[], placeholder="Please select an image from the dropdown list." + ) + image_slider = gr.Slider(0, 2, 1, 0, visible=False) with gr.Accordion("View Parameters", open=False): temperature_slider = gr.Slider( @@ -703,7 +715,10 @@ def create_demo(source, model_path, conv_mode, server_port): clear_btn = gr.Button("Clear Conversation") with gr.Row(variant="compact"): prompt_edit = gr.Textbox( - label="TextPrompt", container=False, placeholder="Please ask a question about the current image or 2D slice", scale=2 + label="TextPrompt", + container=False, + placeholder="Please ask a question about the current image or 2D slice", + scale=2, ) submit_btn = gr.Button("Submit", scale=0) examples = gr.Examples([[""]], prompt_edit) @@ -724,28 +739,13 @@ def create_demo(source, model_path, conv_mode, server_port): image_input.input(fn=input_image, inputs=[image_input, sv], outputs=[image_input, sv]) image_dropdown.change( fn=update_image_selection, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html, examples.dataset], - ) - prev10_btn.click( - fn=update_image_prev_10, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], - ) - prev01_btn.click( - fn=update_image_prev_1, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], + inputs=[image_dropdown, sv], + outputs=[image_input, sv, image_slider, examples.dataset], ) - next01_btn.click( - fn=update_image_next_1, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], - ) - next10_btn.click( - fn=update_image_next_10, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], + image_slider.release( + fn=update_image_selection, + inputs=[image_dropdown, sv, image_slider], + outputs=[image_input, sv, image_slider, examples.dataset], ) temperature_slider.change(fn=update_temperature, inputs=[temperature_slider, sv], outputs=[sv]) top_p_slider.change(fn=update_top_p, inputs=[top_p_slider, sv], outputs=[sv]) @@ -755,7 +755,7 @@ def create_demo(source, model_path, conv_mode, server_port): # Reset button clear_btn.click( fn=clear_all_convs, - inputs=[], + inputs=[sv], outputs=[sv, prompt_edit, chat_history, history_text, history_text_full, sys_prompt_text, sys_message_text], ) @@ -767,13 +767,14 @@ def create_demo(source, model_path, conv_mode, server_port): sv, image_input, image_dropdown, - slice_index_html, temperature_slider, top_p_slider, max_tokens_slider, image_download, + image_slider, ], ) + demo.queue() demo.launch(server_name="0.0.0.0", server_port=server_port) diff --git a/tests/test_demo_expert_model.py b/monai_vila2d/tests/test_demo_expert_model.py similarity index 100% rename from tests/test_demo_expert_model.py rename to monai_vila2d/tests/test_demo_expert_model.py