From c1fd7c0036bed81b475a59f159b2ebfbaf6f5155 Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Fri, 3 Jan 2025 20:56:43 -0500 Subject: [PATCH] Move EarlyStopper to recipes (#951) --- src/fairseq2/early_stopper.py | 24 -------------------- src/fairseq2/recipes/early_stopper.py | 32 +++++++++++++++++++++++++++ src/fairseq2/recipes/trainer.py | 8 ++++--- 3 files changed, 37 insertions(+), 27 deletions(-) delete mode 100644 src/fairseq2/early_stopper.py create mode 100644 src/fairseq2/recipes/early_stopper.py diff --git a/src/fairseq2/early_stopper.py b/src/fairseq2/early_stopper.py deleted file mode 100644 index 41f3e39c9..000000000 --- a/src/fairseq2/early_stopper.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from typing import Protocol - - -class EarlyStopper(Protocol): - """Stops training when an implementation-specific condition is not met.""" - - def __call__(self, step_nr: int, score: float) -> bool: - """ - :param step_nr: - The number of the current training step. - :para score: - The validation score of the current training step. - - :returns: - ``True`` if the training should be stopped; otherwise, ``False``. - """ diff --git a/src/fairseq2/recipes/early_stopper.py b/src/fairseq2/recipes/early_stopper.py new file mode 100644 index 000000000..64612813a --- /dev/null +++ b/src/fairseq2/recipes/early_stopper.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import final + +from typing_extensions import override + + +class EarlyStopper(ABC): + """Stops training when an implementation-specific condition is not met.""" + + @abstractmethod + def should_stop(self, step_nr: int, score: float) -> bool: + """ + :param step_nr: The number of the current training step. + :para score: The validation score of the current training step. + + :returns: ``True`` if the training should be stopped; otherwise, ``False``. + """ + + +@final +class NoopEarlyStopper(EarlyStopper): + @override + def should_stop(self, step_nr: int, score: float) -> bool: + return False diff --git a/src/fairseq2/recipes/trainer.py b/src/fairseq2/recipes/trainer.py index 4e8ceec78..85adc9d3e 100644 --- a/src/fairseq2/recipes/trainer.py +++ b/src/fairseq2/recipes/trainer.py @@ -28,7 +28,6 @@ from fairseq2.checkpoint import CheckpointManager, CheckpointNotFoundError from fairseq2.datasets import DataReader -from fairseq2.early_stopper import EarlyStopper from fairseq2.error import ContractError, InternalError, InvalidOperationError from fairseq2.gang import FakeGang, Gang, broadcast_flag from fairseq2.logging import log @@ -51,6 +50,7 @@ from fairseq2.optim import DynamicLossScaler from fairseq2.optim.lr_scheduler import LRScheduler, NoopLR, get_effective_lr from fairseq2.recipes.common_metrics import extend_batch_metrics +from fairseq2.recipes.early_stopper import EarlyStopper, NoopEarlyStopper from fairseq2.recipes.evaluator import EvalUnit from fairseq2.recipes.utils.rich import create_rich_progress from fairseq2.typing import CPU, DataType @@ -403,7 +403,7 @@ def __init__( ) if root_gang.rank != 0: - early_stopper = lambda step_nr, score: False + early_stopper = NoopEarlyStopper() self._early_stopper = early_stopper else: @@ -1075,7 +1075,9 @@ def _maybe_request_early_stop(self) -> None: if self._valid_score is None: raise InternalError("Early stopping, but `_valid_score` is `None`.") - should_stop = self._early_stopper(self._step_nr, self._valid_score) + should_stop = self._early_stopper.should_stop( + self._step_nr, self._valid_score + ) else: should_stop = False