Skip to content

Commit

Permalink
Improve code in utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cnhwl committed Dec 31, 2024
1 parent e6b6854 commit 8468f6d
Showing 1 changed file with 27 additions and 33 deletions.
60 changes: 27 additions & 33 deletions darts/ad/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,23 @@ def show_anomalies_from_scores(
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Only effective when `pred_scores` is not `None`.
Default: "AUC_ROC".
multivariate_plot
If True, it will separately plot each component in multivariate series.
"""

series = _check_input(
series,
name="series",
num_series_expected=1,
)[0]
series = (
_check_input(
series,
name="series",
check_multivariate=True,
)[0]
if multivariate_plot
else _check_input(
series,
name="series",
num_series_expected=1,
)[0]
)

if title is None and pred_scores is not None:
title = "Anomaly results"
Expand Down Expand Up @@ -424,15 +434,17 @@ def show_anomalies_from_scores(
)

nbr_plots = nbr_plots + len(set(window))

if multivariate_plot:
series = _check_input(
series,
name="series",
check_multivariate=True,
)[0]
series_width = series.n_components
plots_per_ts = nbr_plots * series_width if multivariate_plot else nbr_plots
fig, axs = plt.subplots(
plots_per_ts,
figsize=(8, 4 + 2 * (plots_per_ts - 1)),
sharex=True,
gridspec_kw={"height_ratios": [2] + [1] * (plots_per_ts - 1)},
squeeze=False,
)

if multivariate_plot:
if pred_series is not None:
pred_series = _check_input(
pred_series,
Expand All @@ -446,6 +458,7 @@ def show_anomalies_from_scores(
anomalies,
name="anomalies",
width_expected=series.width,
check_binary=True,
check_multivariate=True,
)[0]

Expand All @@ -458,17 +471,7 @@ def show_anomalies_from_scores(
check_multivariate=True,
)[0]

series_width = series.n_components
fig, axs = plt.subplots(
nbr_plots * series_width,
figsize=(8, 4 + 2 * (nbr_plots * series_width - 1)),
sharex=True,
gridspec_kw={"height_ratios": [2] + [1] * (nbr_plots * series_width - 1)},
squeeze=False,
)

for i in range(series_width):
index_ax = i * nbr_plots
_plot_series_and_anomalies(
series=series[series.components[i]],
anomalies=anomalies[anomalies.components[i]]
Expand All @@ -482,21 +485,12 @@ def show_anomalies_from_scores(
names_of_scorers=names_of_scorers,
metric=metric,
axs=axs,
index_ax=index_ax,
index_ax=i * nbr_plots,
nbr_plots=nbr_plots,
)

fig.suptitle(title)
else:
fig, axs = plt.subplots(
nbr_plots,
figsize=(8, 4 + 2 * (nbr_plots - 1)),
sharex=True,
gridspec_kw={"height_ratios": [2] + [1] * (nbr_plots - 1)},
squeeze=False,
)

index_ax = 0
_plot_series_and_anomalies(
series=series,
anomalies=anomalies,
Expand All @@ -506,7 +500,7 @@ def show_anomalies_from_scores(
names_of_scorers=names_of_scorers,
metric=metric,
axs=axs,
index_ax=index_ax,
index_ax=0,
nbr_plots=nbr_plots,
)

Expand Down

0 comments on commit 8468f6d

Please sign in to comment.