Skip to content

Commit

Permalink
feat: only disallow duplicate names when values don't match (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianhi authored Apr 14, 2023
1 parent 47c1b38 commit f0c0cd3
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 36 deletions.
50 changes: 34 additions & 16 deletions mpl_interactions/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def __init__(
self.indices = defaultdict(lambda: 0)
self._update_funcs = defaultdict(list)
self._user_callbacks = defaultdict(list)
self._hashes = []
self.add_kwargs(kwargs, slider_formats, play_buttons)

def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_duplicates=False):
def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None):
"""Add kwargs to the controller.
If you pass a redundant kwarg it will just be overwritten
Expand All @@ -94,23 +95,37 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli
self.slider_format_strings[k] = v
if self.use_ipywidgets:
for k, v in kwargs.items():
if k in self.params:
if allow_duplicates:
continue
else:
raise ValueError("can't overwrite an existing param in the controller")
if isinstance(v, AxesWidget):
self.params[k], self.controls[k], _ = process_mpl_widget(
# TODO: HASHING behavior
param, control, _, hash_ = process_mpl_widget(
v, partial(self.slider_updated, key=k)
)
if k in self.params:
if hash_ not in self._hashes:
raise ValueError(
f"kwarg {k} already exists and the new values are incompatible."
)
# don't need to add it because it already exists
continue
self.params[k], self.controls[k] = param, control
self._hashes.append(hash)
else:
self.params[k], control = kwarg_to_ipywidget(
param, control, hash_ = kwarg_to_ipywidget(
k,
v,
partial(self.slider_updated, key=k),
self.slider_format_strings[k],
play_button=_play_buttons[k],
)
if k in self.params:
if hash_ not in self._hashes:
raise ValueError(
f"kwarg {k} already exists and the new values are incompatible."
)
# don't need to add it because it already exists
continue
self.params[k] = param
self._hashes.append(hash_)
if control:
self.controls[k] = control
self.vbox.children = [*list(self.vbox.children), control]
Expand All @@ -123,12 +138,7 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli
self.control_figures.append(mpl_layout[0])
widget_y = 0.05
for k, v in kwargs.items():
if k in self.params:
if allow_duplicates:
continue
else:
raise ValueError("Can't overwrite an existing param in the controller")
self.params[k], control, cb, widget_y = kwarg_to_mpl_widget(
param, control, cb, widget_y, hash_ = kwarg_to_mpl_widget(
mpl_layout[0],
mpl_layout[1:],
widget_y,
Expand All @@ -137,6 +147,15 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli
partial(self.slider_updated, key=k),
self.slider_format_strings[k],
)
if k in self.params:
if hash_ not in self._hashes:
raise ValueError(
f"kwarg {k} already exists and the new values are incompatible."
)
# don't need to add it because it already exists
continue
self.params[k] = param
self._hashes.append(hash_)
if control:
self.controls[k] = control
if k == "vmin_vmax":
Expand Down Expand Up @@ -390,7 +409,6 @@ def gogogo_controls(
slider_formats,
play_buttons,
extra_controls=None,
allow_dupes=False,
):
"""
Create a new controls object.
Expand Down Expand Up @@ -446,7 +464,7 @@ def gogogo_controls(
controls.display()
else:
controls = ctrls.pop()
controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes)
controls.add_kwargs(kwargs, slider_formats, play_buttons)
params = {k: controls.params[k] for k in keys}
return controls, params

Expand Down
1 change: 0 additions & 1 deletion mpl_interactions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,6 @@ def hyperslicer(
slider_format_strings,
play_buttons,
extra_ctrls,
allow_dupes=True,
)
if vmin_vmax is not None:
params.pop("vmin_vmax")
Expand Down
44 changes: 25 additions & 19 deletions mpl_interactions/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
The generated widget. This may be the raw widget or a higher level container
widget (e.g. HBox) depending on what widget was generated. If a fixed value is
returned then control will be *None*
param_hash :
A hash of the possible values, to be used to check duplicates in the future.
"""
control = None
if isinstance(val, set):
Expand All @@ -214,7 +216,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
pass
else:
# fixed parameter
return val, None
return val, None, hash(repr(val))
else:
val = list(val)

Expand All @@ -224,15 +226,15 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
else:
selector = widgets.Select(options=val)
selector.observe(partial(update, values=val), names="index")
return val[0], selector
return val[0], selector, hash(repr(val))
elif isinstance(val, widgets.Widget) or isinstance(val, widgets.fixed):
if not hasattr(val, "value"):
raise TypeError(
"widgets passed as parameters must have the `value` trait."
"But the widget passed for {key} does not have a `.value` attribute"
)
if isinstance(val, widgets.fixed):
return val, None
return val, None, hash(repr(val))
elif (
isinstance(val, widgets.Select)
or isinstance(val, widgets.SelectionSlider)
Expand All @@ -242,10 +244,11 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
# it looks unlikely to change but still would be nice to just check
# if its a subclass
val.observe(partial(update, values=val.options), names="index")
return val.value, val, hash(repr(val.options))
else:
# set values to None and hope for the best
val.observe(partial(update, values=None), names="value")
return val.value, val
return val.value, val, hash(repr(val))
# val.observe(partial(update, key=key, label=None), names=["value"])
else:
if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]:
Expand All @@ -267,7 +270,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
)
slider.observe(partial(update, values=vals), names="value")
controls = widgets.HBox([slider, label])
return vals[[0, -1]], controls
return vals[[0, -1]], controls, hash("r" + repr(vals))

if isinstance(val, tuple) and len(val) in [2, 3]:
# treat as an argument to linspace
Expand All @@ -279,7 +282,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar")
if len(val) == 1:
# don't need to create a slider
return val[0], None
return val[0], None, hash(repr(val))
else:
# params[key] = val[0]
label = widgets.Label(value=slider_format_string.format(val[0]))
Expand All @@ -299,7 +302,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
control = widgets.HBox([play, slider, label])
else:
control = widgets.HBox([slider, label])
return val[0], control
return val[0], control, hash(repr(val))


def extract_num_options(val):
Expand Down Expand Up @@ -455,17 +458,19 @@ def process_mpl_widget(val, update):
# oh boy do I ever not want to
val.set_active(0)
cb = val.on_clicked(partial(changeify_radio, labels=val.labels, update=update))
return val.labels[0], val, cb
elif isinstance(val, (mwidgets.Slider, RangeSlider)):
return val.labels[0], val, cb, hash(repr(val.labels))
elif isinstance(val, (mwidgets.Slider, mwidgets.RangeSlider, RangeSlider)):
# TODO: proper inherit matplotlib rand
# potential future improvement:
# check if valstep has been set and then try to infer the values
# but not now, I'm trying to avoid premature optimization lest this
# drag on forever
cb = val.on_changed(partial(changeify, update=partial(update, values=None)))
return val.val, val, cb
hash_ = hash(str(val.valmin) + str(val.valmax) + str(val.valstep))
return val.val, val, cb, hash_
else:
cb = val.on_changed(partial(changeify, update=partial(update, values=None)))
return val.val, val, cb
return val.val, val, cb, hash(repr(val))


def kwarg_to_mpl_widget(
Expand Down Expand Up @@ -512,6 +517,7 @@ def kwarg_to_mpl_widget(
the callback id
new_y
The widget_y to use for the next pass.
hash
"""
slider_height, radio_height, gap_height = heights

Expand All @@ -525,7 +531,7 @@ def kwarg_to_mpl_widget(
if isinstance(val, tuple):
pass
else:
return val, None, None, widget_y
return val, None, None, widget_y, hash(repr(val))
else:
val = list(val)

Expand All @@ -537,10 +543,10 @@ def kwarg_to_mpl_widget(
widget_y += radio_height * n + gap_height
radio_buttons = mwidgets.RadioButtons(radio_ax, val, active=0)
cb = radio_buttons.on_clicked(partial(changeify_radio, labels=val, update=update))
return val[0], radio_buttons, cb, widget_y
return val[0], radio_buttons, cb, widget_y, hash(repr(val))
elif isinstance(val, mwidgets.AxesWidget):
val, widget, cb = process_mpl_widget(val, update)
return val, widget, cb, widget_y
val, widget, cb, hash_ = process_mpl_widget(val, update)
return val, widget, cb, widget_y, hash_
else:
slider = None
if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]:
Expand All @@ -552,7 +558,7 @@ def kwarg_to_mpl_widget(
slider = create_mpl_range_selection_slider(slider_ax, key, vals, slider_format_string)
cb = slider.on_changed(partial(changeify, update=partial(update, values=vals)))
widget_y += slider_height + gap_height
return vals[[0, -1]], slider, cb, widget_y
return vals[[0, -1]], slider, cb, widget_y, hash(repr(vals))

if isinstance(val, tuple):
if len(val) == 2:
Expand All @@ -569,7 +575,7 @@ def update_text(val):
slider.on_changed(update_text)
cb = slider.on_changed(partial(changeify, update=partial(update, values=None)))
widget_y += slider_height + gap_height
return min_, slider, cb, widget_y
return min_, slider, cb, widget_y, hash(repr(val))
elif len(val) == 3:
# should warn that that doesn't make sense with matplotlib sliders
min_ = val[0]
Expand All @@ -580,13 +586,13 @@ def update_text(val):
raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar")
if len(val) == 1:
# don't need to create a slider
return val[0], None, None, widget_y
return val[0], None, None, widget_y, hash(repr(val))
else:
slider_ax = fig.add_axes([0.2, 0.9 - widget_y - gap_height, 0.65, slider_height])
slider = create_mpl_selection_slider(slider_ax, key, val, slider_format_string)
slider.on_changed(partial(changeify, update=partial(update, values=val)))
widget_y += slider_height + gap_height
return val[0], slider, None, widget_y
return val[0], slider, None, widget_y, hash(repr(val))


def create_slider_format_dict(slider_format_string):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,11 @@ def test_xr_hyperslicer_extents():

assert axs[1, 0].get_xlim() == axs[1, 1].get_xlim()
assert axs[1, 0].get_ylim() == axs[1, 1].get_ylim()


def test_duplicate_axis_names():
plt.subplots()
img_stack = np.random.rand(5, 512, 512)

with hyperslicer(img_stack):
hyperslicer(img_stack)

0 comments on commit f0c0cd3

Please sign in to comment.