Skip to content

Commit

Permalink
Move console accessors to recipes.utils.rich (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Jan 4, 2025
1 parent c1fd7c0 commit 1a7eb11
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 51 deletions.
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from fairseq2.context import get_runtime_context
from fairseq2.logging import get_log_writer
from fairseq2.recipes.cli import Cli, CliCommandHandler
from fairseq2.recipes.console import get_console
from fairseq2.recipes.utils.rich import get_console

log = get_log_writer(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
ClusterResolver,
UnknownClusterError,
)
from fairseq2.recipes.console import get_console, set_console
from fairseq2.recipes.logging import DistributedLoggingInitializer, setup_basic_logging
from fairseq2.recipes.runner import (
ConfigFileNotFoundError,
Expand All @@ -43,6 +42,7 @@
get_sweep_keys,
)
from fairseq2.recipes.utils.argparse import ConfigAction
from fairseq2.recipes.utils.rich import get_console, set_console
from fairseq2.recipes.utils.sweep_tagger import (
NoopSweepTagger,
StandardSweepTagger,
Expand Down
45 changes: 0 additions & 45 deletions src/fairseq2/recipes/console.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/fairseq2/recipes/llama/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.integ import convert_to_reference_checkpoint
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.recipes.console import get_error_console
from fairseq2.recipes.utils.rich import get_error_console
from fairseq2.utils.file import dump_torch_tensors, load_torch_tensors

log = get_log_writer(__name__)
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/lm/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from fairseq2.models.decoder import DecoderModel
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.recipes.cluster import ClusterError, ClusterHandler, ClusterResolver
from fairseq2.recipes.console import get_console
from fairseq2.recipes.utils.argparse import parse_dtype
from fairseq2.recipes.utils.rich import get_console
from fairseq2.recipes.utils.setup import setup_gangs
from fairseq2.typing import CPU
from fairseq2.utils.rng import RngBag
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from fairseq2.error import SetupError
from fairseq2.gang import get_rank
from fairseq2.recipes.console import get_error_console
from fairseq2.recipes.utils.rich import get_error_console


def setup_basic_logging(*, debug: bool = False, utc_time: bool = False) -> None:
Expand Down
38 changes: 37 additions & 1 deletion src/fairseq2/recipes/utils/rich.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

from rich import get_console as get_rich_console
from rich.console import Console
from rich.progress import (
BarColumn,
Progress,
Expand All @@ -19,7 +21,41 @@
from typing_extensions import override

from fairseq2.gang import get_rank
from fairseq2.recipes.console import get_error_console

_console: Console | None = None


def get_console() -> Console:
global _console

if _console is None:
_console = get_rich_console()

return _console


def set_console(console: Console) -> None:
global _console

_console = console


_error_console: Console | None = None


def get_error_console() -> Console:
global _error_console

if _error_console is None:
_error_console = Console(stderr=True, highlight=False)

return _error_console


def set_error_console(console: Console) -> None:
global _error_console

_error_console = console


def create_rich_progress() -> Progress:
Expand Down

0 comments on commit 1a7eb11

Please sign in to comment.