From 1b32a28185c5d97d7a39353b32ed832cb250345c Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Thu, 17 Oct 2024 23:12:53 +0800 Subject: [PATCH] Handle expert errors better and disable text as the first prompt (#20) This PR includes the following changes: - Handle expert errors as a message from the `expert`, and display the error image in the conversation. - Forbid the user from sending the first prompt without an image. - Set the default value of `max_tokens` to 1024. - Add a `source` argument to allow future integration of Hugging Face/NIM/Ollama. --------- Signed-off-by: Mingxin Zheng --- demo/experts/expert_monai_vista3d.py | 10 +-- demo/experts/expert_torchxrayvision.py | 7 +- demo/experts/label_groups_dict.json | 4 ++ demo/gradio_monai_vila2d.py | 95 +++++++++++++++++--------- 4 files changed, 79 insertions(+), 37 deletions(-) diff --git a/demo/experts/expert_monai_vista3d.py b/demo/experts/expert_monai_vista3d.py index 6578a97..8081a69 100644 --- a/demo/experts/expert_monai_vista3d.py +++ b/demo/experts/expert_monai_vista3d.py @@ -141,7 +141,10 @@ def mentioned_by(self, input: str): Returns: bool: True if the VISTA-3D model is mentioned, False otherwise. """ - return self.model_name in str(input) + matches = re.findall(r"<(.*?)>", str(input)) + if len(matches) != 1: + return False + return self.model_name in str(matches[0]) def run( self, @@ -181,7 +184,7 @@ def run( arg_matches = ["everything"] if len(arg_matches) > 1: raise ValueError( - "Multiple expert model arguments are provided in the same prompt" + "Multiple expert model arguments are provided in the same prompt, " "which is not supported in this version." 
) @@ -191,7 +194,7 @@ def run( label_groups = json.load(f) if arg_matches[0] not in label_groups: - raise ValueError(f"Label group {arg_matches[0]} not found in {label_groups}") + raise ValueError(f"Label group {arg_matches[0]} is not accepted by the VISTA-3D model.") if arg_matches[0] != "everything": vista3d_prompts = {"classes": list(label_groups[arg_matches[0]].keys())} @@ -213,7 +216,6 @@ def run( response = requests.post(self.NIM_VISTA3D, headers=headers, json=payload) if response.status_code != 200: - payload.pop("image") # hide the image URL in the error message raise requests.exceptions.HTTPError( f"Error triggering POST to {self.NIM_VISTA3D} with Payload {payload}: {response.status_code}" ) diff --git a/demo/experts/expert_torchxrayvision.py b/demo/experts/expert_torchxrayvision.py index c727d84..b545c03 100644 --- a/demo/experts/expert_torchxrayvision.py +++ b/demo/experts/expert_torchxrayvision.py @@ -10,6 +10,7 @@ # limitations under the License. import os +import re import requests from experts.base_expert import BaseExpert @@ -47,8 +48,10 @@ def mentioned_by(self, input: str): Returns: bool: True if the CXR model is mentioned, False otherwise. 
""" - - return self.model_name in str(input) + matches = re.findall(r"<(.*?)>", str(input)) + if len(matches) != 1: + return False + return self.model_name in str(matches[0]) def run(self, image_url: str = "", prompt: str = "", **kwargs): """ diff --git a/demo/experts/label_groups_dict.json b/demo/experts/label_groups_dict.json index 84a94d7..e03cf0f 100644 --- a/demo/experts/label_groups_dict.json +++ b/demo/experts/label_groups_dict.json @@ -4,6 +4,10 @@ "liver": 1, "hepatic tumor": 26 }, + "hepatoma": { + "liver": 1, + "hepatic tumor": 26 + }, "pancreatic tumor": { "pancreas": 4, "pancreatic tumor": 24 diff --git a/demo/gradio_monai_vila2d.py b/demo/gradio_monai_vila2d.py index 4a107fc..e235fb6 100644 --- a/demo/gradio_monai_vila2d.py +++ b/demo/gradio_monai_vila2d.py @@ -240,10 +240,11 @@ def __init__(self): self.axis = 2 self.top_p = 0.9 self.temperature = 0.0 - self.max_tokens = 300 + self.max_tokens = 1024 self.download_file_path = "" # Path to the downloaded file self.temp_working_dir = None self.idx_range = (None, None) + self.interactive = False def new_session_variables(**kwargs): @@ -260,23 +261,28 @@ def new_session_variables(**kwargs): class M3Generator: """Class to generate M3 responses""" - def __init__(self, model_path, conv_mode): + def __init__(self, source="local", model_path="", conv_mode=""): """Initialize the M3 generator""" - # TODO: allow setting the device - disable_torch_init() - self.conv_mode = conv_mode - self.model_name = get_model_name_from_path(model_path) - self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( - model_path, self.model_name - ) - logger.info(f"Model {self.model_name} loaded successfully. 
Context length: {self.context_len}") + self.source = source + # TODO: support huggingface models + if source == "local": + # TODO: allow setting the device + disable_torch_init() + self.conv_mode = conv_mode + self.model_name = get_model_name_from_path(model_path) + self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( + model_path, self.model_name + ) + logger.info(f"Model {self.model_name} loaded successfully. Context length: {self.context_len}") + else: + raise NotImplementedError(f"Source {source} is not supported.") - def generate_response( + def generate_response_local( self, - messages: list, - max_tokens: int, - temperature: float, - top_p: float, + messages: list = [], + max_tokens: int = 1024, + temperature: float = 0.0, + top_p: float = 0.9, system_prompt: str | None = None, ): """Generate the response""" @@ -345,6 +351,11 @@ def generate_response( return outputs + def generate_response(self, **kwargs): + """Generate the response""" + if self.source == "local": + return self.generate_response_local(**kwargs) + def squash_expert_messages_into_user(self, messages: list): """Squash consecutive expert messages into a single user message.""" logger.debug("Squashing expert messages into user messages") @@ -368,6 +379,10 @@ def process_prompt(self, prompt, sv, chat_history): """Process the prompt and return the result. 
Inputs/outputs are the gradio components.""" logger.debug(f"Process the image and return the result") + if not sv.interactive: + # Do not process the prompt if the image is not provided + return None, sv, chat_history, "Please select an image", "Please select an image" + if sv.temp_working_dir is None: sv.temp_working_dir = tempfile.mkdtemp() @@ -415,15 +430,24 @@ def process_prompt(self, prompt, sv, chat_history): break if expert: - logger.debug(f"Expert model {expert.__class__.__name__} is being called.") - text_output, seg_file, instruction, download_pkg = expert.run( - image_url=sv.image_url, - input=outputs, - output_dir=sv.temp_working_dir, - img_file=img_file, - slice_index=sv.slice_index, - prompt=prompt, - ) + logger.debug(f"Expert model {expert.__class__.__name__} is being called to process {sv.image_url}.") + try: + if sv.image_url is None: + raise ValueError(f"No image is provided with {outputs}.") + text_output, seg_file, instruction, download_pkg = expert.run( + image_url=sv.image_url, + input=outputs, + output_dir=sv.temp_working_dir, + img_file=img_file, + slice_index=sv.slice_index, + prompt=prompt, + ) + except Exception as e: + text_output = f"Sorry I met an error: {e}" + seg_file = None + instruction = "" + download_pkg = "" + chat_history.append(text_output, image_path=seg_file, role="expert") if instruction: chat_history.append(instruction, role="expert") @@ -441,6 +465,7 @@ def process_prompt(self, prompt, sv, chat_history): sys_prompt=sv.sys_prompt, sys_msg=sv.sys_msg, download_file_path=download_pkg, + interactive=True, ) return None, new_sv, chat_history, chat_history.get_html(show_all=False), chat_history.get_html(show_all=True) @@ -450,6 +475,7 @@ def input_image(image, sv: SessionVariables): logger.debug(f"Received user input image") # TODO: support user uploaded images sv.image_url = image_to_data_url(image) + sv.interactive = True return image, sv @@ -465,6 +491,7 @@ def update_image_selection(selected_image, sv: SessionVariables, 
slice_index_htm if sv.temp_working_dir is None: sv.temp_working_dir = tempfile.mkdtemp() + sv.interactive = True if img_file.endswith(".nii.gz"): if sv.slice_index is None: data = nib.load(img_file).get_fdata() @@ -555,7 +582,7 @@ def clear_one_conv(sv): else: d_btn = gr.DownloadButton(visible=False) # Order of output: image, image_selector, slice_index_html, temperature_slider, top_p_slider, max_tokens_slider, download_button - return sv, None, None, "Slice Index: N/A", 0.0, 0.9, 300, d_btn + return sv, None, None, "Slice Index: N/A", sv.temperature, sv.top_p, sv.max_tokens, d_btn def clear_all_convs(): @@ -614,9 +641,9 @@ def download_file(): return [gr.DownloadButton(visible=False)] -def create_demo(model_path, conv_mode, server_port): +def create_demo(source, model_path, conv_mode, server_port): """Main function to create the Gradio interface""" - generator = M3Generator(model_path, conv_mode) + generator = M3Generator(source=source, model_path=model_path, conv_mode=conv_mode) with gr.Blocks(css=CSS_STYLES) as demo: gr.HTML(TITLE, label="Title") @@ -625,7 +652,7 @@ def create_demo(model_path, conv_mode, server_port): with gr.Row(): with gr.Column(): - image_input = gr.Image(label="Image", placeholder="Please select an image from the dropdown list.") + image_input = gr.Image(label="Image", sources=[], placeholder="Please select an image from the dropdown list.") image_dropdown = gr.Dropdown(label="Select an image", choices=list(IMAGES_URLS.keys())) with gr.Accordion("View Parameters", open=False): temperature_slider = gr.Slider( @@ -635,7 +662,7 @@ def create_demo(model_path, conv_mode, server_port): label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9, interactive=True ) max_tokens_slider = gr.Slider( - label="Max Tokens", minimum=1, maximum=1024, step=1, value=300, interactive=True + label="Max Tokens", minimum=1, maximum=1024, step=1, value=1024, interactive=True ) with gr.Accordion("3D image panel", open=False): @@ -748,7 +775,7 @@ def 
create_demo(model_path, conv_mode, server_port): parser.add_argument( "--modelpath", type=str, - default="/workspace/nvidia/medical-service-nims/vila/checkpoints/baseline/checkpoint-3500", + default="/data/checkpoints/vila-m3-8b", help="The path to the model to load.", ) parser.add_argument( @@ -757,8 +784,14 @@ def create_demo(model_path, conv_mode, server_port): default=7860, help="The port to run the Gradio server on.", ) + parser.add_argument( + "--source", + type=str, + default="local", + help="The source of the model. Option is only 'local'.", + ) args = parser.parse_args() SYS_PROMPT = conv_templates[args.convmode].system cache_images() - create_demo(args.modelpath, args.convmode, args.port) + create_demo(args.source, args.modelpath, args.convmode, args.port) cache_cleanup()