Handle expert errors better and disable text as the first prompt (#20)
This PR includes the following changes:
- Handle expert errors as a message from the `expert` and display the error in the conversation.
- Forbid the user from sending the first prompt without an image.
- Set the `max_tokens` default value to 1024.
- Create a `source` argument to allow future integration of huggingface/nim/ollama (see the sketch below).
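
For reference, the new `source` dispatch behaves roughly like the self-contained sketch below (illustration only, not the demo code itself; the real `M3Generator.__init__` also loads the local VILA checkpoint, and `M3GeneratorSketch` plus its placeholder return value are stand-ins):

```python
class M3GeneratorSketch:
    """Simplified stand-in for M3Generator's new `source` dispatch (illustration only)."""

    def __init__(self, source="local", model_path="", conv_mode=""):
        self.source = source
        if source == "local":
            # The real constructor calls disable_torch_init() and
            # load_pretrained_model(model_path, ...) at this point.
            self.model_path, self.conv_mode = model_path, conv_mode
        else:
            # huggingface / nim / ollama are reserved for future integration.
            raise NotImplementedError(f"Source {source} is not supported.")

    def generate_response(self, **kwargs):
        """Route generation to the backend selected by `source`."""
        if self.source == "local":
            return self.generate_response_local(**kwargs)

    def generate_response_local(self, messages=None, max_tokens=1024, temperature=0.0, top_p=0.9, system_prompt=None):
        # Local VILA-M3 inference happens here in the real demo; return a placeholder.
        return "<model output>"


gen = M3GeneratorSketch(source="local", model_path="/data/checkpoints/vila-m3-8b")
print(gen.generate_response(messages=[], max_tokens=1024))  # routed to the local backend

try:
    M3GeneratorSketch(source="ollama")  # not integrated yet
except NotImplementedError as err:
    print(err)
```

Only `local` is accepted for now; any other value raises `NotImplementedError`, which keeps the CLI surface stable for the future huggingface/nim/ollama backends.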

---------

Signed-off-by: Mingxin Zheng <[email protected]>
mingxin-zheng authored Oct 17, 2024
1 parent bc31eec commit 1b32a28
Showing 4 changed files with 79 additions and 37 deletions.
10 changes: 6 additions & 4 deletions demo/experts/expert_monai_vista3d.py
@@ -141,7 +141,10 @@ def mentioned_by(self, input: str):
         Returns:
             bool: True if the VISTA-3D model is mentioned, False otherwise.
         """
-        return self.model_name in str(input)
+        matches = re.findall(r"<(.*?)>", str(input))
+        if len(matches) != 1:
+            return False
+        return self.model_name in str(matches[0])

     def run(
         self,
@@ -181,7 +184,7 @@ def run(
             arg_matches = ["everything"]
         if len(arg_matches) > 1:
             raise ValueError(
-                "Multiple expert model arguments are provided in the same prompt"
+                "Multiple expert model arguments are provided in the same prompt, "
                 "which is not supported in this version."
             )

@@ -191,7 +194,7 @@ def run(
             label_groups = json.load(f)

         if arg_matches[0] not in label_groups:
-            raise ValueError(f"Label group {arg_matches[0]} not found in {label_groups}")
+            raise ValueError(f"Label group {arg_matches[0]} is not accepted by the VISTA-3D model.")

         if arg_matches[0] != "everything":
             vista3d_prompts = {"classes": list(label_groups[arg_matches[0]].keys())}
@@ -213,7 +216,6 @@ def run(

         response = requests.post(self.NIM_VISTA3D, headers=headers, json=payload)
         if response.status_code != 200:
-            payload.pop("image")  # hide the image URL in the error message
             raise requests.exceptions.HTTPError(
                 f"Error triggering POST to {self.NIM_VISTA3D} with Payload {payload}: {response.status_code}"
             )
7 changes: 5 additions & 2 deletions demo/experts/expert_torchxrayvision.py
@@ -10,6 +10,7 @@
 # limitations under the License.

 import os
+import re

 import requests
 from experts.base_expert import BaseExpert
@@ -47,8 +48,10 @@ def mentioned_by(self, input: str):
         Returns:
             bool: True if the CXR model is mentioned, False otherwise.
         """
-
-        return self.model_name in str(input)
+        matches = re.findall(r"<(.*?)>", str(input))
+        if len(matches) != 1:
+            return False
+        return self.model_name in str(matches[0])

     def run(self, image_url: str = "", prompt: str = "", **kwargs):
         """
4 changes: 4 additions & 0 deletions demo/experts/label_groups_dict.json
@@ -4,6 +4,10 @@
         "liver": 1,
         "hepatic tumor": 26
     },
+    "hepatoma": {
+        "liver": 1,
+        "hepatic tumor": 26
+    },
     "pancreatic tumor": {
         "pancreas": 4,
         "pancreatic tumor": 24
95 changes: 64 additions & 31 deletions demo/gradio_monai_vila2d.py
@@ -240,10 +240,11 @@ def __init__(self):
         self.axis = 2
         self.top_p = 0.9
         self.temperature = 0.0
-        self.max_tokens = 300
+        self.max_tokens = 1024
         self.download_file_path = ""  # Path to the downloaded file
         self.temp_working_dir = None
         self.idx_range = (None, None)
+        self.interactive = False


 def new_session_variables(**kwargs):
@@ -260,23 +261,28 @@ def new_session_variables(**kwargs):
 class M3Generator:
     """Class to generate M3 responses"""

-    def __init__(self, model_path, conv_mode):
+    def __init__(self, source="local", model_path="", conv_mode=""):
         """Initialize the M3 generator"""
-        # TODO: allow setting the device
-        disable_torch_init()
-        self.conv_mode = conv_mode
-        self.model_name = get_model_name_from_path(model_path)
-        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path, self.model_name
-        )
-        logger.info(f"Model {self.model_name} loaded successfully. Context length: {self.context_len}")
+        self.source = source
+        # TODO: support huggingface models
+        if source == "local":
+            # TODO: allow setting the device
+            disable_torch_init()
+            self.conv_mode = conv_mode
+            self.model_name = get_model_name_from_path(model_path)
+            self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
+                model_path, self.model_name
+            )
+            logger.info(f"Model {self.model_name} loaded successfully. Context length: {self.context_len}")
+        else:
+            raise NotImplementedError(f"Source {source} is not supported.")

-    def generate_response(
+    def generate_response_local(
         self,
-        messages: list,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
+        messages: list = [],
+        max_tokens: int = 1024,
+        temperature: float = 0.0,
+        top_p: float = 0.9,
         system_prompt: str | None = None,
     ):
         """Generate the response"""
@@ -345,6 +351,11 @@ def generate_response(

         return outputs

+    def generate_response(self, **kwargs):
+        """Generate the response"""
+        if self.source == "local":
+            return self.generate_response_local(**kwargs)
+
     def squash_expert_messages_into_user(self, messages: list):
         """Squash consecutive expert messages into a single user message."""
         logger.debug("Squashing expert messages into user messages")
@@ -368,6 +379,10 @@ def process_prompt(self, prompt, sv, chat_history):
         """Process the prompt and return the result. Inputs/outputs are the gradio components."""
         logger.debug(f"Process the image and return the result")

+        if not sv.interactive:
+            # Do not process the prompt if the image is not provided
+            return None, sv, chat_history, "Please select an image", "Please select an image"
+
         if sv.temp_working_dir is None:
             sv.temp_working_dir = tempfile.mkdtemp()

@@ -415,15 +430,24 @@ def process_prompt(self, prompt, sv, chat_history):
                 break

         if expert:
-            logger.debug(f"Expert model {expert.__class__.__name__} is being called.")
-            text_output, seg_file, instruction, download_pkg = expert.run(
-                image_url=sv.image_url,
-                input=outputs,
-                output_dir=sv.temp_working_dir,
-                img_file=img_file,
-                slice_index=sv.slice_index,
-                prompt=prompt,
-            )
+            logger.debug(f"Expert model {expert.__class__.__name__} is being called to process {sv.image_url}.")
+            try:
+                if sv.image_url is None:
+                    raise ValueError(f"No image is provided with {outputs}.")
+                text_output, seg_file, instruction, download_pkg = expert.run(
+                    image_url=sv.image_url,
+                    input=outputs,
+                    output_dir=sv.temp_working_dir,
+                    img_file=img_file,
+                    slice_index=sv.slice_index,
+                    prompt=prompt,
+                )
+            except Exception as e:
+                text_output = f"Sorry I met an error: {e}"
+                seg_file = None
+                instruction = ""
+                download_pkg = ""
+
             chat_history.append(text_output, image_path=seg_file, role="expert")
             if instruction:
                 chat_history.append(instruction, role="expert")
@@ -441,6 +465,7 @@ def process_prompt(self, prompt, sv, chat_history):
             sys_prompt=sv.sys_prompt,
             sys_msg=sv.sys_msg,
             download_file_path=download_pkg,
+            interactive=True,
         )
         return None, new_sv, chat_history, chat_history.get_html(show_all=False), chat_history.get_html(show_all=True)

@@ -450,6 +475,7 @@ def input_image(image, sv: SessionVariables):
     logger.debug(f"Received user input image")
     # TODO: support user uploaded images
     sv.image_url = image_to_data_url(image)
+    sv.interactive = True
     return image, sv


@@ -465,6 +491,7 @@ def update_image_selection(selected_image, sv: SessionVariables, slice_index_htm
     if sv.temp_working_dir is None:
         sv.temp_working_dir = tempfile.mkdtemp()

+    sv.interactive = True
     if img_file.endswith(".nii.gz"):
         if sv.slice_index is None:
             data = nib.load(img_file).get_fdata()
@@ -555,7 +582,7 @@ def clear_one_conv(sv):
     else:
         d_btn = gr.DownloadButton(visible=False)
     # Order of output: image, image_selector, slice_index_html, temperature_slider, top_p_slider, max_tokens_slider, download_button
-    return sv, None, None, "Slice Index: N/A", 0.0, 0.9, 300, d_btn
+    return sv, None, None, "Slice Index: N/A", sv.temperature, sv.top_p, sv.max_tokens, d_btn


def clear_all_convs():
Expand Down Expand Up @@ -614,9 +641,9 @@ def download_file():
return [gr.DownloadButton(visible=False)]


def create_demo(model_path, conv_mode, server_port):
def create_demo(source, model_path, conv_mode, server_port):
"""Main function to create the Gradio interface"""
generator = M3Generator(model_path, conv_mode)
generator = M3Generator(source=source, model_path=model_path, conv_mode=conv_mode)

with gr.Blocks(css=CSS_STYLES) as demo:
gr.HTML(TITLE, label="Title")
@@ -625,7 +652,7 @@ def create_demo(model_path, conv_mode, server_port):

         with gr.Row():
             with gr.Column():
-                image_input = gr.Image(label="Image", placeholder="Please select an image from the dropdown list.")
+                image_input = gr.Image(label="Image", sources=[], placeholder="Please select an image from the dropdown list.")
                 image_dropdown = gr.Dropdown(label="Select an image", choices=list(IMAGES_URLS.keys()))
                 with gr.Accordion("View Parameters", open=False):
                     temperature_slider = gr.Slider(
@@ -635,7 +662,7 @@ def create_demo(model_path, conv_mode, server_port):
                        label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9, interactive=True
                    )
                    max_tokens_slider = gr.Slider(
-                        label="Max Tokens", minimum=1, maximum=1024, step=1, value=300, interactive=True
+                        label="Max Tokens", minimum=1, maximum=1024, step=1, value=1024, interactive=True
                    )

                with gr.Accordion("3D image panel", open=False):
@@ -748,7 +775,7 @@ def create_demo(model_path, conv_mode, server_port):
    parser.add_argument(
        "--modelpath",
        type=str,
-        default="/workspace/nvidia/medical-service-nims/vila/checkpoints/baseline/checkpoint-3500",
+        default="/data/checkpoints/vila-m3-8b",
        help="The path to the model to load.",
    )
    parser.add_argument(
@@ -757,8 +784,14 @@ def create_demo(model_path, conv_mode, server_port):
        default=7860,
        help="The port to run the Gradio server on.",
    )
+    parser.add_argument(
+        "--source",
+        type=str,
+        default="local",
+        help="The source of the model. Option is only 'local'.",
+    )
    args = parser.parse_args()
    SYS_PROMPT = conv_templates[args.convmode].system
    cache_images()
-    create_demo(args.modelpath, args.convmode, args.port)
+    create_demo(args.source, args.modelpath, args.convmode, args.port)
    cache_cleanup()

