From 33006d44e3967f9d2e59e61d7a8367af45dde9ff Mon Sep 17 00:00:00 2001 From: he weilin Date: Tue, 31 Dec 2024 10:19:18 +0800 Subject: [PATCH] refactor the show_anomalies_from_scores() --- darts/ad/utils.py | 347 +++++++++++++++++++++------------------------- 1 file changed, 155 insertions(+), 192 deletions(-) diff --git a/darts/ad/utils.py b/darts/ad/utils.py index 510ef3f924..6f2821a656 100644 --- a/darts/ad/utils.py +++ b/darts/ad/utils.py @@ -469,112 +469,23 @@ def show_anomalies_from_scores( for i in range(series_width): index_ax = i * nbr_plots - - _plot_series( + _plot_series_and_anomalies( series=series[series.components[i]], - ax_id=axs[index_ax][0], - linewidth=0.5, - label_name="", - ) - - if pred_series[pred_series.components[i]] is not None: - _plot_series( - series=pred_series[pred_series.components[i]], - ax_id=axs[index_ax][0], - linewidth=0.5, - label_name=pred_series.components[i] + " model_output", - ) - - axs[index_ax][0].set_title("") - - if anomalies is not None or pred_scores is not None: - axs[index_ax][0].set_xlabel("") - - axs[index_ax][0].legend( - loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2 + anomalies=anomalies[anomalies.components[i]] + if anomalies is not None + else None, + pred_series=pred_series[pred_series.components[i]] + if pred_series is not None + else None, + pred_scores=pred_scores, + window=window, + names_of_scorers=names_of_scorers, + metric=metric, + axs=axs, + index_ax=index_ax, + nbr_plots=nbr_plots, ) - if pred_scores is not None: - dict_input = {} - - for idx, (score, w) in enumerate(zip(pred_scores, window)): - dict_input[idx] = { - "series_score": score, - "window": w, - "name_id": idx, - } - - for index, elem in enumerate( - sorted(dict_input.items(), key=lambda x: x[1]["window"]) - ): - if index == 0: - current_window = elem[1]["window"] - index_ax = index_ax + 1 - - idx = elem[1]["name_id"] - w = elem[1]["window"] - - if w != current_window: - current_window = w - index_ax = index_ax + 1 - - if metric 
is not None: - value = round( - eval_metric_from_scores( - anomalies=anomalies[anomalies.components[i]], - pred_scores=pred_scores[idx][ - pred_scores[idx].components[i] - ], - window=w, - metric=metric, - ), - 3, - ) - else: - value = None - - if names_of_scorers is not None: - label = ( - names_of_scorers[idx] + [f" ({value})", ""][value is None] - ) - else: - label = f"score_{str(idx)}" + [f" ({value})", ""][value is None] - - _plot_series( - series=elem[1]["series_score"][ - elem[1]["series_score"].components[i] - ], - ax_id=axs[index_ax][0], - linewidth=0.5, - label_name=label, - ) - - axs[index_ax][0].legend( - loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2 - ) - axs[index_ax][0].set_title(f"Window: {str(w)}", loc="left") - axs[index_ax][0].set_title("") - axs[index_ax][0].set_xlabel("") - - if anomalies is not None: - _plot_series( - series=anomalies[anomalies.components[i]], - ax_id=axs[index_ax + 1][0], - linewidth=1, - label_name=anomalies.components[i], - color="red", - ) - - axs[index_ax + 1][0].set_title("") - axs[index_ax + 1][0].set_ylim([-0.1, 1.1]) - axs[index_ax + 1][0].set_yticks([0, 1]) - axs[index_ax + 1][0].set_yticklabels(["no", "yes"]) - axs[index_ax + 1][0].legend( - loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2 - ) - else: - axs[index_ax][0].set_xlabel("timestamp") - fig.suptitle(title) else: fig, axs = plt.subplots( @@ -586,97 +497,19 @@ def show_anomalies_from_scores( ) index_ax = 0 - - _plot_series( - series=series, ax_id=axs[index_ax][0], linewidth=0.5, label_name="" + _plot_series_and_anomalies( + series=series, + anomalies=anomalies, + pred_series=pred_series, + pred_scores=pred_scores, + window=window, + names_of_scorers=names_of_scorers, + metric=metric, + axs=axs, + index_ax=index_ax, + nbr_plots=nbr_plots, ) - if pred_series is not None: - _plot_series( - series=pred_series, - ax_id=axs[index_ax][0], - linewidth=0.5, - label_name="model output", - ) - - axs[index_ax][0].set_title("") - - if anomalies is not None 
def _plot_series_and_anomalies(
    series: TimeSeries,
    anomalies: TimeSeries,
    pred_series: TimeSeries,
    pred_scores: Sequence[TimeSeries],
    window: Sequence[int],
    names_of_scorers: Sequence[str],
    metric: str,
    axs: plt.Axes,
    index_ax: int,
    nbr_plots: int,
):
    """Plot one series with its model output, anomaly scores and anomalies.

    The series (and `pred_series`, if given) is drawn on row `index_ax` of
    `axs`. The scores are then drawn on the following rows, one row per
    distinct window value, in increasing window order. If `anomalies` is
    given, it is drawn on the row after the last one used.

    Parameters
    ----------
    series
        The actual series to visualize anomalies from.
    anomalies
        The ground truth of the anomalies (1 if it is an anomaly and 0 if not).
        Can be `None`, in which case no anomaly row is drawn.
    pred_series
        Output of the model given as input the `series` (can be stochastic).
        Can be `None`.
    pred_scores
        Output of the scorers given the output of the model and `series`.
        Can be `None`, in which case no score rows are drawn.
    window
        Window parameter for each anomaly score, paired with `pred_scores`
        by position.
    names_of_scorers
        Names of the scores, indexed like `pred_scores`. If `None`, the labels
        default to ``score_<index>``.
    metric
        The name of the metric function to use. If not `None`, the metric is
        evaluated for every score against `anomalies` and appended to the
        legend label, rounded to 3 decimals.
    axs
        The axes to plot on, indexed as ``axs[row][0]``.
    index_ax
        The index of the row on which `series` is plotted.
    nbr_plots
        The number of plots (rows) per component. ``index_ax // nbr_plots`` is
        used as the index of the component to select in each score series.
    """
    # Resolve the component once, on entry: `index_ax` is advanced below, and
    # re-deriving the component from the moving counter is only correct while
    # the counter stays inside this component's group of rows.
    component_idx = index_ax // nbr_plots

    series_ax = axs[index_ax][0]
    _plot_series(series=series, ax_id=series_ax, linewidth=0.5, label_name="")

    if pred_series is not None:
        _plot_series(
            series=pred_series,
            ax_id=series_ax,
            linewidth=0.5,
            label_name="model output",
        )

    series_ax.set_title("")

    # Only the last row of the group keeps an x-axis label.
    if anomalies is not None or pred_scores is not None:
        series_ax.set_xlabel("")

    series_ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2)

    if pred_scores is not None:
        # `sorted` is stable, so scores sharing a window keep their input order
        # while `idx` still points at the original position (used for names).
        scores_by_window = sorted(
            enumerate(zip(pred_scores, window)), key=lambda item: item[1][1]
        )

        current_window = None
        for position, (idx, (score, w)) in enumerate(scores_by_window):
            # Move to the next row for the first score and on every new window.
            if position == 0 or w != current_window:
                current_window = w
                index_ax += 1

            component_score = score[score.components[component_idx]]

            if metric is not None:
                value = round(
                    eval_metric_from_scores(
                        anomalies=anomalies,
                        pred_scores=component_score,
                        window=w,
                        metric=metric,
                    ),
                    3,
                )
                metric_suffix = f" ({value})"
            else:
                metric_suffix = ""

            if names_of_scorers is not None:
                label = names_of_scorers[idx] + metric_suffix
            else:
                label = f"score_{idx}" + metric_suffix

            score_ax = axs[index_ax][0]
            _plot_series(
                series=component_score,
                ax_id=score_ax,
                linewidth=0.5,
                label_name=label,
            )

            score_ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2)
            # The window goes in the left title; the center title is cleared.
            score_ax.set_title(f"Window: {w}", loc="left")
            score_ax.set_title("")
            score_ax.set_xlabel("")

    if anomalies is not None:
        anomalies_ax = axs[index_ax + 1][0]
        _plot_series(
            series=anomalies,
            ax_id=anomalies_ax,
            linewidth=1,
            label_name="anomalies",
            color="red",
        )

        anomalies_ax.set_title("")
        # Binary ground truth: fix the range and name the two levels.
        anomalies_ax.set_ylim([-0.1, 1.1])
        anomalies_ax.set_yticks([0, 1])
        anomalies_ax.set_yticklabels(["no", "yes"])
        anomalies_ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2)
    else:
        axs[index_ax][0].set_xlabel("timestamp")