Skip to content

Commit

Permalink
Move EarlyStopper to recipes (#951)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Jan 4, 2025
1 parent 612a2a4 commit c1fd7c0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 27 deletions.
24 changes: 0 additions & 24 deletions src/fairseq2/early_stopper.py

This file was deleted.

32 changes: 32 additions & 0 deletions src/fairseq2/recipes/early_stopper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import final

from typing_extensions import override


class EarlyStopper(ABC):
"""Stops training when an implementation-specific condition is not met."""

@abstractmethod
def should_stop(self, step_nr: int, score: float) -> bool:
"""
:param step_nr: The number of the current training step.
:para score: The validation score of the current training step.
:returns: ``True`` if the training should be stopped; otherwise, ``False``.
"""


@final
class NoopEarlyStopper(EarlyStopper):
@override
def should_stop(self, step_nr: int, score: float) -> bool:
return False
8 changes: 5 additions & 3 deletions src/fairseq2/recipes/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from fairseq2.checkpoint import CheckpointManager, CheckpointNotFoundError
from fairseq2.datasets import DataReader
from fairseq2.early_stopper import EarlyStopper
from fairseq2.error import ContractError, InternalError, InvalidOperationError
from fairseq2.gang import FakeGang, Gang, broadcast_flag
from fairseq2.logging import log
Expand All @@ -51,6 +50,7 @@
from fairseq2.optim import DynamicLossScaler
from fairseq2.optim.lr_scheduler import LRScheduler, NoopLR, get_effective_lr
from fairseq2.recipes.common_metrics import extend_batch_metrics
from fairseq2.recipes.early_stopper import EarlyStopper, NoopEarlyStopper
from fairseq2.recipes.evaluator import EvalUnit
from fairseq2.recipes.utils.rich import create_rich_progress
from fairseq2.typing import CPU, DataType
Expand Down Expand Up @@ -403,7 +403,7 @@ def __init__(
)

if root_gang.rank != 0:
early_stopper = lambda step_nr, score: False
early_stopper = NoopEarlyStopper()

self._early_stopper = early_stopper
else:
Expand Down Expand Up @@ -1075,7 +1075,9 @@ def _maybe_request_early_stop(self) -> None:
if self._valid_score is None:
raise InternalError("Early stopping, but `_valid_score` is `None`.")

should_stop = self._early_stopper(self._step_nr, self._valid_score)
should_stop = self._early_stopper.should_stop(
self._step_nr, self._valid_score
)
else:
should_stop = False

Expand Down

0 comments on commit c1fd7c0

Please sign in to comment.