diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index 8cd299a..f062bec 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -532,11 +532,7 @@ def prep_scalars(kwarg_dict, **kwargs): kwargs[name] = _gen_f(arg.keys[0]) extra_ctrls.append(arg) - if len(added_kwargs) == 0: - # shortcircuit options - def param_excluder(params, except_=None): - return params - - else: - param_excluder = _gen_param_excluder(added_kwargs) + # always exclude all these - this will always be matplotlib + # arugments and so should never be passed to user supplied functions. + param_excluder = _gen_param_excluder(list(kwargs.keys())) return kwargs, extra_ctrls, param_excluder diff --git a/mpl_interactions/pyplot.py b/mpl_interactions/pyplot.py index 9fde3bf..4576af2 100644 --- a/mpl_interactions/pyplot.py +++ b/mpl_interactions/pyplot.py @@ -419,6 +419,7 @@ def interactive_scatter( c=None, vmin=None, vmax=None, + vmin_vmax=None, alpha=None, marker=None, edgecolors=None, @@ -452,6 +453,9 @@ def interactive_scatter( or any slider shorthand to control with a slider, or an indexed controls object to use an existing slider, or an arbitrary function of the other parameters. + vmin_vmax : tuple of float + Used to generate a range slider for vmin and vmax. Should be given in range slider + notation: `("r", 0, 1)`. alpha : float or Callable, optional Affects all scatter points. This will compound with any alpha introduced by the ``c`` argument @@ -516,21 +520,38 @@ def interactive_scatter( facecolors = kwargs.pop("facecolor", facecolors) edgecolors = kwargs.pop("edgecolor", edgecolors) - kwargs, collection_kwargs = kwarg_popper(kwargs, collection_kwargs_list) - ipympl = notebook_backend() or force_ipywidgets fig, ax = gogogo_figure(ipympl, ax) slider_formats = create_slider_format_dict(slider_formats) - extra_ctrls = [] - funcs, extra_ctrls, param_excluder = prep_scalars(kwargs, s=s, alpha=alpha, marker=marker) + kwargs, collection_kwargs = kwarg_popper(kwargs, collection_kwargs_list) + funcs, extra_ctrls, param_excluder = prep_scalars( + kwargs, s=s, alpha=alpha, marker=marker, vmin=vmin, vmax=vmax + ) s = funcs["s"] + vmin = funcs["vmin"] + vmax = funcs["vmax"] alpha = funcs["alpha"] marker = funcs["marker"] + if vmin_vmax is not None: + if isinstance(vmin_vmax, tuple) and not isinstance(vmin_vmax[0], str): + vmin_vmax = ("r", *vmin_vmax) + kwargs["vmin_vmax"] = vmin_vmax + controls, params = gogogo_controls( kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls ) + if vmin_vmax is not None: + params.pop("vmin_vmax") + params["vmin"] = controls.params["vmin"] + params["vmax"] = controls.params["vmax"] + + def vmin(**kwargs): + return kwargs["vmin"] + + def vmax(**kwargs): + return kwargs["vmax"] def update(params, indices, cache): if parametric: @@ -545,7 +566,7 @@ def update(params, indices, cache): s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), cache) ec_ = check_callable_xy(edgecolors, x_, y_, param_excluder(params), cache) fc_ = check_callable_xy(facecolors, x_, y_, param_excluder(params), cache) - a_ = check_callable_alpha(alpha, param_excluder(params, "alpha"), cache) + a_ = (callable_else_value_no_cast(alpha, param_excluder(params, "alpha"), cache),) marker_ = callable_else_value_no_cast(marker, param_excluder(params), cache) if marker_ is not None: @@ -576,6 +597,10 @@ def update(params, indices, cache): scatter.set_sizes(s_) if a_ is not None: scatter.set_alpha(a_) + if isinstance(vmin, Callable): + scatter.norm.vmin = callable_else_value(vmin, param_excluder(params, "vmin"), cache) + if isinstance(vmax, Callable): + scatter.norm.vmax = callable_else_value(vmax, param_excluder(params, "vmax"), cache) update_datalim_from_bbox( ax, scatter.get_datalim(ax.transData), stretch_x=stretch_x, stretch_y=stretch_y @@ -592,14 +617,6 @@ def check_callable_xy(arg, x, y, params, cache): else: return arg - def check_callable_alpha(alpha_, params, cache): - if isinstance(alpha_, Callable): - if alpha_ not in cache: - cache[alpha_] = alpha_(**param_excluder(params, "alpha")) - return cache[alpha_] - else: - return alpha_ - p = param_excluder(params) if parametric: out = callable_else_value_no_cast(x, p) @@ -612,17 +629,16 @@ def check_callable_alpha(alpha_, params, cache): s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), {}) ec_ = check_callable_xy(edgecolors, x_, y_, p, {}) fc_ = check_callable_xy(facecolors, x_, y_, p, {}) - a_ = check_callable_alpha(alpha, params, {}) marker_ = callable_else_value_no_cast(marker, p, {}) scatter = ax.scatter( x_, y_, c=c_, s=s_, - vmin=vmin, - vmax=vmax, + alpha=callable_else_value_no_cast(alpha, param_excluder(params, "alpha")), + vmin=callable_else_value_no_cast(vmin, param_excluder(params, "vmin")), + vmax=callable_else_value_no_cast(vmax, param_excluder(params, "vmax")), marker=marker_, - alpha=a_, edgecolors=ec_, facecolors=fc_, label=label,