From f0c0cd333a570771a64f8b70f921fc89b7b2c517 Mon Sep 17 00:00:00 2001 From: Ian Hunt-Isaak Date: Fri, 14 Apr 2023 12:54:45 -0400 Subject: [PATCH] feat: only disallow duplicate names when values don't match (#275) --- mpl_interactions/controller.py | 50 +++++++++++++++++++++++----------- mpl_interactions/generic.py | 1 - mpl_interactions/helpers.py | 44 +++++++++++++++++------------- tests/test_generic.py | 8 ++++++ 4 files changed, 67 insertions(+), 36 deletions(-) diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index f062bec..2a9c6ec 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -65,9 +65,10 @@ def __init__( self.indices = defaultdict(lambda: 0) self._update_funcs = defaultdict(list) self._user_callbacks = defaultdict(list) + self._hashes = [] self.add_kwargs(kwargs, slider_formats, play_buttons) - def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_duplicates=False): + def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None): """Add kwargs to the controller. If you pass a redundant kwarg it will just be overwritten @@ -94,23 +95,37 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli self.slider_format_strings[k] = v if self.use_ipywidgets: for k, v in kwargs.items(): - if k in self.params: - if allow_duplicates: - continue - else: - raise ValueError("can't overwrite an existing param in the controller") if isinstance(v, AxesWidget): - self.params[k], self.controls[k], _ = process_mpl_widget( + # TODO: HASHING behavior + param, control, _, hash_ = process_mpl_widget( v, partial(self.slider_updated, key=k) ) + if k in self.params: + if hash_ not in self._hashes: + raise ValueError( + f"kwarg {k} already exists and the new values are incompatible." + ) + # don't need to add it because it already exists + continue + self.params[k], self.controls[k] = param, control + self._hashes.append(hash) else: - self.params[k], control = kwarg_to_ipywidget( + param, control, hash_ = kwarg_to_ipywidget( k, v, partial(self.slider_updated, key=k), self.slider_format_strings[k], play_button=_play_buttons[k], ) + if k in self.params: + if hash_ not in self._hashes: + raise ValueError( + f"kwarg {k} already exists and the new values are incompatible." + ) + # don't need to add it because it already exists + continue + self.params[k] = param + self._hashes.append(hash_) if control: self.controls[k] = control self.vbox.children = [*list(self.vbox.children), control] @@ -123,12 +138,7 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli self.control_figures.append(mpl_layout[0]) widget_y = 0.05 for k, v in kwargs.items(): - if k in self.params: - if allow_duplicates: - continue - else: - raise ValueError("Can't overwrite an existing param in the controller") - self.params[k], control, cb, widget_y = kwarg_to_mpl_widget( + param, control, cb, widget_y, hash_ = kwarg_to_mpl_widget( mpl_layout[0], mpl_layout[1:], widget_y, @@ -137,6 +147,15 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli partial(self.slider_updated, key=k), self.slider_format_strings[k], ) + if k in self.params: + if hash_ not in self._hashes: + raise ValueError( + f"kwarg {k} already exists and the new values are incompatible." + ) + # don't need to add it because it already exists + continue + self.params[k] = param + self._hashes.append(hash_) if control: self.controls[k] = control if k == "vmin_vmax": @@ -390,7 +409,6 @@ def gogogo_controls( slider_formats, play_buttons, extra_controls=None, - allow_dupes=False, ): """ Create a new controls object. @@ -446,7 +464,7 @@ def gogogo_controls( controls.display() else: controls = ctrls.pop() - controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes) + controls.add_kwargs(kwargs, slider_formats, play_buttons) params = {k: controls.params[k] for k in keys} return controls, params diff --git a/mpl_interactions/generic.py b/mpl_interactions/generic.py index 14a84df..39e8554 100644 --- a/mpl_interactions/generic.py +++ b/mpl_interactions/generic.py @@ -758,7 +758,6 @@ def hyperslicer( slider_format_strings, play_buttons, extra_ctrls, - allow_dupes=True, ) if vmin_vmax is not None: params.pop("vmin_vmax") diff --git a/mpl_interactions/helpers.py b/mpl_interactions/helpers.py index c169664..4ceb846 100644 --- a/mpl_interactions/helpers.py +++ b/mpl_interactions/helpers.py @@ -204,6 +204,8 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) The generated widget. This may be the raw widget or a higher level container widget (e.g. HBox) depending on what widget was generated. If a fixed value is returned then control will be *None* + param_hash : + A hash of the possible values, to be used to check duplicates in the future. """ control = None if isinstance(val, set): @@ -214,7 +216,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) pass else: # fixed parameter - return val, None + return val, None, hash(repr(val)) else: val = list(val) @@ -224,7 +226,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) else: selector = widgets.Select(options=val) selector.observe(partial(update, values=val), names="index") - return val[0], selector + return val[0], selector, hash(repr(val)) elif isinstance(val, widgets.Widget) or isinstance(val, widgets.fixed): if not hasattr(val, "value"): raise TypeError( @@ -232,7 +234,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) "But the widget passed for {key} does not have a `.value` attribute" ) if isinstance(val, widgets.fixed): - return val, None + return val, None, hash(repr(val)) elif ( isinstance(val, widgets.Select) or isinstance(val, widgets.SelectionSlider) @@ -242,10 +244,11 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) # it looks unlikely to change but still would be nice to just check # if its a subclass val.observe(partial(update, values=val.options), names="index") + return val.value, val, hash(repr(val.options)) else: # set values to None and hope for the best val.observe(partial(update, values=None), names="value") - return val.value, val + return val.value, val, hash(repr(val)) # val.observe(partial(update, key=key, label=None), names=["value"]) else: if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]: @@ -267,7 +270,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) ) slider.observe(partial(update, values=vals), names="value") controls = widgets.HBox([slider, label]) - return vals[[0, -1]], controls + return vals[[0, -1]], controls, hash("r" + repr(vals)) if isinstance(val, tuple) and len(val) in [2, 3]: # treat as an argument to linspace @@ -279,7 +282,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar") if len(val) == 1: # don't need to create a slider - return val[0], None + return val[0], None, hash(repr(val)) else: # params[key] = val[0] label = widgets.Label(value=slider_format_string.format(val[0])) @@ -299,7 +302,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) control = widgets.HBox([play, slider, label]) else: control = widgets.HBox([slider, label]) - return val[0], control + return val[0], control, hash(repr(val)) def extract_num_options(val): @@ -455,17 +458,19 @@ def process_mpl_widget(val, update): # oh boy do I ever not want to val.set_active(0) cb = val.on_clicked(partial(changeify_radio, labels=val.labels, update=update)) - return val.labels[0], val, cb - elif isinstance(val, (mwidgets.Slider, RangeSlider)): + return val.labels[0], val, cb, hash(repr(val.labels)) + elif isinstance(val, (mwidgets.Slider, mwidgets.RangeSlider, RangeSlider)): + # TODO: proper inherit matplotlib rand # potential future improvement: # check if valstep has been set and then try to infer the values # but not now, I'm trying to avoid premature optimization lest this # drag on forever cb = val.on_changed(partial(changeify, update=partial(update, values=None))) - return val.val, val, cb + hash_ = hash(str(val.valmin) + str(val.valmax) + str(val.valstep)) + return val.val, val, cb, hash_ else: cb = val.on_changed(partial(changeify, update=partial(update, values=None))) - return val.val, val, cb + return val.val, val, cb, hash(repr(val)) def kwarg_to_mpl_widget( @@ -512,6 +517,7 @@ def kwarg_to_mpl_widget( the callback id new_y The widget_y to use for the next pass. + hash """ slider_height, radio_height, gap_height = heights @@ -525,7 +531,7 @@ def kwarg_to_mpl_widget( if isinstance(val, tuple): pass else: - return val, None, None, widget_y + return val, None, None, widget_y, hash(repr(val)) else: val = list(val) @@ -537,10 +543,10 @@ def kwarg_to_mpl_widget( widget_y += radio_height * n + gap_height radio_buttons = mwidgets.RadioButtons(radio_ax, val, active=0) cb = radio_buttons.on_clicked(partial(changeify_radio, labels=val, update=update)) - return val[0], radio_buttons, cb, widget_y + return val[0], radio_buttons, cb, widget_y, hash(repr(val)) elif isinstance(val, mwidgets.AxesWidget): - val, widget, cb = process_mpl_widget(val, update) - return val, widget, cb, widget_y + val, widget, cb, hash_ = process_mpl_widget(val, update) + return val, widget, cb, widget_y, hash_ else: slider = None if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]: @@ -552,7 +558,7 @@ def kwarg_to_mpl_widget( slider = create_mpl_range_selection_slider(slider_ax, key, vals, slider_format_string) cb = slider.on_changed(partial(changeify, update=partial(update, values=vals))) widget_y += slider_height + gap_height - return vals[[0, -1]], slider, cb, widget_y + return vals[[0, -1]], slider, cb, widget_y, hash(repr(vals)) if isinstance(val, tuple): if len(val) == 2: @@ -569,7 +575,7 @@ def update_text(val): slider.on_changed(update_text) cb = slider.on_changed(partial(changeify, update=partial(update, values=None))) widget_y += slider_height + gap_height - return min_, slider, cb, widget_y + return min_, slider, cb, widget_y, hash(repr(val)) elif len(val) == 3: # should warn that that doesn't make sense with matplotlib sliders min_ = val[0] @@ -580,13 +586,13 @@ def update_text(val): raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar") if len(val) == 1: # don't need to create a slider - return val[0], None, None, widget_y + return val[0], None, None, widget_y, hash(repr(val)) else: slider_ax = fig.add_axes([0.2, 0.9 - widget_y - gap_height, 0.65, slider_height]) slider = create_mpl_selection_slider(slider_ax, key, val, slider_format_string) slider.on_changed(partial(changeify, update=partial(update, values=val))) widget_y += slider_height + gap_height - return val[0], slider, None, widget_y + return val[0], slider, None, widget_y, hash(repr(val)) def create_slider_format_dict(slider_format_string): diff --git a/tests/test_generic.py b/tests/test_generic.py index 1bf91ae..188a9a2 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -55,3 +55,11 @@ def test_xr_hyperslicer_extents(): assert axs[1, 0].get_xlim() == axs[1, 1].get_xlim() assert axs[1, 0].get_ylim() == axs[1, 1].get_ylim() + + +def test_duplicate_axis_names(): + plt.subplots() + img_stack = np.random.rand(5, 512, 512) + + with hyperslicer(img_stack): + hyperslicer(img_stack)