Skip to content

Commit

Permalink
Merge pull request #267 from DHI/Make-consistent-plotting-API
Browse files Browse the repository at this point in the history
Make plotting api more consistent
  • Loading branch information
jsmariegaard authored Oct 12, 2023
2 parents 6373be4 + 97cb41e commit ebc26ad
Show file tree
Hide file tree
Showing 11 changed files with 445 additions and 109 deletions.
49 changes: 33 additions & 16 deletions modelskill/comparison/_collection_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
from typing import Any, List, Union, Optional, Tuple, Sequence
from matplotlib.axes import Axes # type: ignore

import matplotlib.pyplot as plt # type: ignore
import pandas as pd

from .. import metrics as mtr
from ..utils import _get_idx
from ..plotting import taylor_diagram, scatter, TaylorPoint
from ..plotting._misc import (
_xtick_directional,
_ytick_directional,
)
from ..plotting._misc import _xtick_directional, _ytick_directional, _get_fig_ax


class ComparerCollectionPlotter:
Expand Down Expand Up @@ -41,6 +37,7 @@ def scatter(
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
skill_table: Optional[Union[str, List[str], bool]] = None,
ax: Optional[Axes] = None,
**kwargs,
):
"""Scatter plot showing compared data: observation vs modelled
Expand Down Expand Up @@ -95,6 +92,8 @@ def scatter(
list of modelskill.metrics or boolean, if True then by default modelskill.options.metrics.list.
This kword adds a box at the right of the scatter plot,
by default False
ax : matplotlib axes, optional
axes to plot on, by default None
kwargs
Examples
Expand Down Expand Up @@ -167,6 +166,7 @@ def scatter(
ylabel=ylabel,
skill_df=skill_df,
units=units,
ax=ax,
**kwargs,
)

Expand All @@ -176,13 +176,17 @@ def scatter(

return ax

def kde(self, ax=None, **kwargs) -> Axes:
def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes:
"""Plot kernel density estimate of observation and model data.
Parameters
----------
ax : Axes, optional
matplotlib axes, by default None
figsize : tuple, optional
width and height of the figure, by default None
title : str, optional
plot title, by default None
**kwargs
passed to pandas.DataFrame.plot.kde()
Expand All @@ -198,8 +202,7 @@ def kde(self, ax=None, **kwargs) -> Axes:
>>> cc.plot.kde(bw_method='silverman')
"""
if ax is None:
ax = plt.gca()
_, ax = _get_fig_ax(ax, figsize)

df = self.cc.to_dataframe()
ax = df.obs_val.plot.kde(
Expand All @@ -213,9 +216,10 @@ def kde(self, ax=None, **kwargs) -> Axes:
# TODO use unit_text from the first comparer
# TODO make sure they are conistent
# then it should be a property of the collection, not only the comparer
plt.xlabel(f"{self.cc[0].unit_text}")
ax.set_xlabel(f"{self.cc[0].unit_text}")

# TODO title?
default_title = f"KDE plot for {', '.join(cmp.name for cmp in self.cc)}"
ax.set_title(title or default_title)
ax.legend()

# remove y-axis, ticks and label
Expand All @@ -240,6 +244,8 @@ def hist(
title: Optional[str] = None,
density=True,
alpha: float = 0.5,
ax=None,
figsize: Optional[Tuple[float, float]] = None,
**kwargs,
):
"""Plot histogram of specific model and all observations.
Expand All @@ -258,6 +264,10 @@ def hist(
If True, draw and return a probability density, by default True
alpha : float, optional
alpha transparency fraction, by default 0.5
ax : matplotlib axes, optional
axes to plot on, by default None
figsize : tuple, optional
width and height of the figure, by default None
kwargs : other keyword arguments to df.hist()
Returns
Expand All @@ -276,6 +286,8 @@ def hist(
"""
from ._comparison import MOD_COLORS

_, ax = _get_fig_ax(ax, figsize)

mod_id = _get_idx(model, self.cc.mod_names)
mod_name = self.cc.mod_names[mod_id]

Expand All @@ -285,21 +297,22 @@ def hist(
df = cmp.to_dataframe()
kwargs["alpha"] = alpha
kwargs["density"] = density
ax = df.mod_val.hist(bins=bins, color=MOD_COLORS[mod_id], **kwargs)
df.mod_val.hist(bins=bins, color=MOD_COLORS[mod_id], ax=ax, **kwargs)
df.obs_val.hist(
bins=bins,
color=self.cc[0].data["Observation"].attrs["color"],
ax=ax,
**kwargs,
)

ax.legend([mod_name, "observations"])
plt.title(title)
plt.xlabel(f"{self.cc[df.observation[0]].unit_text}")
ax.set_title(title)
ax.set_xlabel(f"{self.cc[df.observation.iloc[0]].unit_text}")

if density:
plt.ylabel("density")
ax.set_ylabel("density")
else:
plt.ylabel("count")
ax.set_ylabel("count")

if self.is_directional:
_xtick_directional(ax)
Expand Down Expand Up @@ -348,6 +361,10 @@ def taylor(
title : str, optional
title of the plot, by default "Taylor diagram"
Returns
-------
matplotlib.figure.Figure
Examples
------
>>> cc.plot.taylor()
Expand Down Expand Up @@ -387,7 +404,7 @@ def taylor(
for r in df.itertuples()
]

taylor_diagram(
return taylor_diagram(
obs_std=ref_std,
points=pts,
figsize=figsize,
Expand Down
Loading

0 comments on commit ebc26ad

Please sign in to comment.