From 013d7811872833ca98d5bb961836d2b1e56cdf5f Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Sat, 4 Jan 2025 02:03:24 +0000 Subject: [PATCH] Move console accessors to recipes.utils.rich --- src/fairseq2/recipes/assets.py | 2 +- src/fairseq2/recipes/cli.py | 2 +- src/fairseq2/recipes/console.py | 45 ------------------- .../recipes/llama/convert_checkpoint.py | 2 +- src/fairseq2/recipes/lm/chatbot.py | 2 +- src/fairseq2/recipes/logging.py | 2 +- src/fairseq2/recipes/utils/rich.py | 38 +++++++++++++++- 7 files changed, 42 insertions(+), 51 deletions(-) delete mode 100644 src/fairseq2/recipes/console.py diff --git a/src/fairseq2/recipes/assets.py b/src/fairseq2/recipes/assets.py index 382157986..122dddeb4 100644 --- a/src/fairseq2/recipes/assets.py +++ b/src/fairseq2/recipes/assets.py @@ -18,7 +18,7 @@ from fairseq2.context import get_runtime_context from fairseq2.logging import get_log_writer from fairseq2.recipes.cli import Cli, CliCommandHandler -from fairseq2.recipes.console import get_console +from fairseq2.recipes.utils.rich import get_console log = get_log_writer(__name__) diff --git a/src/fairseq2/recipes/cli.py b/src/fairseq2/recipes/cli.py index dbd5abb39..78cd17b69 100644 --- a/src/fairseq2/recipes/cli.py +++ b/src/fairseq2/recipes/cli.py @@ -27,7 +27,6 @@ ClusterResolver, UnknownClusterError, ) -from fairseq2.recipes.console import get_console, set_console from fairseq2.recipes.logging import DistributedLoggingInitializer, setup_basic_logging from fairseq2.recipes.runner import ( ConfigFileNotFoundError, @@ -43,6 +42,7 @@ get_sweep_keys, ) from fairseq2.recipes.utils.argparse import ConfigAction +from fairseq2.recipes.utils.rich import get_console, set_console from fairseq2.recipes.utils.sweep_tagger import ( NoopSweepTagger, StandardSweepTagger, diff --git a/src/fairseq2/recipes/console.py b/src/fairseq2/recipes/console.py deleted file mode 100644 index bd44f3ba7..000000000 --- a/src/fairseq2/recipes/console.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from rich import get_console as get_rich_console -from rich.console import Console - -_console: Console | None = None - - -def get_console() -> Console: - global _console - - if _console is None: - _console = get_rich_console() - - return _console - - -def set_console(console: Console) -> None: - global _console - - _console = console - - -_error_console: Console | None = None - - -def get_error_console() -> Console: - global _error_console - - if _error_console is None: - _error_console = Console(stderr=True, highlight=False) - - return _error_console - - -def set_error_console(console: Console) -> None: - global _error_console - - _error_console = console diff --git a/src/fairseq2/recipes/llama/convert_checkpoint.py b/src/fairseq2/recipes/llama/convert_checkpoint.py index cd9c92844..919687290 100644 --- a/src/fairseq2/recipes/llama/convert_checkpoint.py +++ b/src/fairseq2/recipes/llama/convert_checkpoint.py @@ -22,7 +22,7 @@ from fairseq2.models.llama import load_llama_config from fairseq2.models.llama.integ import convert_to_reference_checkpoint from fairseq2.recipes.cli import CliCommandHandler -from fairseq2.recipes.console import get_error_console +from fairseq2.recipes.utils.rich import get_error_console from fairseq2.utils.file import dump_torch_tensors, load_torch_tensors log = get_log_writer(__name__) diff --git a/src/fairseq2/recipes/lm/chatbot.py b/src/fairseq2/recipes/lm/chatbot.py index ed33d1c2c..9a49ecb16 100644 --- a/src/fairseq2/recipes/lm/chatbot.py +++ b/src/fairseq2/recipes/lm/chatbot.py @@ -31,8 +31,8 @@ from fairseq2.models.decoder import DecoderModel from fairseq2.recipes.cli import CliCommandHandler from fairseq2.recipes.cluster import ClusterError, ClusterHandler, ClusterResolver -from fairseq2.recipes.console import get_console from fairseq2.recipes.utils.argparse import parse_dtype +from fairseq2.recipes.utils.rich import get_console from fairseq2.recipes.utils.setup import setup_gangs from fairseq2.typing import CPU from fairseq2.utils.rng import RngBag diff --git a/src/fairseq2/recipes/logging.py b/src/fairseq2/recipes/logging.py index b61ae37fa..60faf3301 100644 --- a/src/fairseq2/recipes/logging.py +++ b/src/fairseq2/recipes/logging.py @@ -20,7 +20,7 @@ from fairseq2.error import SetupError from fairseq2.gang import get_rank -from fairseq2.recipes.console import get_error_console +from fairseq2.recipes.utils.rich import get_error_console def setup_basic_logging(*, debug: bool = False, utc_time: bool = False) -> None: diff --git a/src/fairseq2/recipes/utils/rich.py b/src/fairseq2/recipes/utils/rich.py index 6adc80fc5..c1696e2e6 100644 --- a/src/fairseq2/recipes/utils/rich.py +++ b/src/fairseq2/recipes/utils/rich.py @@ -6,6 +6,8 @@ from __future__ import annotations +from rich import get_console as get_rich_console +from rich.console import Console from rich.progress import ( BarColumn, Progress, @@ -19,7 +21,41 @@ from typing_extensions import override from fairseq2.gang import get_rank -from fairseq2.recipes.console import get_error_console + +_console: Console | None = None + + +def get_console() -> Console: + global _console + + if _console is None: + _console = get_rich_console() + + return _console + + +def set_console(console: Console) -> None: + global _console + + _console = console + + +_error_console: Console | None = None + + +def get_error_console() -> Console: + global _error_console + + if _error_console is None: + _error_console = Console(stderr=True, highlight=False) + + return _error_console + + +def set_error_console(console: Console) -> None: + global _error_console + + _error_console = console def create_rich_progress() -> Progress: