Skip to content

Commit

Permalink
fix: ensure vmin_vmax doesn't get passed to user fxns + add vmin_vmax…
Browse files Browse the repository at this point in the history
… to scatter (#274)
  • Loading branch information
ianhi authored Apr 14, 2023
1 parent b1bbdf1 commit 47c1b38
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
10 changes: 3 additions & 7 deletions mpl_interactions/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,7 @@ def prep_scalars(kwarg_dict, **kwargs):
kwargs[name] = _gen_f(arg.keys[0])
extra_ctrls.append(arg)

if len(added_kwargs) == 0:
# shortcircuit options
def param_excluder(params, except_=None):
return params

else:
param_excluder = _gen_param_excluder(added_kwargs)
# always exclude all these - this will always be matplotlib
# arugments and so should never be passed to user supplied functions.
param_excluder = _gen_param_excluder(list(kwargs.keys()))
return kwargs, extra_ctrls, param_excluder
50 changes: 33 additions & 17 deletions mpl_interactions/pyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def interactive_scatter(
c=None,
vmin=None,
vmax=None,
vmin_vmax=None,
alpha=None,
marker=None,
edgecolors=None,
Expand Down Expand Up @@ -452,6 +453,9 @@ def interactive_scatter(
or any slider shorthand to control with a slider, or an indexed controls
object to use an existing slider, or an arbitrary function of the other
parameters.
vmin_vmax : tuple of float
Used to generate a range slider for vmin and vmax. Should be given in range slider
notation: `("r", 0, 1)`.
alpha : float or Callable, optional
Affects all scatter points. This will compound with any alpha introduced by
the ``c`` argument
Expand Down Expand Up @@ -516,21 +520,38 @@ def interactive_scatter(
facecolors = kwargs.pop("facecolor", facecolors)
edgecolors = kwargs.pop("edgecolor", edgecolors)

kwargs, collection_kwargs = kwarg_popper(kwargs, collection_kwargs_list)

ipympl = notebook_backend() or force_ipywidgets
fig, ax = gogogo_figure(ipympl, ax)
slider_formats = create_slider_format_dict(slider_formats)

extra_ctrls = []
funcs, extra_ctrls, param_excluder = prep_scalars(kwargs, s=s, alpha=alpha, marker=marker)
kwargs, collection_kwargs = kwarg_popper(kwargs, collection_kwargs_list)
funcs, extra_ctrls, param_excluder = prep_scalars(
kwargs, s=s, alpha=alpha, marker=marker, vmin=vmin, vmax=vmax
)
s = funcs["s"]
vmin = funcs["vmin"]
vmax = funcs["vmax"]
alpha = funcs["alpha"]
marker = funcs["marker"]

if vmin_vmax is not None:
if isinstance(vmin_vmax, tuple) and not isinstance(vmin_vmax[0], str):
vmin_vmax = ("r", *vmin_vmax)
kwargs["vmin_vmax"] = vmin_vmax

controls, params = gogogo_controls(
kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls
)
if vmin_vmax is not None:
params.pop("vmin_vmax")
params["vmin"] = controls.params["vmin"]
params["vmax"] = controls.params["vmax"]

def vmin(**kwargs):
return kwargs["vmin"]

def vmax(**kwargs):
return kwargs["vmax"]

def update(params, indices, cache):
if parametric:
Expand All @@ -545,7 +566,7 @@ def update(params, indices, cache):
s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), cache)
ec_ = check_callable_xy(edgecolors, x_, y_, param_excluder(params), cache)
fc_ = check_callable_xy(facecolors, x_, y_, param_excluder(params), cache)
a_ = check_callable_alpha(alpha, param_excluder(params, "alpha"), cache)
a_ = (callable_else_value_no_cast(alpha, param_excluder(params, "alpha"), cache),)
marker_ = callable_else_value_no_cast(marker, param_excluder(params), cache)

if marker_ is not None:
Expand Down Expand Up @@ -576,6 +597,10 @@ def update(params, indices, cache):
scatter.set_sizes(s_)
if a_ is not None:
scatter.set_alpha(a_)
if isinstance(vmin, Callable):
scatter.norm.vmin = callable_else_value(vmin, param_excluder(params, "vmin"), cache)
if isinstance(vmax, Callable):
scatter.norm.vmax = callable_else_value(vmax, param_excluder(params, "vmax"), cache)

update_datalim_from_bbox(
ax, scatter.get_datalim(ax.transData), stretch_x=stretch_x, stretch_y=stretch_y
Expand All @@ -592,14 +617,6 @@ def check_callable_xy(arg, x, y, params, cache):
else:
return arg

def check_callable_alpha(alpha_, params, cache):
if isinstance(alpha_, Callable):
if alpha_ not in cache:
cache[alpha_] = alpha_(**param_excluder(params, "alpha"))
return cache[alpha_]
else:
return alpha_

p = param_excluder(params)
if parametric:
out = callable_else_value_no_cast(x, p)
Expand All @@ -612,17 +629,16 @@ def check_callable_alpha(alpha_, params, cache):
s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), {})
ec_ = check_callable_xy(edgecolors, x_, y_, p, {})
fc_ = check_callable_xy(facecolors, x_, y_, p, {})
a_ = check_callable_alpha(alpha, params, {})
marker_ = callable_else_value_no_cast(marker, p, {})
scatter = ax.scatter(
x_,
y_,
c=c_,
s=s_,
vmin=vmin,
vmax=vmax,
alpha=callable_else_value_no_cast(alpha, param_excluder(params, "alpha")),
vmin=callable_else_value_no_cast(vmin, param_excluder(params, "vmin")),
vmax=callable_else_value_no_cast(vmax, param_excluder(params, "vmax")),
marker=marker_,
alpha=a_,
edgecolors=ec_,
facecolors=fc_,
label=label,
Expand Down

0 comments on commit 47c1b38

Please sign in to comment.