diff --git a/Makefile b/Makefile index a1a852e..4f8bd07 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ demo_monai_vila2d: cd thirdparty/VILA; \ ./environment_setup.sh - pip install -U nibabel python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage] + pip install -U python-dotenv deepspeed gradio monai[nibabel,pynrrd,skimage] diff --git a/demo/gradio_monai_vila2d.py b/demo/gradio_monai_vila2d.py index 9d71d99..62c77dc 100644 --- a/demo/gradio_monai_vila2d.py +++ b/demo/gradio_monai_vila2d.py @@ -482,28 +482,28 @@ def input_image(image, sv: SessionVariables): return image, sv -def update_image_selection(selected_image, sv: SessionVariables, slice_index_html, increment=None): +def update_image_selection(selected_image, sv: SessionVariables, slice_index=None): """Update the gradio components based on the selected image""" logger.debug(f"Updating display image for {selected_image}") sv.image_url = IMAGES_URLS.get(selected_image, None) img_file = CACHED_IMAGES.get(sv.image_url, None) if sv.image_url is None: - return None, sv, slice_index_html, [[""]] + return None, sv, [[""]] if sv.temp_working_dir is None: sv.temp_working_dir = tempfile.mkdtemp() sv.interactive = True if img_file.endswith(".nii.gz"): - if sv.slice_index is None: + if slice_index is None: data = nib.load(img_file).get_fdata() sv.slice_index = data.shape[sv.axis] // 2 sv.idx_range = (0, data.shape[sv.axis] - 1) - - if increment is not None: - sv.slice_index += increment - sv.slice_index = max(sv.idx_range[0], min(sv.idx_range[1], sv.slice_index)) + else: + # Slice index is updated by the slidebar. + # There is no need to update the idx_range. + sv.slice_index = slice_index image_filename = get_slice_filenames(img_file, sv.slice_index)[0] if not os.path.exists(image_filename): @@ -518,39 +518,20 @@ def update_image_selection(selected_image, sv: SessionVariables, slice_index_htm return ( os.path.join(sv.temp_working_dir, image_filename), sv, - f"Slice Index: {sv.slice_index}", + gr.Slider(sv.idx_range[0], sv.idx_range[1], value=sv.slice_index, visible=True, interactive=True), gr.Dataset(samples=EXAMPLE_PROMPTS_3D), ) sv.slice_index = None + sv.idx_range = (None, None) return ( img_file, sv, - "Slice Index: N/A for 2D images, clicking prev/next will not change the image.", + gr.Slider(0, 2, 1, 0, visible=False), gr.Dataset(samples=EXAMPLE_PROMPTS_2D), ) -def update_image_next_10(selected_image, sv, slice_index_html): - """Update the image to the next 10 slices""" - return update_image_selection(selected_image, sv, slice_index_html, increment=10) - - -def update_image_next_1(selected_image, sv, slice_index_html): - """Update the image to the next slice""" - return update_image_selection(selected_image, sv, slice_index_html, increment=1) - - -def update_image_prev_1(selected_image, sv, slice_index_html): - """Update the image to the previous slice""" - return update_image_selection(selected_image, sv, slice_index_html, increment=-1) - - -def update_image_prev_10(selected_image, sv, slice_index_html): - """Update the image to the previous 10 slices""" - return update_image_selection(selected_image, sv, slice_index_html, increment=-10) - - def colorcode_message(text="", data_url=None, show_all=False, role="user"): """Color the text based on the role and return the HTML text""" logger.debug(f"Preparing the HTML text with {show_all} and role: {role}") @@ -590,8 +571,8 @@ def clear_one_conv(sv): d_btn = gr.DownloadButton(label=f"Download {name}", value=filepath, visible=True) else: d_btn = gr.DownloadButton(visible=False) - # Order of output: image, image_selector, slice_index_html, temperature_slider, top_p_slider, max_tokens_slider, download_button - return sv, None, None, "Slice Index: N/A", sv.temperature, sv.top_p, sv.max_tokens, d_btn + # Order of output: image, image_selector, temperature_slider, top_p_slider, max_tokens_slider, download_button, image_slider + return sv, None, None, sv.temperature, sv.top_p, sv.max_tokens, d_btn, gr.Slider(0, 2, 1, 0, visible=False) def clear_all_convs(): @@ -663,13 +644,7 @@ def create_demo(source, model_path, conv_mode, server_port): with gr.Column(): image_dropdown = gr.Dropdown(label="Select an image", choices=["Please select .."] + list(IMAGES_URLS.keys())) image_input = gr.Image(label="Image", sources=[], placeholder="Please select an image from the dropdown list.") - with gr.Accordion("3D image panel", open=False): - slice_index_html = gr.HTML("Slice Index: N/A") - with gr.Row(): - prev10_btn = gr.Button("<<") - prev01_btn = gr.Button("<") - next01_btn = gr.Button(">") - next10_btn = gr.Button(">>") + image_slider = gr.Slider(label="Slice Index", minimum=0, maximum=100, step=1, value=0, visible=False) with gr.Accordion("View Parameters", open=False): temperature_slider = gr.Slider( @@ -724,28 +699,13 @@ def create_demo(source, model_path, conv_mode, server_port): image_input.input(fn=input_image, inputs=[image_input, sv], outputs=[image_input, sv]) image_dropdown.change( fn=update_image_selection, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html, examples.dataset], - ) - prev10_btn.click( - fn=update_image_prev_10, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], + inputs=[image_dropdown, sv], + outputs=[image_input, sv, image_slider, examples.dataset], ) - prev01_btn.click( - fn=update_image_prev_1, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], - ) - next01_btn.click( - fn=update_image_next_1, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], - ) - next10_btn.click( - fn=update_image_next_10, - inputs=[image_dropdown, sv, slice_index_html], - outputs=[image_input, sv, slice_index_html], + image_slider.release( + fn=update_image_selection, + inputs=[image_dropdown, sv, image_slider], + outputs=[image_input, sv, image_slider, examples.dataset], ) temperature_slider.change(fn=update_temperature, inputs=[temperature_slider, sv], outputs=[sv]) top_p_slider.change(fn=update_top_p, inputs=[top_p_slider, sv], outputs=[sv]) @@ -767,11 +727,11 @@ def create_demo(source, model_path, conv_mode, server_port): sv, image_input, image_dropdown, - slice_index_html, temperature_slider, top_p_slider, max_tokens_slider, image_download, + image_slider, ], ) demo.launch(server_name="0.0.0.0", server_port=server_port)