Skip to content

Commit

Permalink
replace buttons with slider to slice 3d image
Browse files Browse the repository at this point in the history
Signed-off-by: Mingxin Zheng <[email protected]>
  • Loading branch information
mingxin-zheng committed Oct 20, 2024
1 parent 96e0766 commit d285fd0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 61 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
demo_monai_vila2d:
cd thirdparty/VILA; \
./environment_setup.sh
pip install -U nibabel python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage]
pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage]
80 changes: 20 additions & 60 deletions demo/gradio_monai_vila2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,28 +482,28 @@ def input_image(image, sv: SessionVariables):
return image, sv


def update_image_selection(selected_image, sv: SessionVariables, slice_index_html, increment=None):
def update_image_selection(selected_image, sv: SessionVariables, slice_index=None):
"""Update the gradio components based on the selected image"""
logger.debug(f"Updating display image for {selected_image}")
sv.image_url = IMAGES_URLS.get(selected_image, None)
img_file = CACHED_IMAGES.get(sv.image_url, None)

if sv.image_url is None:
return None, sv, slice_index_html, [[""]]
return None, sv, [[""]]

if sv.temp_working_dir is None:
sv.temp_working_dir = tempfile.mkdtemp()

sv.interactive = True
if img_file.endswith(".nii.gz"):
if sv.slice_index is None:
if slice_index is None:
data = nib.load(img_file).get_fdata()
sv.slice_index = data.shape[sv.axis] // 2
sv.idx_range = (0, data.shape[sv.axis] - 1)

if increment is not None:
sv.slice_index += increment
sv.slice_index = max(sv.idx_range[0], min(sv.idx_range[1], sv.slice_index))
else:
# Slice index is updated by the slidebar.
# There is no need to update the idx_range.
sv.slice_index = slice_index

image_filename = get_slice_filenames(img_file, sv.slice_index)[0]
if not os.path.exists(image_filename):
Expand All @@ -518,39 +518,20 @@ def update_image_selection(selected_image, sv: SessionVariables, slice_index_htm
return (
os.path.join(sv.temp_working_dir, image_filename),
sv,
f"Slice Index: {sv.slice_index}",
gr.Slider(sv.idx_range[0], sv.idx_range[1], value=sv.slice_index, visible=True, interactive=True),
gr.Dataset(samples=EXAMPLE_PROMPTS_3D),
)

sv.slice_index = None
sv.idx_range = (None, None)
return (
img_file,
sv,
"Slice Index: N/A for 2D images, clicking prev/next will not change the image.",
gr.Slider(0, 2, 1, 0, visible=False),
gr.Dataset(samples=EXAMPLE_PROMPTS_2D),
)


def update_image_next_10(selected_image, sv, slice_index_html):
"""Update the image to the next 10 slices"""
return update_image_selection(selected_image, sv, slice_index_html, increment=10)


def update_image_next_1(selected_image, sv, slice_index_html):
"""Update the image to the next slice"""
return update_image_selection(selected_image, sv, slice_index_html, increment=1)


def update_image_prev_1(selected_image, sv, slice_index_html):
"""Update the image to the previous slice"""
return update_image_selection(selected_image, sv, slice_index_html, increment=-1)


def update_image_prev_10(selected_image, sv, slice_index_html):
"""Update the image to the previous 10 slices"""
return update_image_selection(selected_image, sv, slice_index_html, increment=-10)


def colorcode_message(text="", data_url=None, show_all=False, role="user"):
"""Color the text based on the role and return the HTML text"""
logger.debug(f"Preparing the HTML text with {show_all} and role: {role}")
Expand Down Expand Up @@ -590,8 +571,8 @@ def clear_one_conv(sv):
d_btn = gr.DownloadButton(label=f"Download {name}", value=filepath, visible=True)
else:
d_btn = gr.DownloadButton(visible=False)
# Order of output: image, image_selector, slice_index_html, temperature_slider, top_p_slider, max_tokens_slider, download_button
return sv, None, None, "Slice Index: N/A", sv.temperature, sv.top_p, sv.max_tokens, d_btn
# Order of output: image, image_selector, temperature_slider, top_p_slider, max_tokens_slider, download_button, image_slider
return sv, None, None, sv.temperature, sv.top_p, sv.max_tokens, d_btn, gr.Slider(0, 2, 1, 0, visible=False)


def clear_all_convs():
Expand Down Expand Up @@ -663,13 +644,7 @@ def create_demo(source, model_path, conv_mode, server_port):
with gr.Column():
image_dropdown = gr.Dropdown(label="Select an image", choices=["Please select .."] + list(IMAGES_URLS.keys()))
image_input = gr.Image(label="Image", sources=[], placeholder="Please select an image from the dropdown list.")
with gr.Accordion("3D image panel", open=False):
slice_index_html = gr.HTML("Slice Index: N/A")
with gr.Row():
prev10_btn = gr.Button("<<")
prev01_btn = gr.Button("<")
next01_btn = gr.Button(">")
next10_btn = gr.Button(">>")
image_slider = gr.Slider(label="Slice Index", minimum=0, maximum=100, step=1, value=0, visible=False)

with gr.Accordion("View Parameters", open=False):
temperature_slider = gr.Slider(
Expand Down Expand Up @@ -724,28 +699,13 @@ def create_demo(source, model_path, conv_mode, server_port):
image_input.input(fn=input_image, inputs=[image_input, sv], outputs=[image_input, sv])
image_dropdown.change(
fn=update_image_selection,
inputs=[image_dropdown, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html, examples.dataset],
)
prev10_btn.click(
fn=update_image_prev_10,
inputs=[image_dropdown, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html],
inputs=[image_dropdown, sv],
outputs=[image_input, sv, image_slider, examples.dataset],
)
prev01_btn.click(
fn=update_image_prev_1,
inputs=[image_dropdown, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html],
)
next01_btn.click(
fn=update_image_next_1,
inputs=[image_dropdown, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html],
)
next10_btn.click(
fn=update_image_next_10,
inputs=[image_dropdown, sv, slice_index_html],
outputs=[image_input, sv, slice_index_html],
image_slider.release(
fn=update_image_selection,
inputs=[image_dropdown, sv, image_slider],
outputs=[image_input, sv, image_slider, examples.dataset],
)
temperature_slider.change(fn=update_temperature, inputs=[temperature_slider, sv], outputs=[sv])
top_p_slider.change(fn=update_top_p, inputs=[top_p_slider, sv], outputs=[sv])
Expand All @@ -767,11 +727,11 @@ def create_demo(source, model_path, conv_mode, server_port):
sv,
image_input,
image_dropdown,
slice_index_html,
temperature_slider,
top_p_slider,
max_tokens_slider,
image_download,
image_slider,
],
)
demo.launch(server_name="0.0.0.0", server_port=server_port)
Expand Down

0 comments on commit d285fd0

Please sign in to comment.